diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py
index 7a65077..2cfbbfa 100644
--- a/pytorch_lightning/metrics/__init__.py
+++ b/pytorch_lightning/metrics/__init__.py
@@ -5,7 +5,7 @@ from pytorch_lightning.metrics.regression import (
     MSE,
     PSNR,
     RMSE,
-    RMSLE
+    RMSLE,
 )
 from pytorch_lightning.metrics.classification import (
     Accuracy,
@@ -28,30 +28,32 @@ from pytorch_lightning.metrics.sklearns import (
     PrecisionRecallCurve,
     SklearnMetric,
 )
+from pytorch_lightning.metrics.nlp import BLEUScore
 
 __classification_metrics = [
-    'AUC',
-    'AUROC',
-    'Accuracy',
-    'AveragePrecision',
-    'ConfusionMatrix',
-    'DiceCoefficient',
-    'F1',
-    'FBeta',
-    'MulticlassPrecisionRecall',
-    'MulticlassROC',
-    'Precision',
-    'PrecisionRecall',
-    'PrecisionRecallCurve',
-    'ROC',
-    'Recall',
-    'IoU',
+    "AUC",
+    "AUROC",
+    "Accuracy",
+    "AveragePrecision",
+    "ConfusionMatrix",
+    "DiceCoefficient",
+    "F1",
+    "FBeta",
+    "MulticlassPrecisionRecall",
+    "MulticlassROC",
+    "Precision",
+    "PrecisionRecall",
+    "PrecisionRecallCurve",
+    "ROC",
+    "Recall",
+    "IoU",
 ]
 
 __regression_metrics = [
-    'MAE',
-    'MSE',
-    'PSNR',
-    'RMSE',
-    'RMSLE'
+    "MAE",
+    "MSE",
+    "PSNR",
+    "RMSE",
+    "RMSLE",
 ]
 
-__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
+__sequence_metrics = ["BLEUScore"]
+__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py
index 35cc286..eb92cab 100644
--- a/pytorch_lightning/metrics/functional/__init__.py
+++ b/pytorch_lightning/metrics/functional/__init__.py
@@ -25,5 +25,6 @@ from pytorch_lightning.metrics.functional.regression import (
     mse,
     psnr,
     rmse,
-    rmsle
+    rmsle,
 )
+from pytorch_lightning.metrics.functional.nlp import bleu_score
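For orientation: once the re-exports above land, both the class-based and the functional interface are importable from the metrics packages. A minimal usage sketch (not part of the diff; the expected score comes from the docstring example in the new module below):

    from pytorch_lightning.metrics import BLEUScore
    from pytorch_lightning.metrics.functional import bleu_score

    translate_corpus = ['the cat is on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

    print(BLEUScore()(translate_corpus, reference_corpus))  # tensor(0.7598)
    print(bleu_score(translate_corpus, reference_corpus))   # tensor(0.7598)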
diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py
new file mode 100644
index 0000000..e1bb86a
--- /dev/null
+++ b/pytorch_lightning/metrics/functional/nlp.py
@@ -0,0 +1,92 @@
+# referenced from
+# Library Name: torchtext
+# Authors: torchtext authors and @sluks
+# Date: 2020-07-18
+# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
+from typing import Sequence, List
+from collections import Counter
+
+import torch
+
+
+def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
+    """Count how many times each 1- to ``n_gram``-gram appears in a given text.
+
+    Args:
+        ngram_input_list: A list of tokens from the translated or reference text
+        n_gram: gram value ranging from 1 to 4
+
+    Return:
+        ngram_counter: a collections.Counter object of ngrams
+    """
+
+    ngram_counter = Counter()
+
+    for i in range(1, n_gram + 1):
+        for j in range(len(ngram_input_list) - i + 1):
+            ngram_key = tuple(ngram_input_list[j : i + j])
+            ngram_counter[ngram_key] += 1
+
+    return ngram_counter
+
+
+def bleu_score(
+    translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
+) -> torch.Tensor:
+    """Calculate BLEU score of machine translated text with one or more references.
+
+    Args:
+        translate_corpus: An iterable of machine-translated sentences
+        reference_corpus: An iterable of iterables of reference sentences
+        n_gram: Gram value ranging from 1 to 4 (Default 4)
+        smooth: Whether to apply smoothing (Lin et al. 2004)
+
+    Return:
+        A Tensor with the BLEU score
+
+    Example:
+
+        >>> translate_corpus = ['the cat is on the mat'.split()]
+        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+        >>> bleu_score(translate_corpus, reference_corpus)
+        tensor(0.7598)
+    """
+
+    assert len(translate_corpus) == len(reference_corpus)
+    numerator = torch.zeros(n_gram)
+    denominator = torch.zeros(n_gram)
+    precision_scores = torch.zeros(n_gram)
+    c = 0.0
+    r = 0.0
+    for (translation, references) in zip(translate_corpus, reference_corpus):
+        c += len(translation)
+        ref_len_list = [len(ref) for ref in references]
+        ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
+        r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
+        translation_counter = _count_ngram(translation, n_gram)
+        reference_counter = Counter()
+        for ref in references:
+            reference_counter |= _count_ngram(ref, n_gram)
+
+        ngram_counter_clip = translation_counter & reference_counter
+        for counter_clip in ngram_counter_clip:
+            numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
+
+        for counter in translation_counter:
+            denominator[len(counter) - 1] += translation_counter[counter]
+
+    trans_len = torch.tensor(c)
+    ref_len = torch.tensor(r)
+    if min(numerator) == 0.0:
+        return torch.tensor(0.0)
+
+    if smooth:
+        precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
+    else:
+        precision_scores = numerator / denominator
+    log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
+    geometric_mean = torch.exp(torch.sum(log_precision_scores))
+    brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
+    bleu = brevity_penalty * geometric_mean
+
+    return bleu
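To make the arithmetic above concrete, here is the docstring example worked by hand (a sketch; the clipped counts can be reproduced with _count_ngram). The candidate "the cat is on the mat" has 6 tokens; the closest reference, "a cat is on the mat", also has 6, so the brevity penalty is 1 and only the modified n-gram precisions matter:

    import torch

    # clipped matches / candidate n-gram counts, per order:
    #   1-grams: 5/6  (the second "the" is clipped to its max reference count of 1)
    #   2-grams: 4/5  (only "the cat" is unmatched)
    #   3-grams: 3/4  (only "the cat is" is unmatched)
    #   4-grams: 2/3  (only "the cat is on" is unmatched)
    precisions = torch.tensor([5 / 6, 4 / 5, 3 / 4, 2 / 3])
    geometric_mean = torch.exp(torch.log(precisions).mean())  # == (1/3) ** 0.25
    print(geometric_mean)  # tensor(0.7598)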
diff --git a/pytorch_lightning/metrics/nlp.py b/pytorch_lightning/metrics/nlp.py
new file mode 100644
index 0000000..a4284ad
--- /dev/null
+++ b/pytorch_lightning/metrics/nlp.py
@@ -0,0 +1,46 @@
+import torch
+
+from pytorch_lightning.metrics.functional.nlp import bleu_score
+from pytorch_lightning.metrics.metric import Metric
+
+
+class BLEUScore(Metric):
+    """
+    Calculate BLEU score of machine translated text with one or more references.
+
+    Example:
+
+        >>> translate_corpus = ['the cat is on the mat'.split()]
+        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+        >>> metric = BLEUScore()
+        >>> metric(translate_corpus, reference_corpus)
+        tensor(0.7598)
+    """
+
+    def __init__(self, n_gram: int = 4, smooth: bool = False):
+        """
+        Args:
+            n_gram: Gram value ranging from 1 to 4 (Default 4)
+            smooth: Whether to apply smoothing (Lin et al. 2004)
+        """
+        super().__init__(name="bleu")
+        self.n_gram = n_gram
+        self.smooth = smooth
+
+    def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
+        """
+        Actual metric computation
+
+        Args:
+            translate_corpus: An iterable of machine-translated sentences
+            reference_corpus: An iterable of iterables of reference sentences
+
+        Return:
+            torch.Tensor: BLEU score
+        """
+        return bleu_score(
+            translate_corpus=translate_corpus,
+            reference_corpus=reference_corpus,
+            n_gram=self.n_gram,
+            smooth=self.smooth,
+        ).to(self.device, self.dtype)
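As a quick illustration (values chosen here for the example, not taken from the diff), the constructor arguments map one-to-one onto the functional call, so a bigram BLEU with smoothing reads:

    from pytorch_lightning.metrics.nlp import BLEUScore

    metric = BLEUScore(n_gram=2, smooth=True)  # equivalent to bleu_score(..., n_gram=2, smooth=True)
    translate_corpus = ['the cat is on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
    score = metric(translate_corpus, reference_corpus)  # a torch.Tensor on the metric's device/dtype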
diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py
new file mode 100644
index 0000000..2f16472
--- /dev/null
+++ b/tests/metrics/functional/test_nlp.py
@@ -0,0 +1,66 @@
+import pytest
+import torch
+from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
+
+from pytorch_lightning.metrics.functional.nlp import bleu_score
+
+# example taken from
+# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
+HYPOTHESIS1 = tuple(
+    "It is a guide to action which ensures that the military always obeys the commands of the party".split()
+)
+REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
+REFERENCE2 = tuple(
+    "It is a guiding principle which makes the military forces always being under the command of the Party".split()
+)
+REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
+
+
+# example taken from
+# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
+HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
+HYP2 = "he read the book because he was interested in world history".split()
+
+REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
+REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
+REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
+REF2A = "he was interested in world history because he read the book".split()
+
+LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
+HYPOTHESES = [HYP1, HYP2]
+
+# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
+smooth_func = SmoothingFunction().method2
+
+
+@pytest.mark.parametrize(
+    ["weights", "n_gram", "smooth_func", "smooth"],
+    [
+        pytest.param([1], 1, None, False),
+        pytest.param([0.5, 0.5], 2, smooth_func, True),
+        pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
+        pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
+    ],
+)
+def test_bleu_score(weights, n_gram, smooth_func, smooth):
+    nltk_output = sentence_bleu(
+        [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
+    )
+    pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
+    assert torch.allclose(pl_output, torch.tensor(nltk_output))
+
+    nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
+    pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
+    assert torch.allclose(pl_output, torch.tensor(nltk_output))
+
+
+def test_bleu_empty():
+    hyp = [[]]
+    ref = [[[]]]
+    assert bleu_score(hyp, ref) == torch.tensor(0.0)
+
+
+def test_no_4_gram():
+    hyps = [["My", "full", "pytorch-lightning"]]
+    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
+    assert bleu_score(hyps, refs) == torch.tensor(0.0)
diff --git a/tests/metrics/test_nlp.py b/tests/metrics/test_nlp.py
new file mode 100644
index 0000000..e58b1f3
--- /dev/null
+++ b/tests/metrics/test_nlp.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+
+from pytorch_lightning.metrics.nlp import BLEUScore
+
+# example taken from
+# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
+HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
+HYP2 = "he read the book because he was interested in world history".split()
+
+REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
+REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split()
+REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
+REF2A = "he was interested in world history because he read the book".split()
+
+LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
+HYPOTHESES = [HYP1, HYP2]
+
+
+@pytest.mark.parametrize(
+    ["n_gram", "smooth"],
+    [pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False)],
+)
+def test_bleu(smooth, n_gram):
+    bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
+    assert bleu.name == "bleu"
+
+    pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES)
+    assert isinstance(pl_output, torch.Tensor)
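The parametrized test cross-checks against NLTK; the same comparison can be run by hand on the docstring example (assuming nltk is installed, as the test module already requires):

    import torch
    from nltk.translate.bleu_score import corpus_bleu

    from pytorch_lightning.metrics.functional.nlp import bleu_score

    hypotheses = ['the cat is on the mat'.split()]
    references = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

    nltk_output = corpus_bleu(references, hypotheses)  # default weights: uniform over 1- to 4-grams
    pl_output = bleu_score(hypotheses, references, n_gram=4)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))  # both ≈ 0.7598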