Merge pull request #2177 from mandubian/issue-2106

:zap: #2106 tokenizer.tokenize speed improvement (3-8x) by caching added_tokens in a Set
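The speedup comes from replacing the per-call membership checks against added_tokens_encoder plus the all_special_tokens list with one precomputed set, unique_added_tokens_encoder, so each lookup inside the tokenization loop is a constant-time set probe instead of a linear scan. Below is a minimal, self-contained sketch of that list-vs-set idea; the token values and loop are made up for illustration and are not the measurement behind the 3-8x figure.

import timeit

# Hypothetical stand-ins for the tokenizer's bookkeeping structures.
added_tokens_encoder = {"[E1]": 30522, "[E2]": 30523}
all_special_tokens = ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]

# Old check: a dict lookup plus a linear scan of the special-tokens list.
def old_check(token):
    return token not in added_tokens_encoder and token not in all_special_tokens

# New check: a single lookup in a set that is built once, not on every call.
unique_added_tokens_encoder = set(added_tokens_encoder.keys()).union(set(all_special_tokens))

def new_check(token):
    return token not in unique_added_tokens_encoder

tokens = ["hello", "world", "[E1]", "[SEP]"] * 1000
print(timeit.timeit(lambda: [old_check(t) for t in tokens], number=100))
print(timeit.timeit(lambda: [new_check(t) for t in tokens], number=100))

In the diff itself the set is (re)built inside add_tokens, so tokenize no longer reconstructs the candidate token list on every call and checks membership against a set rather than a list.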
Thomas Wolf 2019-12-21 14:31:20 +01:00 committed by GitHub
Parents ed9b84816e cc0135134b
Commit deceb00161
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 7 additions and 5 deletions


@@ -231,6 +231,7 @@ class PreTrainedTokenizer(object):
         # Added tokens
         self.added_tokens_encoder = {}
+        self.unique_added_tokens_encoder = set()
         self.added_tokens_decoder = {}
         # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
@@ -554,6 +555,7 @@ class PreTrainedTokenizer(object):
         added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
         added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
+        self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
         self.added_tokens_decoder.update(added_tok_decoder)
         return len(to_add_tokens)
@@ -631,6 +633,7 @@ class PreTrainedTokenizer(object):
         return added_tokens

     def tokenize(self, text, **kwargs):
         """ Converts a string in a sequence of tokens (string), using the tokenizer.
             Split in words for word-based vocabulary or sub-words for sub-word-based
@@ -685,18 +688,17 @@ class PreTrainedTokenizer(object):
             for tok in tok_list:
                 tokenized_text = []
                 for sub_text in text_list:
-                    if sub_text not in self.added_tokens_encoder \
-                            and sub_text not in all_special_tokens:
+                    if sub_text not in self.unique_added_tokens_encoder:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
                 text_list = tokenized_text

-            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
-                in self.added_tokens_encoder and token not in all_special_tokens \
+            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) \
+                if token not in self.unique_added_tokens_encoder
                 else [token] for token in tokenized_text)))

-        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
+        added_tokens = self.unique_added_tokens_encoder
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
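For context, the code path touched here is the one exercised whenever user-added or special tokens are present. A hedged usage sketch of that path follows; the checkpoint name and the added tokens are arbitrary placeholders, not part of this PR.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# add_tokens() now also refreshes tokenizer.unique_added_tokens_encoder,
# the cached set that tokenize() consults for its membership checks.
tokenizer.add_tokens(["[E1]", "[E2]"])
print(tokenizer.tokenize("[E1] hello world [E2]"))

Since unique_added_tokens_encoder is an internal attribute introduced by this change, callers never touch it directly; the visible effect is only that tokenize runs faster when many added tokens exist.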