mbart.prepare_translation_batch: pass through kwargs (#5581)
This commit is contained in:
Родитель
353b8f1e7a
Коммит
d6eab53058
|
@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||
max_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare a batch that can be passed directly to an instance of MBartModel.
|
||||
Arguments:
|
||||
|
@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||
tgt_lang: default ro_RO (romanian), the language we are translating to
|
||||
max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
|
||||
padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
|
||||
**kwargs: passed to self.__call__
|
||||
|
||||
Returns:
|
||||
:obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
|
||||
|
@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
|
@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||
padding=padding,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
|
|
Загрузка…
Ссылка в новой задаче