Add `dim` to `pytorch_lightning.metrics.PSNR` (#5957)

* Add dim to PSNR

* Update CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Add reduction tests

* Recover warnings on reduction and add tests

* Add copyright texts

* Refactor PSNR

* Change warnings

* Update pytorch_lightning/metrics/functional/psnr.py

Change functional.psnr dim doc

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Change PSNR dim docs

* Apply suggestions from code review

* tests

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
(cherry picked from commit a3dd545b733f9b0cafbe725a4f2e57bb6bed2759)
Authored by manipopopo on 2021-02-17 18:55:40 +08:00, committed by Jirka Borovec
Parent: 4c2b7b24ec
Commit: c998634d40
3 changed files: 181 additions, 53 deletions

pytorch_lightning/metrics/functional/psnr.py

@@ -1,27 +1,56 @@
-from typing import Optional, Tuple
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union

 import torch

-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning import utilities
+from pytorch_lightning.metrics import utils


 def _psnr_compute(
     sum_squared_error: torch.Tensor,
-    n_obs: int,
-    data_range: float,
+    n_obs: torch.Tensor,
+    data_range: torch.Tensor,
     base: float = 10.0,
     reduction: str = 'elementwise_mean',
 ) -> torch.Tensor:
-    if reduction != 'elementwise_mean':
-        rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
-
     psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
     psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
-    return psnr
+    return utils.reduce(psnr, reduction=reduction)


-def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
-    sum_squared_error = torch.sum(torch.pow(preds - target, 2))
-    n_obs = target.numel()
+def _psnr_update(preds: torch.Tensor,
+                 target: torch.Tensor,
+                 dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dim is None:
+        sum_squared_error = torch.sum(torch.pow(preds - target, 2))
+        n_obs = torch.tensor(target.numel(), device=target.device)
+        return sum_squared_error, n_obs
+
+    sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
+
+    if isinstance(dim, int):
+        dim_list = [dim]
+    else:
+        dim_list = list(dim)
+    if not dim_list:
+        n_obs = torch.tensor(target.numel(), device=target.device)
+    else:
+        n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
+        n_obs = n_obs.expand_as(sum_squared_error)
+
     return sum_squared_error, n_obs

@@ -31,6 +60,7 @@ def psnr(
     data_range: Optional[float] = None,
     base: float = 10.0,
     reduction: str = 'elementwise_mean',
+    dim: Optional[Union[int, Tuple[int, ...]]] = None,
 ) -> torch.Tensor:
     """
     Computes the peak signal-to-noise ratio
@@ -38,7 +68,9 @@ def psnr(
     Args:
         preds: estimated signal
         target: ground truth signal
-        data_range: the range of the data. If None, it is determined from the data (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
+            when ``dim`` is not None.
         base: a base of a logarithm to use (default: 10)
         reduction: a method to reduce metric score over labels.
@@ -46,6 +78,9 @@ def psnr(
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

+        dim:
+            Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
+            None, meaning scores will be reduced across all dimensions.
+
     Return:
         Tensor with PSNR score
@@ -57,9 +92,17 @@ def psnr(
         tensor(2.5527)

     """
+    if dim is None and reduction != 'elementwise_mean':
+        utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
+
     if data_range is None:
+        if dim is not None:
+            # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
+            # calculate `data_range` in the future.
+            raise ValueError("The `data_range` must be given when `dim` is not None.")
+
         data_range = target.max() - target.min()
     else:
         data_range = torch.tensor(float(data_range))
-    sum_squared_error, n_obs = _psnr_update(preds, target)
-    return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)
+    sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
+    return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
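
For orientation, a minimal usage sketch of the new functional `dim` argument. The shapes, values, and the `data_range=1.0` choice are illustrative assumptions, not part of this commit:

import torch
from pytorch_lightning.metrics.functional import psnr

# A batch of 4 single-channel 8x8 "images"; values are placeholders.
preds = torch.rand(4, 1, 8, 8)
target = torch.rand(4, 1, 8, 8)

# Previous behaviour, unchanged: one global score over every element.
global_score = psnr(preds, target, data_range=1.0)

# New: one score per image, reduced over the channel/height/width dims.
# `data_range` is mandatory here, since min/max are not taken per slice.
per_image_mean = psnr(preds, target, data_range=1.0, dim=(1, 2, 3), reduction='elementwise_mean')
per_image_all = psnr(preds, target, data_range=1.0, dim=(1, 2, 3), reduction='none')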

pytorch_lightning/metrics/regression/psnr.py

@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Optional, Sequence, Tuple, Union

 import torch

+from pytorch_lightning import utilities
 from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
 from pytorch_lightning.metrics.metric import Metric

@@ -29,7 +30,9 @@ class PSNR(Metric):
     <https://en.wikipedia.org/wiki/Mean_squared_error>`_ function.

     Args:
-        data_range: the range of the data. If None, it is determined from the data (max - min)
+        data_range:
+            the range of the data. If None, it is determined from the data (max - min).
+            The ``data_range`` must be given when ``dim`` is not None.
         base: a base of a logarithm to use (default: 10)
         reduction: a method to reduce metric score over labels.
@@ -37,6 +40,9 @@ class PSNR(Metric):
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

+        dim:
+            Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
+            None, meaning scores will be reduced across all dimensions and all batches.
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -61,6 +67,7 @@ class PSNR(Metric):
         data_range: Optional[float] = None,
         base: float = 10.0,
         reduction: str = 'elementwise_mean',
+        dim: Optional[Union[int, Tuple[int, ...]]] = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
@@ -71,9 +78,22 @@ class PSNR(Metric):
             process_group=process_group,
         )

-        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        if dim is None and reduction != 'elementwise_mean':
+            utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
+
+        if dim is None:
+            self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        else:
+            self.add_state("sum_squared_error", default=[])
+            self.add_state("total", default=[])

         if data_range is None:
+            if dim is not None:
+                # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
+                # calculate `data_range` in the future.
+                raise ValueError("The `data_range` must be given when `dim` is not None.")
+
             self.data_range = None
             self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
             self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
@@ -81,6 +101,7 @@ class PSNR(Metric):
             self.register_buffer("data_range", torch.tensor(float(data_range)))
         self.base = base
         self.reduction = reduction
+        self.dim = tuple(dim) if isinstance(dim, Sequence) else dim

     def update(self, preds: torch.Tensor, target: torch.Tensor):
         """
@@ -90,14 +111,18 @@ class PSNR(Metric):
             preds: Predictions from model
             target: Ground truth values
         """
-        if self.data_range is None:
-            # keep track of min and max target values
-            self.min_target = min(target.min(), self.min_target)
-            self.max_target = max(target.max(), self.max_target)
-
-        sum_squared_error, n_obs = _psnr_update(preds, target)
-        self.sum_squared_error += sum_squared_error
-        self.total += n_obs
+        sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
+        if self.dim is None:
+            if self.data_range is None:
+                # keep track of min and max target values
+                self.min_target = min(target.min(), self.min_target)
+                self.max_target = max(target.max(), self.max_target)
+
+            self.sum_squared_error += sum_squared_error
+            self.total += n_obs
+        else:
+            self.sum_squared_error.append(sum_squared_error)
+            self.total.append(n_obs)

     def compute(self):
         """
@@ -107,4 +132,11 @@ class PSNR(Metric):
             data_range = self.data_range
         else:
             data_range = self.max_target - self.min_target
-        return _psnr_compute(self.sum_squared_error, self.total, data_range, self.base, self.reduction)
+
+        if self.dim is None:
+            sum_squared_error = self.sum_squared_error
+            total = self.total
+        else:
+            sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
+            total = torch.cat([values.flatten() for values in self.total])
+        return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
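
The module API mirrors the functional one. A sketch of accumulating per-image scores across several batches, again with made-up shapes; with `dim` set, state is kept as lists of per-batch tensors and concatenated in `compute()`:

import torch
from pytorch_lightning.metrics import PSNR

# One PSNR score per image, kept unreduced until compute() is called.
metric = PSNR(data_range=1.0, dim=(1, 2, 3), reduction='none')

for _ in range(3):  # e.g. three batches of 4 images each
    preds = torch.rand(4, 1, 8, 8)
    target = torch.rand(4, 1, 8, 8)
    metric.update(preds, target)

# Per-batch score components are concatenated here, so the result
# spans every batch seen so far.
per_image_scores = metric.compute()  # 12 scores: 3 batches x 4 images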

tests/metrics/regression/test_psnr.py

@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from collections import namedtuple
 from functools import partial

@@ -14,67 +28,106 @@ torch.manual_seed(42)

 Input = namedtuple('Input', ["preds", "target"])

+_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32)
 _inputs = [
     Input(
-        preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
-        target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
+        preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float),
+        target=torch.randint(n_cls_target, _input_size, dtype=torch.float),
     ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
 ]


-def _sk_metric(preds, target, data_range):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
+def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
+    value = value.numpy()
+    batches = value[None] if value.ndim == len(_input_size) - 1 else value
+
+    if dim is None:
+        return [batches]
+
+    num_dims = np.size(dim)
+    if not num_dims:
+        return batches
+
+    inputs = []
+    for batch in batches:
+        batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
+        psnr_input_shape = batch.shape[-num_dims:]
+        inputs.extend(batch.reshape(-1, *psnr_input_shape))
+    return inputs


-def _base_e_sk_metric(preds, target, data_range):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) * np.log(10)
+def _sk_psnr(preds, target, data_range, reduction, dim):
+    sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
+    sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
+    np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
+    return np_reduce_map[reduction]([
+        peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
+        for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
+    ])
+
+
+def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
+    return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10)


 @pytest.mark.parametrize(
-    "preds, target, data_range",
+    "preds, target, data_range, reduction, dim",
     [
-        (_inputs[0].preds, _inputs[0].target, 10),
-        (_inputs[1].preds, _inputs[1].target, 10),
-        (_inputs[2].preds, _inputs[2].target, 5),
+        (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None),
+        (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
+        (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
+        (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
     ],
 )
 @pytest.mark.parametrize(
     "base, sk_metric",
     [
-        (10.0, _sk_metric),
-        (2.718281828459045, _base_e_sk_metric),
+        (10.0, _sk_psnr),
+        (2.718281828459045, _base_e_sk_psnr),
     ],
 )
 class TestPSNR(MetricTester):

     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
+    def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step):
+        _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
         self.run_class_metric_test(
             ddp,
             preds,
             target,
             PSNR,
-            partial(sk_metric, data_range=data_range),
-            metric_args={
-                "data_range": data_range,
-                "base": base
-            },
+            partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
+            metric_args=_args,
             dist_sync_on_step=dist_sync_on_step,
         )

-    def test_psnr_functional(self, preds, target, sk_metric, data_range, base):
+    def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim):
+        _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
         self.run_functional_metric_test(
             preds,
             target,
             psnr,
-            partial(sk_metric, data_range=data_range),
-            metric_args={
-                "data_range": data_range,
-                "base": base
-            },
+            partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
+            metric_args=_args,
         )
+
+
+@pytest.mark.parametrize("reduction", ["none", "sum"])
+def test_reduction_for_dim_none(reduction):
+    match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
+    with pytest.warns(UserWarning, match=match):
+        PSNR(reduction=reduction, dim=None)
+
+    with pytest.warns(UserWarning, match=match):
+        psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
+
+
+def test_missing_data_range():
+    with pytest.raises(ValueError):
+        PSNR(data_range=None, dim=0)
+
+    with pytest.raises(ValueError):
+        psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
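
As a side note on the test helper: `_to_sk_peak_signal_noise_ratio_inputs` splits each batch into the independent slices that `peak_signal_noise_ratio` should score. A standalone sketch of the same axis shuffling, using placeholder random data rather than the fixtures above:

import numpy as np

batches = np.random.rand(2, 4, 32, 32)  # (NUM_BATCHES, BATCH_SIZE, 32, 32), illustrative
dim = (1, 2)                            # reduce over dims 1 and 2 of each (4, 32, 32) batch
num_dims = np.size(dim)

inputs = []
for batch in batches:
    # Move the reduced axes to the end, then cut the batch into one
    # array per PSNR score, exactly as the helper above does.
    batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
    slice_shape = batch.shape[-num_dims:]
    inputs.extend(batch.reshape(-1, *slice_shape))

print(len(inputs), inputs[0].shape)  # 8 (32, 32): one slice per scored image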