metrics/torchmetrics/metric.py

801 строка
32 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import operator as op
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, List, Optional, Union
import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version
def jit_distributed_available() -> bool:
    """Tell whether ``torch.distributed`` is compiled in and a default process group is initialized.

    Returns:
        ``False`` when the distributed package is unavailable, otherwise the result of
        ``torch.distributed.is_initialized()``.
    """
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
class Metric(Module, ABC):
    """Base class for all metrics present in the Metrics API.

    Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
    handle distributed synchronization and per-step metric computation.

    Override ``update()`` and ``compute()`` functions to implement your own metric. Use
    ``add_state()`` to register metric state variables which keep track of state on each
    call of ``update()`` and are synchronized across processes when ``compute()`` is called.

    Note:
        Metric state variables can either be ``torch.Tensors`` or an empty list which can be used
        to store ``torch.Tensors``.

    Note:
        Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
        is valid, but it won't return the metric value at the current step. A call to ``forward()``
        automatically calls ``update()`` and also returns the metric value at the current step.

    Args:
        compute_on_step:
            If False, ``forward()`` only calls ``update()`` and does not compute a step value
            (it then returns ``self._forward_cache``, which ``reset()`` sets to None). default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.
            default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``,
            ``gather_all_tensors`` is used to perform the allgather.
    """

    __jit_ignored_attributes__ = ["device"]
    __jit_unused_properties__ = ["is_differentiable"]

    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
        # torch/nn/modules/module.py#L227)
        torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")

        self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", op.ge, "1.3.0")
        self._device = torch.device("cpu")

        self.dist_sync_on_step = dist_sync_on_step
        self.compute_on_step = compute_on_step
        self.process_group = process_group
        self.dist_sync_fn = dist_sync_fn
        # flags read by the wrapped ``compute``; ``forward`` overrides them temporarily
        self._to_sync = True
        self._should_unsync = True

        # the signature of the user's ``update`` is kept for ``_filter_kwargs`` before it gets wrapped
        self._update_signature = inspect.signature(self.update)
        self.update: Callable = self._wrap_update(self.update)  # type: ignore
        self.compute: Callable = self._wrap_compute(self.compute)  # type: ignore
        self._computed = None
        self._forward_cache = None
        self._update_called = False

        # initialize state
        self._defaults: Dict[str, Union[List, Tensor]] = {}
        self._persistent: Dict[str, bool] = {}
        self._reductions: Dict[str, Union[str, Callable[[Union[List[Tensor], Tensor]], Tensor], None]] = {}

        # state management
        self._is_synced = False
        self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

    def add_state(
        self,
        name: str,
        default: Union[list, Tensor],
        dist_reduce_fx: Optional[Union[str, Callable]] = None,
        persistent: bool = False,
    ) -> None:
        """Adds metric state variable. Only used by subclasses.

        Args:
            name: The name of the state variable. The variable will then be accessible at ``self.name``.
            default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
                reset to this value when ``self.reset()`` is called.
            dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
                If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
                and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
                only makes sense if the state is a list, and not a tensor. The user can also pass a custom
                function in this parameter.
            persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
                Default is ``False``.

        Note:
            Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
            However, there won't be any reduction function applied to the synchronized metric state.

            The metric states would be synced as follows

            - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across
              the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric
              state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``.

            - If the metric state is a ``list``, the synced value will be a ``list`` containing the
              combined elements from all processes.

        Note:
            When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
            the format discussed in the above note.

        Raises:
            ValueError:
                If ``default`` is not a ``tensor`` or an ``empty list``.
            ValueError:
                If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``.
        """
        if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default):
            raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

        if dist_reduce_fx == "sum":
            dist_reduce_fx = dim_zero_sum
        elif dist_reduce_fx == "mean":
            dist_reduce_fx = dim_zero_mean
        elif dist_reduce_fx == "cat":
            dist_reduce_fx = dim_zero_cat
        elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
            raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]")

        if isinstance(default, Tensor):
            default = default.contiguous()

        setattr(self, name, default)

        # a private copy is kept so that ``reset()`` is not affected by in-place changes to the live state
        self._defaults[name] = deepcopy(default)
        self._persistent[name] = persistent
        self._reductions[name] = dist_reduce_fx

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Automatically calls ``update()``.

        Returns the metric value over inputs if ``compute_on_step`` is True.

        Raises:
            TorchMetricsUserError:
                If the metric is currently synced (``unsync`` was not called).
        """
        # add current step
        if self._is_synced:
            raise TorchMetricsUserError(
                "The Metric shouldn't be synced when performing ``update``. "
                "HINT: Did you forget to call ``unsync`` ?."
            )

        with torch.no_grad():
            self.update(*args, **kwargs)

        if self.compute_on_step:
            self._to_sync = self.dist_sync_on_step
            # skip restore cache operation from compute as cache is stored below.
            self._should_unsync = False

            # save context before switch
            cache = {attr: getattr(self, attr) for attr in self._defaults}

            try:
                # call reset, update, compute, on single batch
                self.reset()
                self.update(*args, **kwargs)
                self._forward_cache = self.compute()
            finally:
                # restore context even when the single-batch update/compute raised; otherwise the
                # accumulated states would be replaced by the per-step ones and the flags left switched
                for attr, val in cache.items():
                    setattr(self, attr, val)
                self._is_synced = False
                self._should_unsync = True
                self._to_sync = True
                self._computed = None

        return self._forward_cache

    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
        """Gather every registered state with ``dist_sync_fn`` and apply its reduction in place.

        Args:
            dist_sync_fn: gather function applied to each tensor of the states.
            process_group: group to gather over; falls back to ``self.process_group``.
        """
        input_dict = {attr: getattr(self, attr) for attr in self._reductions}

        for attr, reduction_fn in self._reductions.items():
            # pre-concatenate metric states that are lists to reduce number of all_gather operations
            if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
                input_dict[attr] = [dim_zero_cat(input_dict[attr])]

        output_dict = apply_to_collection(
            input_dict,
            Tensor,
            dist_sync_fn,
            group=process_group or self.process_group,
        )

        for attr, reduction_fn in self._reductions.items():
            # an empty list state has nothing to stack/flatten/reduce (and indexing [0] below would fail)
            if isinstance(output_dict[attr], list) and len(output_dict[attr]) == 0:
                setattr(self, attr, [])
                continue

            # pre-processing ops (stack or flatten for inputs)
            if isinstance(output_dict[attr][0], Tensor):
                output_dict[attr] = torch.stack(output_dict[attr])
            elif isinstance(output_dict[attr][0], list):
                output_dict[attr] = _flatten(output_dict[attr])

            if not (callable(reduction_fn) or reduction_fn is None):
                raise TypeError("reduction_fn must be callable or None")
            reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
            setattr(self, attr, reduced)

    def _wrap_update(self, update: Callable) -> Callable:
        """Wrap ``update`` so each call invalidates the cached ``compute`` result and records that it ran."""

        @functools.wraps(update)
        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
            self._computed = None
            self._update_called = True
            return update(*args, **kwargs)

        return wrapped_func

    def sync(
        self,
        dist_sync_fn: Optional[Callable] = None,
        process_group: Optional[Any] = None,
        should_sync: bool = True,
        distributed_available: Optional[Callable] = jit_distributed_available,
    ) -> None:
        """Sync function for manually controlling when metrics states should be synced across processes.

        Args:
            dist_sync_fn: Function to be used to perform states synchronization
            process_group:
                Specify the process group on which synchronization is called.
                default: None (which selects the entire world)
            should_sync: Whether to apply state synchronization. This will have an impact
                only when running in a distributed setting.
            distributed_available: Function to determine if we are running inside a distributed setting

        Raises:
            TorchMetricsUserError:
                If the metric is already synced and ``should_sync`` is True.
        """
        if self._is_synced and should_sync:
            raise TorchMetricsUserError("The Metric has already been synced.")

        is_distributed = distributed_available() if callable(distributed_available) else None

        if not should_sync or not is_distributed:
            return

        if dist_sync_fn is None:
            dist_sync_fn = gather_all_tensors

        # cache prior to syncing
        self._cache = {attr: getattr(self, attr) for attr in self._defaults}

        # sync
        self._sync_dist(dist_sync_fn, process_group=process_group)
        self._is_synced = True

    def unsync(self, should_unsync: bool = True) -> None:
        """Unsync function for manually controlling when metrics states should be reverted back to their local
        states.

        Args:
            should_unsync: Whether to perform unsync

        Raises:
            TorchMetricsUserError:
                If the metric is not synced, or no cached local states exist.
        """
        if not should_unsync:
            return

        if not self._is_synced:
            raise TorchMetricsUserError("The Metric has already been un-synced.")

        if self._cache is None:
            raise TorchMetricsUserError("The internal cache should exist to unsync the Metric.")

        # if we synced, restore to cache so that we can continue to accumulate un-synced state
        for attr, val in self._cache.items():
            setattr(self, attr, val)
        self._is_synced = False
        self._cache = None

    @contextmanager
    def sync_context(
        self,
        dist_sync_fn: Optional[Callable] = None,
        process_group: Optional[Any] = None,
        should_sync: bool = True,
        should_unsync: bool = True,
        distributed_available: Optional[Callable] = jit_distributed_available,
    ) -> Generator:
        """Context manager to synchronize the states between processes when running in a distributed setting and
        restore the local cache states after yielding.

        Args:
            dist_sync_fn: Function to be used to perform states synchronization
            process_group:
                Specify the process group on which synchronization is called.
                default: None (which selects the entire world)
            should_sync: Whether to apply state synchronization. This will have an impact
                only when running in a distributed setting.
            should_unsync: Whether to restore the cache state so that the metrics can
                continue to be accumulated.
            distributed_available: Function to determine if we are running inside a distributed setting
        """
        self.sync(
            dist_sync_fn=dist_sync_fn,
            process_group=process_group,
            should_sync=should_sync,
            distributed_available=distributed_available,
        )

        try:
            yield
        finally:
            # restore the local states even if the body raised; a metric left synced would make every
            # later ``forward`` fail with "The Metric shouldn't be synced"
            self.unsync(should_unsync=self._is_synced and should_unsync)

    def _wrap_compute(self, compute: Callable) -> Callable:
        """Wrap ``compute`` with result caching, an early-call warning and state synchronization."""

        @functools.wraps(compute)
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            if not self._update_called:
                rank_zero_warn(
                    f"The ``compute`` method of metric {self.__class__.__name__}"
                    " was called before the ``update`` method which may lead to errors,"
                    " as metric states have not yet been updated.",
                    UserWarning,
                )

            # return cached value
            if self._computed is not None:
                return self._computed

            # compute relies on the sync context manager to gather the states across processes and apply reduction
            # if synchronization happened, the current rank accumulated states will be restored to keep
            # accumulation going if ``should_unsync=True``,
            with self.sync_context(
                dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, should_unsync=self._should_unsync
            ):
                self._computed = compute(*args, **kwargs)

            return self._computed

        return wrapped_func

    @abstractmethod
    def update(self, *_: Any, **__: Any) -> None:
        """Override this method to update the state variables of your metric class."""

    @abstractmethod
    def compute(self) -> Any:
        """Override this method to compute the final metric value from state variables synchronized across the
        distributed backend."""

    def reset(self) -> None:
        """This method automatically resets the metric state variables to their default value."""
        self._update_called = False
        self._forward_cache = None
        # lower lightning versions requires this implicitly to log metric objects correctly in self.log
        if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3:
            self._computed = None

        for attr, default in self._defaults.items():
            current_val = getattr(self, attr)
            if isinstance(default, Tensor):
                # fresh copy of the default, placed on the device the live state currently uses
                setattr(self, attr, default.detach().clone().to(current_val.device))
            else:
                setattr(self, attr, [])

        # reset internal states
        self._cache = None
        self._is_synced = False

    def clone(self) -> "Metric":
        """Make a copy of the metric."""
        return deepcopy(self)

    def __getstate__(self) -> Dict[str, Any]:
        # ignore update and compute functions for pickling
        return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # manually restore update and compute functions for pickling
        self.__dict__.update(state)
        self._update_signature = inspect.signature(self.update)
        self.update: Callable = self._wrap_update(self.update)  # type: ignore
        self.compute: Callable = self._wrap_compute(self.compute)  # type: ignore

    @property
    def device(self) -> "torch.device":
        """Return the device of the metric."""
        return self._device

    def to(self, *args: Any, **kwargs: Any) -> "Metric":
        """Moves the parameters and buffers.

        Normal dtype casting is not supported by this method; use the `set_dtype` method instead.
        """
        out = torch._C._nn._parse_to(*args, **kwargs)
        if len(out) == 4:  # pytorch 1.5 and higher
            device, dtype, non_blocking, convert_to_format = out
        else:  # pytorch 1.4 and lower
            device, dtype, non_blocking = out
            convert_to_format = None
        dtype = None  # prevent dtype being casted

        def convert(t: Tensor) -> Tensor:
            if convert_to_format is not None and t.dim() in (4, 5):
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                    memory_format=convert_to_format,
                )
            return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

        # a dtype-only call (e.g. ``.to(torch.float64)``) parses to ``device=None``; keep the known device then
        if device is not None:
            self._device = device
        return self._apply(convert)

    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric":
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device: if specified, all parameters will be copied to that device
        """
        if device is None or isinstance(device, int):
            device = torch.device("cuda", index=device)
        self._device = device
        return super().cuda(device=device)

    def cpu(self) -> "Metric":
        """Moves all model parameters and buffers to the CPU."""
        self._device = torch.device("cpu")
        return super().cpu()

    def type(self, dst_type: Union[str, torch.dtype]) -> "Metric":
        """Method override default and prevent dtype casting.

        Please use `metric.set_dtype(dtype)` instead.
        """
        return self

    def float(self) -> "Metric":
        """Method override default and prevent dtype casting.

        Please use `metric.set_dtype(dtype)` instead.
        """
        return self

    def double(self) -> "Metric":
        """Method override default and prevent dtype casting.

        Please use `metric.set_dtype(dtype)` instead.
        """
        return self

    def half(self) -> "Metric":
        """Method override default and prevent dtype casting.

        Please use `metric.set_dtype(dtype)` instead.
        """
        return self

    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
        """Special version of `type` for transferring all metric states to specific dtype.

        Arguments:
            dst_type (type or string): the desired type

        Returns:
            The module itself, as returned by ``Module.type``.
        """
        return super().type(dst_type)

    def _apply(self, fn: Callable) -> Module:
        """Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
        `.cuda`, etc methods are called.

        Raises:
            TypeError: If a metric state is neither a Tensor nor a sequence.
        """
        this = super()._apply(fn)
        # Also apply fn to metric states and defaults
        for key, value in this._defaults.items():
            if isinstance(value, Tensor):
                this._defaults[key] = fn(value)
            elif isinstance(value, Sequence):
                this._defaults[key] = [fn(v) for v in value]

            current_val = getattr(this, key)
            if isinstance(current_val, Tensor):
                setattr(this, key, fn(current_val))
            elif isinstance(current_val, Sequence):
                setattr(this, key, [fn(cur_v) for cur_v in current_val])
            else:
                raise TypeError(
                    f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
                )

        # Additional apply to forward cache and computed attributes (may be nested)
        if this._computed is not None:
            this._computed = apply_to_collection(this._computed, Tensor, fn)
        if this._forward_cache is not None:
            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)

        return this

    def persistent(self, mode: bool = False) -> None:
        """Method for post-init to change if metric states should be saved to its state_dict."""
        for key in self._persistent:
            self._persistent[key] = mode

    def state_dict(
        self,
        destination: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Return the module state dict extended with the metric states marked persistent."""
        destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Register metric states to be part of the state_dict
        for key in self._defaults:
            if not self._persistent[key]:
                continue
            current_val = getattr(self, key)
            if not keep_vars:
                if isinstance(current_val, Tensor):
                    current_val = current_val.detach()
                elif isinstance(current_val, list):
                    current_val = [cur_v.detach() if isinstance(cur_v, Tensor) else cur_v for cur_v in current_val]
            destination[prefix + key] = deepcopy(current_val)  # type: ignore
        return destination

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Loads metric states from state_dict."""
        for key in self._defaults:
            name = prefix + key
            # popped so the parent implementation does not see the metric states
            if name in state_dict:
                setattr(self, key, state_dict.pop(name))
        # NOTE(review): ``strict`` is always forwarded as True (unchanged from the original); the caller's
        # strictness is presumably enforced by ``Module.load_state_dict`` itself -- confirm
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )

    def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """filter kwargs such that they match the update signature of the metric."""

        # filter all parameters based on update signature except those of
        # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
        _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        _sign_params = self._update_signature.parameters
        filtered_kwargs = {
            k: v for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
        }

        # if no kwargs filtered, return all kwargs as default
        if not filtered_kwargs:
            filtered_kwargs = kwargs
        return filtered_kwargs

    def __hash__(self) -> int:
        # we need to add the id here, since PyTorch requires a module hash to be unique.
        # Internally, PyTorch nn.Module relies on that for children discovery
        # (see https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1544)
        # For metrics that include tensors it is not a problem,
        # since their hash is unique based on the memory location but we cannot rely on that for every metric.
        hash_vals = [self.__class__.__name__, id(self)]

        for key in self._defaults:
            val = getattr(self, key)
            # Special case: allow list values, so long
            # as their elements are hashable
            if hasattr(val, "__iter__") and not isinstance(val, Tensor):
                hash_vals.extend(val)
            else:
                hash_vals.append(val)

        return hash(tuple(hash_vals))

    def __add__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.add, self, other)

    def __and__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.bitwise_and, self, other)

    # Fixme: this shall return bool instead of Metric
    def __eq__(self, other: "Metric") -> "Metric":  # type: ignore
        return CompositionalMetric(torch.eq, self, other)

    def __floordiv__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.floor_divide, self, other)

    def __ge__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.ge, self, other)

    def __gt__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.gt, self, other)

    def __le__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.le, self, other)

    def __lt__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.lt, self, other)

    def __matmul__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.matmul, self, other)

    def __mod__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.fmod, self, other)

    def __mul__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.mul, self, other)

    # Fixme: this shall return bool instead of Metric
    def __ne__(self, other: "Metric") -> "Metric":  # type: ignore
        return CompositionalMetric(torch.ne, self, other)

    def __or__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.bitwise_or, self, other)

    def __pow__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.pow, self, other)

    def __radd__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.add, other, self)

    def __rand__(self, other: "Metric") -> "Metric":
        # swap them since bitwise_and only supports that way and it's commutative
        return CompositionalMetric(torch.bitwise_and, self, other)

    def __rfloordiv__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.floor_divide, other, self)

    def __rmatmul__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.matmul, other, self)

    def __rmod__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.fmod, other, self)

    def __rmul__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.mul, other, self)

    def __ror__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.bitwise_or, other, self)

    def __rpow__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.pow, other, self)

    def __rsub__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.sub, other, self)

    def __rtruediv__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.true_divide, other, self)

    def __rxor__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.bitwise_xor, other, self)

    def __sub__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.sub, self, other)

    def __truediv__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.true_divide, self, other)

    def __xor__(self, other: "Metric") -> "Metric":
        return CompositionalMetric(torch.bitwise_xor, self, other)

    def __abs__(self) -> "Metric":
        return CompositionalMetric(torch.abs, self, None)

    def __inv__(self) -> "Metric":
        return CompositionalMetric(torch.bitwise_not, self, None)

    def __invert__(self) -> "Metric":
        return self.__inv__()

    def __neg__(self) -> "Metric":
        # NOTE: ``_neg`` computes ``-abs(x)``, not plain negation
        return CompositionalMetric(_neg, self, None)

    def __pos__(self) -> "Metric":
        # NOTE: unary plus maps to ``torch.abs`` here, not identity
        return CompositionalMetric(torch.abs, self, None)

    def __getitem__(self, idx: int) -> "Metric":
        return CompositionalMetric(lambda x: x[idx], self, None)

    @property
    def is_differentiable(self) -> Optional[bool]:
        # There is a bug in PyTorch that leads to properties being executed during scripting
        # To make the metric scriptable, we add property to ignore list and switch to return None here
        return None
def _neg(x: Tensor) -> Tensor:
return -torch.abs(x)
class CompositionalMetric(Metric):
    """Composition of two metrics with a specific operator which will be executed upon metrics compute."""

    def __init__(
        self,
        operator: Callable,
        metric_a: Union[Metric, int, float, Tensor],
        metric_b: Union[Metric, int, float, Tensor, None],
    ) -> None:
        """
        Args:
            operator: the operator taking in one (if metric_b is None)
                or two arguments. Will be applied to outputs of metric_a.compute()
                and (optionally if metric_b is not None) metric_b.compute()
            metric_a: first metric whose compute() result is the first argument of operator
            metric_b: second metric whose compute() result is the second argument of operator.
                For operators taking in only one input, this should be None
        """
        super().__init__()

        self.op = operator

        # tensor operands are registered as buffers so they are moved along with the module
        if isinstance(metric_a, Tensor):
            self.register_buffer("metric_a", metric_a)
        else:
            self.metric_a = metric_a

        if isinstance(metric_b, Tensor):
            self.register_buffer("metric_b", metric_b)
        else:
            self.metric_b = metric_b

    def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
        # No syncing required here. syncing will be done in metric_a and metric_b
        pass

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Pass the inputs to each wrapped metric's ``update``, with kwargs filtered by its ``_filter_kwargs``."""
        if isinstance(self.metric_a, Metric):
            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

        if isinstance(self.metric_b, Metric):
            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

    def compute(self) -> Any:
        """Apply ``op`` to the computed values (or raw constants) of the operands."""
        # also some parsing for kwargs?
        if isinstance(self.metric_a, Metric):
            val_a = self.metric_a.compute()
        else:
            val_a = self.metric_a

        if isinstance(self.metric_b, Metric):
            val_b = self.metric_b.compute()
        else:
            val_b = self.metric_b

        if val_b is None:
            return self.op(val_a)

        return self.op(val_a, val_b)

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``forward`` on the wrapped metrics and combine their step values with ``op``.

        The inherited ``Metric.forward`` snapshots only this object's own states (it registers
        none) and then calls ``reset()``, which resets the wrapped metrics, so their accumulated
        states would be lost. Delegating to each wrapped metric's ``forward`` lets every wrapped
        metric save and restore its own states.

        Returns:
            ``op`` applied to the step values, or ``None`` when a wrapped metric produced no
            step value. If ``compute_on_step`` is False, only ``update()`` is called and
            ``self._forward_cache`` is returned, as in ``Metric.forward``.
        """
        if not self.compute_on_step:
            # same as ``Metric.forward`` with ``compute_on_step=False``: accumulate only
            with torch.no_grad():
                self.update(*args, **kwargs)
            return self._forward_cache

        val_a = (
            self.metric_a(*args, **self.metric_a._filter_kwargs(**kwargs))
            if isinstance(self.metric_a, Metric)
            else self.metric_a
        )
        val_b = (
            self.metric_b(*args, **self.metric_b._filter_kwargs(**kwargs))
            if isinstance(self.metric_b, Metric)
            else self.metric_b
        )

        # the wrapped metrics were updated, so mirror the bookkeeping done by the wrapped ``update``
        self._computed = None
        self._update_called = True

        if val_a is None:
            self._forward_cache = None
        elif val_b is None:
            # ``None`` from a wrapped metric means "no step value"; a ``None`` operand means a unary op
            self._forward_cache = None if isinstance(self.metric_b, Metric) else self.op(val_a)
        else:
            self._forward_cache = self.op(val_a, val_b)
        return self._forward_cache

    def reset(self) -> None:
        """Reset the wrapped metrics and this composition's own cached value and flags."""
        if isinstance(self.metric_a, Metric):
            self.metric_a.reset()

        if isinstance(self.metric_b, Metric):
            self.metric_b.reset()

        # this object registers no states, so the base reset only clears its own cache/flags
        # (without it, ``compute()`` after ``reset()`` could return a stale cached result)
        super().reset()

    def persistent(self, mode: bool = False) -> None:
        """Propagate the persistence mode to the wrapped metrics."""
        if isinstance(self.metric_a, Metric):
            self.metric_a.persistent(mode=mode)
        if isinstance(self.metric_b, Metric):
            self.metric_b.persistent(mode=mode)

    def __repr__(self) -> str:
        # not every callable has ``__name__`` (e.g. ``functools.partial``); fall back to its repr
        op_name = getattr(self.op, "__name__", repr(self.op))
        _op_metrics = f"(\n {op_name}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
        repr_str = self.__class__.__name__ + _op_metrics

        return repr_str