Add `dim` to `pytorch_lightning.metrics.PSNR` (PL^5957)
* Add dim to PSNR * Update CHANGELOG.md * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add reduction tests * Recover warnings on reduction and add tests * Add copyright texts * Refactor PSNR * Change warnings * Update pytorch_lightning/metrics/functional/psnr.py Change functional.psnr dim doc Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Change PSNR dim docs * Apply suggestions from code review * tests Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> (cherry picked from commit a3dd545b733f9b0cafbe725a4f2e57bb6bed2759)
This commit is contained in:
Родитель
4c2b7b24ec
Коммит
c998634d40
|
@ -1,27 +1,56 @@
|
|||
from typing import Optional, Tuple
|
||||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning import utilities
|
||||
from pytorch_lightning.metrics import utils
|
||||
|
||||
|
||||
def _psnr_compute(
|
||||
sum_squared_error: torch.Tensor,
|
||||
n_obs: int,
|
||||
data_range: float,
|
||||
n_obs: torch.Tensor,
|
||||
data_range: torch.Tensor,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
if reduction != 'elementwise_mean':
|
||||
rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
|
||||
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
|
||||
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
|
||||
return psnr
|
||||
return utils.reduce(psnr, reduction=reduction)
|
||||
|
||||
|
||||
def _psnr_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
||||
n_obs = target.numel()
|
||||
def _psnr_update(preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if dim is None:
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
||||
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
|
||||
|
||||
if isinstance(dim, int):
|
||||
dim_list = [dim]
|
||||
else:
|
||||
dim_list = list(dim)
|
||||
if not dim_list:
|
||||
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||
else:
|
||||
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
|
||||
n_obs = n_obs.expand_as(sum_squared_error)
|
||||
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
|
||||
|
@ -31,6 +60,7 @@ def psnr(
|
|||
data_range: Optional[float] = None,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the peak signal-to-noise ratio
|
||||
|
@ -38,7 +68,9 @@ def psnr(
|
|||
Args:
|
||||
preds: estimated signal
|
||||
target: ground truth signal
|
||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
||||
data_range:
|
||||
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
|
||||
when ``dim`` is not None.
|
||||
base: a base of a logarithm to use (default: 10)
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
|
@ -46,6 +78,9 @@ def psnr(
|
|||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
|
||||
dim:
|
||||
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
|
||||
None meaning scores will be reduced across all dimensions.
|
||||
Return:
|
||||
Tensor with PSNR score
|
||||
|
||||
|
@ -57,9 +92,17 @@ def psnr(
|
|||
tensor(2.5527)
|
||||
|
||||
"""
|
||||
if dim is None and reduction != 'elementwise_mean':
|
||||
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
|
||||
|
||||
if data_range is None:
|
||||
if dim is not None:
|
||||
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
|
||||
# `data_range` in the future.
|
||||
raise ValueError("The `data_range` must be given when `dim` is not None.")
|
||||
|
||||
data_range = target.max() - target.min()
|
||||
else:
|
||||
data_range = torch.tensor(float(data_range))
|
||||
sum_squared_error, n_obs = _psnr_update(preds, target)
|
||||
return _psnr_compute(sum_squared_error, n_obs, data_range, base, reduction)
|
||||
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
|
||||
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
|
||||
|
|
|
@ -11,10 +11,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import utilities
|
||||
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
@ -29,7 +30,9 @@ class PSNR(Metric):
|
|||
<https://en.wikipedia.org/wiki/Mean_squared_error>`_ function.
|
||||
|
||||
Args:
|
||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
||||
data_range:
|
||||
the range of the data. If None, it is determined from the data (max - min).
|
||||
The ``data_range`` must be given when ``dim`` is not None.
|
||||
base: a base of a logarithm to use (default: 10)
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
|
@ -37,6 +40,9 @@ class PSNR(Metric):
|
|||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
|
||||
dim:
|
||||
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
|
||||
None meaning scores will be reduced across all dimensions and all batches.
|
||||
compute_on_step:
|
||||
Forward only calls ``update()`` and returns None if this is set to False. Default: True
|
||||
dist_sync_on_step:
|
||||
|
@ -61,6 +67,7 @@ class PSNR(Metric):
|
|||
data_range: Optional[float] = None,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
compute_on_step: bool = True,
|
||||
dist_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
|
@ -71,9 +78,22 @@ class PSNR(Metric):
|
|||
process_group=process_group,
|
||||
)
|
||||
|
||||
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
if dim is None and reduction != 'elementwise_mean':
|
||||
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
|
||||
|
||||
if dim is None:
|
||||
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
else:
|
||||
self.add_state("sum_squared_error", default=[])
|
||||
self.add_state("total", default=[])
|
||||
|
||||
if data_range is None:
|
||||
if dim is not None:
|
||||
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
|
||||
# calculate `data_range` in the future.
|
||||
raise ValueError("The `data_range` must be given when `dim` is not None.")
|
||||
|
||||
self.data_range = None
|
||||
self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min)
|
||||
self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max)
|
||||
|
@ -81,6 +101,7 @@ class PSNR(Metric):
|
|||
self.register_buffer("data_range", torch.tensor(float(data_range)))
|
||||
self.base = base
|
||||
self.reduction = reduction
|
||||
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
|
@ -90,14 +111,18 @@ class PSNR(Metric):
|
|||
preds: Predictions from model
|
||||
target: Ground truth values
|
||||
"""
|
||||
if self.data_range is None:
|
||||
# keep track of min and max target values
|
||||
self.min_target = min(target.min(), self.min_target)
|
||||
self.max_target = max(target.max(), self.max_target)
|
||||
sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim)
|
||||
if self.dim is None:
|
||||
if self.data_range is None:
|
||||
# keep track of min and max target values
|
||||
self.min_target = min(target.min(), self.min_target)
|
||||
self.max_target = max(target.max(), self.max_target)
|
||||
|
||||
sum_squared_error, n_obs = _psnr_update(preds, target)
|
||||
self.sum_squared_error += sum_squared_error
|
||||
self.total += n_obs
|
||||
self.sum_squared_error += sum_squared_error
|
||||
self.total += n_obs
|
||||
else:
|
||||
self.sum_squared_error.append(sum_squared_error)
|
||||
self.total.append(n_obs)
|
||||
|
||||
def compute(self):
|
||||
"""
|
||||
|
@ -107,4 +132,11 @@ class PSNR(Metric):
|
|||
data_range = self.data_range
|
||||
else:
|
||||
data_range = self.max_target - self.min_target
|
||||
return _psnr_compute(self.sum_squared_error, self.total, data_range, self.base, self.reduction)
|
||||
|
||||
if self.dim is None:
|
||||
sum_squared_error = self.sum_squared_error
|
||||
total = self.total
|
||||
else:
|
||||
sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
|
||||
total = torch.cat([values.flatten() for values in self.total])
|
||||
return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
|
@ -14,67 +28,106 @@ torch.manual_seed(42)
|
|||
|
||||
Input = namedtuple('Input', ["preds", "target"])
|
||||
|
||||
_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32)
|
||||
_inputs = [
|
||||
Input(
|
||||
preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
|
||||
target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
|
||||
preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float),
|
||||
target=torch.randint(n_cls_target, _input_size, dtype=torch.float),
|
||||
) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
|
||||
]
|
||||
|
||||
|
||||
def _sk_metric(preds, target, data_range):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
|
||||
def _to_sk_peak_signal_noise_ratio_inputs(value, dim):
|
||||
value = value.numpy()
|
||||
batches = value[None] if value.ndim == len(_input_size) - 1 else value
|
||||
|
||||
if dim is None:
|
||||
return [batches]
|
||||
|
||||
num_dims = np.size(dim)
|
||||
if not num_dims:
|
||||
return batches
|
||||
|
||||
inputs = []
|
||||
for batch in batches:
|
||||
batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0))
|
||||
psnr_input_shape = batch.shape[-num_dims:]
|
||||
inputs.extend(batch.reshape(-1, *psnr_input_shape))
|
||||
return inputs
|
||||
|
||||
|
||||
def _base_e_sk_metric(preds, target, data_range):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
return peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) * np.log(10)
|
||||
def _sk_psnr(preds, target, data_range, reduction, dim):
|
||||
sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim)
|
||||
sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim)
|
||||
np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
|
||||
return np_reduce_map[reduction]([
|
||||
peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
|
||||
for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
|
||||
])
|
||||
|
||||
|
||||
def _base_e_sk_psnr(preds, target, data_range, reduction, dim):
|
||||
return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds, target, data_range",
|
||||
"preds, target, data_range, reduction, dim",
|
||||
[
|
||||
(_inputs[0].preds, _inputs[0].target, 10),
|
||||
(_inputs[1].preds, _inputs[1].target, 10),
|
||||
(_inputs[2].preds, _inputs[2].target, 5),
|
||||
(_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None),
|
||||
(_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None),
|
||||
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None),
|
||||
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1),
|
||||
(_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)),
|
||||
(_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"base, sk_metric",
|
||||
[
|
||||
(10.0, _sk_metric),
|
||||
(2.718281828459045, _base_e_sk_metric),
|
||||
(10.0, _sk_psnr),
|
||||
(2.718281828459045, _base_e_sk_psnr),
|
||||
],
|
||||
)
|
||||
class TestPSNR(MetricTester):
|
||||
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
|
||||
def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
|
||||
def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step):
|
||||
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
|
||||
self.run_class_metric_test(
|
||||
ddp,
|
||||
preds,
|
||||
target,
|
||||
PSNR,
|
||||
partial(sk_metric, data_range=data_range),
|
||||
metric_args={
|
||||
"data_range": data_range,
|
||||
"base": base
|
||||
},
|
||||
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
|
||||
metric_args=_args,
|
||||
dist_sync_on_step=dist_sync_on_step,
|
||||
)
|
||||
|
||||
def test_psnr_functional(self, preds, target, sk_metric, data_range, base):
|
||||
def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim):
|
||||
_args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
|
||||
self.run_functional_metric_test(
|
||||
preds,
|
||||
target,
|
||||
psnr,
|
||||
partial(sk_metric, data_range=data_range),
|
||||
metric_args={
|
||||
"data_range": data_range,
|
||||
"base": base
|
||||
},
|
||||
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
|
||||
metric_args=_args,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("reduction", ["none", "sum"])
|
||||
def test_reduction_for_dim_none(reduction):
|
||||
match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
PSNR(reduction=reduction, dim=None)
|
||||
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
|
||||
|
||||
|
||||
def test_missing_data_range():
|
||||
with pytest.raises(ValueError):
|
||||
PSNR(data_range=None, dim=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
|
||||
|
|
Загрузка…
Ссылка в новой задаче