added optional prob dist predictions
This commit is contained in:
Родитель
4388e6c588
Коммит
fad2564604
|
@ -22,6 +22,7 @@ class BERTSequenceClassifier:
|
|||
|
||||
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
|
||||
"""Initializes the classifier and the underlying pretrained model.
|
||||
|
||||
Args:
|
||||
language (Language, optional): The pretrained model's language.
|
||||
Defaults to Language.ENGLISH.
|
||||
|
@ -54,6 +55,7 @@ class BERTSequenceClassifier:
|
|||
verbose=True,
|
||||
):
|
||||
"""Fine-tunes the BERT classifier using the given training data.
|
||||
|
||||
Args:
|
||||
token_ids (list): List of training token id lists.
|
||||
input_mask (list): List of input mask lists.
|
||||
|
@ -149,8 +151,16 @@ class BERTSequenceClassifier:
|
|||
del [x_batch, y_batch, mask_batch]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
|
||||
def predict(
|
||||
self,
|
||||
token_ids,
|
||||
input_mask,
|
||||
num_gpus=None,
|
||||
batch_size=32,
|
||||
probabilities=False,
|
||||
):
|
||||
"""Scores the given dataset and returns the predicted classes.
|
||||
|
||||
Args:
|
||||
token_ids (list): List of training token lists.
|
||||
input_mask (list): List of input mask lists.
|
||||
|
@ -158,6 +168,9 @@ class BERTSequenceClassifier:
|
|||
If None is specified, all available GPUs
|
||||
will be used. Defaults to None.
|
||||
batch_size (int, optional): Scoring batch size. Defaults to 32.
|
||||
probabilities: If true, the predicted probability distribution
|
||||
is returned; otherwise, the predicted classes are returned.
|
||||
Defaults to False.
|
||||
Returns:
|
||||
[ndarray]: Predicted classes.
|
||||
"""
|
||||
|
@ -188,6 +201,10 @@ class BERTSequenceClassifier:
|
|||
preds.append(p_batch.cpu().data.numpy())
|
||||
if i % batch_size == 0:
|
||||
pbar.update(batch_size)
|
||||
preds = [x.argmax(1) for x in preds]
|
||||
|
||||
preds = np.concatenate(preds)
|
||||
return preds
|
||||
|
||||
if probabilities:
|
||||
return nn.Softmax(dim=1)(torch.Tensor(preds)).numpy()
|
||||
else:
|
||||
return preds.argmax(axis=1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче