* fix test imports

* fix max_length

* style

* fix tests
Quentin Lhoest 2020-07-07 16:35:12 +02:00 committed by GitHub
Parent d4886173b2
Commit 4fedc1256c
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 31 additions and 9 deletions

View File

@@ -157,13 +157,13 @@ CUSTOM_DPR_READER_DOCSTRING = r"""
             The passages titles to be encoded. This can be a string or a list of strings if there are several passages.
         texts (:obj:`str`, :obj:`List[str]`):
             The passages texts to be encoded. This can be a string or a list of strings if there are several passages.
-        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control padding. Accepts the following values:
             * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
             * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
             * `False` or `'do_not_pad'` (default): no padding (i.e. can output a batch with sequences of uneven lengths)
-        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control truncation. Accepts the following values:
             * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
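
Because padding and truncation now default to False, callers that relied on the old always-pad, always-truncate behaviour have to request it explicitly. A minimal sketch of that opt-in with the DPR reader tokenizer follows; the checkpoint name and example strings are illustrative and not part of this commit:

from transformers import DPRReaderTokenizer

# Illustrative checkpoint; any DPR reader checkpoint would do.
tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

# The defaults are now padding=False and truncation=False, so pass both
# explicitly to reproduce the previous pad-and-truncate-to-512 behaviour.
encoding = tokenizer(
    questions=["What is the capital of France?"],
    titles=["Paris"],
    texts=["Paris is the capital and most populous city of France."],
    padding="max_length",
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
print(encoding["input_ids"].shape)  # (n_passages, 512) with these settings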
@@ -203,15 +203,37 @@ class CustomDPRReaderTokenizerMixin:
     def __call__(
         self,
         questions,
-        titles,
-        texts,
-        padding: Union[bool, str] = True,
-        truncation: Union[bool, str] = True,
-        max_length: Optional[int] = 512,
+        titles: Optional[str] = None,
+        texts: Optional[str] = None,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_attention_mask: Optional[bool] = None,
         **kwargs
     ) -> BatchEncoding:
+        if titles is None and texts is None:
+            return super().__call__(
+                questions,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
+        elif titles is None or texts is None:
+            text_pair = titles if texts is None else texts
+            return super().__call__(
+                questions,
+                text_pair,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
         titles = titles if not isinstance(titles, str) else [titles]
         texts = texts if not isinstance(texts, str) else [texts]
         n_passages = len(titles)
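
With titles and texts now optional, the reader tokenizer degrades gracefully: questions alone take the regular single-sequence path, questions plus exactly one of titles or texts are encoded as an ordinary text pair, and only the full question/title/text triple goes through the custom passage encoding. A rough sketch of the three call shapes, again with an illustrative checkpoint and inputs:

from transformers import DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
question = "What is the capital of France?"

# 1) No titles or texts: falls back to a plain tokenizer call on the question.
question_only = tokenizer(question)

# 2) Exactly one of titles/texts: question and passage form a standard text pair.
question_and_text = tokenizer(question, texts="Paris is the capital of France.")

# 3) Both titles and texts: the custom question/title/text encoding is built.
full_triple = tokenizer(question, titles="Paris", texts="Paris is the capital of France.")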

View File

@@ -17,10 +17,10 @@
 import unittest
 from transformers import is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
-from .utils import require_torch, slow, torch_device
 if is_torch_available():

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
+from transformers.testing_utils import slow
 from transformers.tokenization_dpr import (
     DPRContextEncoderTokenizer,
     DPRContextEncoderTokenizerFast,
@@ -26,7 +27,6 @@ from transformers.tokenization_dpr import (
 from transformers.tokenization_utils_base import BatchEncoding
 from .test_tokenization_bert import BertTokenizationTest
-from .utils import slow
 class DPRContextEncoderTokenizationTest(BertTokenizationTest):
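
Both test files now import their helpers from transformers.testing_utils rather than from the tests-local .utils module. For context, a bare-bones sketch of how those helpers are typically combined in a test module; the class and test names here are made up for illustration:

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

if is_torch_available():
    import torch


@require_torch  # the whole class is skipped when torch is not installed
class ExampleDPRSmokeTest(unittest.TestCase):
    @slow  # only runs when the RUN_SLOW=1 environment variable is set
    def test_tensor_on_default_device(self):
        tensor = torch.ones(2, 2, device=torch_device)
        self.assertEqual(tuple(tensor.shape), (2, 2))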