Use DataLoader in sequence classifier.
Parent: ed07643e1c
Commit: 90d4b9020b
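In short: fit() and predict() previously sliced batches by hand — fit() picked a random start offset with random.random() each step (sampling with replacement, so some examples could be visited twice per epoch and others never), and predict() walked the inputs with manual index arithmetic. Both now wrap the pre-tensorized inputs in a TensorDataset and let a DataLoader do the batching, with a RandomSampler for training and a SequentialSampler for scoring. A minimal sketch of the pattern, on hypothetical toy tensors rather than this repo's BERT inputs:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hypothetical toy stand-ins for token ids, attention masks, and labels.
token_ids = torch.randint(0, 100, (32, 8))       # 32 examples, sequence length 8
input_mask = torch.ones(32, 8, dtype=torch.long)
labels = torch.randint(0, 2, (32,))

dataset = TensorDataset(token_ids, input_mask, labels)
sampler = RandomSampler(dataset)                 # shuffled; each example seen once per epoch
loader = DataLoader(dataset, sampler=sampler, batch_size=8)

for ids, mask, y in loader:                      # yields 4 batches of 8
    print(ids.shape, mask.shape, y.shape)

Unlike the old random-offset scheme, RandomSampler draws a fresh permutation every epoch, so coverage of the training set is guaranteed.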
@@ -4,15 +4,21 @@
 # This script reuses some code from
 # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

-import random
-from collections import namedtuple
+# import random
+# from collections import namedtuple

 import numpy as np
 import torch
 import torch.nn as nn
+from torch.utils.data import (
+    DataLoader,
+    RandomSampler,
+    SequentialSampler,
+    TensorDataset,
+)
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.optimization import BertAdam
-from tqdm import tqdm
+from tqdm import tqdm, trange

 from utils_nlp.models.bert.common import Language
 from utils_nlp.common.pytorch_utils import get_device, move_to_device

@@ -124,65 +130,55 @@ class BERTSequenceClassifier:

         # define loss function
         # loss_func = nn.CrossEntropyLoss().to(device)
-        loss_func = nn.CrossEntropyLoss()
+        # loss_func = nn.CrossEntropyLoss()

         # train
         self.model.train()  # training mode

-        token_type_ids_batch = None
-        for epoch in range(num_epochs):
-            for i in range(num_batches):
+        all_input_ids = torch.tensor(token_ids, dtype=torch.long)
+        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
+        all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long)
+        all_label_ids = torch.tensor(labels, dtype=torch.long)

-                # get random batch
-                start = int(random.random() * num_examples)
-                end = start + batch_size
-                x_batch = torch.tensor(
-                    token_ids[start:end], dtype=torch.long, device=device
-                )
-                y_batch = torch.tensor(
-                    labels[start:end], dtype=torch.long, device=device
-                )
-                mask_batch = torch.tensor(
-                    input_mask[start:end], dtype=torch.long, device=device
-                )
+        train_data = TensorDataset(
+            all_input_ids, all_input_mask, all_segment_ids, all_label_ids
+        )
+        train_sampler = RandomSampler(train_data)
+
+        train_dataloader = DataLoader(
+            train_data, sampler=train_sampler, batch_size=batch_size
+        )
+
+        for _ in trange(int(num_epochs), desc="Epoch"):
+            tr_loss = 0
+            nb_tr_examples, nb_tr_steps = 0, 0
+            for step, batch in enumerate(
+                tqdm(train_dataloader, desc="Iteration")
+            ):
+                batch = tuple(t.to(device) for t in batch)
+                input_ids, input_mask, segment_ids, label_ids = batch
+
+                # define a new function to compute loss values for both output_modes
+                logits = self.model(
+                    input_ids, segment_ids, input_mask, labels=None
+                )

-                if token_type_ids is not None:
-                    token_type_ids_batch = torch.tensor(
-                        token_type_ids[start:end],
-                        dtype=torch.long,
-                        device=device,
-                    )
-
-                y_h = self.model(
-                    input_ids=x_batch,
-                    token_type_ids=token_type_ids_batch,
-                    attention_mask=mask_batch,
-                    labels=None,
-                )
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(
+                    logits.view(-1, self.num_labels), label_ids.view(-1)
+                )

-                loss = loss_func(
-                    y_h.view(-1, self.num_labels), y_batch.view(-1)
-                ).mean()
+                loss = loss.mean()  # mean() to average on multi-gpu.

                 loss.backward()
+
+                tr_loss += loss.item()
+                nb_tr_examples += input_ids.size(0)
+                nb_tr_steps += 1

                 opt.step()
                 opt.zero_grad()

-                if verbose:
-                    if i % ((num_batches // 10) + 1) == 0:
-                        print(
-                            "epoch:{}/{}; batch:{}->{}/{}; loss:{:.6f}".format(
-                                epoch + 1,
-                                num_epochs,
-                                i + 1,
-                                min(i + 1 + num_batches // 10, num_batches),
-                                num_batches,
-                                loss.data,
-                            )
-                        )
         # empty cache
-        del [x_batch, y_batch, mask_batch, token_type_ids_batch]
         torch.cuda.empty_cache()

     def predict(
         self,
         token_ids,
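For reference, the training pattern the hunk above introduces, reduced to a self-contained runnable toy. The linear model, shapes, and hyperparameters are hypothetical stand-ins for BertForSequenceClassification and the method's arguments; the Dataset/Sampler/DataLoader wiring and the loss/backward/step accounting follow the diff:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

num_labels, seq_len, num_epochs, batch_size = 2, 8, 2, 4

# Hypothetical toy features/labels in place of tokenized BERT inputs.
features = torch.randn(16, seq_len)
labels = torch.randint(0, num_labels, (16,))

model = nn.Linear(seq_len, num_labels)           # toy model standing in for BERT
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fct = nn.CrossEntropyLoss()

train_data = TensorDataset(features, labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data, sampler=train_sampler, batch_size=batch_size
)

model.train()
for _ in range(num_epochs):
    tr_loss, nb_tr_steps = 0.0, 0
    for step, (x_batch, y_batch) in enumerate(train_dataloader):
        logits = model(x_batch)
        loss = loss_fct(logits.view(-1, num_labels), y_batch.view(-1))
        loss.backward()                          # accumulate gradients
        tr_loss += loss.item()
        nb_tr_steps += 1
        opt.step()                               # update, then clear gradients
        opt.zero_grad()
    print("mean loss:", tr_loss / nb_tr_steps)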
@@ -220,41 +216,43 @@ class BERTSequenceClassifier:

         # score
         self.model.eval()

+        all_input_ids = torch.tensor(token_ids, dtype=torch.long)
+        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
+        all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long)
+
+        eval_data = TensorDataset(
+            all_input_ids, all_input_mask, all_segment_ids
+        )
+
+        eval_sampler = SequentialSampler(eval_data)
+        eval_dataloader = DataLoader(
+            eval_data, sampler=eval_sampler, batch_size=batch_size
+        )
+
+        nb_eval_steps = 0
         preds = []
-        with tqdm(total=len(token_ids)) as pbar:
-            for i in range(0, len(token_ids), batch_size):
-                x_batch = token_ids[i : i + batch_size]
-                mask_batch = input_mask[i : i + batch_size]
-                x_batch = torch.tensor(
-                    x_batch, dtype=torch.long, device=device
-                )
-                mask_batch = torch.tensor(
-                    mask_batch, dtype=torch.long, device=device
-                )
-                token_type_ids_batch = None
-                if token_type_ids is not None:
-                    token_type_ids_batch = torch.tensor(
-                        token_type_ids[i : i + batch_size],
-                        dtype=torch.long,
-                        device=device,
-                    )
-                with torch.no_grad():
-                    p_batch = self.model(
-                        input_ids=x_batch,
-                        token_type_ids=token_type_ids_batch,
-                        attention_mask=mask_batch,
-                        labels=None,
-                    )
-                preds.append(p_batch.cpu())
-                if i % batch_size == 0:
-                    pbar.update(batch_size)
+        for input_ids, input_mask, segment_ids in tqdm(
+            eval_dataloader, desc="Evaluating"
+        ):
+            input_ids = input_ids.to(device)
+            input_mask = input_mask.to(device)
+            segment_ids = segment_ids.to(device)

-        preds = np.concatenate(preds)
+            with torch.no_grad():
+                logits = self.model(
+                    input_ids, segment_ids, input_mask, labels=None
+                )

-        if probabilities:
-            return namedtuple("Predictions", "classes probabilities")(
-                preds.argmax(axis=1),
-                nn.Softmax(dim=1)(torch.Tensor(preds)).numpy(),
-            )
-        else:
-            return preds.argmax(axis=1)
+            nb_eval_steps += 1
+            if len(preds) == 0:
+                preds.append(logits.detach().cpu().numpy())
+            else:
+                preds[0] = np.append(
+                    preds[0], logits.detach().cpu().numpy(), axis=0
+                )
+
+        preds = preds[0]
+        preds = np.argmax(preds, axis=1)
+
+        return preds
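Likewise, the new scoring path as a standalone sketch; the model and tensors are again hypothetical placeholders, while the SequentialSampler and the accumulation of logits into a single numpy array mirror the hunk above:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

seq_len, num_labels, batch_size = 8, 2, 4
features = torch.randn(10, seq_len)              # hypothetical inputs to score
model = nn.Linear(seq_len, num_labels)           # toy stand-in for BERT

eval_data = TensorDataset(features)
eval_sampler = SequentialSampler(eval_data)      # preserve input order
eval_dataloader = DataLoader(
    eval_data, sampler=eval_sampler, batch_size=batch_size
)

model.eval()
preds = []
for (x_batch,) in eval_dataloader:
    with torch.no_grad():                        # no gradients needed for scoring
        logits = model(x_batch)
    if len(preds) == 0:
        preds.append(logits.detach().cpu().numpy())
    else:
        preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

preds = np.argmax(preds[0], axis=1)              # one class id per input example
print(preds)

Because the sampler is sequential, preds[i] lines up with the i-th input example.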