* add r2metric

* change init

* add test

* add docs

* add math

* Apply suggestions from code review

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* changelog

* adjusted parameter

* add more test

* pep8

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* add warnings for adjusted score

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
(cherry picked from commit 4d44437074051a24366973e1ead758ff2f44c3b2)
Nicki Skafte 2021-01-01 12:23:19 +01:00, committed by Jirka Borovec
Parent 247e085d5f
Commit 8736ee01eb
6 changed files with 382 additions and 1 deletion

View file

@@ -34,4 +34,5 @@ from pytorch_lightning.metrics.regression import ( # noqa: F401
ExplainedVariance,
PSNR,
SSIM,
R2Score
)

View file

@@ -28,9 +28,9 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
to_categorical,
to_onehot,
)
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
@@ -40,6 +40,7 @@ from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401
from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401
from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401
from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401
from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401

View file

@@ -0,0 +1,126 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
from pytorch_lightning.metrics.utils import _check_same_shape
from pytorch_lightning.utilities import rank_zero_warn
def _r2score_update(
preds: torch.Tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_check_same_shape(preds, target)
if preds.ndim > 2:
raise ValueError('Expected both prediction and target to be 1D or 2D tensors,'
f' but received tensors with dimension {preds.shape}')
if len(preds) < 2:
raise ValueError('Needs at least two samples to calculate r2 score.')
sum_error = torch.sum(target, dim=0)
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
total = target.size(0)
return sum_squared_error, sum_error, residual, total
def _r2score_compute(sum_squared_error: torch.Tensor,
sum_error: torch.Tensor,
residual: torch.Tensor,
total: torch.Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average") -> torch.Tensor:
mean_error = sum_error / total
diff = sum_squared_error - sum_error * mean_error
raw_scores = 1 - (residual / diff)
if multioutput == "raw_values":
r2score = raw_scores
elif multioutput == "uniform_average":
r2score = torch.mean(raw_scores)
elif multioutput == "variance_weighted":
diff_sum = torch.sum(diff)
r2score = torch.sum(diff / diff_sum * raw_scores)
else:
raise ValueError('Argument `multioutput` must be either `raw_values`,'
f' `uniform_average` or `variance_weighted`. Received {multioutput}.')
if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError('`adjusted` parameter should be an integer greater than'
' or equal to 0.')
if adjusted != 0:
if adjusted > total - 1:
rank_zero_warn("More independent regressions than datapoints in"
" adjusted r2 score. Falls back to standard r2 score.",
UserWarning)
elif adjusted == total - 1:
rank_zero_warn("Division by zero in adjusted r2 score. Falls back to"
" standard r2 score.", UserWarning)
else:
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
return r2score
def r2score(
preds: torch.Tensor,
target: torch.Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average",
) -> torch.Tensor:
r"""
Computes R2 score, also known as the `coefficient of determination
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
.. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}
where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of squared residuals, and
:math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. It can also calculate
the adjusted R2 score, given by
.. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
where the parameter :math:`k` (the number of independent regressors) should
be provided as the ``adjusted`` argument.
Args:
preds: estimated labels
target: ground truth labels
adjusted: number of independent regressors for calculating adjusted r2 score.
Default 0 (standard r2 score).
multioutput: Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is ``'uniform_average'``):
* ``'raw_values'`` returns full set of scores
* ``'uniform_average'`` scores are uniformly averaged
* ``'variance_weighted'`` scores are weighted by their individual variances
Example:
>>> from pytorch_lightning.metrics.functional import r2score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score(preds, target, multioutput='raw_values')
tensor([0.9654, 0.9082])
"""
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput)
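For intuition, the single-output example in the docstring can be reproduced by hand directly from the definitions above; the following is a minimal sketch using plain torch, where the choice of ``k = 1`` regressors for the adjusted variant is purely illustrative:
>>> import torch
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> ss_res = torch.sum((target - preds) ** 2)           # sum of squared residuals
>>> ss_tot = torch.sum((target - target.mean()) ** 2)   # total sum of squares
>>> r2 = 1 - ss_res / ss_tot
>>> r2
tensor(0.9486)
>>> n, k = target.numel(), 1                            # n samples, k hypothetical regressors
>>> 1 - (1 - r2) * (n - 1) / (n - k - 1)                # adjusted R^2
tensor(0.9229)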

View file

@@ -17,3 +17,4 @@ from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401

View file

@@ -0,0 +1,143 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
import torch
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.r2score import (
_r2score_update,
_r2score_compute
)
class R2Score(Metric):
r"""
Computes R2 score, also known as the `coefficient of determination
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
.. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}
where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of squared residuals, and
:math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. It can also calculate
the adjusted R2 score, given by
.. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
where the parameter :math:`k` (the number of independent regressors) should
be provided as the ``adjusted`` argument.
Forward accepts
- ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
- ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
In the case of multioutput, the variances are uniformly averaged over the
additional dimensions by default. See the ``multioutput`` argument for
changing this behavior.
Args:
num_outputs:
Number of outputs in multioutput setting (default is 1)
adjusted:
number of independent regressors for calculating adjusted r2 score.
Default 0 (standard r2 score).
multioutput:
Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is ``'uniform_average'``):
* ``'raw_values'`` returns full set of scores
* ``'uniform_average'`` scores are uniformly averaged
* ``'variance_weighted'`` scores are weighted by their individual variances
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
"""
def __init__(
self,
num_outputs: int = 1,
adjusted: int = 0,
multioutput: str = "uniform_average",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.num_outputs = num_outputs
if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError('`adjusted` parameter should be an integer greater than'
' or equal to 0.')
self.adjusted = adjusted
allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
if multioutput not in allowed_multioutput:
raise ValueError(
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
)
self.multioutput = multioutput
self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total
def compute(self) -> torch.Tensor:
"""
Computes r2 score over the metric states.
"""
return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual,
self.total, self.adjusted, self.multioutput)
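As a usage sketch of the stateful API described above, accumulating the four state tensors with ``update()`` over several batches and reducing once with ``compute()``; the batch count, sizes and noise level below are arbitrary:
import torch
from pytorch_lightning.metrics import R2Score

metric = R2Score(num_outputs=2, multioutput='raw_values')
for _ in range(4):                               # e.g. four validation batches
    preds = torch.randn(16, 2)                   # (N, M) multioutput predictions
    target = preds + 0.1 * torch.randn(16, 2)    # noisy ground truth
    metric.update(preds, target)                 # accumulates sum_squared_error, sum_error, residual, total
score = metric.compute()                         # per-output R^2 scores, shape (2,)
metric.reset()                                   # clear accumulated state before the next epoch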

View file

@@ -0,0 +1,109 @@
from collections import namedtuple
from functools import partial
import pytest
import torch
from sklearn.metrics import r2_score as sk_r2score
from pytorch_lightning.metrics.regression import R2Score
from pytorch_lightning.metrics.functional import r2score
from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, MetricTester
torch.manual_seed(42)
num_targets = 5
Input = namedtuple('Input', ["preds", "target"])
_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
def _single_target_sk_metric(preds, target, adjusted, multioutput):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput)
if adjusted != 0:
r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1)
return r2_score
def _multi_target_sk_metric(preds, target, adjusted, multioutput):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput)
if adjusted != 0:
r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1)
return r2_score
@pytest.mark.parametrize("adjusted", [0, 5, 10])
@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted'])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_outputs",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets),
],
)
class TestR2Score(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
R2Score,
partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
dist_sync_on_step,
metric_args=dict(adjusted=adjusted,
multioutput=multioutput,
num_outputs=num_outputs),
)
def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_functional_metric_test(
preds,
target,
r2score,
partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
metric_args=dict(adjusted=adjusted,
multioutput=multioutput),
)
def test_error_on_different_shape(metric_class=R2Score):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
def test_error_on_multidim_tensors(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,'
r' but received tensors with dimension .'):
metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))
def test_error_on_too_few_samples(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match='Needs at least two samples to calculate r2 score.'):
metric(torch.randn(1,), torch.randn(1,))
def test_warning_on_too_large_adjusted(metric_class=R2Score):
metric = metric_class(adjusted=10)
with pytest.warns(UserWarning,
match="More independent regressions than datapoints in"
" adjusted r2 score. Falls back to standard r2 score."):
metric(torch.randn(10,), torch.randn(10,))
with pytest.warns(UserWarning,
match="Division by zero in adjusted r2 score. Falls back to"
" standard r2 score."):
metric(torch.randn(11,), torch.randn(11,))
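The tests above delegate reference values to scikit-learn through MetricTester; a standalone sanity check along the same lines (a sketch, assuming sklearn is installed) could look like:
import torch
from sklearn.metrics import r2_score as sk_r2score
from pytorch_lightning.metrics.functional import r2score

preds = torch.randn(100)
target = torch.randn(100)
# the functional metric should agree with the sklearn reference implementation
pl_score = r2score(preds, target)
sk_score = sk_r2score(target.numpy(), preds.numpy())
# loose atol accounts for float32 vs float64 accumulation differences
assert torch.allclose(pl_score, torch.tensor(sk_score, dtype=pl_score.dtype), atol=1e-5)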