[pegasus] Faster tokenizer tests (#7672)
Parent: bc00b37a0d
Commit: b0f05e0c4c
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+
+# this script builds a small sample spm file tests/fixtures/test_sentencepiece_no_bos.model, with features needed by pegasus
+
+# 1. pip install sentencepiece
+#
+# 2. wget https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt
+
+# 3. build
+import sentencepiece as spm
+
+# pegasus:
+# 1. no bos
+# 2. eos_id is 1
+# 3. unk_id is 2
+# build a sample spm file accordingly
+spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=test_sentencepiece_no_bos --bos_id=-1 --unk_id=2 --eos_id=1 --vocab_size=1000')
+
+# 4. now update the fixture
+# mv test_sentencepiece_no_bos.model ../../tests/fixtures/
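As a quick sanity check (a sketch, not part of this diff), the freshly trained model can be loaded back with sentencepiece to confirm the ids the comments above describe:

import sentencepiece as spm

# sketch only: verify the fixture matches the flags passed to SentencePieceTrainer.train above
sp = spm.SentencePieceProcessor()
sp.load("test_sentencepiece_no_bos.model")
assert sp.bos_id() == -1            # --bos_id=-1, i.e. no BOS token
assert sp.eos_id() == 1             # --eos_id=1
assert sp.unk_id() == 2             # --unk_id=2
assert sp.get_piece_size() == 1000  # --vocab_size=1000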
@@ -184,13 +184,23 @@ def require_faiss(test_case):
     return test_case


-def get_tests_dir():
+def get_tests_dir(append_path=None):
     """
-    returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
+    Args:
+        append_path: optional path to append to the tests dir path
+
+    Return:
+        The full path to the `tests` dir, so that the tests can be invoked from anywhere.
+        Optionally `append_path` is joined after the `tests` dir if the former is provided.
+
     """
     # this function caller's __file__
     caller__file__ = inspect.stack()[1][1]
-    return os.path.abspath(os.path.dirname(caller__file__))
+    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+    if append_path:
+        return os.path.join(tests_dir, append_path)
+    else:
+        return tests_dir


 #
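For context, the new `append_path` argument lets a test module resolve fixture paths in one call; the line below is how the pegasus test added later in this PR uses it:

from transformers.testing_utils import get_tests_dir

# resolves to <tests dir>/fixtures/test_sentencepiece_no_bos.model no matter where pytest is invoked from
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")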
@@ -49,7 +49,7 @@ class PegasusTokenizer(ReformerTokenizer):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # Dont use reserved words added_token_encoder, added_tokens_decoder because of
+        # Don't use reserved words added_token_encoder, added_tokens_decoder because of
         # AssertionError: Non-consecutive added token '1' found. in from_pretrained
         assert len(self.added_tokens_decoder) == 0
         self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
@@ -58,7 +58,7 @@ class PegasusTokenizer(ReformerTokenizer):
         self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}

     def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token (str) in an id using the vocab. """
+        """ Converts a token (str) to an id using the vocab. """
         if token in self.decoder:
             return self.decoder[token]
         elif token in self.added_tokens_decoder:
@@ -67,7 +67,7 @@ class PegasusTokenizer(ReformerTokenizer):
             return sp_id + self.offset

     def _convert_id_to_token(self, index: int) -> str:
-        """Converts an index (integer) in a token (str) using the vocab."""
+        """Converts an index (integer) to a token (str) using the vocab."""
         if index in self.encoder:
             return self.encoder[index]
         elif index in self.added_tokens_encoder:
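The `_convert_*` helpers above split the id space in two: ids stored in the hand-built `encoder`/`decoder` dicts (pad and eos, see `__init__`), and regular sentencepiece ids shifted up by `self.offset`. A minimal sketch of that mapping (the function and argument names are invented for illustration, not taken from the tokenizer):

def to_model_id(token, decoder, sp_model, offset):
    # sketch only: pad/eos come from the hand-built dict, every other piece is its spm id plus the reserved offset
    if token in decoder:
        return decoder[token]
    return sp_model.piece_to_id(token) + offset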
@@ -81,11 +81,6 @@ class PegasusTokenizer(ReformerTokenizer):
     def vocab_size(self) -> int:
         return len(self.sp_model) + self.offset

-    def get_vocab(self) -> Dict[str, int]:
-        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
-        vocab.update(self.added_tokens_encoder)
-        return vocab
-
     def num_special_tokens_to_add(self, pair=False):
         """Just EOS"""
         return 1
@@ -109,12 +104,12 @@ class PegasusTokenizer(ReformerTokenizer):

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """
-        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
         by concatenating and adding special tokens.
         A Pegasus sequence has the following format, where ``X`` represents the sequence:

         - single sequence: ``X </s>``
-        - pair of sequences: ``A B </s>`` (not intended use)
+        - pair of sequences: ``A B </s>`` (not intended use)

         BOS is never used.
         Pairs of sequences are not the expected use case, but they will be handled without a separator.
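A short illustration of the format that docstring describes (a sketch; `tok` stands for a loaded PegasusTokenizer and `eos` for its EOS id, neither of which is defined in this diff):

single = tok.build_inputs_with_special_tokens([4, 5, 6])        # -> [4, 5, 6, eos]
pair = tok.build_inputs_with_special_tokens([4, 5, 6], [7, 8])  # -> [4, 5, 6, 7, 8, eos], no separator between the pair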
@@ -17,6 +17,7 @@

 import os
 from shutil import copyfile
+from typing import Dict

 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -119,7 +120,7 @@ class ReformerTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return self.sp_model.get_piece_size()

-    def get_vocab(self):
+    def get_vocab(self) -> Dict[str, int]:
         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
@@ -186,7 +186,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):

             num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
             print('We have added', num_added_toks, 'tokens')
-            # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+            # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
             model.resize_token_embeddings(len(tokenizer))
         """
         new_tokens = [str(tok) for tok in new_tokens]
Binary file not shown.
@@ -1,13 +1,15 @@
 import unittest
-from pathlib import Path

 from transformers.file_utils import cached_property
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import get_tests_dir, require_torch
 from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast

 from .test_tokenization_common import TokenizerTesterMixin


+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
+
+
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = PegasusTokenizer
@@ -17,11 +19,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def setUp(self):
         super().setUp()

-        save_dir = Path(self.tmpdirname)
-        spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
-        if not (save_dir / spm_file).exists():
-            tokenizer = self.pegasus_large_tokenizer
-            tokenizer.save_pretrained(self.tmpdirname)
+        # We have a SentencePiece fixture for testing
+        tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
+        tokenizer.save_pretrained(self.tmpdirname)

     @cached_property
     def pegasus_large_tokenizer(self):
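This is where the speed-up in the PR title comes from: setUp now builds the tokenizer from the tiny local fixture instead of relying on the pretrained `pegasus_large_tokenizer`. A hypothetical test method using it (name and asserts are illustrative, not copied from the suite):

def test_fixture_roundtrip(self):
    tok = PegasusTokenizer.from_pretrained(self.tmpdirname)  # loads the fixture saved in setUp
    ids = tok("This is a test").input_ids
    assert ids[-1] == tok.eos_token_id  # pegasus appends only EOS
    assert tok.decode(ids, skip_special_tokens=True) == "This is a test"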
@@ -32,10 +32,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         pass

     def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
-        if not kwargs:
-            return self.pegasus_large_tokenizer
-        else:
-            return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+        return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self, tokenizer):
         return ("This is a test", "This is a test")
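After this simplification `get_tokenizer()` always builds from the fixture saved in `tmpdirname`, with or without extra kwargs; roughly (values are illustrative):

default_tok = self.get_tokenizer()                    # fixture-backed PegasusTokenizer
custom_tok = self.get_tokenizer(model_max_length=16)  # kwargs are still forwarded to from_pretrained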
@@ -14,19 +14,18 @@
 # limitations under the License.


-import os
 import unittest

 from transformers import BatchEncoding
 from transformers.file_utils import cached_property
-from transformers.testing_utils import _torch_available
+from transformers.testing_utils import _torch_available, get_tests_dir
 from transformers.tokenization_t5 import T5Tokenizer, T5TokenizerFast
 from transformers.tokenization_xlnet import SPIECE_UNDERLINE

 from .test_tokenization_common import TokenizerTesterMixin


-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

 FRAMEWORK = "pt" if _torch_available else "tf"