Added warmup and support for two-sequence classification.

Parent: 42340a7508
Commit: e5b5af4c78
@@ -5,6 +5,7 @@ from pytorch_pretrained_bert.tokenization import BertTokenizer
 from enum import Enum
 import warnings
 import torch
+from tqdm import tqdm
 
 from torch.utils.data import (
     DataLoader,
@@ -54,8 +55,10 @@ class Tokenizer:
         Returns:
             [list]: [description]
         """
-        tokens = [self.tokenizer.tokenize(x) for x in text]
-        return tokens
+        if isinstance(text[0], str):
+            return [self.tokenizer.tokenize(x) for x in tqdm(text)]
+        else:
+            return [[self.tokenizer.tokenize(x) for x in sentences] for sentences in tqdm(text)]
 
     def preprocess_classification_tokens(self, tokens, max_len=BERT_MAX_LEN):
         """Preprocessing of input tokens:
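For reference, a minimal usage sketch of the two tokenize branches added above. It calls pytorch_pretrained_bert's BertTokenizer directly rather than the repo's Tokenizer wrapper; the example sentences and the bert-base-uncased vocabulary choice are illustrative, not taken from the commit.

    # Mirrors the single-sentence vs. sentence-pair branches added in tokenize().
    from pytorch_pretrained_bert.tokenization import BertTokenizer

    bert_tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True
    )

    # Single sequences -> one token list per example.
    singles = ["BERT classifies one sentence.", "Warmup stabilizes early training."]
    print([bert_tokenizer.tokenize(s) for s in singles])

    # Sentence pairs -> a list of two token lists per example.
    pairs = [
        ("The cat sat on the mat.", "It looked comfortable."),
        ("BERT accepts two segments.", "Segment ids tell them apart."),
    ]
    print([[bert_tokenizer.tokenize(s) for s in pair] for pair in pairs])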
@@ -80,15 +83,59 @@ class Tokenizer:
             )
             max_len = BERT_MAX_LEN
 
-        # truncate and add BERT sentence markers
-        tokens = [["[CLS]"] + x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
+        if isinstance(tokens[0], str):
+            tokens = [x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
+            token_type_ids = None
+        else:
+
+            def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+                """Truncates a sequence pair in place to the maximum length."""
+                # This is a simple heuristic which will always truncate the longer sequence
+                # one token at a time. This makes more sense than truncating an equal percent
+                # of tokens from each, since if one sequence is very short then each token
+                # that's truncated likely contains more information than a longer sequence.
+                while True:
+                    total_length = len(tokens_a) + len(tokens_b)
+                    if total_length <= max_length:
+                        break
+                    if len(tokens_a) > len(tokens_b):
+                        tokens_a.pop()
+                    else:
+                        tokens_b.pop()
+
+                tokens_a.append("[SEP]")
+                tokens_b.append("[SEP]")
+
+                return [tokens_a, tokens_b]
+
+            # print(tokens[:2])
+            # get tokens for each sentence [[t00, t01, ...] [t10, t11,... ]]
+            tokens = [_truncate_seq_pair(sentence[0], sentence[1], max_len - 3)  # [CLS] + 2x [SEP]
+                      for sentence in tokens]
+
+            # construct token_type_ids [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
+            token_type_ids = [
+                [[i] * len(sentence) for i, sentence in enumerate(example)]
+                for example in tokens
+            ]
+            # merge sentences
+            tokens = [[token for sentence in example for token in sentence]
+                      for example in tokens]
+            # prefix with [0] for [CLS]
+            token_type_ids = [[0] + [i for sentence in example for i in sentence]
+                              for example in token_type_ids]
+            # pad sequence
+            token_type_ids = [x + [0] * (max_len - len(x))
+                              for x in token_type_ids]
+
+        tokens = [["[CLS]"] + x for x in tokens]
         # convert tokens to indices
         tokens = [self.tokenizer.convert_tokens_to_ids(x) for x in tokens]
         # pad sequence
         tokens = [x + [0] * (max_len - len(x)) for x in tokens]
         # create input mask
         input_mask = [[min(1, x) for x in y] for y in tokens]
-        return tokens, input_mask
+        return tokens, input_mask, token_type_ids
 
     def preprocess_ner_tokens(
         self,
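To make the new sentence-pair path easier to follow, here is a standalone trace of the merge and segment-id construction for one toy example. The token strings and max_len are made up; this is not the class method itself, just the same steps in miniature.

    tokens_a = ["the", "cat", "sat"]
    tokens_b = ["it", "was", "tired"]
    max_len = 12

    # The pair already fits in max_len - 3 ([CLS] + two [SEP]), so no truncation.
    tokens_a = tokens_a + ["[SEP]"]
    tokens_b = tokens_b + ["[SEP]"]

    merged = ["[CLS]"] + tokens_a + tokens_b
    token_type_ids = [0] + [0] * len(tokens_a) + [1] * len(tokens_b)

    # Pad the segment ids to max_len, as the method above does.
    token_type_ids = token_type_ids + [0] * (max_len - len(token_type_ids))

    print(merged)
    # ['[CLS]', 'the', 'cat', 'sat', '[SEP]', 'it', 'was', 'tired', '[SEP]']
    print(token_type_ids)
    # [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0]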
@@ -42,10 +42,12 @@ class BERTSequenceClassifier:
         token_ids,
         input_mask,
         labels,
+        token_type_ids=None,
         num_gpus=None,
         num_epochs=1,
         batch_size=32,
         lr=2e-5,
+        warmup_proportion=None,
         verbose=True,
     ):
         """Fine-tunes the BERT classifier using the given training data.
@@ -62,6 +64,9 @@ class BERTSequenceClassifier:
                 Defaults to 1.
             batch_size (int, optional): Training batch size. Defaults to 32.
             lr (float): Learning rate of the Adam optimizer. Defaults to 2e-5.
+            warmup_proportion (float, optional): Proportion of training to
+                perform linear learning rate warmup for. E.g., 0.1 = 10% of
+                training. Defaults to None.
             verbose (bool, optional): If True, shows the training progress and
                 loss values. Defaults to True.
         """
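A hedged call sketch for the extended fit() signature. The names `classifier`, `tokenizer`, `train_text_pairs`, and `train_labels` are hypothetical placeholders for a BERTSequenceClassifier instance and preprocessed data; they are not code from the commit.

    # Hypothetical fine-tuning call; inputs come from the Tokenizer changes above.
    train_tokens = tokenizer.tokenize(train_text_pairs)
    train_tokens, train_mask, train_token_type_ids = (
        tokenizer.preprocess_classification_tokens(train_tokens, max_len=128)
    )

    classifier.fit(
        token_ids=train_tokens,
        input_mask=train_mask,
        labels=train_labels,
        token_type_ids=train_token_type_ids,  # None for single-sequence tasks
        num_epochs=2,
        batch_size=32,
        lr=2e-5,
        warmup_proportion=0.1,  # warm the learning rate up over ~10% of steps
    )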
@@ -90,7 +95,19 @@ class BERTSequenceClassifier:
             },
         ]
 
-        opt = BertAdam(optimizer_grouped_parameters, lr=lr)
+        num_train_optimization_steps = (
+            int(len(token_ids) / batch_size) * num_epochs
+        )
+
+        if warmup_proportion is None:
+            opt = BertAdam(optimizer_grouped_parameters, lr=lr)
+        else:
+            opt = BertAdam(
+                optimizer_grouped_parameters,
+                lr=lr,
+                t_total=num_train_optimization_steps,
+                warmup=warmup_proportion,
+            )
 
         # define loss function
         loss_func = nn.CrossEntropyLoss().to(device)
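The warmup branch above hands `t_total` and `warmup` to BertAdam, which, with pytorch_pretrained_bert's default warmup_linear schedule, ramps the learning rate up over the first warmup fraction of the total optimization steps and then decays it linearly. A quick arithmetic sketch with made-up numbers shows how the step counts work out:

    # Illustrative numbers only; not from the commit.
    num_examples = 10000
    batch_size = 32
    num_epochs = 2
    warmup_proportion = 0.1

    num_train_optimization_steps = int(num_examples / batch_size) * num_epochs
    warmup_steps = int(warmup_proportion * num_train_optimization_steps)

    print(num_train_optimization_steps)  # 624
    print(warmup_steps)                  # 62: LR ramps up here, then decays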
@@ -99,7 +116,8 @@ class BERTSequenceClassifier:
         self.model.train()  # training mode
         num_examples = len(token_ids)
         num_batches = int(num_examples / batch_size)
 
 
+        token_type_ids_batch = None
         for epoch in range(num_epochs):
             for i in range(num_batches):
@@ -115,12 +133,17 @@ class BERTSequenceClassifier:
                 mask_batch = torch.tensor(
                     input_mask[start:end], dtype=torch.long, device=device
                 )
 
+                if token_type_ids is not None:
+                    token_type_ids_batch = torch.tensor(
+                        token_type_ids[start:end], dtype=torch.long, device=device
+                    )
+
                 opt.zero_grad()
 
                 y_h = self.model(
                     input_ids=x_batch,
-                    token_type_ids=None,
+                    token_type_ids=token_type_ids_batch,
                     attention_mask=mask_batch,
                     labels=None,
                 )
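The per-batch tensors above are built by slicing plain Python lists and passing the slice to torch.tensor, which only works because preprocess_classification_tokens pads every example to the same max_len. A tiny self-contained sketch with made-up segment ids:

    import torch

    token_type_ids = [
        [0, 0, 0, 1, 1, 0, 0, 0],  # example 0, already padded to max_len = 8
        [0, 0, 1, 1, 1, 1, 0, 0],  # example 1
        [0, 0, 0, 0, 1, 1, 1, 0],  # example 2
    ]
    start, end = 0, 2
    token_type_ids_batch = torch.tensor(
        token_type_ids[start:end], dtype=torch.long, device="cpu"
    )
    print(token_type_ids_batch.shape)  # torch.Size([2, 8])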
@@ -141,10 +164,10 @@ class BERTSequenceClassifier:
                             )
                         )
         # empty cache
-        del [x_batch, y_batch, mask_batch]
+        del [x_batch, y_batch, mask_batch, token_type_ids_batch]
         torch.cuda.empty_cache()
 
-    def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
+    def predict(self, token_ids, input_mask, token_type_ids=None, num_gpus=None, batch_size=32):
         """Scores the given dataset and returns the predicted classes.
         Args:
             token_ids (list): List of training token lists.
@@ -173,10 +196,16 @@ class BERTSequenceClassifier:
                 mask_batch = torch.tensor(
                     mask_batch, dtype=torch.long, device=device
                 )
+                token_type_ids_batch = None
+                if token_type_ids is not None:
+                    token_type_ids_batch = torch.tensor(
+                        token_type_ids[i : i +
+                                       batch_size], dtype=torch.long, device=device
+                    )
                 with torch.no_grad():
                     p_batch = self.model(
                         input_ids=x_batch,
-                        token_type_ids=None,
+                        token_type_ids=token_type_ids_batch,
                         attention_mask=mask_batch,
                         labels=None,
                     )
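Finally, a hedged scoring sketch for the widened predict() signature. As in the fit() example, `classifier`, `tokenizer`, and `test_text_pairs` are hypothetical placeholders:

    test_tokens = tokenizer.tokenize(test_text_pairs)
    test_tokens, test_mask, test_token_type_ids = (
        tokenizer.preprocess_classification_tokens(test_tokens, max_len=128)
    )

    preds = classifier.predict(
        token_ids=test_tokens,
        input_mask=test_mask,
        token_type_ids=test_token_type_ids,
        batch_size=64,
    )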