From 4fedc1256c1102e7dc28ca47033b5d9105a33a9c Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Tue, 7 Jul 2020 16:35:12 +0200
Subject: [PATCH] Fix tests imports dpr (#5576)

* fix test imports

* fix max_length

* style

* fix tests
---
 src/transformers/tokenization_dpr.py | 36 ++++++++++++++++++++++------
 tests/test_modeling_dpr.py           |  2 +-
 tests/test_tokenization_dpr.py       |  2 +-
 3 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/src/transformers/tokenization_dpr.py b/src/transformers/tokenization_dpr.py
index 33458c485..0c7b70eff 100644
--- a/src/transformers/tokenization_dpr.py
+++ b/src/transformers/tokenization_dpr.py
@@ -157,13 +157,13 @@ CUSTOM_DPR_READER_DOCSTRING = r"""
             The passages titles to be encoded. This can be a string, a list of strings if there are several passages.
         texts (:obj:`str`, :obj:`List[str]`):
             The passages texts to be encoded. This can be a string, a list of strings if there are several passages.
-        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control padding. Accepts the following values:
             * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
             * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
             * `False` or `'do_not_pad'` (default): No padding (i.e. can output batch with sequences of uneven lengths)
-        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control truncation. Accepts the following values:
             * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
@@ -203,15 +203,37 @@ class CustomDPRReaderTokenizerMixin:
     def __call__(
         self,
         questions,
-        titles,
-        texts,
-        padding: Union[bool, str] = True,
-        truncation: Union[bool, str] = True,
-        max_length: Optional[int] = 512,
+        titles: Optional[str] = None,
+        texts: Optional[str] = None,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_attention_mask: Optional[bool] = None,
         **kwargs
     ) -> BatchEncoding:
+        if titles is None and texts is None:
+            return super().__call__(
+                questions,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
+        elif titles is None or texts is None:
+            text_pair = titles if texts is None else texts
+            return super().__call__(
+                questions,
+                text_pair,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
         titles = titles if not isinstance(titles, str) else [titles]
         texts = texts if not isinstance(texts, str) else [texts]
         n_passages = len(titles)
diff --git a/tests/test_modeling_dpr.py b/tests/test_modeling_dpr.py
index 9bcd85873..42883a404 100644
--- a/tests/test_modeling_dpr.py
+++ b/tests/test_modeling_dpr.py
@@ -17,10 +17,10 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
-from .utils import require_torch, slow, torch_device
 
 
 if is_torch_available():
diff --git a/tests/test_tokenization_dpr.py b/tests/test_tokenization_dpr.py
index f825b0429..2043d4e9f 100644
--- a/tests/test_tokenization_dpr.py
+++ b/tests/test_tokenization_dpr.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 
+from transformers.testing_utils import slow
 from transformers.tokenization_dpr import (
     DPRContextEncoderTokenizer,
     DPRContextEncoderTokenizerFast,
@@ -26,7 +27,6 @@ from transformers.tokenization_dpr import (
 from transformers.tokenization_utils_base import BatchEncoding
 
 from .test_tokenization_bert import BertTokenizationTest
-from .utils import slow
 
 
 class DPRContextEncoderTokenizationTest(BertTokenizationTest):
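
Usage note (not part of the patch): a minimal sketch of what the new signature allows for DPRReaderTokenizer, which uses CustomDPRReaderTokenizerMixin. titles and texts are now optional and fall back to the standard tokenizer call, while padding, truncation and max_length default to False/None. The checkpoint name and example strings below are assumptions for illustration only.

# Sketch assuming the "facebook/dpr-reader-single-nq-base" checkpoint (not named in this patch).
from transformers import DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

# Full reader input: a question paired with passage titles and texts.
# Padding and truncation are now off by default, so padding is requested explicitly here.
encoding = tokenizer(
    questions=["What is love?"],
    titles=["Haddaway"],
    texts=["'What Is Love' is a song recorded by the artist Haddaway."],
    padding="longest",
    return_tensors="pt",
)
print(encoding["input_ids"].shape)  # (n_passages, sequence_length)

# With titles and texts omitted, __call__ now falls back to the plain
# single-sequence tokenizer behavior instead of requiring passages.
question_only = tokenizer("What is love?", return_tensors="pt")
print(question_only["input_ids"].shape)  # (1, sequence_length)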