Merge pull request #2177 from mandubian/issue-2106
:zip: #2106 tokenizer.tokenize speed improvement (3-8x) by caching added_tokens in a Set
Commit deceb00161
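The speed-up comes from caching the union of added tokens and special tokens in a set once, at the point where tokens are added, instead of rebuilding a Python list and scanning it for every token inside tokenize(). A minimal sketch of the pattern (illustrative only; ToyTokenizer and its lower-casing stand-in for subword splitting are hypothetical, not the transformers implementation):

class ToyTokenizer:
    # Hypothetical, simplified tokenizer that only illustrates the caching pattern of this commit.
    def __init__(self):
        self.added_tokens_encoder = {}                 # added token -> id
        self.all_special_tokens = ["[CLS]", "[SEP]"]
        # Cached union of added and special tokens: one O(1) set lookup per token in tokenize().
        self.unique_added_tokens_encoder = set(self.all_special_tokens)

    def add_tokens(self, new_tokens):
        start = len(self.added_tokens_encoder)
        for i, tok in enumerate(new_tokens):
            self.added_tokens_encoder[tok] = start + i
        # Rebuild the cache once per add_tokens() call instead of on every tokenize() call.
        self.unique_added_tokens_encoder = set(self.added_tokens_encoder) | set(self.all_special_tokens)

    def tokenize(self, text):
        # Added/special tokens are kept whole; everything else goes through the "real" tokenizer
        # (lower-casing stands in for subword splitting here).
        return [tok if tok in self.unique_added_tokens_encoder else tok.lower()
                for tok in text.split()]

tokenizer = ToyTokenizer()
tokenizer.add_tokens(["<new_token>"])
print(tokenizer.tokenize("Hello <new_token> [SEP]"))   # ['hello', '<new_token>', '[SEP]']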
@@ -231,6 +231,7 @@ class PreTrainedTokenizer(object):
 
         # Added tokens
         self.added_tokens_encoder = {}
+        self.unique_added_tokens_encoder = set()
         self.added_tokens_decoder = {}
 
         # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
@@ -554,6 +555,7 @@ class PreTrainedTokenizer(object):
         added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
         added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
+        self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
         self.added_tokens_decoder.update(added_tok_decoder)
 
         return len(to_add_tokens)
@@ -631,6 +633,7 @@ class PreTrainedTokenizer(object):
 
         return added_tokens
 
     def tokenize(self, text, **kwargs):
         """ Converts a string in a sequence of tokens (string), using the tokenizer.
             Split in words for word-based vocabulary or sub-words for sub-word-based
@@ -685,18 +688,17 @@ class PreTrainedTokenizer(object):
             for tok in tok_list:
                 tokenized_text = []
                 for sub_text in text_list:
-                    if sub_text not in self.added_tokens_encoder \
-                            and sub_text not in all_special_tokens:
+                    if sub_text not in self.unique_added_tokens_encoder:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
                 text_list = tokenized_text
 
-            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
-                    in self.added_tokens_encoder and token not in all_special_tokens \
-                    else [token] for token in tokenized_text)))
+            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) \
+                    if token not in self.unique_added_tokens_encoder
+                    else [token] for token in tokenized_text)))
 
-        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
+        added_tokens = self.unique_added_tokens_encoder
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
 