Added optional probability distribution predictions

This commit is contained in:
Said Bleik 2019-06-20 14:23:09 -04:00
Родитель 4388e6c588
Коммит fad2564604
1 изменённый файл: 20 добавлений и 3 удаления

Просмотреть файл

@ -22,6 +22,7 @@ class BERTSequenceClassifier:
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
"""Initializes the classifier and the underlying pretrained model.
Args:
language (Language, optional): The pretrained model's language.
Defaults to Language.ENGLISH.
@ -54,6 +55,7 @@ class BERTSequenceClassifier:
verbose=True,
):
"""Fine-tunes the BERT classifier using the given training data.
Args:
token_ids (list): List of training token id lists.
input_mask (list): List of input mask lists.
@ -149,8 +151,16 @@ class BERTSequenceClassifier:
del [x_batch, y_batch, mask_batch]
torch.cuda.empty_cache()
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
def predict(
self,
token_ids,
input_mask,
num_gpus=None,
batch_size=32,
probabilities=False,
):
"""Scores the given dataset and returns the predicted classes.
Args:
token_ids (list): List of training token lists.
input_mask (list): List of input mask lists.
@ -158,6 +168,9 @@ class BERTSequenceClassifier:
If None is specified, all available GPUs
will be used. Defaults to None.
batch_size (int, optional): Scoring batch size. Defaults to 32.
probabilities (bool, optional): If True, the predicted probability
distribution is returned; otherwise, the predicted classes are
returned. Defaults to False.
Returns:
[ndarray]: Predicted classes, or the predicted probability distribution if probabilities is True.
"""
@ -188,6 +201,10 @@ class BERTSequenceClassifier:
preds.append(p_batch.cpu().data.numpy())
if i % batch_size == 0:
pbar.update(batch_size)
preds = [x.argmax(1) for x in preds]
preds = np.concatenate(preds)
return preds
if probabilities:
return nn.Softmax(dim=1)(torch.Tensor(preds)).numpy()
else:
return preds.argmax(axis=1)