diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 53f6b5b..68268c6 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -34,4 +34,5 @@ from pytorch_lightning.metrics.regression import ( # noqa: F401 ExplainedVariance, PSNR, SSIM, + R2Score ) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3aa5335..b4cd125 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -28,9 +28,9 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401 to_categorical, to_onehot, ) -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401 # TODO: unify metrics between class and functional, add below from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401 +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401 from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401 from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401 from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401 @@ -40,6 +40,7 @@ from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squ from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401 from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401 from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401 +from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401 from pytorch_lightning.metrics.functional.roc import roc # noqa: F401 from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401 from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py new file mode 100644 index 0000000..f689e3a --- /dev/null +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -0,0 +1,126 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch + +from pytorch_lightning.metrics.utils import _check_same_shape +from pytorch_lightning.utilities import rank_zero_warn + + +def _r2score_update( + preds: torch.tensor, + target: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + _check_same_shape(preds, target) + if preds.ndim > 2: + raise ValueError('Expected both prediction and target to be 1D or 2D tensors,' + f' but recevied tensors with dimension {preds.shape}') + if len(preds) < 2: + raise ValueError('Needs atleast two samples to calculate r2 score.') + + sum_error = torch.sum(target, dim=0) + sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) + residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) + total = target.size(0) + + return sum_squared_error, sum_error, residual, total + + +def _r2score_compute(sum_squared_error: torch.Tensor, + sum_error: torch.Tensor, + residual: torch.Tensor, + total: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average") -> torch.Tensor: + mean_error = sum_error / total + diff = sum_squared_error - sum_error * mean_error + raw_scores = 1 - (residual / diff) + + if multioutput == "raw_values": + r2score = raw_scores + elif multioutput == "uniform_average": + r2score = torch.mean(raw_scores) + elif multioutput == "variance_weighted": + diff_sum = torch.sum(diff) + r2score = torch.sum(diff / diff_sum * raw_scores) + else: + raise ValueError('Argument `multioutput` must be either `raw_values`,' + f' `uniform_average` or `variance_weighted`. Received {multioutput}.') + + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or' + ' equal to 0.') + + if adjusted != 0: + if adjusted > total - 1: + rank_zero_warn("More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score.", + UserWarning) + elif adjusted == total - 1: + rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" + " standard r2 score.", UserWarning) + else: + r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) + return r2score + + +def r2score( + preds: torch.Tensor, + target: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average", +) -> torch.Tensor: + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the ``adjusted`` argument. + + Args: + pred: estimated labels + target: ground truth labels + adjusted: number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). + multioutput: Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is ``'uniform_average'``.): + + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances + + Example: + + >>> from pytorch_lightning.metrics.functional import r2score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score(preds, target, multioutput='raw_values') + tensor([0.9654, 0.9082]) + """ + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index 8f48189..3e2fed1 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -17,3 +17,4 @@ from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSqua from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401 from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401 from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401 +from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401 diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py new file mode 100644 index 0000000..f8f6e98 --- /dev/null +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -0,0 +1,143 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch + +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.functional.r2score import ( + _r2score_update, + _r2score_compute +) + + +class R2Score(Metric): + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the `adjusted` argument. + + Forward accepts + + - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + + In the case of multioutput, as default the variances will be uniformly + averaged over the additional dimensions. Please see argument `multioutput` + for changing this behavior. + + Args: + num_outputs: + Number of outputs in multioutput setting (default is 1) + adjusted: + number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). + multioutput: + Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is ``'uniform_average'``.): + + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import R2Score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score = R2Score() + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') + >>> r2score(preds, target) + tensor([0.9654, 0.9082]) + """ + def __init__( + self, + num_outputs: int = 1, + adjusted: int = 0, + multioutput: str = "uniform_average", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.num_outputs = num_outputs + + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or' + ' equal to 0.') + self.adjusted = adjusted + + allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') + if multioutput not in allowed_multioutput: + raise ValueError( + f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' + ) + self.multioutput = multioutput + + self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + + self.sum_squared_error += sum_squared_error + self.sum_error += sum_error + self.residual += residual + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes r2 score over the metric states. + """ + return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual, + self.total, self.adjusted, self.multioutput) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py new file mode 100644 index 0000000..ef3ec89 --- /dev/null +++ b/tests/metrics/regression/test_r2score.py @@ -0,0 +1,109 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import r2_score as sk_r2score + +from pytorch_lightning.metrics.regression import R2Score +from pytorch_lightning.metrics.functional import r2score +from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, MetricTester + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, adjusted, multioutput): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score + + +def _multi_target_sk_metric(preds, target, adjusted, multioutput): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score + + +@pytest.mark.parametrize("adjusted", [0, 5, 10]) +@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_outputs", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), + ], +) +class TestR2Score(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + R2Score, + partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + dist_sync_on_step, + metric_args=dict(adjusted=adjusted, + multioutput=multioutput, + num_outputs=num_outputs), + ) + + def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): + self.run_functional_metric_test( + preds, + target, + r2score, + partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + metric_args=dict(adjusted=adjusted, + multioutput=multioutput), + ) + + +def test_error_on_different_shape(metric_class=R2Score): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100,), torch.randn(50,)) + + +def test_error_on_multidim_tensors(metric_class=R2Score): + metric = metric_class() + with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,' + r' but recevied tensors with dimension .'): + metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) + + +def test_error_on_too_few_samples(metric_class=R2Score): + metric = metric_class() + with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): + metric(torch.randn(1,), torch.randn(1,)) + + +def test_warning_on_too_large_adjusted(metric_class=R2Score): + metric = metric_class(adjusted=10) + + with pytest.warns(UserWarning, + match="More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score."): + metric(torch.randn(10,), torch.randn(10,)) + + with pytest.warns(UserWarning, + match="Division by zero in adjusted r2 score. Falls back to" + " standard r2 score."): + metric(torch.randn(11,), torch.randn(11,))