diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index 7378da968..2b6f7240c 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         max_length: Optional[int] = None,
         padding: str = "longest",
         return_tensors: str = "pt",
+        **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
@@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             tgt_lang: default ro_RO (romanian), the language we are translating to
             max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
             padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
+            **kwargs: passed to self.__call__
 
         Returns:
             :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             max_length=max_length,
             padding=padding,
             truncation=True,
+            **kwargs,
         )
         if tgt_texts is None:
             return model_inputs
@@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             padding=padding,
             max_length=max_length,
             truncation=True,
+            **kwargs,
         )
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
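
For context, a minimal usage sketch (not part of the diff): with **kwargs now forwarded to self.__call__, callers can pass any extra keyword the tokenizer's __call__ accepts through prepare_translation_batch, and it is applied to both the source and target encodings. The checkpoint name and the return_length keyword below are illustrative assumptions, not taken from this patch.

from transformers import MBartTokenizer

# Illustrative sketch; assumes the facebook/mbart-large-en-ro checkpoint and a
# transformers version whose tokenizer __call__ accepts return_length.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

batch = tokenizer.prepare_translation_batch(
    src_texts=["UN Chief Says There Is No Military Solution in Syria"],
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    # Extra keywords are now forwarded to self.__call__ for both the source
    # and the target encodings (target keys come back prefixed with decoder_).
    return_length=True,
)

print(sorted(batch.keys()))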