Moved _truncate_seq_pair outside of if else block.

This commit is contained in:
Hong Lu 2019-06-21 14:48:10 -04:00
Parent b0bfcba2ea
Commit fb3f7ddef6
1 changed file with 47 additions and 34 deletions

View file

@@ -53,15 +53,20 @@ class Tokenizer:
     def tokenize(self, text):
         """Tokenizes a list of documents using a BERT tokenizer
         Args:
-            text (list): List of strings (one sequence) or tuples (two sequences).
+            text (list): List of strings (one sequence) or
+                tuples (two sequences).
 
         Returns:
-            [list]: List of lists. Each sublist contains WordPiece tokens of the input sequence(s).
+            [list]: List of lists. Each sublist contains WordPiece tokens
+                of the input sequence(s).
         """
         if isinstance(text[0], str):
             return [self.tokenizer.tokenize(x) for x in tqdm(text)]
         else:
-            return [[self.tokenizer.tokenize(x) for x in sentences] for sentences in tqdm(text)]
+            return [
+                [self.tokenizer.tokenize(x) for x in sentences]
+                for sentences in tqdm(text)
+            ]
 
     def preprocess_classification_tokens(self, tokens, max_len=BERT_MAX_LEN):
         """Preprocessing of input tokens:
@@ -75,7 +80,7 @@ class Tokenizer:
             max_len (int, optional): Maximum number of tokens
                 (documents will be truncated or padded).
                 Defaults to 512.
-        Returns:
+        Returns:
             tuple: A tuple containing the following three lists
                 list of preprocesssed token lists
                 list of input mask lists
@@ -89,50 +94,58 @@ class Tokenizer:
             )
             max_len = BERT_MAX_LEN
+        def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+            """Truncates a sequence pair in place to the maximum length."""
+            # This is a simple heuristic which will always truncate the longer
+            # sequence one token at a time. This makes more sense than
+            # truncating an equal percent of tokens from each, since if one
+            # sequence is very short then each token that's truncated likely
+            # contains more information than a longer sequence.
+            while True:
+                total_length = len(tokens_a) + len(tokens_b)
+                if total_length <= max_length:
+                    break
+                if len(tokens_a) > len(tokens_b):
+                    tokens_a.pop()
+                else:
+                    tokens_b.pop()
+            tokens_a.append("[SEP]")
+            tokens_b.append("[SEP]")
+            return [tokens_a, tokens_b]
         if isinstance(tokens[0], str):
             tokens = [x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
             token_type_ids = None
         else:
-            def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-                """Truncates a sequence pair in place to the maximum length."""
-                # This is a simple heuristic which will always truncate the longer sequence
-                # one token at a time. This makes more sense than truncating an equal percent
-                # of tokens from each, since if one sequence is very short then each token
-                # that's truncated likely contains more information than a longer sequence.
-                while True:
-                    total_length = len(tokens_a) + len(tokens_b)
-                    if total_length <= max_length:
-                        break
-                    if len(tokens_a) > len(tokens_b):
-                        tokens_a.pop()
-                    else:
-                        tokens_b.pop()
-                tokens_a.append("[SEP]")
-                tokens_b.append("[SEP]")
-                return [tokens_a, tokens_b]
             # print(tokens[:2])
             # get tokens for each sentence [[t00, t01, ...] [t10, t11,... ]]
-            tokens = [_truncate_seq_pair(sentence[0], sentence[1], max_len - 3)  # [CLS] + 2x [SEP]
-                      for sentence in tokens]
+            tokens = [
+                _truncate_seq_pair(sentence[0], sentence[1], max_len - 3)
+                for sentence in tokens
+            ]
-            # construct token_type_ids [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
+            # construct token_type_ids
+            # [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
             token_type_ids = [
                 [[i] * len(sentence) for i, sentence in enumerate(example)]
                 for example in tokens
             ]
             # merge sentences
-            tokens = [[token for sentence in example for token in sentence]
-                      for example in tokens]
+            tokens = [
+                [token for sentence in example for token in sentence]
+                for example in tokens
+            ]
             # prefix with [0] for [CLS]
-            token_type_ids = [[0] + [i for sentence in example for i in sentence]
-                              for example in token_type_ids]
+            token_type_ids = [
+                [0] + [i for sentence in example for i in sentence]
+                for example in token_type_ids
+            ]
             # pad sequence
-            token_type_ids = [x + [0] * (max_len - len(x))
-                              for x in token_type_ids]
+            token_type_ids = [
+                x + [0] * (max_len - len(x)) for x in token_type_ids
+            ]
         tokens = [["[CLS]"] + x for x in tokens]
         # convert tokens to indices
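
For readers skimming the diff, here is a minimal standalone sketch of what the relocated helper and the token-type-id construction do. _truncate_seq_pair is copied from the diff above; the toy token lists, the small max_len, and the surrounding demo code are illustrative assumptions, not part of the repository.

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # Drop one token at a time from the longer sequence until the pair fits.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
    tokens_a.append("[SEP]")
    tokens_b.append("[SEP]")
    return [tokens_a, tokens_b]


# Toy sentence pair; max_len - 3 reserves room for [CLS] and two [SEP] markers.
max_len = 10
sentences = _truncate_seq_pair(["hello", "world", "again"], ["short"], max_len - 3)
# sentences == [["hello", "world", "again", "[SEP]"], ["short", "[SEP]"]]

# Merge the pair and build token_type_ids: 0 for the first sentence, 1 for the
# second, a leading 0 for [CLS], then 0-padding up to max_len.
tokens = ["[CLS]"] + [t for sentence in sentences for t in sentence]
token_type_ids = [0] + [i for i, sentence in enumerate(sentences) for _ in sentence]
token_type_ids = token_type_ids + [0] * (max_len - len(token_type_ids))

print(tokens)          # ['[CLS]', 'hello', 'world', 'again', '[SEP]', 'short', '[SEP]']
print(token_type_ids)  # [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]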