Fix tests imports dpr (#5576)
* fix test imports
* fix max_length
* style
* fix tests

Parent: d4886173b2
Commit: 4fedc1256c

tokenization_dpr.py
@@ -157,13 +157,13 @@ CUSTOM_DPR_READER_DOCSTRING = r"""
         The passages titles to be encoded. This can be a string, a list of strings if there are several passages.
     texts (:obj:`str`, :obj:`List[str]`):
         The passages texts to be encoded. This can be a string, a list of strings if there are several passages.
-    padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+    padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
         Activate and control padding. Accepts the following values:

         * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
         * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
         * `False` or `'do_not_pad'` (default): no padding (i.e. can output a batch with sequences of uneven lengths)
-    truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+    truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
         Activate and control truncation. Accepts the following values:

         * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
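
The net effect of this docstring change (together with the matching signature change in the next hunk) is that the DPR reader tokenizer no longer pads or truncates unless asked to. A minimal sketch of the difference, assuming a standard DPR reader checkpoint (the checkpoint name and inputs are illustrative, not part of this diff):

```python
from transformers import DPRReaderTokenizer

# Illustrative checkpoint; any DPR reader checkpoint behaves the same way.
tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

# New defaults (padding=False, truncation=False): the encoding keeps its natural length.
encoding = tokenizer(
    questions="What is love?",
    titles="Haddaway",
    texts="'What Is Love' is a song recorded by Haddaway",
)

# Opting back in to the old behavior explicitly:
padded = tokenizer(
    questions="What is love?",
    titles="Haddaway",
    texts="'What Is Love' is a song recorded by Haddaway",
    padding="max_length",  # pad out to max_length
    truncation=True,       # truncate down to max_length
    max_length=512,        # the value that used to be the hard-coded default
)
```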

@@ -203,15 +203,37 @@ class CustomDPRReaderTokenizerMixin:
     def __call__(
         self,
         questions,
-        titles,
-        texts,
-        padding: Union[bool, str] = True,
-        truncation: Union[bool, str] = True,
-        max_length: Optional[int] = 512,
+        titles: Optional[str] = None,
+        texts: Optional[str] = None,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_attention_mask: Optional[bool] = None,
         **kwargs
     ) -> BatchEncoding:
+        if titles is None and texts is None:
+            return super().__call__(
+                questions,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
+        elif titles is None or texts is None:
+            text_pair = titles if texts is None else texts
+            return super().__call__(
+                questions,
+                text_pair,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
         titles = titles if not isinstance(titles, str) else [titles]
         texts = texts if not isinstance(texts, str) else [texts]
         n_passages = len(titles)
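
With `titles` and `texts` now optional, `__call__` dispatches on which of them are present. A hedged sketch of the three call patterns this enables (checkpoint name illustrative):

```python
from transformers import DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

question = "What is love?"
title = "Haddaway"
text = "'What Is Love' is a song recorded by Haddaway"

# 1. Question only: titles and texts are both None, so the call falls
#    through to the parent tokenizer's ordinary single-text encoding.
single = tokenizer(question)

# 2. Question plus only one of titles/texts: encoded as a plain text pair.
pair = tokenizer(question, title)

# 3. Question, title and text: the DPR reader layout, one sequence per
#    passage, laid out as question [SEP] title [SEP] text.
reader_input = tokenizer(questions=question, titles=title, texts=text)
```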

test_modeling_dpr.py
@@ -17,10 +17,10 @@
 import unittest

 from transformers import is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
-from .utils import require_torch, slow, torch_device


 if is_torch_available():
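
The test helpers (`require_torch`, `slow`, `torch_device`) now come from the installable `transformers.testing_utils` module rather than the tests-local `.utils` module. A minimal sketch of how these helpers are typically used (the test class and its body are illustrative, not from this diff):

```python
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device


if is_torch_available():
    import torch


@require_torch  # the whole class is skipped when torch is not installed
class ExampleDeviceTest(unittest.TestCase):
    @slow  # only collected when the RUN_SLOW=1 environment variable is set
    def test_tensor_on_device(self):
        # torch_device resolves to "cuda" when a GPU is available, else "cpu".
        tensor = torch.ones(2, 2, device=torch_device)
        self.assertEqual(tuple(tensor.shape), (2, 2))
```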

test_tokenization_dpr.py
@@ -14,6 +14,7 @@
 # limitations under the License.


+from transformers.testing_utils import slow
 from transformers.tokenization_dpr import (
     DPRContextEncoderTokenizer,
     DPRContextEncoderTokenizerFast,
@@ -26,7 +27,6 @@ from transformers.tokenization_dpr import (
 from transformers.tokenization_utils_base import BatchEncoding

 from .test_tokenization_bert import BertTokenizationTest
-from .utils import slow


 class DPRContextEncoderTokenizationTest(BertTokenizationTest):