R2Score (PL^5241)
* add r2metric * change init * add test * add docs * add math * Apply suggestions from code review Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * changelog * adjusted parameter * add more test * pep8 * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * add warnings for adjusted score Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> (cherry picked from commit 4d44437074051a24366973e1ead758ff2f44c3b2)
This commit is contained in:
Родитель
247e085d5f
Коммит
8736ee01eb
|
@ -34,4 +34,5 @@ from pytorch_lightning.metrics.regression import ( # noqa: F401
|
||||||
ExplainedVariance,
|
ExplainedVariance,
|
||||||
PSNR,
|
PSNR,
|
||||||
SSIM,
|
SSIM,
|
||||||
|
R2Score
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,9 +28,9 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
|
||||||
to_categorical,
|
to_categorical,
|
||||||
to_onehot,
|
to_onehot,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
|
|
||||||
# TODO: unify metrics between class and functional, add below
|
# TODO: unify metrics between class and functional, add below
|
||||||
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
|
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
|
||||||
|
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
|
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
|
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
|
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
|
||||||
|
@ -40,6 +40,7 @@ from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squ
|
||||||
from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401
|
from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401
|
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401
|
from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401
|
||||||
|
from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
|
from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
|
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
|
||||||
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
|
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_lightning.metrics.utils import _check_same_shape
|
||||||
|
from pytorch_lightning.utilities import rank_zero_warn
|
||||||
|
|
||||||
|
|
||||||
|
def _r2score_update(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Reduce one batch to the sufficient statistics of the r2 score.

    Args:
        preds: predictions, checked against ``target`` with ``_check_same_shape``
        target: ground truth values

    Returns:
        A tuple ``(sum_squared_error, sum_error, residual, total)`` where the
        first three entries are reductions over dimension 0:

        * ``sum_squared_error``: sum of the squared targets
        * ``sum_error``: sum of the targets
        * ``residual``: sum of the squared differences ``target - preds``
        * ``total``: size of dimension 0 of ``target`` (a plain ``int``)

    Raises:
        ValueError: if the inputs have more than two dimensions, or if they
            contain fewer than two samples along dimension 0.
    """
    _check_same_shape(preds, target)
    # NOTE: the wording of both messages below is matched by the test-suite,
    # so the text is intentionally left untouched.
    if preds.ndim > 2:
        raise ValueError('Expected both prediction and target to be 1D or 2D tensors,'
                         f' but recevied tensors with dimension {preds.shape}')
    # ``len`` raises ``TypeError`` on a 0-d tensor; a scalar is a single sample,
    # so report it through the same "too few samples" error.
    if preds.ndim == 0 or len(preds) < 2:
        raise ValueError('Needs atleast two samples to calculate r2 score.')

    sum_error = torch.sum(target, dim=0)
    sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
    residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
    total = target.size(0)

    return sum_squared_error, sum_error, residual, total
|
||||||
|
|
||||||
|
|
||||||
|
def _r2score_compute(sum_squared_error: torch.Tensor,
|
||||||
|
sum_error: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
total: torch.Tensor,
|
||||||
|
adjusted: int = 0,
|
||||||
|
multioutput: str = "uniform_average") -> torch.Tensor:
|
||||||
|
mean_error = sum_error / total
|
||||||
|
diff = sum_squared_error - sum_error * mean_error
|
||||||
|
raw_scores = 1 - (residual / diff)
|
||||||
|
|
||||||
|
if multioutput == "raw_values":
|
||||||
|
r2score = raw_scores
|
||||||
|
elif multioutput == "uniform_average":
|
||||||
|
r2score = torch.mean(raw_scores)
|
||||||
|
elif multioutput == "variance_weighted":
|
||||||
|
diff_sum = torch.sum(diff)
|
||||||
|
r2score = torch.sum(diff / diff_sum * raw_scores)
|
||||||
|
else:
|
||||||
|
raise ValueError('Argument `multioutput` must be either `raw_values`,'
|
||||||
|
f' `uniform_average` or `variance_weighted`. Received {multioutput}.')
|
||||||
|
|
||||||
|
if adjusted < 0 or not isinstance(adjusted, int):
|
||||||
|
raise ValueError('`adjusted` parameter should be an integer larger or'
|
||||||
|
' equal to 0.')
|
||||||
|
|
||||||
|
if adjusted != 0:
|
||||||
|
if adjusted > total - 1:
|
||||||
|
rank_zero_warn("More independent regressions than datapoints in"
|
||||||
|
" adjusted r2 score. Falls back to standard r2 score.",
|
||||||
|
UserWarning)
|
||||||
|
elif adjusted == total - 1:
|
||||||
|
rank_zero_warn("Division by zero in adjusted r2 score. Falls back to"
|
||||||
|
" standard r2 score.", UserWarning)
|
||||||
|
else:
|
||||||
|
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
|
||||||
|
return r2score
|
||||||
|
|
||||||
|
|
||||||
|
def r2score(
    preds: torch.Tensor,
    target: torch.Tensor,
    adjusted: int = 0,
    multioutput: str = "uniform_average",
) -> torch.Tensor:
    r"""
    Functional r2 score, i.e. the `coefficient of determination
    <https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

    .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

    with :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` the sum of residual squares and
    :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` the total sum of squares. The adjusted
    r2 score

    .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

    is returned instead when :math:`k`, the number of independent regressors, is
    passed as the ``adjusted`` argument.

    Args:
        preds: estimated labels
        target: ground truth labels
        adjusted: number of independent regressors for calculating adjusted r2 score.
            Default 0 (standard r2 score).
        multioutput: aggregation used when there are multiple output scores. One of
            the following strings (default is ``'uniform_average'``):

            * ``'raw_values'`` returns full set of scores
            * ``'uniform_average'`` scores are uniformly averaged
            * ``'variance_weighted'`` scores are weighted by their individual variances

    Example:

        >>> from pytorch_lightning.metrics.functional import r2score
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> r2score(preds, target)
        tensor(0.9486)

        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score(preds, target, multioutput='raw_values')
        tensor([0.9654, 0.9082])
    """
    # The update step yields (sum_squared_error, sum_error, residual, total),
    # which is exactly the leading positional signature of the compute step.
    batch_stats = _r2score_update(preds, target)
    return _r2score_compute(*batch_stats, adjusted, multioutput)
|
|
@ -17,3 +17,4 @@ from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSqua
|
||||||
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
|
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
|
||||||
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
|
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
|
||||||
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
|
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
|
||||||
|
from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_lightning.metrics.metric import Metric
|
||||||
|
from pytorch_lightning.metrics.functional.r2score import (
|
||||||
|
_r2score_update,
|
||||||
|
_r2score_compute
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class R2Score(Metric):
    r"""
    Computes r2 score also known as `coefficient of determination
    <https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

    .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

    where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
    :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
    adjusted r2 score given by

    .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

    where the parameter :math:`k` (the number of independent regressors) should
    be provided as the ``adjusted`` argument.

    Forward accepts

    - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
    - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)

    In the case of multioutput, as default the scores will be uniformly
    averaged over the outputs. Please see argument ``multioutput``
    for changing this behavior.

    Args:
        num_outputs:
            Number of outputs in multioutput setting (default is 1)
        adjusted:
            number of independent regressors for calculating adjusted r2 score.
            Default 0 (standard r2 score).
        multioutput:
            Defines aggregation in the case of multiple output scores. Can be one
            of the following strings (default is ``'uniform_average'``):

            * ``'raw_values'`` returns full set of scores
            * ``'uniform_average'`` scores are uniformly averaged
            * ``'variance_weighted'`` scores are weighted by their individual variances

        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Forwarded unchanged to the ``Metric`` base class. default: None

    Raises:
        ValueError: if ``adjusted`` is not a non-negative integer, or if
            ``multioutput`` is not one of the supported strings.

    Example:

        >>> from pytorch_lightning.metrics import R2Score
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> r2score = R2Score()
        >>> r2score(preds, target)
        tensor(0.9486)

        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
        >>> r2score(preds, target)
        tensor([0.9654, 0.9082])
    """
    def __init__(
        self,
        num_outputs: int = 1,
        adjusted: int = 0,
        multioutput: str = "uniform_average",
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        self.num_outputs = num_outputs

        # Type check first: comparing a non-numeric ``adjusted`` with 0 would
        # raise ``TypeError`` before the intended ``ValueError`` is reached.
        if not isinstance(adjusted, int) or adjusted < 0:
            raise ValueError('`adjusted` parameter should be an integer larger or'
                             ' equal to 0.')
        self.adjusted = adjusted

        allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
        if multioutput not in allowed_multioutput:
            raise ValueError(
                f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
            )
        self.multioutput = multioutput

        # Running sums, one entry per output; every state is registered with a
        # "sum" reduction for distributed synchronization.
        self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
        self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
        self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)

        self.sum_squared_error += sum_squared_error
        self.sum_error += sum_error
        self.residual += residual
        self.total += total

    def compute(self) -> torch.Tensor:
        """
        Computes r2 score over the metric states.
        """
        return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual,
                                self.total, self.adjusted, self.multioutput)
|
|
@ -0,0 +1,109 @@
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from sklearn.metrics import r2_score as sk_r2score
|
||||||
|
|
||||||
|
from pytorch_lightning.metrics.regression import R2Score
|
||||||
|
from pytorch_lightning.metrics.functional import r2score
|
||||||
|
from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, MetricTester
|
||||||
|
|
||||||
|
# Fix the global RNG so the random fixtures below are reproducible.
# The order of the ``torch.rand`` calls determines the values each fixture gets.
torch.manual_seed(42)

# Number of regression outputs used by the multi-target fixtures and reference metric.
num_targets = 5

# Simple container pairing predictions with their targets.
Input = namedtuple('Input', ["preds", "target"])

# Single-output case: tensors of shape (NUM_BATCHES, BATCH_SIZE).
_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)

# Multi-output case: tensors of shape (NUM_BATCHES, BATCH_SIZE, num_targets).
_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
|
||||||
|
|
||||||
|
|
||||||
|
def _single_target_sk_metric(preds, target, adjusted, multioutput):
    """Reference r2 score for single-output data, computed with scikit-learn.

    Both tensors are flattened to 1D before being handed to ``sk_r2score``.
    When ``adjusted`` is non-zero the adjusted-r2 correction is applied to the
    scikit-learn result, using the flattened length as the sample count.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    score = sk_r2score(flat_target, flat_preds, multioutput=multioutput)
    if adjusted == 0:
        return score
    n_samples = flat_preds.shape[0]
    return 1 - (1 - score) * (n_samples - 1) / (n_samples - adjusted - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_target_sk_metric(preds, target, adjusted, multioutput):
    """Reference r2 score for multi-output data, computed with scikit-learn.

    Both tensors are reshaped to ``(-1, num_targets)`` before being handed to
    ``sk_r2score``. When ``adjusted`` is non-zero the adjusted-r2 correction is
    applied to the scikit-learn result, using the row count as the sample count.
    """
    stacked_preds = preds.view(-1, num_targets).numpy()
    stacked_target = target.view(-1, num_targets).numpy()
    score = sk_r2score(stacked_target, stacked_preds, multioutput=multioutput)
    if adjusted == 0:
        return score
    n_samples = stacked_preds.shape[0]
    return 1 - (1 - score) * (n_samples - 1) / (n_samples - adjusted - 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("adjusted", [0, 5, 10])
@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted'])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_outputs",
    [
        (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1),
        (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets),
    ],
)
class TestR2Score(MetricTester):
    """Checks the class and functional r2 score against the scikit-learn reference."""

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step):
        """Run the ``R2Score`` module metric through the shared class-metric harness."""
        reference_metric = partial(sk_metric, adjusted=adjusted, multioutput=multioutput)
        module_kwargs = dict(adjusted=adjusted,
                             multioutput=multioutput,
                             num_outputs=num_outputs)
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            R2Score,
            reference_metric,
            dist_sync_on_step,
            metric_args=module_kwargs,
        )

    def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
        """Run the ``r2score`` function through the shared functional-metric harness."""
        reference_metric = partial(sk_metric, adjusted=adjusted, multioutput=multioutput)
        functional_kwargs = dict(adjusted=adjusted,
                                 multioutput=multioutput)
        self.run_functional_metric_test(
            preds,
            target,
            r2score,
            reference_metric,
            metric_args=functional_kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_different_shape(metric_class=R2Score):
    """Inputs whose shapes differ must be rejected with a ``RuntimeError``."""
    expected_message = 'Predictions and targets are expected to have the same shape'
    metric = metric_class()
    mismatched_preds = torch.randn(100,)
    mismatched_target = torch.randn(50,)
    with pytest.raises(RuntimeError, match=expected_message):
        metric(mismatched_preds, mismatched_target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_multidim_tensors(metric_class=R2Score):
    """Inputs with more than two dimensions must be rejected with a ``ValueError``."""
    expected_pattern = (r'Expected both prediction and target to be 1D or 2D tensors,'
                        r' but recevied tensors with dimension .')
    metric = metric_class()
    cube_preds = torch.randn(10, 20, 5)
    cube_target = torch.randn(10, 20, 5)
    with pytest.raises(ValueError, match=expected_pattern):
        metric(cube_preds, cube_target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_too_few_samples(metric_class=R2Score):
    """A single sample is not enough to compute an r2 score."""
    expected_message = 'Needs atleast two samples to calculate r2 score.'
    metric = metric_class()
    lone_pred = torch.randn(1,)
    lone_target = torch.randn(1,)
    with pytest.raises(ValueError, match=expected_message):
        metric(lone_pred, lone_target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_warning_on_too_large_adjusted(metric_class=R2Score):
    """Both fallback warnings of the adjusted r2 score are emitted.

    With ``adjusted=10``, a batch of 10 samples triggers the "more regressors
    than datapoints" warning and a batch of 11 samples triggers the
    "division by zero" warning.
    """
    too_many_regressors = ("More independent regressions than datapoints in"
                           " adjusted r2 score. Falls back to standard r2 score.")
    zero_denominator = ("Division by zero in adjusted r2 score. Falls back to"
                        " standard r2 score.")
    metric = metric_class(adjusted=10)

    with pytest.warns(UserWarning, match=too_many_regressors):
        metric(torch.randn(10,), torch.randn(10,))

    with pytest.warns(UserWarning, match=zero_denominator):
        metric(torch.randn(11,), torch.randn(11,))
|
Загрузка…
Ссылка в новой задаче