Added warmup and support for two-sequence classification.

This commit is contained in:
hlums 2019-06-17 22:40:24 +00:00
Parent 42340a7508
Commit e5b5af4c78
2 changed files with 87 additions and 11 deletions

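Taken together, the two changes can be exercised roughly as in the sketch below. The constructor calls, example sentence pairs, and labels are hypothetical; only the method names and the new token_type_ids / warmup_proportion arguments follow the diff.

```python
# Minimal usage sketch (hypothetical data and constructor arguments; method
# signatures follow the diff below).
train_pairs = [
    ["The cat sat on the mat.", "A cat was sitting on a mat."],
    ["It is raining outside.", "The weather is sunny."],
]
train_labels = [1, 0]  # e.g. 1 = paraphrase, 0 = not a paraphrase

tokenizer = Tokenizer()                            # assumed default construction
classifier = BERTSequenceClassifier(num_labels=2)  # assumed constructor

# tokenize() now accepts sentence pairs and returns nested token lists
tokens = tokenizer.tokenize(train_pairs)

# preprocess_classification_tokens() now also returns token_type_ids
token_ids, input_mask, token_type_ids = tokenizer.preprocess_classification_tokens(
    tokens, max_len=128
)

classifier.fit(
    token_ids=token_ids,
    input_mask=input_mask,
    labels=train_labels,
    token_type_ids=token_type_ids,  # new: segment ids marking the second sentence
    warmup_proportion=0.1,          # new: warm the learning rate up over 10% of steps
    num_epochs=1,
    batch_size=2,
)
```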
View file

@@ -5,6 +5,7 @@ from pytorch_pretrained_bert.tokenization import BertTokenizer
from enum import Enum
import warnings
import torch
from tqdm import tqdm
from torch.utils.data import (
DataLoader,
@@ -54,8 +55,10 @@ class Tokenizer:
Returns:
list: A list of token lists; for sentence-pair inputs, a list of lists of token lists.
"""
tokens = [self.tokenizer.tokenize(x) for x in text]
return tokens
if isinstance(text[0], str):
return [self.tokenizer.tokenize(x) for x in tqdm(text)]
else:
return [[self.tokenizer.tokenize(x) for x in sentences] for sentences in tqdm(text)]
def preprocess_classification_tokens(self, tokens, max_len=BERT_MAX_LEN):
"""Preprocessing of input tokens:
@@ -80,15 +83,59 @@ class Tokenizer:
)
max_len = BERT_MAX_LEN
# truncate and add BERT sentence markers
tokens = [["[CLS]"] + x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
if isinstance(tokens[0][0], str):
tokens = [x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
token_type_ids = None
else:
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
tokens_a.append("[SEP]")
tokens_b.append("[SEP]")
return [tokens_a, tokens_b]
# truncate each sentence pair: [[t00, t01, ...], [t10, t11, ...]]
tokens = [_truncate_seq_pair(sentence[0], sentence[1], max_len - 3)  # [CLS] + 2x [SEP]
for sentence in tokens]
# construct token_type_ids, e.g. [[0, 0, ..., 0, 1, 1, ..., 1], [0, 0, ..., 1, 1], ...]
token_type_ids = [
[[i] * len(sentence) for i, sentence in enumerate(example)]
for example in tokens
]
# merge sentences
tokens = [[token for sentence in example for token in sentence]
for example in tokens]
# prefix with [0] for [CLS]
token_type_ids = [[0] + [i for sentence in example for i in sentence]
for example in token_type_ids]
# pad sequence
token_type_ids = [x + [0] * (max_len - len(x))
for x in token_type_ids]
tokens = [["[CLS]"] + x for x in tokens]
# convert tokens to indices
tokens = [self.tokenizer.convert_tokens_to_ids(x) for x in tokens]
# pad sequence
tokens = [x + [0] * (max_len - len(x)) for x in tokens]
# create input mask
input_mask = [[min(1, x) for x in y] for y in tokens]
return tokens, input_mask
return tokens, input_mask, token_type_ids
def preprocess_ner_tokens(
self,

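For reference, the sketch below shows the layout the sentence-pair branch above produces: the longer sequence is truncated one token at a time, the pair is joined as [CLS] a... [SEP] b... [SEP], and the segment ids mark the second sentence with 1s before being zero-padded to max_len. The helper mirrors the diff's logic but is an illustration, not the module's code.

```python
# Standalone illustration of the sentence-pair layout built above
# (mirrors the diff's logic; hypothetical helper, not the module's code).
def layout_pair(tokens_a, tokens_b, max_len):
    # truncate the longer sequence one token at a time; [CLS] + 2x [SEP] are reserved
    while len(tokens_a) + len(tokens_b) > max_len - 3:
        (tokens_a if len(tokens_a) > len(tokens_b) else tokens_b).pop()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    token_type_ids += [0] * (max_len - len(token_type_ids))  # pad segment ids
    return tokens, token_type_ids

tokens, segment_ids = layout_pair(["he", "went", "home"], ["she", "stayed"], max_len=12)
print(tokens)       # ['[CLS]', 'he', 'went', 'home', '[SEP]', 'she', 'stayed', '[SEP]']
print(segment_ids)  # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
```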
View file

@@ -42,10 +42,12 @@ class BERTSequenceClassifier:
token_ids,
input_mask,
labels,
token_type_ids=None,
num_gpus=None,
num_epochs=1,
batch_size=32,
lr=2e-5,
warmup_proportion=None,
verbose=True,
):
"""Fine-tunes the BERT classifier using the given training data.
@@ -62,6 +64,9 @@
Defaults to 1.
batch_size (int, optional): Training batch size. Defaults to 32.
lr (float): Learning rate of the Adam optimizer. Defaults to 2e-5.
warmup_proportion (float, optional): Proportion of training to
perform linear learning rate warmup for. E.g., 0.1 = 10% of
training. Defaults to None.
verbose (bool, optional): If True, shows the training progress and
loss values. Defaults to True.
"""
@@ -90,7 +95,19 @@
},
]
opt = BertAdam(optimizer_grouped_parameters, lr=lr)
num_train_optimization_steps = (
int(len(token_ids) / batch_size) * num_epochs
)
if warmup_proportion is None:
opt = BertAdam(optimizer_grouped_parameters, lr=lr)
else:
opt = BertAdam(
optimizer_grouped_parameters,
lr=lr,
t_total=num_train_optimization_steps,
warmup=warmup_proportion,
)
# define loss function
loss_func = nn.CrossEntropyLoss().to(device)
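When warmup_proportion is set, BertAdam is given the total number of optimization steps (t_total) and ramps the learning rate up over the first warmup_proportion fraction of them; its default warmup_linear schedule then decays the rate over the remaining steps. A simplified illustration of the ramp, not the library's exact schedule:

```python
# Simplified linear-warmup illustration (not BertAdam's exact schedule).
def warmup_lr(step, t_total, lr=2e-5, warmup=0.1):
    progress = step / t_total
    if progress < warmup:
        return lr * progress / warmup  # ramp up from 0 to lr
    return lr  # BertAdam's warmup_linear decays the rate from here instead

t_total = int(1000 / 32) * 1  # e.g. 1000 examples, batch_size=32, 1 epoch -> 31 steps
for step in (0, 1, 3, 15, 30):
    print(step, warmup_lr(step, t_total))
```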
@@ -99,7 +116,8 @@
self.model.train() # training mode
num_examples = len(token_ids)
num_batches = int(num_examples / batch_size)
token_type_ids_batch = None
for epoch in range(num_epochs):
for i in range(num_batches):
@@ -115,12 +133,17 @@
mask_batch = torch.tensor(
input_mask[start:end], dtype=torch.long, device=device
)
token_type_ids_batch = None
if token_type_ids is not None:
token_type_ids_batch = torch.tensor(
token_type_ids[start:end], dtype=torch.long, device=device
)
opt.zero_grad()
y_h = self.model(
input_ids=x_batch,
token_type_ids=None,
token_type_ids=token_type_ids_batch,
attention_mask=mask_batch,
labels=None,
)
@@ -141,10 +164,10 @@
)
)
# empty cache
del [x_batch, y_batch, mask_batch]
del [x_batch, y_batch, mask_batch, token_type_ids_batch]
torch.cuda.empty_cache()
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
def predict(self, token_ids, input_mask, token_type_ids=None, num_gpus=None, batch_size=32):
"""Scores the given dataset and returns the predicted classes.
Args:
token_ids (list): List of token lists of the data to score.
@@ -173,10 +196,16 @@
mask_batch = torch.tensor(
mask_batch, dtype=torch.long, device=device
)
token_type_ids_batch = None
if token_type_ids is not None:
token_type_ids_batch = torch.tensor(
token_type_ids[i : i + batch_size],
dtype=torch.long,
device=device,
)
with torch.no_grad():
p_batch = self.model(
input_ids=x_batch,
token_type_ids=None,
token_type_ids=token_type_ids_batch,
attention_mask=mask_batch,
labels=None,
)