[breaking|pipelines|tokenizers] Adding slow-fast tokenizers equivalence tests pipelines - Removing sentencepiece as a required dependency (#8073)

* Fixing roberta for slow-fast tests

* WIP getting equivalence on pipelines

* slow-to-fast equivalence - working on question-answering pipeline

* optional FAISS tests

* Pipeline Q&A

* Move pipeline tests to their own test job again

* update tokenizer to add sequence id methods

* update to tokenizers 0.9.4

* set sentencepiece as optional

* clean up squad

* clean up pipelines to use sequence_ids

* style/quality

* wording

* Switch to use_fast = True by default

* update tests for use_fast at True by default

* fix rag tokenizer test

* removing protobuf from required dependencies

* fix NER test for use_fast = True by default

* fixing example tests (Q&A examples use slow tokenizers for now)

* protobuf moved to extras["sentencepiece"] in the main deps and to the example deps

* fix protobuf install test

* try to fix seq2seq by switching to slow tokenizers for now

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Thomas Wolf 2020-11-15 22:50:59 +01:00 committed by GitHub
Parent 24184e73c4
Commit f4e04cd2c6
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 689 additions and 262 deletions
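
The headline breaking change: `pipeline()` and `AutoTokenizer.from_pretrained()` now default to `use_fast=True`. A minimal sketch of the new default and of opting back into the slow Python tokenizers, assuming a standard checkpoint such as bert-base-cased:

from transformers import AutoTokenizer, pipeline

# A fast (Rust-backed) tokenizer is now returned by default when one exists for the model
tok = AutoTokenizer.from_pretrained("bert-base-cased")
assert tok.is_fast

# Opting back into the slow Python tokenizer
slow_tok = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False)
assert not slow_tok.is_fast

# The same toggle exists on pipeline(), which also defaults to use_fast=True now
nlp = pipeline("sentiment-analysis", use_fast=False)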

View file

@ -736,6 +736,7 @@ def main():
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
)
model = AutoModelForQuestionAnswering.from_pretrained(
args.model_name_or_path,
@ -784,7 +785,10 @@ def main():
# Load a trained model and vocabulary that you have fine-tuned
model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
# SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
# So we use use_fast=False here for now until Fast-tokenizer-compatible-examples are out
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, use_fast=False)
model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory

View file

@ -114,6 +114,7 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
)
model = AutoModelForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,

View file

@ -18,3 +18,4 @@ fire
pytest
conllu
sentencepiece != 0.1.92
protobuf

View file

@ -197,7 +197,7 @@ class TestAll(TestCasePlus):
)
@require_torch_non_multi_gpu_but_fix_me
def test_dataset_kwargs(self, tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
if tok_name == MBART_TINY:
train_dataset = Seq2SeqDataset(
tokenizer,

View file

@ -96,13 +96,13 @@ else:
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]
extras["tokenizers"] = ["tokenizers==0.9.2"]
extras["tokenizers"] = ["tokenizers==0.9.4"]
extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
extras["modelcreation"] = ["cookiecutter==1.7.2"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["sentencepiece"] = ["sentencepiece==0.1.91"]
extras["sentencepiece"] = ["sentencepiece==0.1.91", "protobuf"]
extras["retrieval"] = ["faiss-cpu", "datasets"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"] + extras["modelcreation"]
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
@ -130,7 +130,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.9.3",
"tokenizers == 0.9.4",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
@ -143,9 +143,6 @@ setup(
"tqdm >= 4.27",
# for OpenAI GPT
"regex != 2019.12.17",
# for SentencePiece models
"sentencepiece == 0.1.91",
"protobuf",
# for XLM
"sacremoses",
],

View file

@ -24,10 +24,7 @@ from typing import Dict, List, Tuple
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
# from transformers.tokenization_openai import OpenAIGPTTokenizer
from transformers.utils import sentencepiece_model_pb2 as model
from .file_utils import requires_sentencepiece
from .file_utils import requires_protobuf, requires_sentencepiece
class SentencePieceExtractor:
@ -64,12 +61,6 @@ def check_number_comma(piece: str) -> bool:
return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
def get_proto(filename: str):
m = model.ModelProto()
m.ParseFromString(open(filename, "rb").read())
return m
class Converter:
def __init__(self, original_tokenizer):
self.original_tokenizer = original_tokenizer
@ -292,8 +283,15 @@ class RobertaConverter(Converter):
class SpmConverter(Converter):
def __init__(self, *args):
requires_protobuf(self)
super().__init__(*args)
self.proto = get_proto(self.original_tokenizer.vocab_file)
from .utils import sentencepiece_model_pb2 as model_pb2
m = model_pb2.ModelProto()
m.ParseFromString(open(self.original_tokenizer.vocab_file, "rb").read())
self.proto = m
def vocab(self, proto):
return [(piece.piece, piece.score) for piece in proto.pieces]
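
A hedged sketch of what the converter now does lazily, so protobuf is only needed when a SentencePiece-based converter is actually instantiated (the model file path below is hypothetical):

from transformers.file_utils import is_protobuf_available

if is_protobuf_available():
    from transformers.utils import sentencepiece_model_pb2 as model_pb2

    proto = model_pb2.ModelProto()
    with open("spiece.model", "rb") as f:  # hypothetical path to a SentencePiece .model file
        proto.ParseFromString(f.read())
    vocab = [(piece.piece, piece.score) for piece in proto.pieces]
    print(len(vocab), "pieces loaded")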

View file

@ -8,7 +8,7 @@ from tqdm import tqdm
from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import PreTrainedTokenizerBase, TruncationStrategy
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
from ...utils import logging
from .utils import DataProcessor
@ -765,6 +765,7 @@ class SquadFeatures:
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
start_position: start of the answer token index
end_position: end of the answer token index
encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
"""
def __init__(
@ -784,6 +785,7 @@ class SquadFeatures:
end_position,
is_impossible,
qas_id: str = None,
encoding: BatchEncoding = None,
):
self.input_ids = input_ids
self.attention_mask = attention_mask
@ -803,6 +805,8 @@ class SquadFeatures:
self.is_impossible = is_impossible
self.qas_id = qas_id
self.encoding = encoding
class SquadResult:
"""

View file

@ -185,6 +185,15 @@ except ImportError:
_sentencepiece_available = False
try:
import google.protobuf # noqa: F401
_protobuf_available = True
except ImportError:
_protobuf_available = False
try:
import tokenizers # noqa: F401
@ -270,6 +279,10 @@ def is_sentencepiece_available():
return _sentencepiece_available
def is_protobuf_available():
return _protobuf_available
def is_tokenizers_available():
return _tokenizers_available
@ -330,6 +343,14 @@ that match your environment.
"""
# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment.
"""
# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
@ -420,6 +441,12 @@ def requires_sentencepiece(obj):
raise ImportError(SENTENCEPIECE_IMPORT_ERROR.format(name))
def requires_protobuf(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_protobuf_available():
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
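
The new guard mirrors `requires_sentencepiece`: call it in a constructor and it raises the formatted ImportError when the package is missing. A minimal sketch of the intended usage pattern (the class below is purely illustrative):

from transformers.file_utils import requires_protobuf

class MyProtoBackedConverter:  # hypothetical class, mirroring how SpmConverter uses the guard
    def __init__(self, original_tokenizer):
        requires_protobuf(self)  # raises ImportError with PROTOBUF_IMPORT_ERROR if protobuf is absent
        self.original_tokenizer = original_tokenizer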

View file

@ -32,7 +32,7 @@ import numpy as np
from .configuration_auto import AutoConfig
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
@ -1758,6 +1758,7 @@ class QuestionAnsweringPipeline(Pipeline):
- **answer** (:obj:`str`) -- The answer to the question.
"""
# Set defaults values
kwargs.setdefault("padding", "longest")
kwargs.setdefault("topk", 1)
kwargs.setdefault("doc_stride", 128)
kwargs.setdefault("max_answer_len", 15)
@ -1773,19 +1774,87 @@ class QuestionAnsweringPipeline(Pipeline):
# Convert inputs to features
examples = self._args_parser(*args, **kwargs)
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.MAX_LENGTH.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
if not self.tokenizer.is_fast:
features_list = [
squad_convert_examples_to_features(
examples=[example],
tokenizer=self.tokenizer,
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.MAX_LENGTH.value,
is_training=False,
tqdm_enabled=False,
)
for example in examples
]
else:
features_list = []
for example in examples:
# Define the side we want to truncate / pad and the text/pair sorting
question_first = bool(self.tokenizer.padding_side == "right")
encoded_inputs = self.tokenizer(
text=example.question_text if question_first else example.context_text,
text_pair=example.context_text if question_first else example.question_text,
padding=kwargs["padding"],
truncation="only_second" if question_first else "only_first",
max_length=kwargs["max_seq_len"],
stride=kwargs["doc_stride"],
return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_special_tokens_mask=True,
)
# When the input is too long, it's converted into a batch of inputs with overflowing tokens
# and a stride of overlap between the inputs. If a batch of inputs is given, a special output
# "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
# Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
# "num_spans" is the number of output samples generated from the overflowing tokens.
num_spans = len(encoded_inputs["input_ids"])
# p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in the answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
p_mask = np.asarray(
[
[tok != 1 if question_first else tok != 0 for tok in encoded_inputs.sequence_ids(span_id)]
for span_id in range(num_spans)
]
)
# keep the cls_token unmasked (some models use it to indicate unanswerable questions)
if self.tokenizer.cls_token_id:
cls_index = np.nonzero(encoded_inputs["input_ids"] == self.tokenizer.cls_token_id)
p_mask[cls_index] = 0
features = []
for span_idx in range(num_spans):
features.append(
SquadFeatures(
input_ids=encoded_inputs["input_ids"][span_idx],
attention_mask=encoded_inputs["attention_mask"][span_idx],
token_type_ids=encoded_inputs["token_type_ids"][span_idx],
p_mask=p_mask[span_idx].tolist(),
encoding=encoded_inputs[span_idx],
# We don't use the rest of the values - and actually
# for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
cls_index=None,
token_to_orig_map={},
example_index=0,
unique_id=0,
paragraph_len=0,
token_is_max_context=0,
tokens=[],
start_position=0,
end_position=0,
is_impossible=False,
qas_id=None,
)
)
features_list.append(features)
all_answers = []
for features, example in zip(features_list, examples):
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
@ -1828,20 +1897,56 @@ class QuestionAnsweringPipeline(Pipeline):
start_[0] = end_[0] = 0.0
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
char_to_word = np.array(example.char_to_word_offset)
if not self.tokenizer.is_fast:
char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
else:
# Convert the answer (tokens) back to the original text
# Score: score from the model
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
question_first = bool(self.tokenizer.padding_side == "right")
enc = feature.encoding
# Sometimes the max probability token is in the middle of a word so:
# - we start by finding the right word containing the token with `token_to_word`
# - then we convert this word in a character span with `word_to_chars`
answers += [
{
"score": score.item(),
"start": enc.word_to_chars(
enc.token_to_word(s), sequence_index=1 if question_first else 0
)[0],
"end": enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
],
"answer": example.context_text[
enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if question_first else 0)[
0
] : enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
]
],
}
for s, e, score in zip(starts, ends, scores)
]
if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
@ -2735,7 +2840,7 @@ def pipeline(
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = False,
use_fast: bool = True,
**kwargs
) -> Pipeline:
"""
@ -2795,7 +2900,7 @@ def pipeline(
When passing a task name or a string model identifier: The specific model version to use. It can be a
branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
kwargs:
Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
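
A self-contained, hedged sketch of the alignment chain used above to turn a predicted token index back into a character span of the original context (no model is needed to demonstrate the tokenizer side; the predicted index is faked with `char_to_token`):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # any fast tokenizer will do
question = "Where was HuggingFace founded ?"
context = "HuggingFace was founded in Paris."
enc = tokenizer(question, context)

# Pretend the model predicted this token as both start and end of the answer
s = e = enc.char_to_token(context.find("Paris"), sequence_index=1)

# token index -> word index -> character span, always pointing at the context (sequence 1)
start_char = enc.word_to_chars(enc.token_to_word(s), sequence_index=1)[0]
end_char = enc.word_to_chars(enc.token_to_word(e), sequence_index=1)[1]
print(context[start_char:end_char])  # expected: "Paris"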

View file

@ -280,7 +280,7 @@ class AutoTokenizer:
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
@ -308,7 +308,7 @@ class AutoTokenizer:
if "bert-base-japanese" in str(pretrained_model_name_or_path):
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
use_fast = kwargs.pop("use_fast", False)
use_fast = kwargs.pop("use_fast", True)
if config.tokenizer_class is not None:
if use_fast and not config.tokenizer_class.endswith("Fast"):

View file

@ -18,6 +18,7 @@ from typing import List, Optional
from .tokenization_gpt2_fast import GPT2TokenizerFast
from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils_base import AddedToken
from .utils import logging
@ -172,6 +173,32 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
**kwargs,
)
@property
def mask_token(self) -> str:
"""
:obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while
not having been set.
Roberta tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
include the space before the `<mask>`.
"""
if self._mask_token is None and self.verbose:
logger.error("Using mask_token, but it is not set yet.")
return None
return str(self._mask_token)
@mask_token.setter
def mask_token(self, value):
"""
Overriding the default behavior of the mask token to have it eat the space before it.
This is needed to preserve backward compatibility with all the previously used models based on Roberta.
"""
# Mask token behaves like a normal word, i.e. includes the space before it
# So we set lstrip to True
value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
self._mask_token = value
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
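
A hedged sketch of the behaviour the lstrip=True override is meant to preserve: in the fill-mask pipeline, `<mask>` should absorb the space in front of it so the masked position matches RoBERTa's pretraining (the exact token strings in the comment are an assumption about roberta-base):

from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
ids = tokenizer("Hello <mask>")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
# With lstrip=True the space is eaten by the mask token, e.g. ['<s>', 'Hello', '<mask>', '</s>'],
# instead of leaving a stray 'Ġ' token before '<mask>'.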

View file

@ -182,7 +182,9 @@ def to_py_obj(obj):
"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
"""
if isinstance(obj, (list, tuple)):
if isinstance(obj, (dict, BatchEncoding)):
return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj]
elif is_tf_available() and isinstance(obj, tf.Tensor):
return obj.numpy().tolist()
@ -216,6 +218,9 @@ class BatchEncoding(UserDict):
initialization.
prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to add a batch axis when converting to tensors (see :obj:`tensor_type` above).
n_sequences (:obj:`Optional[int]`, `optional`):
The number of sequences that were jointly encoded to produce this :class:`~transformers.BatchEncoding`
(:obj:`None` if unknown, :obj:`1` for a single sequence, :obj:`2` for a pair of sequences).
"""
def __init__(
@ -224,6 +229,7 @@ class BatchEncoding(UserDict):
encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
tensor_type: Union[None, str, TensorType] = None,
prepend_batch_axis: bool = False,
n_sequences: Optional[int] = None,
):
super().__init__(data)
@ -232,8 +238,22 @@ class BatchEncoding(UserDict):
self._encodings = encoding
if n_sequences is None and encoding is not None and len(encoding):
n_sequences = encoding[0].n_sequences
self._n_sequences = n_sequences
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
@property
def n_sequences(self) -> Optional[int]:
"""
:obj:`Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this
:class:`~transformers.BatchEncoding`. Currently can be one of :obj:`None` (unknown), :obj:`1` (a single
sentence) or :obj:`2` (a pair of sentences)
"""
return self._n_sequences
@property
def is_fast(self) -> bool:
"""
@ -311,6 +331,27 @@ class BatchEncoding(UserDict):
raise ValueError("tokens() is not available when using Python-based tokenizers")
return self._encodings[batch_index].tokens
def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
"""
Return a list mapping the tokens to the id of their original sentences:
- :obj:`None` for special tokens added around or between sequences,
- :obj:`0` for tokens corresponding to words in the first sequence,
- :obj:`1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly
encoded.
Args:
batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
Returns:
:obj:`List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens
added by the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their
corresponding sequence.
"""
if not self._encodings:
raise ValueError("sequence_ids() is not available when using Python-based tokenizers")
return self._encodings[batch_index].sequence_ids
def words(self, batch_index: int = 0) -> List[Optional[int]]:
"""
Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
@ -325,7 +366,67 @@ class BatchEncoding(UserDict):
"""
if not self._encodings:
raise ValueError("words() is not available when using Python-based tokenizers")
return self._encodings[batch_index].words
warnings.warn(
"`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
"but more self-explanatory `BatchEncoding.word_ids()` property.",
FutureWarning,
)
return self.word_ids(batch_index)
def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
"""
Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
Args:
batch_index (:obj:`int`, `optional`, defaults to 0): The index to access in the batch.
Returns:
:obj:`List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by
the tokenizer are mapped to :obj:`None` and other tokens are mapped to the index of their corresponding
word (several tokens will be mapped to the same word index if they are parts of that word).
"""
if not self._encodings:
raise ValueError("word_ids() is not available when using Python-based tokenizers")
return self._encodings[batch_index].word_ids
def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
"""
Get the index of the sequence represented by the given token. In the general use case, this method returns
:obj:`0` for a single sequence or the first sequence of a pair, and :obj:`1` for the second sequence of a pair
Can be called as:
- ``self.token_to_sequence(token_index)`` if batch size is 1
- ``self.token_to_sequence(batch_index, token_index)`` if batch size is greater than 1
This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
words are defined by the user). In this case it allows to easily associate encoded tokens with provided
tokenized words.
Args:
batch_or_token_index (:obj:`int`):
Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
the token in the sequence.
token_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index of the token in the
sequence.
Returns:
:obj:`int`: Index of the sequence the token belongs to in the encoded inputs.
"""
if not self._encodings:
raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
if token_index is not None:
batch_index = batch_or_token_index
else:
batch_index = 0
token_index = batch_or_token_index
if batch_index < 0:
batch_index = self._batch_size + batch_index
if token_index < 0:
token_index = self._seq_len + token_index
return self._encodings[batch_index].token_to_sequence(token_index)
def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
"""
@ -365,9 +466,11 @@ class BatchEncoding(UserDict):
token_index = self._seq_len + token_index
return self._encodings[batch_index].token_to_word(token_index)
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
def word_to_tokens(
self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
) -> Optional[TokenSpan]:
"""
Get the encoded token span corresponding to a word in the sequence of the batch.
Get the encoded token span corresponding to a word in a sequence of the batch.
Token spans are returned as a :class:`~transformers.tokenization_utils_base.TokenSpan` with:
@ -376,8 +479,9 @@ class BatchEncoding(UserDict):
Can be called as:
- ``self.word_to_tokens(word_index)`` if batch size is 1
- ``self.word_to_tokens(batch_index, word_index)`` if batch size is greater or equal to 1
- ``self.word_to_tokens(word_index, sequence_index=0)`` if batch size is 1
- ``self.word_to_tokens(batch_index, word_index, sequence_index=0)`` if batch size is greater or equal to 1
This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
@ -390,6 +494,9 @@ class BatchEncoding(UserDict):
word_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
sequence.
sequence_index (:obj:`int`, `optional`, defaults to 0):
If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair (0
or 1) the provided word index belongs to.
Returns:
Optional :class:`~transformers.tokenization_utils_base.TokenSpan` Span of tokens in the encoded sequence.
@ -407,7 +514,7 @@ class BatchEncoding(UserDict):
batch_index = self._batch_size + batch_index
if word_index < 0:
word_index = self._seq_len + word_index
span = self._encodings[batch_index].word_to_tokens(word_index)
span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
return TokenSpan(*span) if span is not None else None
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
@ -446,7 +553,9 @@ class BatchEncoding(UserDict):
token_index = batch_or_token_index
return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))
def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
def char_to_token(
self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
) -> int:
"""
Get the index of the token in the encoded output comprising a character in the original string for a sequence
of the batch.
@ -467,6 +576,9 @@ class BatchEncoding(UserDict):
char_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
sequence.
sequence_index (:obj:`int`, `optional`, defaults to 0):
If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair (0
or 1) the provided character index belongs to.
Returns:
@ -480,9 +592,11 @@ class BatchEncoding(UserDict):
else:
batch_index = 0
char_index = batch_or_char_index
return self._encodings[batch_index].char_to_token(char_index)
return self._encodings[batch_index].char_to_token(char_index, sequence_index)
def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan:
def word_to_chars(
self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
) -> CharSpan:
"""
Get the character span in the original string corresponding to given word in a sequence of the batch.
@ -503,6 +617,9 @@ class BatchEncoding(UserDict):
word_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index of the word in the
sequence.
sequence_index (:obj:`int`, `optional`, defaults to 0):
If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair (0
or 1) the provided word index belongs to.
Returns:
:obj:`CharSpan` or :obj:`List[CharSpan]`: Span(s) of the associated character or characters in the string.
@ -520,9 +637,9 @@ class BatchEncoding(UserDict):
else:
batch_index = 0
word_index = batch_or_word_index
return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index)))
return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))
def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
"""
Get the word in the original string corresponding to a character in the original string of a sequence of the
batch.
@ -543,6 +660,9 @@ class BatchEncoding(UserDict):
char_index (:obj:`int`, `optional`):
If a batch index is provided in `batch_or_token_index`, this can be the index of the character in the
original string.
sequence_index (:obj:`int`, `optional`, defaults to 0):
If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair (0
or 1) the provided character index belongs to.
Returns:
@ -556,7 +676,7 @@ class BatchEncoding(UserDict):
else:
batch_index = 0
char_index = batch_or_char_index
return self._encodings[batch_index].char_to_word(char_index)
return self._encodings[batch_index].char_to_word(char_index, sequence_index)
def convert_to_tensors(
self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
@ -1872,6 +1992,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
"Only fast tokenizers (instances of PretrainedTokenizerFast) can be saved in non legacy format."
)
save_directory = str(save_directory)
added_tokens_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
)
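
A minimal sketch of the pair-aware API added above (`sequence_ids`, `token_to_sequence`, and the new `sequence_index` argument on the word/char helpers), assuming any fast tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # fast by default now
enc = tokenizer("Test this method.", "With these inputs.")

print(enc.sequence_ids())        # e.g. [None, 0, 0, 0, 0, None, 1, 1, 1, 1, None]
print(enc.token_to_sequence(2))  # 0 -> this token comes from the first sentence

# The word/char helpers take sequence_index to disambiguate pair inputs
print(enc.word_to_tokens(0, sequence_index=1))  # token span of the first word of the second sentence
print(enc.word_to_chars(0, sequence_index=1))   # its character span inside "With these inputs."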

View file

@ -169,9 +169,10 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> Dict[str, Any]:
) -> Tuple[Dict[str, Any], List[EncodingFast]]:
"""
Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list
of encodings, taking care of building a batch from overflowing tokens.
Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are
lists (overflows) of lists (tokens).
@ -203,7 +204,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
if return_length:
encoding_dict["length"].append(len(e.ids))
return encoding_dict
return encoding_dict, encodings
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
"""
@ -390,9 +391,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
)
# Convert encoding to dict
# `Tokens` has type: List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]]
# `Tokens` has type: Tuple[
# List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
# List[EncodingFast]
# ]
# with nested dimensions corresponding to batch, overflows, sequence length
tokens = [
tokens_and_encodings = [
self._convert_encoding(
encoding=encoding,
return_token_type_ids=return_token_type_ids,
@ -406,22 +410,27 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
for encoding in encodings
]
# Convert the output to have dict[list] from list[dict]
sanitized = {}
for key in tokens[0].keys():
# To List[List[List[int]]] of shape (batch, overflows, sequence length)
stack = [e for item in tokens for e in item[key]]
sanitized[key] = stack
# Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
# From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
# (we say ~ because the number of overflows varies with the example in the batch)
#
# To match each overflowing sample with the original sample in the batch
# we add an overflow_to_sample_mapping array (see below)
sanitized_tokens = {}
for key in tokens_and_encodings[0][0].keys():
stack = [e for item, _ in tokens_and_encodings for e in item[key]]
sanitized_tokens[key] = stack
sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
# If returning overflowing tokens, we need to return a mapping
# from the batch idx to the original sample
if return_overflowing_tokens:
overflow_to_sample_mapping = []
for i, enc in enumerate(tokens):
overflow_to_sample_mapping += [i] * len(enc["input_ids"])
sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping
for i, (toks, _) in enumerate(tokens_and_encodings):
overflow_to_sample_mapping += [i] * len(toks["input_ids"])
sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
return BatchEncoding(sanitized, encodings, tensor_type=return_tensors)
return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
def _encode_plus(
self,
@ -518,6 +527,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the
specific :meth:`~transformers.PreTrainedTokenizerFast._save_pretrained`
"""
save_directory = str(save_directory)
if legacy_format:
added_tokens_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
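
A hedged sketch of the flattening described in the comments above: every original sample may expand into several overflowing windows, all stacked into one batch dimension, with `overflow_to_sample_mapping` recording which row came from which input:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
texts = [
    "a fairly long sentence that will be split into several overlapping windows " * 4,
    "a short one",
]
enc = tokenizer(texts, max_length=16, stride=4, truncation=True, return_overflowing_tokens=True)

print(len(enc["input_ids"]))              # batch * overflows rows, flattened
print(enc["overflow_to_sample_mapping"])  # e.g. [0, 0, 0, ..., 1]: row index -> original sample index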

View file

@ -1,10 +1,10 @@
from typing import List, Optional
from unittest import mock
from transformers import is_tf_available, is_torch_available, pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
from transformers.tokenization_utils_base import to_py_obj
VALID_INPUTS = ["A simple string", ["list of strings"]]
@ -13,9 +13,11 @@ VALID_INPUTS = ["A simple string", ["list of strings"]]
@is_pipeline_test
class CustomInputPipelineCommonMixin:
pipeline_task = None
pipeline_loading_kwargs = {}
small_models = None # Models tested without the @slow decorator
large_models = None # Models tested with the @slow decorator
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
small_models = [] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
def setUp(self) -> None:
if not is_tf_available() and not is_torch_available():
@ -47,73 +49,11 @@ class CustomInputPipelineCommonMixin:
@require_torch
@slow
def test_pt_defaults(self):
pipeline(self.pipeline_task, framework="pt")
@require_tf
@slow
def test_tf_defaults(self):
pipeline(self.pipeline_task, framework="tf")
@require_torch
def test_torch_small(self):
for model_name in self.small_models:
nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
self._test_pipeline(nlp)
@require_tf
def test_tf_small(self):
for model_name in self.small_models:
nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
self._test_pipeline(nlp)
@require_torch
@slow
def test_torch_large(self):
for model_name in self.large_models:
nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
self._test_pipeline(nlp)
@require_tf
@slow
def test_tf_large(self):
for model_name in self.large_models:
nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
self._test_pipeline(nlp)
def _test_pipeline(self, nlp: Pipeline):
raise NotImplementedError
@is_pipeline_test
class MonoInputPipelineCommonMixin:
pipeline_task = None
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
small_models = [] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
mandatory_keys = {} # Keys which should be in the output
valid_inputs = VALID_INPUTS # inputs which are valid
invalid_inputs = [None] # inputs which are not allowed
expected_multi_result: Optional[List] = None
expected_check_keys: Optional[List[str]] = None
def setUp(self) -> None:
if not is_tf_available() and not is_torch_available():
return # Currently no JAX pipelines
for model_name in self.small_models:
pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
for model_name in self.large_models:
pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
@require_torch
@slow
def test_pt_defaults_loads(self):
pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
@require_tf
@slow
def test_tf_defaults_loads(self):
def test_tf_defaults(self):
pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
@require_torch
@ -166,6 +106,95 @@ class MonoInputPipelineCommonMixin:
)
self._test_pipeline(nlp)
def _test_pipeline(self, nlp: Pipeline):
raise NotImplementedError
@require_torch
def test_compare_slow_fast_torch(self):
for model_name in self.small_models:
nlp_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=False,
**self.pipeline_loading_kwargs,
)
nlp_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="forward")
@require_tf
def test_compare_slow_fast_tf(self):
for model_name in self.small_models:
nlp_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=False,
**self.pipeline_loading_kwargs,
)
nlp_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="call")
def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline, method: str):
"""We check that the inputs to the models forward passes are identical for
slow and fast tokenizers.
"""
with mock.patch.object(
nlp_slow.model, method, wraps=getattr(nlp_slow.model, method)
) as mock_slow, mock.patch.object(nlp_fast.model, method, wraps=getattr(nlp_fast.model, method)) as mock_fast:
for inputs in self.valid_inputs:
if isinstance(inputs, dict):
inputs.update(self.pipeline_running_kwargs)
_ = nlp_slow(**inputs)
_ = nlp_fast(**inputs)
else:
_ = nlp_slow(inputs, **self.pipeline_running_kwargs)
_ = nlp_fast(inputs, **self.pipeline_running_kwargs)
mock_slow.assert_called()
mock_fast.assert_called()
self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
for mock_slow_call_args, mock_fast_call_args in zip(
mock_slow.call_args_list, mock_fast.call_args_list
):
slow_call_args, slow_call_kwargs = mock_slow_call_args
fast_call_args, fast_call_kwargs = mock_fast_call_args
slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
self.assertEqual(slow_call_args, fast_call_args)
self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
@is_pipeline_test
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
"""A version of the CustomInputPipelineCommonMixin
with a predefined `_test_pipeline` method.
"""
mandatory_keys = {} # Keys which should be in the output
invalid_inputs = [None] # inputs which are not allowed
expected_multi_result: Optional[List] = None
expected_check_keys: Optional[List[str]] = None
def _test_pipeline(self, nlp: Pipeline):
self.assertIsNotNone(nlp)
@ -199,76 +228,3 @@ class MonoInputPipelineCommonMixin:
self.assertIn(key, result)
self.assertRaises(Exception, nlp, self.invalid_inputs)
# @is_pipeline_test
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
# def setUp(self) -> None:
# self.handler = DefaultArgumentHandler()
#
# def test_kwargs_x(self):
# mono_data = {"X": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_kwargs_data(self):
# mono_data = {"data": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_multi_kwargs(self):
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 2)
#
# multi_data = {
# "data": ["This is a sample input", "This is a second sample input"],
# "test": ["This is a sample input 2", "This is a second sample input 2"],
# }
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 4)
#
# def test_args(self):
# mono_data = "This is a sample input"
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# mono_data = ["This is a sample input"]
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(*multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
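
The equivalence test above relies on `mock.patch.object(..., wraps=...)`, which records call arguments while still executing the wrapped method. A minimal, self-contained sketch of that pattern on a toy object (not the actual pipeline code):

from unittest import mock

class ToyModel:  # illustrative stand-in for the real model
    def forward(self, input_ids):
        return [i * 2 for i in input_ids]

model = ToyModel()
with mock.patch.object(model, "forward", wraps=model.forward) as spy:
    model.forward([1, 2, 3])

spy.assert_called()
args, kwargs = spy.call_args_list[0]
print(args, kwargs)  # ([1, 2, 3],) {} -- the recorded inputs the test compares across slow/fast pipelines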

View file

@ -1,29 +0,0 @@
import unittest
from transformers.pipelines import Conversation, Pipeline
from .test_pipelines_common import CustomInputPipelineCommonMixin
class DialoguePipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "conversational"
small_models = [] # Default model - Models tested without the @slow decorator
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
def _test_pipeline(self, nlp: Pipeline):
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
invalid_inputs = ["Hi there!", Conversation()]
self.assertIsNotNone(nlp)
mono_result = nlp(valid_inputs[0])
self.assertIsInstance(mono_result, Conversation)
multi_result = nlp(valid_inputs[1])
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], Conversation)
# Inactive conversations passed to the pipeline raise a ValueError
self.assertRaises(ValueError, nlp, valid_inputs[1])
for bad_input in invalid_inputs:
self.assertRaises(Exception, nlp, bad_input)
self.assertRaises(Exception, nlp, invalid_inputs)

View file

@ -146,10 +146,10 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
with self.assertRaises(ValueError):
pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True)
pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False)
@require_torch
def test_pt_defaults_slow_tokenizer(self):

View file

@ -8,10 +8,22 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin
class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "question-answering"
pipeline_running_kwargs = {
"padding": "max_length",
"max_seq_len": 25,
"doc_stride": 5,
} # Default is 'longest' but we use 'max_length' to test equivalence between slow/fast tokenizers
small_models = [
"sshleifer/tiny-distilbert-base-cased-distilled-squad"
] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
valid_inputs = [
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{
"question": "In what field is HuggingFace working ?",
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
]
def _test_pipeline(self, nlp: Pipeline):
output_keys = {"score", "answer", "start", "end"}

View file

@ -12,6 +12,18 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
] # Models tested without the @slow decorator
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
valid_inputs = [
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]},
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"},
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]},
{"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"},
{
"sequences": "Who are you voting for in 2020?",
"candidate_labels": "politics",
"hypothesis_template": "This text is about {}",
},
]
def _test_scores_sum_to_one(self, result):
sum = 0.0

View file

@ -9,7 +9,7 @@ from unittest.mock import patch
import numpy as np
from datasets import Dataset
import faiss
from transformers import is_faiss_available
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
@ -27,6 +27,10 @@ from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
if is_faiss_available():
import faiss
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):

View file

@ -116,5 +116,5 @@ class AutoTokenizerTest(unittest.TestCase):
@require_tokenizers
def test_from_pretrained_use_fast_toggle(self):
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizer)
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True), BertTokenizerFast)
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)

View file

@ -576,6 +576,42 @@ class TokenizerTesterMixin:
sequences, mask = information["input_ids"], information["token_type_ids"]
self.assertEqual(len(sequences), len(mask))
def test_token_type_ids(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
seq_0 = "Test this method."
# We want sequence 0 and sequence 1 to be tagged
# with 0 and 1 token_type_ids respectively
# (regardless of whether the model uses token type ids)
# We rely on this assumption in the QA pipeline among other places
output = tokenizer(seq_0, return_token_type_ids=True)
self.assertIn(0, output["token_type_ids"])
def test_sequence_ids(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
if not tokenizer.is_fast:
continue
with self.subTest(f"{tokenizer.__class__.__name__}"):
seq_0 = "Test this method."
seq_1 = "With these inputs."
# We want tokens from sequence 0 and sequence 1 to be tagged
# with sequence ids 0 and 1 respectively
# (regardless of whether the model uses token type ids)
# We rely on this assumption in the QA pipeline among other places
output = tokenizer(seq_0)
self.assertIn(0, output.sequence_ids())
output = tokenizer(seq_0, seq_1)
self.assertIn(0, output.sequence_ids())
self.assertIn(1, output.sequence_ids())
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
@ -1878,6 +1914,144 @@ class TokenizerTesterMixin:
batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1
)
# Assert token_to_sequence
self.assertEqual(encoding.token_to_sequence(num_tokens // 2), 0)
self.assertEqual(encoding.token_to_sequence(0, num_tokens // 2), 0)
self.assertEqual(batch_encoding.token_to_sequence(1, num_tokens // 2), 0)
self.assertEqual(batch_encoding.token_to_sequence(0, num_tokens // 2), 0)
self.assertEqual(batch_encoding.token_to_sequence(last_batch_index, num_tokens // 2), 0)
# Pair of input sequences
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
text = " ".join(words)
pair_words = ["Amazing", "example", "full", "of", "inspiration"]
pair_text = " ".join(pair_words)
batch_size = 3
index_word_in_first_seq = words.index("inspiration")
index_word_in_pair_seq = pair_words.index("inspiration")
index_char_in_first_seq = text.find("inspiration")
index_char_in_pair_seq = pair_text.find("inspiration")
pair_encoding = tokenizer_r.encode_plus(text, pair_text, add_special_tokens=False)
pair_batch_encoding = tokenizer_r.batch_encode_plus(
[(text, pair_text)] * batch_size, add_special_tokens=False
)
num_tokens = len(encoding["input_ids"])
last_word_index = len(words) - 1
last_token_index = num_tokens - 1
last_batch_index = batch_size - 1
last_char_index = len(text) - 1
# Assert word_to_tokens
self.assertNotEqual(
pair_encoding.word_to_tokens(index_word_in_first_seq, sequence_index=0).start,
pair_encoding.word_to_tokens(index_word_in_pair_seq, sequence_index=1).start,
)
self.assertEqual(
pair_encoding["input_ids"][
pair_encoding.word_to_tokens(index_word_in_first_seq, sequence_index=0).start
],
pair_encoding["input_ids"][
pair_encoding.word_to_tokens(index_word_in_pair_seq, sequence_index=1).start
],
)
self.assertNotEqual(
pair_batch_encoding.word_to_tokens(1, index_word_in_first_seq, sequence_index=0).start,
pair_batch_encoding.word_to_tokens(1, index_word_in_pair_seq, sequence_index=1).start,
)
self.assertEqual(
pair_batch_encoding["input_ids"][1][
pair_batch_encoding.word_to_tokens(1, index_word_in_first_seq, sequence_index=0).start
],
pair_batch_encoding["input_ids"][1][
pair_batch_encoding.word_to_tokens(1, index_word_in_pair_seq, sequence_index=1).start
],
)
# Assert char_to_token
self.assertNotEqual(
pair_encoding.char_to_token(index_char_in_first_seq, sequence_index=0),
pair_encoding.char_to_token(index_char_in_pair_seq, sequence_index=1),
)
self.assertEqual(
pair_encoding["input_ids"][pair_encoding.char_to_token(index_char_in_first_seq, sequence_index=0)],
pair_encoding["input_ids"][pair_encoding.char_to_token(index_char_in_pair_seq, sequence_index=1)],
)
self.assertNotEqual(
pair_batch_encoding.char_to_token(1, index_char_in_first_seq, sequence_index=0),
pair_batch_encoding.char_to_token(1, index_char_in_pair_seq, sequence_index=1),
)
self.assertEqual(
pair_batch_encoding["input_ids"][1][
pair_batch_encoding.char_to_token(1, index_char_in_first_seq, sequence_index=0)
],
pair_batch_encoding["input_ids"][1][
pair_batch_encoding.char_to_token(1, index_char_in_pair_seq, sequence_index=1)
],
)
# Assert char_to_word
self.assertNotEqual(
pair_encoding.char_to_word(index_char_in_first_seq, sequence_index=0),
pair_encoding.char_to_word(index_char_in_pair_seq, sequence_index=1),
)
self.assertEqual(
words[pair_encoding.char_to_word(index_char_in_first_seq, sequence_index=0)],
pair_words[pair_encoding.char_to_word(index_char_in_pair_seq, sequence_index=1)],
)
self.assertNotEqual(
pair_batch_encoding.char_to_word(1, index_char_in_first_seq, sequence_index=0),
pair_batch_encoding.char_to_word(1, index_char_in_pair_seq, sequence_index=1),
)
self.assertEqual(
words[pair_batch_encoding.char_to_word(1, index_char_in_first_seq, sequence_index=0)],
pair_words[pair_batch_encoding.char_to_word(1, index_char_in_pair_seq, sequence_index=1)],
)
# Assert word_to_chars
self.assertNotEqual(
pair_encoding.word_to_chars(index_word_in_first_seq, sequence_index=0).start,
pair_encoding.word_to_chars(index_word_in_pair_seq, sequence_index=1).start,
)
self.assertEqual(
text[pair_encoding.word_to_chars(index_word_in_first_seq, sequence_index=0).start],
pair_text[pair_encoding.word_to_chars(index_word_in_pair_seq, sequence_index=1).start],
)
self.assertNotEqual(
pair_batch_encoding.word_to_chars(1, index_word_in_first_seq, sequence_index=0).start,
pair_batch_encoding.word_to_chars(1, index_word_in_pair_seq, sequence_index=1).start,
)
self.assertEqual(
text[pair_batch_encoding.word_to_chars(1, index_word_in_first_seq, sequence_index=0).start],
pair_text[pair_batch_encoding.word_to_chars(1, index_word_in_pair_seq, sequence_index=1).start],
)
# Assert token_to_sequence
pair_encoding = tokenizer_r.encode_plus(text, pair_text, add_special_tokens=True)
pair_sequence_ids = [
pair_encoding.token_to_sequence(i) for i in range(len(pair_encoding["input_ids"]))
]
self.assertIn(0, pair_sequence_ids)
self.assertIn(1, pair_sequence_ids)
if tokenizer_r.num_special_tokens_to_add(pair=True):
self.assertIn(None, pair_sequence_ids)
pair_batch_encoding = tokenizer_r.batch_encode_plus(
[(text, pair_text)] * batch_size, add_special_tokens=True
)
pair_batch_sequence_ids = [
pair_batch_encoding.token_to_sequence(1, i)
for i in range(len(pair_batch_encoding["input_ids"][0]))
]
self.assertIn(0, pair_batch_sequence_ids)
self.assertIn(1, pair_batch_sequence_ids)
if tokenizer_r.num_special_tokens_to_add(pair=True):
self.assertIn(None, pair_batch_sequence_ids)
def test_tokenization_python_rust_equals(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):

View file

@ -4,13 +4,12 @@ import shutil
import tempfile
from unittest import TestCase
from transformers import BartTokenizer, BartTokenizerFast, DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_torch_available
from transformers.testing_utils import require_datasets, require_faiss, require_torch, slow
from transformers.tokenization_bart import BartTokenizer
from transformers.testing_utils import require_datasets, require_faiss, require_tokenizers, require_torch, slow
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@ -96,6 +95,7 @@ class RagTokenizerTest(TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
@require_tokenizers
def test_save_load_pretrained_with_saved_config(self):
save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
@ -104,10 +104,10 @@ class RagTokenizerTest(TestCase):
rag_config.save_pretrained(save_dir)
rag_tokenizer.save_pretrained(save_dir)
new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizerFast)
self.assertEqual(new_rag_tokenizer.question_encoder.get_vocab(), rag_tokenizer.question_encoder.get_vocab())
self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizerFast)
self.assertEqual(new_rag_tokenizer.generator.get_vocab(), rag_tokenizer.generator.get_vocab())
@slow
def test_pretrained_token_nq_tokenizer(self):

View file

@ -18,7 +18,7 @@ import os
import unittest
from transformers.file_utils import cached_property
from transformers.testing_utils import slow
from transformers.testing_utils import require_sentencepiece, slow
from transformers.tokenization_xlm_prophetnet import SPIECE_UNDERLINE, XLMProphetNetTokenizer
from .test_tokenization_common import TokenizerTesterMixin
@ -27,6 +27,7 @@ from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
@require_sentencepiece
class XLMProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XLMProphetNetTokenizer