mbart.prepare_translation_batch: pass through kwargs (#5581)

This commit is contained in:
Sam Shleifer 2020-07-07 13:46:05 -04:00 committed by GitHub
Parent 353b8f1e7a
Commit d6eab53058
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 0 deletions

View file

@@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
max_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "pt",
**kwargs,
) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of MBartModel.
Arguments:
@@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
tgt_lang: default ro_RO (romanian), the language we are translating to
max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
**kwargs: passed to self.__call__
Returns:
:obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
max_length=max_length,
padding=padding,
truncation=True,
**kwargs,
)
if tgt_texts is None:
return model_inputs
@@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
padding=padding,
max_length=max_length,
truncation=True,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v