add BLEU (PL^2535)
* metrics: added bleu score and test bleu * metrics: fixed type hints in bleu * bleu score moved to metrics/functional/nlp.py * refactor with torch.Tensor * Update test_sequence.py * refactor as Borda requests and nltk==3.2 * locked nltk==3.3 * nltk>=3.3, parametrized smooth argument for test * fix bleu_score example * added class BLEUScore metrics and test * added class BLEUScore metrics and test * update CHANGELOG * refactor with torchtext * torchtext changed to optional import * fix E501 line too long * add else: in optional import * remove pragma: no-cover * constants changed to CAPITALS * remove class in tests * List -> Sequence, conda -> pip, cast with tensor * add torchtext in test.txt * remove torchtext from test.txt * bump torchtext to 0.5.0 * bump torchtext to 0.5.0 * Apply suggestions from code review * ignore bleu score in doctest, renamed to nlp.py * back to implementation with torch * remove --ignore in CI test, proper reference format * apply justus comment Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Родитель
277dbb4252
Коммит
ee4ef0ea95
|
@ -5,7 +5,7 @@ from pytorch_lightning.metrics.regression import (
|
|||
MSE,
|
||||
PSNR,
|
||||
RMSE,
|
||||
RMSLE
|
||||
RMSLE,
|
||||
)
|
||||
from pytorch_lightning.metrics.classification import (
|
||||
Accuracy,
|
||||
|
@ -28,30 +28,32 @@ from pytorch_lightning.metrics.sklearns import (
|
|||
PrecisionRecallCurve,
|
||||
SklearnMetric,
|
||||
)
|
||||
from pytorch_lightning.metrics.nlp import BLEUScore
|
||||
|
||||
__classification_metrics = [
|
||||
'AUC',
|
||||
'AUROC',
|
||||
'Accuracy',
|
||||
'AveragePrecision',
|
||||
'ConfusionMatrix',
|
||||
'DiceCoefficient',
|
||||
'F1',
|
||||
'FBeta',
|
||||
'MulticlassPrecisionRecall',
|
||||
'MulticlassROC',
|
||||
'Precision',
|
||||
'PrecisionRecall',
|
||||
'PrecisionRecallCurve',
|
||||
'ROC',
|
||||
'Recall',
|
||||
'IoU',
|
||||
"AUC",
|
||||
"AUROC",
|
||||
"Accuracy",
|
||||
"AveragePrecision",
|
||||
"ConfusionMatrix",
|
||||
"DiceCoefficient",
|
||||
"F1",
|
||||
"FBeta",
|
||||
"MulticlassPrecisionRecall",
|
||||
"MulticlassROC",
|
||||
"Precision",
|
||||
"PrecisionRecall",
|
||||
"PrecisionRecallCurve",
|
||||
"ROC",
|
||||
"Recall",
|
||||
"IoU",
|
||||
]
|
||||
__regression_metrics = [
|
||||
'MAE',
|
||||
'MSE',
|
||||
'PSNR',
|
||||
'RMSE',
|
||||
'RMSLE'
|
||||
"MAE",
|
||||
"MSE",
|
||||
"PSNR",
|
||||
"RMSE",
|
||||
"RMSLE",
|
||||
]
|
||||
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
|
||||
__sequence_metrics = ["BLEUScore"]
|
||||
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
|
||||
|
|
|
@ -25,5 +25,6 @@ from pytorch_lightning.metrics.functional.regression import (
|
|||
mse,
|
||||
psnr,
|
||||
rmse,
|
||||
rmsle
|
||||
rmsle,
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# referenced from
|
||||
# Library Name: torchtext
|
||||
# Authors: torchtext authors and @sluks
|
||||
# Date: 2020-07-18
|
||||
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
|
||||
from typing import Sequence, List
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
|
||||
"""Counting how many times each word appears in a given text with ngram
|
||||
|
||||
Args:
|
||||
ngram_input_list: A list of translated text or reference texts
|
||||
n_gram: gram value ranged 1 to 4
|
||||
|
||||
Return:
|
||||
ngram_counter: a collections.Counter object of ngram
|
||||
"""
|
||||
|
||||
ngram_counter = Counter()
|
||||
|
||||
for i in range(1, n_gram + 1):
|
||||
for j in range(len(ngram_input_list) - i + 1):
|
||||
ngram_key = tuple(ngram_input_list[j : i + j])
|
||||
ngram_counter[ngram_key] += 1
|
||||
|
||||
return ngram_counter
|
||||
|
||||
|
||||
def bleu_score(
|
||||
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""Calculate BLEU score of machine translated text with one or more references.
|
||||
|
||||
Args:
|
||||
translate_corpus: An iterable of machine translated corpus
|
||||
reference_corpus: An iterable of iterables of reference corpus
|
||||
n_gram: Gram value ranged from 1 to 4 (Default 4)
|
||||
smooth: Whether or not to apply smoothing – Lin et al. 2004
|
||||
|
||||
Return:
|
||||
A Tensor with BLEU Score
|
||||
|
||||
Example:
|
||||
|
||||
>>> translate_corpus = ['the cat is on the mat'.split()]
|
||||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
|
||||
>>> bleu_score(translate_corpus, reference_corpus)
|
||||
tensor(0.7598)
|
||||
"""
|
||||
|
||||
assert len(translate_corpus) == len(reference_corpus)
|
||||
numerator = torch.zeros(n_gram)
|
||||
denominator = torch.zeros(n_gram)
|
||||
precision_scores = torch.zeros(n_gram)
|
||||
c = 0.0
|
||||
r = 0.0
|
||||
for (translation, references) in zip(translate_corpus, reference_corpus):
|
||||
c += len(translation)
|
||||
ref_len_list = [len(ref) for ref in references]
|
||||
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
|
||||
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
|
||||
translation_counter = _count_ngram(translation, n_gram)
|
||||
reference_counter = Counter()
|
||||
for ref in references:
|
||||
reference_counter |= _count_ngram(ref, n_gram)
|
||||
|
||||
ngram_counter_clip = translation_counter & reference_counter
|
||||
for counter_clip in ngram_counter_clip:
|
||||
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
|
||||
|
||||
for counter in translation_counter:
|
||||
denominator[len(counter) - 1] += translation_counter[counter]
|
||||
|
||||
trans_len = torch.tensor(c)
|
||||
ref_len = torch.tensor(r)
|
||||
if min(numerator) == 0.0:
|
||||
return torch.tensor(0.0)
|
||||
|
||||
if smooth:
|
||||
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
|
||||
else:
|
||||
precision_scores = numerator / denominator
|
||||
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
|
||||
geometric_mean = torch.exp(torch.sum(log_precision_scores))
|
||||
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
|
||||
bleu = brevity_penalty * geometric_mean
|
||||
|
||||
return bleu
|
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class BLEUScore(Metric):
|
||||
"""
|
||||
Calculate BLEU score of machine translated text with one or more references.
|
||||
|
||||
Example:
|
||||
|
||||
>>> translate_corpus = ['the cat is on the mat'.split()]
|
||||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
|
||||
>>> metric = BLEUScore()
|
||||
>>> metric(translate_corpus, reference_corpus)
|
||||
tensor(0.7598)
|
||||
"""
|
||||
|
||||
def __init__(self, n_gram: int = 4, smooth: bool = False):
|
||||
"""
|
||||
Args:
|
||||
n_gram: Gram value ranged from 1 to 4 (Default 4)
|
||||
smooth: Whether or not to apply smoothing – Lin et al. 2004
|
||||
"""
|
||||
super().__init__(name="bleu")
|
||||
self.n_gram = n_gram
|
||||
self.smooth = smooth
|
||||
|
||||
def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
translate_corpus: An iterable of machine translated corpus
|
||||
reference_corpus: An iterable of iterables of reference corpus
|
||||
|
||||
Return:
|
||||
torch.Tensor: BLEU Score
|
||||
"""
|
||||
return bleu_score(
|
||||
translate_corpus=translate_corpus,
|
||||
reference_corpus=reference_corpus,
|
||||
n_gram=self.n_gram,
|
||||
smooth=self.smooth,
|
||||
).to(self.device, self.dtype)
|
|
@ -0,0 +1,66 @@
|
|||
import pytest
|
||||
import torch
|
||||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
|
||||
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
|
||||
HYPOTHESIS1 = tuple(
|
||||
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
)
|
||||
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
|
||||
REFERENCE2 = tuple(
|
||||
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
|
||||
)
|
||||
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
|
||||
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
|
||||
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
HYP2 = "he read the book because he was interested in world history".split()
|
||||
|
||||
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
|
||||
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
|
||||
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
|
||||
REF2A = "he was interested in world history because he read the book".split()
|
||||
|
||||
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
|
||||
HYPOTHESES = [HYP1, HYP2]
|
||||
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
|
||||
smooth_func = SmoothingFunction().method2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["weights", "n_gram", "smooth_func", "smooth"],
|
||||
[
|
||||
pytest.param([1], 1, None, False),
|
||||
pytest.param([0.5, 0.5], 2, smooth_func, True),
|
||||
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
|
||||
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
|
||||
],
|
||||
)
|
||||
def test_bleu_score(weights, n_gram, smooth_func, smooth):
|
||||
nltk_output = sentence_bleu(
|
||||
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
|
||||
)
|
||||
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
|
||||
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
|
||||
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
|
||||
|
||||
def test_bleu_empty():
|
||||
hyp = [[]]
|
||||
ref = [[[]]]
|
||||
assert bleu_score(hyp, ref) == torch.tensor(0.0)
|
||||
|
||||
|
||||
def test_no_4_gram():
|
||||
hyps = [["My", "full", "pytorch-lightning"]]
|
||||
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
|
||||
assert bleu_score(hyps, refs) == torch.tensor(0.0)
|
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.nlp import BLEUScore
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
|
||||
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
HYP2 = "he read the book because he was interested in world history".split()
|
||||
|
||||
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
|
||||
REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split()
|
||||
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
|
||||
REF2A = "he was interested in world history because he read the book".split()
|
||||
|
||||
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
|
||||
HYPOTHESES = [HYP1, HYP2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["n_gram", "smooth"],
|
||||
[pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),],
|
||||
)
|
||||
def test_bleu(smooth, n_gram):
|
||||
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
|
||||
assert bleu.name == "bleu"
|
||||
|
||||
pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES)
|
||||
assert isinstance(pl_output, torch.Tensor)
|
Загрузка…
Ссылка в новой задаче