add LPIPS (#431)
* grouping in Chlog * implemention * init files * requirements * change to optional testing * update * working test * docs * Update tests/image/test_lpips.py * fix suggestions * fix * lower cpu load * fix * fix docs * fix docs * skip testing * add seed * fix * diable scripting * pep8 Co-authored-by: Jirka <jirka.borovec@seznam.cz> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Родитель
00f0256db3
Коммит
db281f7188
|
@ -10,8 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
### Added
|
||||
|
||||
- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))
|
||||
|
||||
- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
|
|
|
@ -38,3 +38,4 @@
|
|||
.. _MAPE implementation returns: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html
|
||||
.. _mean squared logarithmic error: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-squared-log-error
|
||||
.. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
|
||||
.. _LPIPS: https://arxiv.org/abs/1801.03924
|
||||
|
|
|
@ -335,6 +335,11 @@ KID
|
|||
.. autoclass:: torchmetrics.KID
|
||||
:noindex:
|
||||
|
||||
LPIPS
|
||||
~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.LPIPS
|
||||
:noindex:
|
||||
|
||||
PSNR
|
||||
~~~~
|
||||
|
@ -342,7 +347,6 @@ PSNR
|
|||
.. autoclass:: torchmetrics.PSNR
|
||||
:noindex:
|
||||
|
||||
|
||||
SSIM
|
||||
~~~~
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
scipy
|
||||
torchvision # this is needed to internally set TV version according installed PT
|
||||
torch-fidelity
|
||||
lpips
|
||||
|
|
|
@ -15,7 +15,7 @@ import os
|
|||
import pickle
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Sequence
|
||||
from typing import Any, Callable, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -61,7 +61,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
|
|||
"""Utility function for recursively asserting that two results are within a certain tolerance."""
|
||||
# single output compare
|
||||
if isinstance(pl_result, Tensor):
|
||||
assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
|
||||
assert np.allclose(pl_result.detach().cpu().numpy(), sk_result, atol=atol, equal_nan=True)
|
||||
# multi output compare
|
||||
elif isinstance(pl_result, Sequence):
|
||||
for pl_res, sk_res in zip(pl_result, sk_result):
|
||||
|
@ -127,6 +127,9 @@ def _class_test(
|
|||
kwargs_update: Additional keyword arguments that will be passed with preds and
|
||||
target when running update on the metric.
|
||||
"""
|
||||
assert preds.shape[0] == target.shape[0]
|
||||
num_batches = preds.shape[0]
|
||||
|
||||
if not metric_args:
|
||||
metric_args = {}
|
||||
|
||||
|
@ -149,7 +152,7 @@ def _class_test(
|
|||
pickled_metric = pickle.dumps(metric)
|
||||
metric = pickle.loads(pickled_metric)
|
||||
|
||||
for i in range(rank, NUM_BATCHES, worldsize):
|
||||
for i in range(rank, num_batches, worldsize):
|
||||
batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
|
||||
|
||||
batch_result = metric(preds[i], target[i], **batch_kwargs_update)
|
||||
|
@ -177,10 +180,10 @@ def _class_test(
|
|||
result = metric.compute()
|
||||
_assert_tensor(result)
|
||||
|
||||
total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]).cpu()
|
||||
total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]).cpu()
|
||||
total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
|
||||
total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
|
||||
total_kwargs_update = {
|
||||
k: torch.cat([v[i] for i in range(NUM_BATCHES)]).cpu() if isinstance(v, Tensor) else v
|
||||
k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v
|
||||
for k, v in kwargs_update.items()
|
||||
}
|
||||
sk_result = sk_metric(total_preds, total_target, **total_kwargs_update)
|
||||
|
@ -213,6 +216,9 @@ def _functional_test(
|
|||
kwargs_update: Additional keyword arguments that will be passed with preds and
|
||||
target when running update on the metric.
|
||||
"""
|
||||
assert preds.shape[0] == target.shape[0]
|
||||
num_batches = preds.shape[0]
|
||||
|
||||
if not metric_args:
|
||||
metric_args = {}
|
||||
|
||||
|
@ -223,7 +229,7 @@ def _functional_test(
|
|||
target = target.to(device)
|
||||
kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
|
||||
|
||||
for i in range(NUM_BATCHES):
|
||||
for i in range(num_batches):
|
||||
extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
|
||||
lightning_result = metric(preds[i], target[i], **extra_kwargs)
|
||||
extra_kwargs = {
|
||||
|
@ -238,7 +244,7 @@ def _functional_test(
|
|||
|
||||
def _assert_half_support(
|
||||
metric_module: Metric,
|
||||
metric_functional: Callable,
|
||||
metric_functional: Optional[Callable],
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
device: str = "cpu",
|
||||
|
@ -263,7 +269,8 @@ def _assert_half_support(
|
|||
}
|
||||
metric_module = metric_module.to(device)
|
||||
_assert_tensor(metric_module(y_hat, y, **kwargs_update))
|
||||
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))
|
||||
if metric_functional is not None:
|
||||
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))
|
||||
|
||||
|
||||
class MetricTester:
|
||||
|
@ -411,8 +418,8 @@ class MetricTester:
|
|||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_module: Metric,
|
||||
metric_functional: Callable,
|
||||
metric_args: dict = None,
|
||||
metric_functional: Optional[Callable] = None,
|
||||
metric_args: Optional[dict] = None,
|
||||
**kwargs_update,
|
||||
):
|
||||
"""Test if a metric can be used with half precision tensors on cpu
|
||||
|
@ -435,8 +442,8 @@ class MetricTester:
|
|||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_module: Metric,
|
||||
metric_functional: Callable,
|
||||
metric_args: dict = None,
|
||||
metric_functional: Optional[Callable] = None,
|
||||
metric_args: Optional[dict] = None,
|
||||
**kwargs_update,
|
||||
):
|
||||
"""Test if a metric can be used with half precision tensors on gpu
|
||||
|
@ -459,8 +466,8 @@ class MetricTester:
|
|||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_module: Metric,
|
||||
metric_functional: Callable,
|
||||
metric_args: dict = None,
|
||||
metric_functional: Optional[Callable] = None,
|
||||
metric_args: Optional[dict] = None,
|
||||
):
|
||||
"""Test if a metric is differentiable or not.
|
||||
|
||||
|
@ -480,7 +487,7 @@ class MetricTester:
|
|||
# Check if requires_grad matches is_differentiable attribute
|
||||
_assert_requires_grad(metric, out)
|
||||
|
||||
if metric.is_differentiable:
|
||||
if metric.is_differentiable and metric_functional is not None:
|
||||
# check for numerical correctness
|
||||
assert torch.autograd.gradcheck(
|
||||
partial(metric_functional, **metric_args), (preds[0].double(), target[0])
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from lpips import LPIPS as reference_LPIPS
|
||||
from torch import Tensor
|
||||
|
||||
from tests.helpers import seed_all
|
||||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
|
||||
from torchmetrics.image.lpip_similarity import LPIPS
|
||||
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
|
||||
|
||||
seed_all(42)
|
||||
|
||||
Input = namedtuple("Input", ["img1", "img2"])
|
||||
|
||||
_inputs = Input(
|
||||
img1=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
|
||||
img2=torch.rand(int(NUM_BATCHES * 0.4), int(BATCH_SIZE / 16), 3, 100, 100),
|
||||
)
|
||||
|
||||
|
||||
def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mean") -> Tensor:
|
||||
"""comparison function for tm implementation."""
|
||||
ref = reference_LPIPS(net=net_type)
|
||||
res = ref(img1, img2).detach().cpu().numpy()
|
||||
if reduction == "mean":
|
||||
return res.mean()
|
||||
return res.sum()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
|
||||
@pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"])
|
||||
class TestLPIPS(MetricTester):
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
def test_lpips(self, net_type, ddp):
|
||||
"""test modular implementation for correctness."""
|
||||
self.run_class_metric_test(
|
||||
ddp=ddp,
|
||||
preds=_inputs.img1,
|
||||
target=_inputs.img2,
|
||||
metric_class=LPIPS,
|
||||
sk_metric=partial(_compare_fn, net_type=net_type),
|
||||
dist_sync_on_step=False,
|
||||
check_scriptable=False,
|
||||
metric_args={"net_type": net_type},
|
||||
)
|
||||
|
||||
def test_lpips_differentiability(self, net_type):
|
||||
"""test for differentiability of LPIPS metric."""
|
||||
self.run_differentiability_test(preds=_inputs.img1, target=_inputs.img2, metric_module=LPIPS)
|
||||
|
||||
# LPIPS half + cpu does not work due to missing support in torch.min
|
||||
@pytest.mark.xfail(reason="PearsonCorrcoef metric does not support cpu + half precision")
|
||||
def test_lpips_half_cpu(self, net_type):
|
||||
"""test for half + cpu support."""
|
||||
self.run_precision_test_cpu(_inputs.img1, _inputs.img2, LPIPS)
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
|
||||
def test_lpips_half_gpu(self, net_type):
|
||||
"""test for half + gpu support."""
|
||||
self.run_precision_test_gpu(_inputs.img1, _inputs.img2, LPIPS)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
|
||||
def test_error_on_wrong_init():
|
||||
"""Test class raises the expected errors."""
|
||||
with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
|
||||
LPIPS(net_type="resnet")
|
||||
|
||||
with pytest.raises(ValueError, match="Argument `reduction` must be one .*"):
|
||||
LPIPS(reduction=None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
|
||||
@pytest.mark.parametrize(
|
||||
"inp1, inp2",
|
||||
[
|
||||
(torch.rand(1, 1, 28, 28), torch.rand(1, 3, 28, 28)), # wrong number of channels
|
||||
(torch.rand(1, 3, 28, 28), torch.rand(1, 1, 28, 28)), # wrong number of channels
|
||||
(torch.randn(1, 3, 28, 28), torch.rand(1, 3, 28, 28)), # non-normalized input
|
||||
(torch.rand(1, 3, 28, 28), torch.randn(1, 3, 28, 28)), # non-normalized input
|
||||
],
|
||||
)
|
||||
def test_error_on_wrong_update(inp1, inp2):
|
||||
"""test error is raised on wrong input to update method."""
|
||||
metric = LPIPS()
|
||||
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
|
||||
metric(inp1, inp2)
|
|
@ -19,11 +19,14 @@ import torch
|
|||
from sklearn.metrics import precision_score, recall_score
|
||||
from torch import Tensor
|
||||
|
||||
from tests.helpers import seed_all
|
||||
from torchmetrics.classification import Precision, Recall
|
||||
from torchmetrics.utilities import apply_to_collection
|
||||
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
|
||||
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler
|
||||
|
||||
seed_all(42)
|
||||
|
||||
_preds = torch.randint(10, (10, 32))
|
||||
_target = torch.randint(10, (10, 32))
|
||||
|
||||
|
@ -55,10 +58,10 @@ def _sample_checker(old_samples, new_samples, op: operator, threshold: int):
|
|||
@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
|
||||
def test_bootstrap_sampler(sampling_strategy):
|
||||
"""make sure that the bootstrap sampler works as intended."""
|
||||
old_samples = torch.randn(10, 2)
|
||||
old_samples = torch.randn(20, 2)
|
||||
|
||||
# make sure that the new samples are only made up of old samples
|
||||
idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
|
||||
idx = _bootstrap_sampler(20, sampling_strategy=sampling_strategy)
|
||||
new_samples = old_samples[idx]
|
||||
for ns in new_samples:
|
||||
assert ns in old_samples
|
||||
|
|
|
@ -40,7 +40,7 @@ from torchmetrics.classification import ( # noqa: E402
|
|||
StatScores,
|
||||
)
|
||||
from torchmetrics.collections import MetricCollection # noqa: E402
|
||||
from torchmetrics.image import FID, IS, KID, PSNR, SSIM # noqa: E402
|
||||
from torchmetrics.image import FID, IS, KID, LPIPS, PSNR, SSIM # noqa: E402
|
||||
from torchmetrics.metric import Metric # noqa: E402
|
||||
from torchmetrics.regression import ( # noqa: E402
|
||||
CosineSimilarity,
|
||||
|
@ -92,6 +92,7 @@ __all__ = [
|
|||
"IS",
|
||||
"KID",
|
||||
"KLDivergence",
|
||||
"LPIPS",
|
||||
"MatthewsCorrcoef",
|
||||
"MeanAbsoluteError",
|
||||
"MeanAbsolutePercentageError",
|
||||
|
|
|
@ -14,5 +14,6 @@
|
|||
from torchmetrics.image.fid import FID # noqa: F401
|
||||
from torchmetrics.image.inception import IS # noqa: F401
|
||||
from torchmetrics.image.kid import KID # noqa: F401
|
||||
from torchmetrics.image.lpip_similarity import LPIPS # noqa: F401
|
||||
from torchmetrics.image.psnr import PSNR # noqa: F401
|
||||
from torchmetrics.image.ssim import SSIM # noqa: F401
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.metric import Metric
|
||||
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
|
||||
|
||||
if _LPIPS_AVAILABLE:
|
||||
from lpips import LPIPS as Lpips_backbone
|
||||
else:
|
||||
|
||||
class Lpips_backbone(torch.nn.Module): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class NoTrainLpips(Lpips_backbone):
|
||||
def train(self, mode: bool) -> "NoTrainLpips":
|
||||
"""the network should not be able to be switched away from evaluation mode."""
|
||||
return super().train(False)
|
||||
|
||||
|
||||
def _valid_img(img: Tensor) -> bool:
|
||||
"""check that input is a valid image to the network."""
|
||||
return img.ndim == 4 and img.shape[1] == 3 and img.min() >= -1.0 and img.max() <= 1.0
|
||||
|
||||
|
||||
class LPIPS(Metric):
|
||||
"""The Learned Perceptual Image Patch Similarity (`LPIPS_`) is used to judge the perceptual similarity between
|
||||
two images. LPIPS essentially computes the similarity between the activations of two image patches for some
|
||||
pre-defined network. This measure have been shown to match human perseption well. A low LPIPS score means that
|
||||
image patches are perceptual similar.
|
||||
|
||||
Both input image patches are expected to have shape `[N, 3, H, W]` and be normalized to the [-1,1]
|
||||
range. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg).
|
||||
|
||||
.. note:: using this metrics requires you to have ``lpips`` package installed. Either install
|
||||
as ``pip install torchmetrics[image]`` or ``pip install lpips``
|
||||
|
||||
.. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation
|
||||
if this is a issue.
|
||||
|
||||
Args:
|
||||
net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
|
||||
reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
|
||||
compute_on_step:
|
||||
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
|
||||
dist_sync_on_step:
|
||||
Synchronize metric state across processes at each ``forward()``
|
||||
before returning the value at the step
|
||||
process_group:
|
||||
Specify the process group on which synchronization is called.
|
||||
default: ``None`` (which selects the entire world)
|
||||
dist_sync_fn:
|
||||
Callback that performs the allgather operation on the metric state. When ``None``, DDP
|
||||
will be used to perform the allgather
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If ``lpips`` package is not installed
|
||||
ValueError:
|
||||
If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
|
||||
ValueError:
|
||||
If ``reduction`` is not one of ``"mean"`` or ``"sum"``
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> _ = torch.manual_seed(123)
|
||||
>>> from torchmetrics import LPIPS
|
||||
>>> lpips = LPIPS(net_type='vgg')
|
||||
>>> img1 = torch.rand(10, 3, 100, 100)
|
||||
>>> img2 = torch.rand(10, 3, 100, 100)
|
||||
>>> lpips(img1, img2)
|
||||
tensor([0.3566], grad_fn=<DivBackward0>)
|
||||
"""
|
||||
|
||||
real_features: List[Tensor]
|
||||
fake_features: List[Tensor]
|
||||
|
||||
# due to the use of named tuple in the backbone the net variable cannot be scriptet
|
||||
__jit_ignored_attributes__ = ["net"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
net_type: str = "alex",
|
||||
reduction: str = "mean",
|
||||
compute_on_step: bool = True,
|
||||
dist_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
dist_sync_fn: Callable[[Tensor], List[Tensor]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
compute_on_step=compute_on_step,
|
||||
dist_sync_on_step=dist_sync_on_step,
|
||||
process_group=process_group,
|
||||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
if not _LPIPS_AVAILABLE:
|
||||
raise ValueError(
|
||||
"LPIPS metric requires that lpips is installed."
|
||||
"Either install as `pip install torchmetrics[image]` or `pip install lpips`"
|
||||
)
|
||||
|
||||
valid_net_type = ("vgg", "alex", "squeeze")
|
||||
if net_type not in valid_net_type:
|
||||
raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")
|
||||
self.net = NoTrainLpips(net=net_type, verbose=False)
|
||||
|
||||
valid_reduction = ("mean", "sum")
|
||||
if reduction not in valid_reduction:
|
||||
raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
|
||||
self.reduction = reduction
|
||||
|
||||
self.add_state("sum_scores", torch.zeros(1), dist_reduce_fx="sum")
|
||||
self.add_state("total", torch.zeros(1), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, img1: Tensor, img2: Tensor) -> None: # type: ignore
|
||||
"""Update internal states with lpips score.
|
||||
|
||||
Args:
|
||||
img1: tensor with images of shape [N, 3, H, W]
|
||||
img2: tensor with images of shape [N, 3, H, W]
|
||||
"""
|
||||
if not (_valid_img(img1) and _valid_img(img2)):
|
||||
raise ValueError(
|
||||
"Expected both input arguments to be normalized tensors (all values in range [-1,1])"
|
||||
f" and to have shape [N, 3, H, W] but `img1` have shape {img1.shape} with values in"
|
||||
f" range {[img1.min(), img1.max()]} and `img2` have shape {img2.shape} with value"
|
||||
f" in range {[img2.min(), img2.max()]}"
|
||||
)
|
||||
|
||||
loss = self.net(img1, img2).squeeze()
|
||||
self.sum_scores += loss.sum()
|
||||
self.total += img1.shape[0]
|
||||
|
||||
def compute(self) -> Tensor:
|
||||
"""Compute final perceptual similarity metric."""
|
||||
if self.reduction == "mean":
|
||||
return self.sum_scores / self.total
|
||||
elif self.reduction == "sum":
|
||||
return self.sum_scores
|
||||
|
||||
@property
|
||||
def is_differentiable(self) -> bool:
|
||||
return True
|
|
@ -79,3 +79,4 @@ _ROUGE_SCORE_AVAILABLE: bool = _module_available("rouge_score")
|
|||
_BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
|
||||
_SCIPY_AVAILABLE: bool = _module_available("scipy")
|
||||
_TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
|
||||
_LPIPS_AVAILABLE: bool = _module_available("lpips")
|
||||
|
|
Загрузка…
Ссылка в новой задаче