diff --git a/utils_nlp/models/bert/sequence_classification.py b/utils_nlp/models/bert/sequence_classification.py
index 9828764..3c1f34d 100644
--- a/utils_nlp/models/bert/sequence_classification.py
+++ b/utils_nlp/models/bert/sequence_classification.py
@@ -4,15 +4,21 @@
 # This script reuses some code from
 # https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py
 
-import random
-from collections import namedtuple
+# import random
+# from collections import namedtuple
 
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.utils.data import (
+    DataLoader,
+    RandomSampler,
+    SequentialSampler,
+    TensorDataset,
+)
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.optimization import BertAdam
-from tqdm import tqdm
+from tqdm import tqdm, trange
 
 from utils_nlp.models.bert.common import Language
 from utils_nlp.common.pytorch_utils import get_device, move_to_device
@@ -124,65 +130,55 @@ class BERTSequenceClassifier:
 
         # define loss function
         # loss_func = nn.CrossEntropyLoss().to(device)
-        loss_func = nn.CrossEntropyLoss()
+        # loss_func = nn.CrossEntropyLoss()
 
         # train
         self.model.train()  # training mode
 
-        token_type_ids_batch = None
-        for epoch in range(num_epochs):
-            for i in range(num_batches):
+        all_input_ids = torch.tensor(token_ids, dtype=torch.long)
+        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
+        all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long)
+        all_label_ids = torch.tensor(labels, dtype=torch.long)
 
-                # get random batch
-                start = int(random.random() * num_examples)
-                end = start + batch_size
-                x_batch = torch.tensor(
-                    token_ids[start:end], dtype=torch.long, device=device
-                )
-                y_batch = torch.tensor(
-                    labels[start:end], dtype=torch.long, device=device
-                )
-                mask_batch = torch.tensor(
-                    input_mask[start:end], dtype=torch.long, device=device
+        train_data = TensorDataset(
+            all_input_ids, all_input_mask, all_segment_ids, all_label_ids
+        )
+        train_sampler = RandomSampler(train_data)
+
+        train_dataloader = DataLoader(
+            train_data, sampler=train_sampler, batch_size=batch_size
+        )
+
+        for _ in trange(int(num_epochs), desc="Epoch"):
+            tr_loss = 0
+            nb_tr_examples, nb_tr_steps = 0, 0
+            for step, batch in enumerate(
+                tqdm(train_dataloader, desc="Iteration")
+            ):
+                batch = tuple(t.to(device) for t in batch)
+                input_ids, input_mask, segment_ids, label_ids = batch
+
+                # define a new function to compute loss values for both output_modes
+                logits = self.model(
+                    input_ids, segment_ids, input_mask, labels=None
                 )
 
-                if token_type_ids is not None:
-                    token_type_ids_batch = torch.tensor(
-                        token_type_ids[start:end],
-                        dtype=torch.long,
-                        device=device,
-                    )
-
-                y_h = self.model(
-                    input_ids=x_batch,
-                    token_type_ids=token_type_ids_batch,
-                    attention_mask=mask_batch,
-                    labels=None,
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(
+                    logits.view(-1, self.num_labels), label_ids.view(-1)
                 )
 
-                loss = loss_func(
-                    y_h.view(-1, self.num_labels), y_batch.view(-1)
-                ).mean()
+                loss = loss.mean()  # mean() to average on multi-gpu.
+
                 loss.backward()
+
+                tr_loss += loss.item()
+                nb_tr_examples += input_ids.size(0)
+                nb_tr_steps += 1
+
                 opt.step()
                 opt.zero_grad()
 
-                if verbose:
-                    if i % ((num_batches // 10) + 1) == 0:
-                        print(
-                            "epoch:{}/{}; batch:{}->{}/{}; loss:{:.6f}".format(
-                                epoch + 1,
-                                num_epochs,
-                                i + 1,
-                                min(i + 1 + num_batches // 10, num_batches),
-                                num_batches,
-                                loss.data,
-                            )
-                        )
-        # empty cache
-        del [x_batch, y_batch, mask_batch, token_type_ids_batch]
-        torch.cuda.empty_cache()
-
     def predict(
         self,
         token_ids,
@@ -220,41 +216,43 @@ class BERTSequenceClassifier:
 
         # score
        self.model.eval()
+
+        all_input_ids = torch.tensor(token_ids, dtype=torch.long)
+        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
+        all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long)
+
+        eval_data = TensorDataset(
+            all_input_ids, all_input_mask, all_segment_ids
+        )
+
+        eval_sampler = SequentialSampler(eval_data)
+        eval_dataloader = DataLoader(
+            eval_data, sampler=eval_sampler, batch_size=batch_size
+        )
+
+        nb_eval_steps = 0
         preds = []
-        with tqdm(total=len(token_ids)) as pbar:
-            for i in range(0, len(token_ids), batch_size):
-                x_batch = token_ids[i : i + batch_size]
-                mask_batch = input_mask[i : i + batch_size]
-                x_batch = torch.tensor(
-                    x_batch, dtype=torch.long, device=device
-                )
-                mask_batch = torch.tensor(
-                    mask_batch, dtype=torch.long, device=device
-                )
-                token_type_ids_batch = None
-                if token_type_ids is not None:
-                    token_type_ids_batch = torch.tensor(
-                        token_type_ids[i : i + batch_size],
-                        dtype=torch.long,
-                        device=device,
-                    )
-                with torch.no_grad():
-                    p_batch = self.model(
-                        input_ids=x_batch,
-                        token_type_ids=token_type_ids_batch,
-                        attention_mask=mask_batch,
-                        labels=None,
-                    )
-                preds.append(p_batch.cpu())
-                if i % batch_size == 0:
-                    pbar.update(batch_size)
+        for input_ids, input_mask, segment_ids in tqdm(
+            eval_dataloader, desc="Evaluating"
+        ):
+            input_ids = input_ids.to(device)
+            input_mask = input_mask.to(device)
+            segment_ids = segment_ids.to(device)
 
-        preds = np.concatenate(preds)
+            with torch.no_grad():
+                logits = self.model(
+                    input_ids, segment_ids, input_mask, labels=None
+                )
 
-        if probabilities:
-            return namedtuple("Predictions", "classes probabilities")(
-                preds.argmax(axis=1),
-                nn.Softmax(dim=1)(torch.Tensor(preds)).numpy(),
-            )
-        else:
-            return preds.argmax(axis=1)
+            nb_eval_steps += 1
+            if len(preds) == 0:
+                preds.append(logits.detach().cpu().numpy())
+            else:
+                preds[0] = np.append(
+                    preds[0], logits.detach().cpu().numpy(), axis=0
+                )
+
+        preds = preds[0]
+        preds = np.argmax(preds, axis=1)
+
+        return preds
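
For context, the change swaps the hand-rolled random-batch loop for PyTorch's TensorDataset / RandomSampler / DataLoader batching. Below is a minimal, self-contained sketch of that pattern; the toy tensors, shapes, and batch size are invented for illustration and are not part of the repository code.

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Toy stand-ins for token_ids, input_mask, token_type_ids and labels;
# the shapes and vocabulary size here are arbitrary.
all_input_ids = torch.randint(0, 30522, (100, 64), dtype=torch.long)
all_input_mask = torch.ones(100, 64, dtype=torch.long)
all_segment_ids = torch.zeros(100, 64, dtype=torch.long)
all_label_ids = torch.randint(0, 2, (100,), dtype=torch.long)

# Wrap the tensors in a TensorDataset and draw shuffled mini-batches,
# mirroring the training loop introduced in the diff above.
train_data = TensorDataset(
    all_input_ids, all_input_mask, all_segment_ids, all_label_ids
)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=32)

for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
    # Each element is a batch_size-long slice of the corresponding tensor.
    print(input_ids.shape, label_ids.shape)  # torch.Size([32, 64]) torch.Size([32])
    break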