Moved _truncate_seq_pair outside of the if/else block.
Parent: b0bfcba2ea
Commit: fb3f7ddef6
@@ -53,15 +53,20 @@ class Tokenizer:
     def tokenize(self, text):
         """Tokenizes a list of documents using a BERT tokenizer
         Args:
-            text (list): List of strings (one sequence) or tuples (two sequences).
+            text (list): List of strings (one sequence) or
+                tuples (two sequences).
 
         Returns:
-            [list]: List of lists. Each sublist contains WordPiece tokens of the input sequence(s).
+            [list]: List of lists. Each sublist contains WordPiece tokens
+                of the input sequence(s).
         """
         if isinstance(text[0], str):
             return [self.tokenizer.tokenize(x) for x in tqdm(text)]
         else:
-            return [[self.tokenizer.tokenize(x) for x in sentences] for sentences in tqdm(text)]
+            return [
+                [self.tokenizer.tokenize(x) for x in sentences]
+                for sentences in tqdm(text)
+            ]
 
     def preprocess_classification_tokens(self, tokens, max_len=BERT_MAX_LEN):
         """Preprocessing of input tokens:
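For context rather than as part of the commit: a minimal, self-contained sketch of how the reformatted tokenize method behaves, with transformers.BertTokenizer standing in for self.tokenizer (an assumption; the repo's wrapper may construct its tokenizer differently).

    # Standalone sketch; BertTokenizer here is an assumed stand-in for self.tokenizer.
    from tqdm import tqdm
    from transformers import BertTokenizer

    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def tokenize(text):
        # Single sequences: one WordPiece token list per document.
        if isinstance(text[0], str):
            return [bert_tokenizer.tokenize(x) for x in tqdm(text)]
        # Sequence pairs: one token list per sentence, grouped per document.
        return [
            [bert_tokenizer.tokenize(x) for x in sentences]
            for sentences in tqdm(text)
        ]

    print(tokenize(["hello world"]))             # [['hello', 'world']]
    print(tokenize([("hello world", "nice day")]))  # [[['hello', 'world'], ['nice', 'day']]]

Strings take the first branch and yield one token list per document; tuples take the second and keep the two sequences separate, which the sentence-pair path below relies on.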
@@ -75,7 +80,7 @@ class Tokenizer:
             max_len (int, optional): Maximum number of tokens
                 (documents will be truncated or padded).
                 Defaults to 512.
-        Returns:
+        Returns:
             tuple: A tuple containing the following three lists
                 list of preprocesssed token lists
                 list of input mask lists
@@ -89,50 +94,58 @@ class Tokenizer:
             )
             max_len = BERT_MAX_LEN
 
+        def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+            """Truncates a sequence pair in place to the maximum length."""
+            # This is a simple heuristic which will always truncate the longer
+            # sequence one token at a time. This makes more sense than
+            # truncating an equal percent of tokens from each, since if one
+            # sequence is very short then each token that's truncated likely
+            # contains more information than a longer sequence.
+            while True:
+                total_length = len(tokens_a) + len(tokens_b)
+                if total_length <= max_length:
+                    break
+                if len(tokens_a) > len(tokens_b):
+                    tokens_a.pop()
+                else:
+                    tokens_b.pop()
+
+            tokens_a.append("[SEP]")
+            tokens_b.append("[SEP]")
+
+            return [tokens_a, tokens_b]
+
         if isinstance(tokens[0], str):
             tokens = [x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
             token_type_ids = None
         else:
 
-            def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-                """Truncates a sequence pair in place to the maximum length."""
-                # This is a simple heuristic which will always truncate the longer sequence
-                # one token at a time. This makes more sense than truncating an equal percent
-                # of tokens from each, since if one sequence is very short then each token
-                # that's truncated likely contains more information than a longer sequence.
-                while True:
-                    total_length = len(tokens_a) + len(tokens_b)
-                    if total_length <= max_length:
-                        break
-                    if len(tokens_a) > len(tokens_b):
-                        tokens_a.pop()
-                    else:
-                        tokens_b.pop()
-
-                tokens_a.append("[SEP]")
-                tokens_b.append("[SEP]")
-
-                return [tokens_a, tokens_b]
-
-            # print(tokens[:2])
             # get tokens for each sentence [[t00, t01, ...] [t10, t11,... ]]
-            tokens = [_truncate_seq_pair(sentence[0], sentence[1], max_len - 3) # [CLS] + 2x [SEP]
-                      for sentence in tokens]
+            tokens = [
+                _truncate_seq_pair(sentence[0], sentence[1], max_len - 3)
+                for sentence in tokens
+            ]
 
-            # construct token_type_ids [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
+            # construct token_type_ids
+            # [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
             token_type_ids = [
                 [[i] * len(sentence) for i, sentence in enumerate(example)]
                 for example in tokens
             ]
             # merge sentences
-            tokens = [[token for sentence in example for token in sentence]
-                      for example in tokens]
+            tokens = [
+                [token for sentence in example for token in sentence]
+                for example in tokens
+            ]
             # prefix with [0] for [CLS]
-            token_type_ids = [[0] + [i for sentence in example for i in sentence]
-                              for example in token_type_ids]
+            token_type_ids = [
+                [0] + [i for sentence in example for i in sentence]
+                for example in token_type_ids
+            ]
             # pad sequence
-            token_type_ids = [x + [0] * (max_len - len(x))
-                              for x in token_type_ids]
+            token_type_ids = [
+                x + [0] * (max_len - len(x)) for x in token_type_ids
+            ]
 
         tokens = [["[CLS]"] + x for x in tokens]
         # convert tokens to indices
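Not part of the diff: a standalone sketch, using made-up toy token lists and a deliberately small max_len, of what the relocated _truncate_seq_pair helper and the token_type_ids construction above produce for one sentence pair.

    # Toy illustration of the truncation heuristic and segment-id layout shown above.
    def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""
        while True:
            if len(tokens_a) + len(tokens_b) <= max_length:
                break
            # Trim the longer sequence one token at a time.
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
        tokens_a.append("[SEP]")
        tokens_b.append("[SEP]")
        return [tokens_a, tokens_b]

    max_len = 10  # assumption: shrunk from BERT_MAX_LEN (512) so truncation is visible
    pair = _truncate_seq_pair(list("abcdef"), list("wxyz"), max_len - 3)  # [CLS] + 2x [SEP]
    print(pair)  # [['a', 'b', 'c', 'd', '[SEP]'], ['w', 'x', 'y', '[SEP]']]

    # Segment ids: 0 for [CLS] and the first sentence, 1 for the second, 0-padded to max_len.
    token_type_ids = [0] + [i for i, sent in enumerate(pair) for _ in sent]
    token_type_ids += [0] * (max_len - len(token_type_ids))
    tokens = ["[CLS]"] + [t for sent in pair for t in sent]
    print(tokens)          # ['[CLS]', 'a', 'b', 'c', 'd', '[SEP]', 'w', 'x', 'y', '[SEP]']
    print(token_type_ids)  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]

As the in-code comment explains, trimming the longer sequence one token at a time preserves more of the shorter sequence, whose tokens tend to carry proportionally more information.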