Add `dim` to `pytorch_lightning.metrics.PSNR` (PL^5957)
* Add dim to PSNR * Update CHANGELOG.md * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add reduction tests * Recover warnings on reduction and add tests * Add copyright texts * Refactor PSNR * Change warnings * Update pytorch_lightning/metrics/functional/psnr.py Change functional.psnr dim doc Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Change PSNR dim docs * Apply suggestions from code review * tests Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> (cherry picked from commit a3dd545b733f9b0cafbe725a4f2e57bb6bed2759)
This commit is contained in:
Родитель
4c2b7b24ec
Коммит
c998634d40
|
@ -1,27 +1,56 @@
|
||||||
from typing import Optional, Tuple
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning.utilities import rank_zero_warn
|
from pytorch_lightning import utilities
|
||||||
|
from pytorch_lightning.metrics import utils
|
||||||
|
|
||||||
|
|
||||||
def _psnr_compute(
|
def _psnr_compute(
|
||||||
sum_squared_error: torch.Tensor,
|
sum_squared_error: torch.Tensor,
|
||||||
n_obs: int,
|
n_obs: torch.Tensor,
|
||||||
data_range: float,
|
data_range: torch.Tensor,
|
||||||
base: float = 10.0,
|
base: float = 10.0,
|
||||||
reduction: str = 'elementwise_mean',
|
reduction: str = 'elementwise_mean',
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if reduction != 'elementwise_mean':
|
|
||||||
rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
|
|
||||||
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
|
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
|
||||||
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
|
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
|
||||||
return psnr
|
return utils.reduce(psnr, reduction=reduction)
|
||||||
|
|
||||||
|
|
||||||
def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
def _psnr_update(preds: torch.Tensor,
|
||||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
target: torch.Tensor,
|
||||||
n_obs = target.numel()
|
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if dim is None:
|
||||||
|
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
||||||
|
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||||
|
return sum_squared_error, n_obs
|
||||||
|
|
||||||
|
sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
|
||||||
|
|
||||||
|
if isinstance(dim, int):
|
||||||
|
dim_list = [dim]
|
||||||
|
else:
|
||||||
|
dim_list = list(dim)
|
||||||
|
if not dim_list:
|
||||||
|
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||||
|
else:
|
||||||
|
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
|
||||||
|
n_obs = n_obs.expand_as(sum_squared_error)
|
||||||
|
|
||||||
return sum_squared_error, n_obs
|
return sum_squared_error, n_obs
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,6 +60,7 @@ def psnr(
|
||||||
data_range: Optional[float] = None,
|
data_range: Optional[float] = None,
|
||||||
base: float = 10.0,
|
base: float = 10.0,
|
||||||
reduction: str = 'elementwise_mean',
|
reduction: str = 'elementwise_mean',
|
||||||
|
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Computes the peak signal-to-noise ratio
|
Computes the peak signal-to-noise ratio
|
||||||
|
@ -38,7 +68,9 @@ def psnr(
|
||||||
Args:
|
Args:
|
||||||
preds: estimated signal
|
preds: estimated signal
|
||||||
target: groun truth signal
|
target: groun truth signal
|
||||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
data_range:
|
||||||
|
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
|
||||||
|
when ``dim`` is not None.
|
||||||
base: a base of a logarithm to use (default: 10)
|
base: a base of a logarithm to use (default: 10)
|
||||||
reduction: a method to reduce metric score over labels.
|
reduction: a method to reduce metric score over labels.
|
||||||
|
|
||||||
|
@ -46,6 +78,9 @@ def psnr(
|
||||||
- ``'sum'``: takes the sum
|
- ``'sum'``: takes the sum
|
||||||
- ``'none'``: no reduction will be applied
|
- ``'none'``: no reduction will be applied
|
||||||
|
|
||||||
|
dim:
|
||||||
|
Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
|
||||||
|
None meaning scores will be reduced across all dimensions.
|
||||||
Return:
|
Return:
|
||||||
Tensor with PSNR score
|
Tensor with PSNR score
|
||||||
|
|
||||||
|
@ -57,9 +92,17 @@ def psnr(
|
||||||
tensor(2.5527)
|
tensor(2.5527)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if dim is None and reduction != 'elementwise_mean':
|
||||||
|
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
|
||||||
|
|
||||||
if data_range is None:
|
if data_range is None:
|
||||||
|
if dim is not None:
|
||||||
|
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
|
||||||
|
# `data_range` in the future.
|
||||||
|
raise ValueError("The `data_range` must be given when `dim` is not None.")
|
||||||
|
|
||||||
data_range = target.max() - target.min()
|
data_range = target.max() - target.min()
|
||||||
else:
|
else:
|
||||||
data_range = torch.tensor(float(data_range))
|
data_range = torch.tensor(float(data_range))
|
||||||
sum_squared_error, n_obs = _psnr_update(preds, target)
|
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
|
||||||
return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)
|
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
|
||||||
|
|
|
@ -11,10 +11,11 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from pytorch_lightning import utilities
|
||||||
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
|
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
|
||||||
from pytorch_lightning.metrics.metric import Metric
|
from pytorch_lightning.metrics.metric import Metric
|
||||||
|
|
||||||
|
@ -29,7 +30,9 @@ class PSNR(Metric):
|
||||||
<https://en.wikipedia.org/wiki/Mean_squared_error>`_ function.
|
<https://en.wikipedia.org/wiki/Mean_squared_error>`_ function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
data_range:
|
||||||
|
the range of the data. If None, it is determined from the data (max - min).
|
||||||
|
The ``data_range`` must be given when ``dim`` is not None.
|
||||||
base: a base of a logarithm to use (default: 10)
|
base: a base of a logarithm to use (default: 10)
|
||||||
reduction: a method to reduce metric score over labels.
|
reduction: a method to reduce metric score over labels.
|
||||||
|
|
||||||
|
@ -37,6 +40,9 @@ class PSNR(Metric):
|
||||||
- ``'sum'``: takes the sum
|
- ``'sum'``: takes the sum
|
||||||
- ``'none'``: no reduction will be applied
|
- ``'none'``: no reduction will be applied
|
||||||
|
|
||||||
|
dim:
|
||||||
|
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
|
||||||
|
None meaning scores will be reduced across all dimensions and all batches.
|
||||||
compute_on_step:
|
compute_on_step:
|
||||||
Forward only calls ``update()`` and return None if this is set to False. default: True
|
Forward only calls ``update()`` and return None if this is set to False. default: True
|
||||||
dist_sync_on_step:
|
dist_sync_on_step:
|
||||||
|
@ -61,6 +67,7 @@ class PSNR(Metric):
|
||||||
data_range: Optional[float] = None,
|
data_range: Optional[float] = None,
|
||||||
base: float = 10.0,
|
base: float = 10.0,
|
||||||
reduction: str = 'elementwise_mean',
|
reduction: str = 'elementwise_mean',
|
||||||
|
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||||
compute_on_step: bool = True,
|
compute_on_step: bool = True,
|
||||||
dist_sync_on_step: bool = False,
|
dist_sync_on_step: bool = False,
|
||||||
process_group: Optional[Any] = None,
|
process_group: Optional[Any] = None,
|
||||||
|
@ -71,9 +78,22 @@ class PSNR(Metric):
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
if dim is None and reduction != 'elementwise_mean':
|
||||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
|
||||||
|
|
||||||
|
if dim is None:
|
||||||
|
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||||
|
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||||
|
else:
|
||||||
|
self.add_state("sum_squared_error", default=[])
|
||||||
|
self.add_state("total", default=[])
|
||||||
|
|
||||||
if data_range is None:
|
if data_range is None:
|
||||||
|
if dim is not None:
|
||||||
|
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
|
||||||
|
# calculate `data_range` in the future.
|
||||||
|
raise ValueError("The `data_range` must be given when `dim` is not None.")
|
||||||
|
|
||||||
self.data_range = None
|
self.data_range = None
|
||||||
self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
|
self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
|
||||||
self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
|
self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
|
||||||
|
@ -81,6 +101,7 @@ class PSNR(Metric):
|
||||||
self.register_buffer("data_range", torch.tensor(float(data_range)))
|
self.register_buffer("data_range", torch.tensor(float(data_range)))
|
||||||
self.base = base
|
self.base = base
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
|
||||||
|
|
||||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
|
@ -90,14 +111,18 @@ class PSNR(Metric):
|
||||||
preds: Predictions from model
|
preds: Predictions from model
|
||||||
target: Ground truth values
|
target: Ground truth values
|
||||||
"""
|
"""
|
||||||
if self.data_range is None:
|
sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
|
||||||
# keep track of min and max target values
|
if self.dim is None:
|
||||||
self.min_target = min(target.min(), self.min_target)
|
if self.data_range is None:
|
||||||
self.max_target = max(target.max(), self.max_target)
|
# keep track of min and max target values
|
||||||
|
self.min_target = min(target.min(), self.min_target)
|
||||||
|
self.max_target = max(target.max(), self.max_target)
|
||||||
|
|
||||||
sum_squared_error, n_obs = _psnr_update(preds, target)
|
self.sum_squared_error += sum_squared_error
|
||||||
self.sum_squared_error += sum_squared_error
|
self.total += n_obs
|
||||||
self.total += n_obs
|
else:
|
||||||
|
self.sum_squared_error.append(sum_squared_error)
|
||||||
|
self.total.append(n_obs)
|
||||||
|
|
||||||
def compute(self):
|
def compute(self):
|
||||||
"""
|
"""
|
||||||
|
@ -107,4 +132,11 @@ class PSNR(Metric):
|
||||||
data_range = self.data_range
|
data_range = self.data_range
|
||||||
else:
|
else:
|
||||||
data_range = self.max_target - self.min_target
|
data_range = self.max_target - self.min_target
|
||||||
return _psnr_compute(self.sum_squared_error, self.total, data_range, self.base, self.reduction)
|
|
||||||
|
if self.dim is None:
|
||||||
|
sum_squared_error = self.sum_squared_error
|
||||||
|
total = self.total
|
||||||
|
else:
|
||||||
|
sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
|
||||||
|
total = torch.cat([values.flatten() for values in self.total])
|
||||||
|
return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright The PyTorch Lightning team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
@ -14,67 +28,106 @@ torch.manual_seed(42)
|
||||||
|
|
||||||
Input = namedtuple('Input', ["preds", "target"])
|
Input = namedtuple('Input', ["preds", "target"])
|
||||||
|
|
||||||
|
_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32)
|
||||||
_inputs = [
|
_inputs = [
|
||||||
Input(
|
Input(
|
||||||
preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
|
preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float),
|
||||||
target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
|
target=torch.randint(n_cls_target, _input_size, dtype=torch.float),
|
||||||
) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
|
) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _sk_metric(preds, target, data_range):
|
def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
|
||||||
sk_preds = preds.view(-1).numpy()
|
value = value.numpy()
|
||||||
sk_target = target.view(-1).numpy()
|
batches = value[None] if value.ndim == len(_input_size) - 1 else value
|
||||||
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
|
|
||||||
|
if dim is None:
|
||||||
|
return [batches]
|
||||||
|
|
||||||
|
num_dims = np.size(dim)
|
||||||
|
if not num_dims:
|
||||||
|
return batches
|
||||||
|
|
||||||
|
inputs = []
|
||||||
|
for batch in batches:
|
||||||
|
batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
|
||||||
|
psnr_input_shape = batch.shape[-num_dims:]
|
||||||
|
inputs.extend(batch.reshape(-1, *psnr_input_shape))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
def _base_e_sk_metric(preds, target, data_range):
|
def _sk_psnr(preds, target, data_range, reduction, dim):
|
||||||
sk_preds = preds.view(-1).numpy()
|
sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
|
||||||
sk_target = target.view(-1).numpy()
|
sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
|
||||||
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) * np.log(10)
|
np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
|
||||||
|
return np_reduce_map[reduction]([
|
||||||
|
peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
|
||||||
|
for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
|
||||||
|
return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"preds, target, data_range",
|
"preds, target, data_range, reduction, dim",
|
||||||
[
|
[
|
||||||
(_inputs[0].preds, _inputs[0].target, 10),
|
(_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None),
|
||||||
(_inputs[1].preds, _inputs[1].target, 10),
|
(_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None),
|
||||||
(_inputs[2].preds, _inputs[2].target, 5),
|
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None),
|
||||||
|
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
|
||||||
|
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
|
||||||
|
(_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"base, sk_metric",
|
"base, sk_metric",
|
||||||
[
|
[
|
||||||
(10.0, _sk_metric),
|
(10.0, _sk_psnr),
|
||||||
(2.718281828459045, _base_e_sk_metric),
|
(2.718281828459045, _base_e_sk_psnr),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
class TestPSNR(MetricTester):
|
class TestPSNR(MetricTester):
|
||||||
|
|
||||||
@pytest.mark.parametrize("ddp", [True, False])
|
@pytest.mark.parametrize("ddp", [True, False])
|
||||||
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
|
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
|
||||||
def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
|
def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step):
|
||||||
|
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
|
||||||
self.run_class_metric_test(
|
self.run_class_metric_test(
|
||||||
ddp,
|
ddp,
|
||||||
preds,
|
preds,
|
||||||
target,
|
target,
|
||||||
PSNR,
|
PSNR,
|
||||||
partial(sk_metric, data_range=data_range),
|
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
|
||||||
metric_args={
|
metric_args=_args,
|
||||||
"data_range": data_range,
|
|
||||||
"base": base
|
|
||||||
},
|
|
||||||
dist_sync_on_step=dist_sync_on_step,
|
dist_sync_on_step=dist_sync_on_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_psnr_functional(self, preds, target, sk_metric, data_range, base):
|
def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim):
|
||||||
|
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
|
||||||
self.run_functional_metric_test(
|
self.run_functional_metric_test(
|
||||||
preds,
|
preds,
|
||||||
target,
|
target,
|
||||||
psnr,
|
psnr,
|
||||||
partial(sk_metric, data_range=data_range),
|
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
|
||||||
metric_args={
|
metric_args=_args,
|
||||||
"data_range": data_range,
|
|
||||||
"base": base
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("reduction", ["none", "sum"])
|
||||||
|
def test_reduction_for_dim_none(reduction):
|
||||||
|
match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
|
||||||
|
with pytest.warns(UserWarning, match=match):
|
||||||
|
PSNR(reduction=reduction, dim=None)
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning, match=match):
|
||||||
|
psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_data_range():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
PSNR(data_range=None, dim=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче