* grouping in Chlog

* implementation

* init files

* requirements

* change to optional testing

* update

* working test

* docs

* Update tests/image/test_lpips.py

* fix suggestions

* fix

* lower cpu load

* fix

* fix docs

* fix docs

* skip testing

* add seed

* fix

* disable scripting

* pep8

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Nicki Skafte 2021-08-18 10:00:12 +02:00 committed by GitHub
Parent 00f0256db3
Commit db281f7188
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 304 additions and 20 deletions

View file

@@ -10,8 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))
- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))
### Changed

View file

@@ -38,3 +38,4 @@
.. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
.. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
.. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
.. _LPIPS: https://arxiv.org/abs/1801.03924

View file

@@ -335,6 +335,11 @@ KID
.. autoclass:: torchmetrics.KID
:noindex:
LPIPS
~~~~~
.. autoclass:: torchmetrics.LPIPS
:noindex:
PSNR
~~~~
@@ -342,7 +347,6 @@ PSNR
.. autoclass:: torchmetrics.PSNR
:noindex:
SSIM
~~~~
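
The documentation hunk above registers the new LPIPS class. As a quick illustration of how the metric is meant to be used, a minimal sketch mirroring the example in the class docstring further down (it assumes the lpips package is installed, e.g. via `pip install torchmetrics[image]`):

import torch
from torchmetrics import LPIPS

lpips = LPIPS(net_type="vgg")
# inputs must have shape [N, 3, H, W] with values in [-1, 1]; torch.rand stays inside [0, 1]
img1 = torch.rand(4, 3, 100, 100)
img2 = torch.rand(4, 3, 100, 100)
score = lpips(img1, img2)  # lower scores mean the image patches are perceptually more similar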

View file

@@ -1,3 +1,4 @@
scipy
torchvision # this is needed to internally set the TV version according to the installed PT
torch-fidelity
lpips

View file

@@ -15,7 +15,7 @@ import os
import pickle
import sys
from functools import partial
from typing import Any, Callable, Sequence
from typing import Any, Callable, Optional, Sequence
import numpy as np
import pytest
@@ -61,7 +61,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
"""Utility function for recursively asserting that two results are within a certain tolerance."""
# single output compare
if isinstance(pl_result, Tensor):
assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
assert np.allclose(pl_result.detach().cpu().numpy(), sk_result, atol=atol, equal_nan=True)
# multi output compare
elif isinstance(pl_result, Sequence):
for pl_res, sk_res in zip(pl_result, sk_result):
@@ -127,6 +127,9 @@ def _class_test(
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
assert preds.shape[0] == target.shape[0]
num_batches = preds.shape[0]
if not metric_args:
metric_args = {}
@@ -149,7 +152,7 @@ def _class_test(
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)
for i in range(rank, NUM_BATCHES, worldsize):
for i in range(rank, num_batches, worldsize):
batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
batch_result = metric(preds[i], target[i], **batch_kwargs_update)
@@ -177,10 +180,10 @@ def _class_test(
result = metric.compute()
_assert_tensor(result)
total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]).cpu()
total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]).cpu()
total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
total_kwargs_update = {
k: torch.cat([v[i] for i in range(NUM_BATCHES)]).cpu() if isinstance(v, Tensor) else v
k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v
for k, v in kwargs_update.items()
}
sk_result = sk_metric(total_preds, total_target, **total_kwargs_update)
@@ -213,6 +216,9 @@ def _functional_test(
kwargs_update: Additional keyword arguments that will be passed with preds and
target when running update on the metric.
"""
assert preds.shape[0] == target.shape[0]
num_batches = preds.shape[0]
if not metric_args:
metric_args = {}
@@ -223,7 +229,7 @@ def _functional_test(
target = target.to(device)
kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
for i in range(NUM_BATCHES):
for i in range(num_batches):
extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
lightning_result = metric(preds[i], target[i], **extra_kwargs)
extra_kwargs = {
@@ -238,7 +244,7 @@
def _assert_half_support(
metric_module: Metric,
metric_functional: Callable,
metric_functional: Optional[Callable],
preds: Tensor,
target: Tensor,
device: str = "cpu",
@@ -263,7 +269,8 @@
}
metric_module = metric_module.to(device)
_assert_tensor(metric_module(y_hat, y, **kwargs_update))
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))
if metric_functional is not None:
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))
class MetricTester:
@@ -411,8 +418,8 @@
preds: Tensor,
target: Tensor,
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = None,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
**kwargs_update,
):
"""Test if a metric can be used with half precision tensors on cpu
@@ -435,8 +442,8 @@
preds: Tensor,
target: Tensor,
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = None,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
**kwargs_update,
):
"""Test if a metric can be used with half precision tensors on gpu
@@ -459,8 +466,8 @@
preds: Tensor,
target: Tensor,
metric_module: Metric,
metric_functional: Callable,
metric_args: dict = None,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
):
"""Test if a metric is differentiable or not.
@@ -480,7 +487,7 @@
# Check if requires_grad matches is_differentiable attribute
_assert_requires_grad(metric, out)
if metric.is_differentiable:
if metric.is_differentiable and metric_functional is not None:
# check for numerical correctness
assert torch.autograd.gradcheck(
partial(metric_functional, **metric_args), (preds[0].double(), target[0])

tests/image/test_lpips.py (new file, 103 additions)
View file

@@ -0,0 +1,103 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
import pytest
import torch
from lpips import LPIPS as reference_LPIPS
from torch import Tensor
from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.image.lpip_similarity import LPIPS
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
seed_all(42)
Input = namedtuple("Input", ["img1", "img2"])
_inputs = Input(
img1=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
img2=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
)
def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mean") -> Tensor:
"""comparison function for tm implementation."""
ref = reference_LPIPS(net=net_type)
res = ref(img1, img2).detach().cpu().numpy()
if reduction == "mean":
return res.mean()
return res.sum()
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
class TestLPIPS(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
def test_lpips(self, net_type, ddp):
"""test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
preds=_inputs.img1,
target=_inputs.img2,
metric_class=LPIPS,
sk_metric=partial(_compare_fn, net_type=net_type),
dist_sync_on_step=False,
check_scriptable=False,
metric_args={"net_type": net_type},
)
def test_lpips_differentiability(self, net_type):
"""test for differentiability of LPIPS metric."""
self.run_differentiability_test(preds=_inputs.img1, target=_inputs.img2, metric_module=LPIPS)
# LPIPS half + cpu does not work due to missing support in torch.min
@pytest.mark.xfail(reason="LPIPS metric does not support cpu + half precision")
def test_lpips_half_cpu(self, net_type):
"""test for half + cpu support."""
self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LPIPS)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_lpips_half_gpu(self, net_type):
"""test for half + gpu support."""
self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LPIPS)
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
LPIPS(net_type="resnet")
with pytest.raises(ValueError, match="Argument `reduction` must be one .*"):
LPIPS(reduction=None)
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize(
"inp1, inp2",
[
(torch.rand(1, 1, 28, 28), torch.rand(1, 3, 28, 28)), # wrong number of channels
(torch.rand(1, 3, 28, 28), torch.rand(1, 1, 28, 28)), # wrong number of channels
(torch.randn(1, 3, 28, 28), torch.rand(1, 3, 28, 28)), # non-normalized input
(torch.rand(1, 3, 28, 28), torch.randn(1, 3, 28, 28)), # non-normalized input
],
)
def test_error_on_wrong_update(inp1, inp2):
"""test error is raised on wrong input to update method."""
metric = LPIPS()
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
metric(inp1, inp2)

View file

@@ -19,11 +19,14 @@ import torch
from sklearn.metrics import precision_score, recall_score
from torch import Tensor
from tests.helpers import seed_all
from torchmetrics.classification import Precision, Recall
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler
seed_all(42)
_preds = torch.randint(10, (10, 32))
_target = torch.randint(10, (10, 32))
@@ -55,10 +58,10 @@ def _sample_checker(old_samples, new_samples, op: operator, threshold: int):
@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
def test_bootstrap_sampler(sampling_strategy):
"""make sure that the bootstrap sampler works as intended."""
old_samples = torch.randn(10, 2)
old_samples = torch.randn(20, 2)
# make sure that the new samples are only made up of old samples
idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
idx = _bootstrap_sampler(20, sampling_strategy=sampling_strategy)
new_samples = old_samples[idx]
for ns in new_samples:
assert ns in old_samples
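
For context on what the assertion above checks, a hedged sketch of what a "multinomial" bootstrap sampler of the kind exercised here is assumed to do: draw `size` indices uniformly with replacement, so every resampled row is guaranteed to be one of the original rows. The helper name and signature below are illustrative, not the actual `_bootstrap_sampler` implementation:

import torch

def multinomial_bootstrap_indices(size: int) -> torch.Tensor:
    # uniform probabilities over the `size` rows, drawn with replacement
    probs = torch.ones(size) / size
    return torch.multinomial(probs, num_samples=size, replacement=True)

idx = multinomial_bootstrap_indices(20)
old_samples = torch.randn(20, 2)
new_samples = old_samples[idx]  # every row of new_samples is some row of old_samples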

View file

@@ -40,7 +40,7 @@ from torchmetrics.classification import ( # noqa: E402
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.image import FID, IS, KID, PSNR, SSIM # noqa: E402
from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.regression import ( # noqa: E402
CosineSimilarity,
@@ -92,6 +92,7 @@ __all__ = [
"IS",
"KID",
"KLDivergence",
"LPIPS",
"MatthewsCorrcoef",
"MeanAbsoluteError",
"MeanAbsolutePercentageError",

View file

@@ -14,5 +14,6 @@
from torchmetrics.image.fid import FID # noqa: F401
from torchmetrics.image.inception import IS # noqa: F401
from torchmetrics.image.kid import KID # noqa: F401
from torchmetrics.image.lpip_similarity import LPIPS # noqa: F401
from torchmetrics.image.psnr import PSNR # noqa: F401
from torchmetrics.image.ssim import SSIM # noqa: F401

View file

@@ -0,0 +1,159 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional
import torch
from torch import Tensor
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
if _LPIPS_AVAILABLE:
from lpips import LPIPS as Lpips_backbone
else:
class Lpips_backbone(torch.nn.Module): # type: ignore
pass
class NoTrainLpips(Lpips_backbone):
def train(self, mode: bool) -> "NoTrainLpips":
"""the network should not be able to be switched away from evaluation mode."""
return super().train(False)
def _valid_img(img: Tensor) -> bool:
"""check that input is a valid image to the network."""
return img.ndim == 4 and img.shape[1] == 3 and img.min() >= -1.0 and img.max() <= 1.0
class LPIPS(Metric):
"""The Learned Perceptual Image Patch Similarity (`LPIPS_`) is used to judge the perceptual similarity between
two images. LPIPS essentially computes the similarity between the activations of two image patches for some
pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that
image patches are perceptually similar.
Both input image patches are expected to have shape `[N, 3, H, W]` and be normalized to the [-1,1]
range. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg).
.. note:: using this metric requires you to have the ``lpips`` package installed. Either install
as ``pip install torchmetrics[image]`` or ``pip install lpips``
.. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation
if this is an issue.
Args:
net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Raises:
ValueError:
If ``lpips`` package is not installed
ValueError:
If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
ValueError:
If ``reduction`` is not one of ``"mean"`` or ``"sum"``
Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics import LPIPS
>>> lpips = LPIPS(net_type='vgg')
>>> img1 = torch.rand(10, 3, 100, 100)
>>> img2 = torch.rand(10, 3, 100, 100)
>>> lpips(img1, img2)
tensor([0.3566], grad_fn=<DivBackward0>)
"""
sum_scores: Tensor
total: Tensor
# due to the use of named tuple in the backbone the net variable cannot be scripted
__jit_ignored_attributes__ = ["net"]
def __init__(
self,
net_type: str = "alex",
reduction: str = "mean",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], List[Tensor]]] = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
if not _LPIPS_AVAILABLE:
raise ValueError(
"LPIPS metric requires that lpips is installed."
"Either install as `pip install torchmetrics[image]` or `pip install lpips`"
)
valid_net_type = ("vgg", "alex", "squeeze")
if net_type not in valid_net_type:
raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")
self.net = NoTrainLpips(net=net_type, verbose=False)
valid_reduction = ("mean", "sum")
if reduction not in valid_reduction:
raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
self.reduction = reduction
self.add_state("sum_scores", torch.zeros(1), dist_reduce_fx="sum")
self.add_state("total", torch.zeros(1), dist_reduce_fx="sum")
def update(self, img1: Tensor, img2: Tensor) -> None: # type: ignore
"""Update internal states with lpips score.
Args:
img1: tensor with images of shape [N, 3, H, W]
img2: tensor with images of shape [N, 3, H, W]
"""
if not (_valid_img(img1) and _valid_img(img2)):
raise ValueError(
"Expected both input arguments to be normalized tensors (all values in range [-1,1])"
f" and to have shape [N, 3, H, W] but `img1` have shape {img1.shape} with values in"
f" range {[img1.min(), img1.max()]} and `img2` have shape {img2.shape} with value"
f" in range {[img2.min(), img2.max()]}"
)
loss = self.net(img1, img2).squeeze()
self.sum_scores += loss.sum()
self.total += img1.shape[0]
def compute(self) -> Tensor:
"""Compute final perceptual similarity metric."""
if self.reduction == "mean":
return self.sum_scores / self.total
elif self.reduction == "sum":
return self.sum_scores
@property
def is_differentiable(self) -> bool:
return True
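
For reference, the distance computed by the lpips backbone wrapped above (per the LPIPS paper linked in the docs hunk, https://arxiv.org/abs/1801.03924) averages channel-weighted squared differences of unit-normalized deep activations over layers and spatial positions. A sketch of the formula, in the paper's notation rather than anything defined in this diff:

d(x, x_0) = \sum_l \frac{1}{H_l W_l} \sum_{h,w} \left\lVert w_l \odot \left( \hat{y}^{l}_{hw} - \hat{y}^{l}_{0,hw} \right) \right\rVert_2^2

where \hat{y}^{l}_{hw} denotes the unit-normalized activations of layer l at spatial position (h, w) and w_l are learned per-channel weights. The ``reduction`` argument of this metric then aggregates d over the batch as a mean or a sum.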

View file

@@ -79,3 +79,4 @@ _ROUGE_SCORE_AVAILABLE: bool = _module_available("rouge_score")
_BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
_SCIPY_AVAILABLE: bool = _module_available("scipy")
_TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
_LPIPS_AVAILABLE: bool = _module_available("lpips")
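
The new flag follows the same optional-dependency pattern as the other *_AVAILABLE constants: it is evaluated once at import time and then guards both the lpips import and the constructor check in lpip_similarity.py above. As a rough, illustrative sketch of what a helper like `_module_available` typically boils down to (simplified; not the actual torchmetrics implementation):

from importlib.util import find_spec

def module_available(module_path: str) -> bool:
    """Return True if the top-level package of `module_path` can be located without importing it."""
    try:
        return find_spec(module_path.split(".")[0]) is not None
    except ImportError:
        return False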