diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py
index c0e95a1..434b2ae 100644
--- a/pytorch_lightning/metrics/functional/psnr.py
+++ b/pytorch_lightning/metrics/functional/psnr.py
@@ -1,27 +1,56 @@
-from typing import Optional, Tuple
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union
 
 import torch
 
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning import utilities
+from pytorch_lightning.metrics import utils
 
 
 def _psnr_compute(
     sum_squared_error: torch.Tensor,
-    n_obs: int,
-    data_range: float,
+    n_obs: torch.Tensor,
+    data_range: torch.Tensor,
     base: float = 10.0,
     reduction: str = 'elementwise_mean',
 ) -> torch.Tensor:
-    if reduction != 'elementwise_mean':
-        rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
     psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
     psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
-    return psnr
+    return utils.reduce(psnr, reduction=reduction)
 
 
-def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
-    sum_squared_error = torch.sum(torch.pow(preds - target, 2))
-    n_obs = target.numel()
+def _psnr_update(preds: torch.Tensor,
+                 target: torch.Tensor,
+                 dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dim is None:
+        sum_squared_error = torch.sum(torch.pow(preds - target, 2))
+        n_obs = torch.tensor(target.numel(), device=target.device)
+        return sum_squared_error, n_obs
+
+    sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
+
+    if isinstance(dim, int):
+        dim_list = [dim]
+    else:
+        dim_list = list(dim)
+    if not dim_list:
+        n_obs = torch.tensor(target.numel(), device=target.device)
+    else:
+        n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
+        n_obs = n_obs.expand_as(sum_squared_error)
+
     return sum_squared_error, n_obs
@@ -31,6 +60,7 @@ def psnr(
     data_range: Optional[float] = None,
     base: float = 10.0,
     reduction: str = 'elementwise_mean',
+    dim: Optional[Union[int, Tuple[int, ...]]] = None,
 ) -> torch.Tensor:
     """
     Computes the peak signal-to-noise ratio
@@ -38,7 +68,9 @@
     Args:
         preds: estimated signal
         target: ground truth signal
-        data_range: the range of the data. If None, it is determined from the data (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
+            when ``dim`` is not None.
         base: a base of a logarithm to use (default: 10)
         reduction: a method to reduce metric score over labels.
@@ -46,6 +78,9 @@
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied
 
+        dim:
+            Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
+            None, meaning scores will be reduced across all dimensions.
 
     Return:
         Tensor with PSNR score
@@ -57,9 +92,17 @@
         tensor(2.5527)
 
     """
+    if dim is None and reduction != 'elementwise_mean':
+        utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
+
     if data_range is None:
+        if dim is not None:
+            # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
+            # `data_range` in the future.
+            raise ValueError("The `data_range` must be given when `dim` is not None.")
+
         data_range = target.max() - target.min()
     else:
         data_range = torch.tensor(float(data_range))
-    sum_squared_error, n_obs = _psnr_update(preds, target)
-    return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)
+    sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
+    return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py
index bfadf97..b07941f 100644
--- a/pytorch_lightning/metrics/regression/psnr.py
+++ b/pytorch_lightning/metrics/regression/psnr.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
 
+from pytorch_lightning import utilities
 from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
 from pytorch_lightning.metrics.metric import Metric
@@ -29,7 +30,9 @@ class PSNR(Metric):
     `_ function.
 
     Args:
-        data_range: the range of the data. If None, it is determined from the data (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min).
+            The ``data_range`` must be given when ``dim`` is not None.
         base: a base of a logarithm to use (default: 10)
         reduction: a method to reduce metric score over labels.
@@ -37,6 +40,9 @@
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied
 
+        dim:
+            Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
+            None, meaning scores will be reduced across all dimensions and all batches.
     compute_on_step:
         Forward only calls ``update()`` and return None if this is set to False. default: True
     dist_sync_on_step:
@@ -61,6 +67,7 @@
         data_range: Optional[float] = None,
         base: float = 10.0,
         reduction: str = 'elementwise_mean',
+        dim: Optional[Union[int, Tuple[int, ...]]] = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
@@ -71,9 +78,22 @@
             process_group=process_group,
         )
 
-        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        if dim is None and reduction != 'elementwise_mean':
+            utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
+
+        if dim is None:
+            self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        else:
+            self.add_state("sum_squared_error", default=[])
+            self.add_state("total", default=[])
+
         if data_range is None:
+            if dim is not None:
+                # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
+                # calculate `data_range` in the future.
+                raise ValueError("The `data_range` must be given when `dim` is not None.")
+
             self.data_range = None
             self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
             self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
@@ -81,6 +101,7 @@
             self.register_buffer("data_range", torch.tensor(float(data_range)))
         self.base = base
         self.reduction = reduction
+        self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
 
     def update(self, preds: torch.Tensor, target: torch.Tensor):
         """
@@ -90,14 +111,18 @@
             preds: Predictions from model
             target: Ground truth values
         """
-        if self.data_range is None:
-            # keep track of min and max target values
-            self.min_target = min(target.min(), self.min_target)
-            self.max_target = max(target.max(), self.max_target)
+        sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
+        if self.dim is None:
+            if self.data_range is None:
+                # keep track of min and max target values
+                self.min_target = min(target.min(), self.min_target)
+                self.max_target = max(target.max(), self.max_target)
 
-        sum_squared_error, n_obs = _psnr_update(preds, target)
-        self.sum_squared_error += sum_squared_error
-        self.total += n_obs
+            self.sum_squared_error += sum_squared_error
+            self.total += n_obs
+        else:
+            self.sum_squared_error.append(sum_squared_error)
+            self.total.append(n_obs)
 
     def compute(self):
         """
@@ -107,4 +132,11 @@
             data_range = self.data_range
         else:
             data_range = self.max_target - self.min_target
-        return _psnr_compute(self.sum_squared_error, self.total, data_range, self.base, self.reduction)
+
+        if self.dim is None:
+            sum_squared_error = self.sum_squared_error
+            total = self.total
+        else:
+            sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
+            total = torch.cat([values.flatten() for values in self.total])
+        return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
diff --git a/tests/metrics/regression/test_psnr.py b/tests/metrics/regression/test_psnr.py
index bc1c8d9..eb07fff 100644
--- a/tests/metrics/regression/test_psnr.py
+++ b/tests/metrics/regression/test_psnr.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from collections import namedtuple
 from functools import partial
 
@@ -14,67 +28,106 @@
 torch.manual_seed(42)
 
 Input = namedtuple('Input', ["preds", "target"])
 
+_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32)
 _inputs = [
     Input(
-        preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
-        target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
+        preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float),
+        target=torch.randint(n_cls_target, _input_size, dtype=torch.float),
     )
     for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
 ]
 
 
-def _sk_metric(preds, target, data_range):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
+def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
+    value = value.numpy()
+    batches = value[None] if value.ndim == len(_input_size) - 1 else value
+
+    if dim is None:
+        return [batches]
+
+    num_dims = np.size(dim)
+    if not num_dims:
+        return batches
+
+    inputs = []
+    for batch in batches:
+        batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
+        psnr_input_shape = batch.shape[-num_dims:]
+        inputs.extend(batch.reshape(-1, *psnr_input_shape))
+    return inputs
 
 
-def _base_e_sk_metric(preds, target, data_range):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) * np.log(10)
+def _sk_psnr(preds, target, data_range, reduction, dim):
+    sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
+    sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
+    np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
+    return np_reduce_map[reduction]([
+        peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
+        for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
+    ])
+
+
+def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
+    return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10)
 
 
 @pytest.mark.parametrize(
-    "preds, target, data_range",
+    "preds, target, data_range, reduction, dim",
     [
-        (_inputs[0].preds, _inputs[0].target, 10),
-        (_inputs[1].preds, _inputs[1].target, 10),
-        (_inputs[2].preds, _inputs[2].target, 5),
+        (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None),
+        (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
+        (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
     ],
 )
 @pytest.mark.parametrize(
     "base, sk_metric",
     [
-        (10.0, _sk_metric),
-        (2.718281828459045, _base_e_sk_metric),
+        (10.0, _sk_psnr),
+        (2.718281828459045, _base_e_sk_psnr),
     ],
 )
 class TestPSNR(MetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
+    def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step):
+        _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
         self.run_class_metric_test(
             ddp,
             preds,
             target,
             PSNR,
-            partial(sk_metric, data_range=data_range),
-            metric_args={
-                "data_range": data_range,
-                "base": base
-            },
+            partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
+            metric_args=_args,
             dist_sync_on_step=dist_sync_on_step,
         )
 
-    def test_psnr_functional(self, preds, target, sk_metric, data_range, base):
+    def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim):
+        _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
         self.run_functional_metric_test(
             preds,
             target,
             psnr,
-            partial(sk_metric, data_range=data_range),
-            metric_args={
-                "data_range": data_range,
-                "base": base
-            },
+            partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
+            metric_args=_args,
         )
+
+
+@pytest.mark.parametrize("reduction", ["none", "sum"])
+def test_reduction_for_dim_none(reduction):
+    match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
+    with pytest.warns(UserWarning, match=match):
+        PSNR(reduction=reduction, dim=None)
+
+    with pytest.warns(UserWarning, match=match):
+        psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
+
+
+def test_missing_data_range():
+    with pytest.raises(ValueError):
+        PSNR(data_range=None, dim=0)
+
+    with pytest.raises(ValueError):
+        psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
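
For reference, a minimal usage sketch of the ``dim`` argument introduced by this diff. The import paths are the files patched above; the tensor shapes, values, and the per-sample interpretation of ``dim=(1, 2, 3)`` are illustrative assumptions, not part of the patch.

import torch

from pytorch_lightning.metrics.functional.psnr import psnr
from pytorch_lightning.metrics.regression.psnr import PSNR

# Hypothetical batch of 4 images with values in [0, 1].
preds = torch.rand(4, 3, 16, 16)
target = torch.rand(4, 3, 16, 16)

# dim=None keeps the previous behaviour: one score over all elements.
print(psnr(preds, target, data_range=1.0))

# Reducing only over the channel/height/width dims with reduction='none'
# yields one PSNR value per sample; data_range is required whenever dim is set.
print(psnr(preds, target, data_range=1.0, dim=(1, 2, 3), reduction='none'))  # shape (4,)

# The module interface mirrors the functional one and accumulates the
# per-sample statistics across update() calls when dim is given.
metric = PSNR(data_range=1.0, dim=(1, 2, 3))
metric.update(preds, target)
print(metric.compute())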