Added optional probability distribution predictions
This commit is contained in:
Родитель
4388e6c588
Коммит
fad2564604
|
@ -22,6 +22,7 @@ class BERTSequenceClassifier:
|
||||||
|
|
||||||
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
|
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
|
||||||
"""Initializes the classifier and the underlying pretrained model.
|
"""Initializes the classifier and the underlying pretrained model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
language (Language, optional): The pretrained model's language.
|
language (Language, optional): The pretrained model's language.
|
||||||
Defaults to Language.ENGLISH.
|
Defaults to Language.ENGLISH.
|
||||||
|
@ -54,6 +55,7 @@ class BERTSequenceClassifier:
|
||||||
verbose=True,
|
verbose=True,
|
||||||
):
|
):
|
||||||
"""Fine-tunes the BERT classifier using the given training data.
|
"""Fine-tunes the BERT classifier using the given training data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids (list): List of training token id lists.
|
token_ids (list): List of training token id lists.
|
||||||
input_mask (list): List of input mask lists.
|
input_mask (list): List of input mask lists.
|
||||||
|
@ -149,8 +151,16 @@ class BERTSequenceClassifier:
|
||||||
del [x_batch, y_batch, mask_batch]
|
del [x_batch, y_batch, mask_batch]
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
|
def predict(
|
||||||
|
self,
|
||||||
|
token_ids,
|
||||||
|
input_mask,
|
||||||
|
num_gpus=None,
|
||||||
|
batch_size=32,
|
||||||
|
probabilities=False,
|
||||||
|
):
|
||||||
"""Scores the given dataset and returns the predicted classes.
|
"""Scores the given dataset and returns the predicted classes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids (list): List of training token lists.
|
token_ids (list): List of training token lists.
|
||||||
input_mask (list): List of input mask lists.
|
input_mask (list): List of input mask lists.
|
||||||
|
@ -158,6 +168,9 @@ class BERTSequenceClassifier:
|
||||||
If None is specified, all available GPUs
|
If None is specified, all available GPUs
|
||||||
will be used. Defaults to None.
|
will be used. Defaults to None.
|
||||||
batch_size (int, optional): Scoring batch size. Defaults to 32.
|
batch_size (int, optional): Scoring batch size. Defaults to 32.
|
||||||
|
probabilities (bool, optional): If True, the predicted probability distribution
|
||||||
|
is returned; otherwise, the predicted classes are returned.
|
||||||
|
Defaults to False.
|
||||||
Returns:
|
Returns:
|
||||||
[ndarray]: Predicted classes.
|
[ndarray]: Predicted classes, or the predicted probability distribution when probabilities is True.
|
||||||
"""
|
"""
|
||||||
|
@ -188,6 +201,10 @@ class BERTSequenceClassifier:
|
||||||
preds.append(p_batch.cpu().data.numpy())
|
preds.append(p_batch.cpu().data.numpy())
|
||||||
if i % batch_size == 0:
|
if i % batch_size == 0:
|
||||||
pbar.update(batch_size)
|
pbar.update(batch_size)
|
||||||
preds = [x.argmax(1) for x in preds]
|
|
||||||
preds = np.concatenate(preds)
|
preds = np.concatenate(preds)
|
||||||
return preds
|
|
||||||
|
if probabilities:
|
||||||
|
return nn.Softmax(dim=1)(torch.Tensor(preds)).numpy()
|
||||||
|
else:
|
||||||
|
return preds.argmax(axis=1)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче