diff --git a/utils_nlp/bert/sequence_classification.py b/utils_nlp/bert/sequence_classification.py
index c5d4614..7ba75ee 100644
--- a/utils_nlp/bert/sequence_classification.py
+++ b/utils_nlp/bert/sequence_classification.py
@@ -22,6 +22,7 @@ class BERTSequenceClassifier:
 
     def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
         """Initializes the classifier and the underlying pretrained model.
+
         Args:
             language (Language, optional): The pretrained model's language.
                 Defaults to Language.ENGLISH.
@@ -54,6 +55,7 @@ class BERTSequenceClassifier:
         verbose=True,
     ):
         """Fine-tunes the BERT classifier using the given training data.
+
        Args:
            token_ids (list): List of training token id lists.
            input_mask (list): List of input mask lists.
@@ -149,8 +151,16 @@ class BERTSequenceClassifier:
                 del [x_batch, y_batch, mask_batch]
                 torch.cuda.empty_cache()
 
-    def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
+    def predict(
+        self,
+        token_ids,
+        input_mask,
+        num_gpus=None,
+        batch_size=32,
+        probabilities=False,
+    ):
         """Scores the given dataset and returns the predicted classes.
+
         Args:
             token_ids (list): List of training token lists.
             input_mask (list): List of input mask lists.
@@ -158,6 +168,9 @@ class BERTSequenceClassifier:
                 If None is specified, all available GPUs will be used.
                 Defaults to None.
             batch_size (int, optional): Scoring batch size. Defaults to 32.
+            probabilities (bool, optional): If True, the predicted probability
+                distribution is returned; otherwise, the predicted classes
+                are returned. Defaults to False.
         Returns:
             [ndarray]: Predicted classes.
         """
@@ -188,6 +201,10 @@ class BERTSequenceClassifier:
             preds.append(p_batch.cpu().data.numpy())
             if i % batch_size == 0:
                 pbar.update(batch_size)
 
-        preds = [x.argmax(1) for x in preds]
+        preds = np.concatenate(preds)
-        return preds
+
+        if probabilities:
+            return nn.Softmax(dim=1)(torch.Tensor(preds)).numpy()
+        else:
+            return preds.argmax(axis=1)
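
For reference, a minimal standalone sketch of what the new return path computes: the per-batch logits are concatenated, then probabilities=True yields a row-wise softmax while the default path returns argmax class ids. The array values below are made up for illustration and do not come from the model.

    import numpy as np
    import torch
    import torch.nn as nn

    # Illustrative logits for 3 examples and 2 classes (made-up values),
    # standing in for the concatenated per-batch model outputs.
    preds = np.concatenate(
        [np.array([[2.0, -1.0], [0.5, 0.3]]), np.array([[-0.2, 1.7]])]
    )

    # probabilities=True branch: row-wise softmax over the logits.
    probs = nn.Softmax(dim=1)(torch.Tensor(preds)).numpy()  # shape (3, 2), rows sum to 1

    # Default branch: hard class predictions.
    classes = preds.argmax(axis=1)  # array([0, 0, 1]) for the logits above

    print(probs)
    print(classes)

Note that the softmax is applied on CPU over the full concatenated logits array in one pass, after scoring has finished.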