diff --git a/CHANGELOG.md b/CHANGELOG.md
index a6c5885..1efedc6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378))
 
+- Added `MetricTracker` wrapper metric for keeping track of the same metric over multiple epochs ([#238](https://github.com/PyTorchLightning/metrics/pull/238))
+
+
 ### Changed
 
 - Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index 9fced28..ea315ee 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -556,3 +556,9 @@ BootStrapper
 
 .. autoclass:: torchmetrics.BootStrapper
     :noindex:
+
+MetricTracker
+~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.MetricTracker
+    :noindex:
diff --git a/tests/wrappers/test_tracker.py b/tests/wrappers/test_tracker.py
new file mode 100644
index 0000000..ce3f977
--- /dev/null
+++ b/tests/wrappers/test_tracker.py
@@ -0,0 +1,76 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+import pytest
+import torch
+
+from tests.helpers import seed_all
+from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall
+from torchmetrics.wrappers import MetricTracker
+
+seed_all(42)
+
+
+def test_raises_error_on_wrong_input():
+    with pytest.raises(TypeError, match="metric arg needs to be an instance of a torchmetrics metric.*"):
+        MetricTracker([1, 2, 3])
+
+
+@pytest.mark.parametrize(
+    "method, method_input",
+    [
+        ("update", (torch.randint(10, (50,)), torch.randint(10, (50,)))),
+        ("forward", (torch.randint(10, (50,)), torch.randint(10, (50,)))),
+        ("compute", None),
+    ],
+)
+def test_raises_error_if_increment_not_called(method, method_input):
+    tracker = MetricTracker(Accuracy(num_classes=10))
+    with pytest.raises(ValueError, match=f"`{method}` cannot be called before .*"):
+        if method_input is not None:
+            getattr(tracker, method)(*method_input)
+        else:
+            getattr(tracker, method)()
+
+
+@pytest.mark.parametrize(
+    "base_metric, metric_input, maximize",
+    [
+        (partial(Accuracy, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
+        (partial(Precision, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
+        (partial(Recall, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
+        (MeanSquaredError, (torch.randn(50), torch.randn(50)), False),
+        (MeanAbsoluteError, (torch.randn(50), torch.randn(50)), False),
+    ],
+)
+def test_tracker(base_metric, metric_input, maximize):
+    tracker = MetricTracker(base_metric(), maximize=maximize)
+    for i in range(5):
+        tracker.increment()
+        # check that both update and forward work
+        for _ in range(5):
+            tracker.update(*metric_input)
+        for _ in range(5):
+            tracker(*metric_input)
+
+        val = tracker.compute()
+        assert val != 0.0
+        assert tracker.n_steps == i + 1
+
+    assert tracker.n_steps == 5
+    assert tracker.compute_all().shape[0] == 5
+    val, idx = tracker.best_metric(return_step=True)
+    assert val != 0.0
+    assert idx in list(range(5))
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index aeb12e6..09af128 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -63,7 +63,7 @@ from torchmetrics.retrieval import (  # noqa: E402
     RetrievalRecall,
 )
 from torchmetrics.text import WER, BLEUScore, ROUGEScore  # noqa: E402
-from torchmetrics.wrappers import BootStrapper  # noqa: E402
+from torchmetrics.wrappers import BootStrapper, MetricTracker  # noqa: E402
 
 __all__ = [
     "functional",
@@ -98,6 +98,7 @@ __all__ = [
     "MeanSquaredLogError",
     "Metric",
     "MetricCollection",
+    "MetricTracker",
     "PearsonCorrcoef",
     "PIT",
     "Precision",
diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py
index 4f506ea..1655a0b 100644
--- a/torchmetrics/wrappers/__init__.py
+++ b/torchmetrics/wrappers/__init__.py
@@ -12,3 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
+from torchmetrics.wrappers.tracker import MetricTracker  # noqa: F401
diff --git a/torchmetrics/wrappers/tracker.py b/torchmetrics/wrappers/tracker.py
new file mode 100644
index 0000000..90be923
--- /dev/null
+++ b/torchmetrics/wrappers/tracker.py
@@ -0,0 +1,127 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from typing import Any, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from torchmetrics.metric import Metric
+
+
+class MetricTracker(nn.ModuleList):
+    """A wrapper class that keeps track of a metric over time. The wrapper implements the standard ``update``,
+    ``compute`` and ``reset`` methods, which simply call the corresponding method of the currently tracked
+    metric. In addition, the following methods are provided:
+
+    - ``MetricTracker.n_steps``: number of metrics being tracked, i.e. the number of calls to ``increment``
+
+    - ``MetricTracker.increment()``: initialize a new metric that will be tracked from now on
+
+    - ``MetricTracker.compute_all()``: get the metric value for all tracked steps
+
+    - ``MetricTracker.best_metric()``: get the best value across all tracked steps
+
+    Args:
+        metric: instance of a torchmetrics metric that should be tracked at each timestep.
+        maximize: bool indicating whether higher metric values are better (``True``) or lower
+            values are better (``False``)
+
+    Example:
+
+        >>> from torchmetrics import Accuracy, MetricTracker
+        >>> _ = torch.manual_seed(42)
+        >>> tracker = MetricTracker(Accuracy(num_classes=10))
+        >>> for epoch in range(5):
+        ...     tracker.increment()
+        ...     for batch_idx in range(5):
+        ...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
+        ...         tracker.update(preds, target)
+        ...     print(f"current acc={tracker.compute()}")  # doctest: +NORMALIZE_WHITESPACE
+        current acc=0.1120000034570694
+        current acc=0.08799999952316284
+        current acc=0.12600000202655792
+        current acc=0.07999999821186066
+        current acc=0.10199999809265137
+        >>> best_acc, which_epoch = tracker.best_metric(return_step=True)
+        >>> tracker.compute_all()
+        tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
+    """
+
+    def __init__(self, metric: Metric, maximize: bool = True) -> None:
+        super().__init__()
+        if not isinstance(metric, Metric):
+            raise TypeError(f"metric arg needs to be an instance of a torchmetrics metric, but got {metric}")
+        self._base_metric = metric
+        self.maximize = maximize
+
+        self._increment_called = False
+
+    @property
+    def n_steps(self) -> int:
+        """Returns the number of times the tracker has been incremented."""
+        return len(self) - 1  # subtract 1 for the base metric stored at index 0
+
+    def increment(self) -> None:
+        """Creates a new instance of the input metric that will be updated from now on."""
+        self._increment_called = True
+        self.append(deepcopy(self._base_metric))
+
+    def forward(self, *args, **kwargs) -> Any:  # type: ignore
+        """Calls forward of the current metric being tracked."""
+        self._check_for_increment("forward")
+        return self[-1](*args, **kwargs)
+
+    def update(self, *args, **kwargs) -> None:  # type: ignore
+        """Updates the current metric being tracked."""
+        self._check_for_increment("update")
+        self[-1].update(*args, **kwargs)
+
+    def compute(self) -> Any:
+        """Calls compute of the current metric being tracked."""
+        self._check_for_increment("compute")
+        return self[-1].compute()
+
+    def compute_all(self) -> Tensor:
+        """Computes the metric value for all tracked metrics."""
+        self._check_for_increment("compute_all")
+        return torch.stack([metric.compute() for i, metric in enumerate(self) if i != 0], dim=0)  # skip base metric
+
+    def reset(self) -> None:
+        """Resets the current metric being tracked."""
+        self[-1].reset()
+
+    def reset_all(self) -> None:
+        """Resets all metrics being tracked."""
+        for metric in self:
+            metric.reset()
+
+    def best_metric(self, return_step: bool = False) -> Union[float, Tuple[float, int]]:
+        """Returns the best metric out of all tracked (the highest if ``maximize=True``, the lowest otherwise).
+
+        Args:
+            return_step: If ``True``, also return the step at which the best metric value was observed.
+
+        Returns:
+            The best metric value, and optionally the timestep.
+        """
+        fn = torch.max if self.maximize else torch.min
+        value, idx = fn(self.compute_all(), 0)
+        if return_step:
+            return value.item(), idx.item()
+        return value.item()
+
+    def _check_for_increment(self, method: str) -> None:
+        if not self._increment_called:
+            raise ValueError(f"`{method}` cannot be called before `.increment()` has been called")
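
As a usage illustration beyond the doctest in the class docstring (which only covers the ``maximize=True`` case), here is a minimal sketch of tracking a "lower is better" metric with the API introduced in this diff. It is not part of the patch itself; the loop bounds and random data are placeholders for a real validation loop.

# Minimal usage sketch (illustration only, not part of this patch):
# tracking a "lower is better" metric across epochs with MetricTracker.
import torch
from torchmetrics import MeanSquaredError
from torchmetrics.wrappers import MetricTracker

tracker = MetricTracker(MeanSquaredError(), maximize=False)  # lower MSE is better
for epoch in range(3):
    tracker.increment()  # start a fresh MeanSquaredError instance for this epoch
    for _ in range(10):  # stand-in for iterating over a validation loader
        preds, target = torch.randn(32), torch.randn(32)
        tracker.update(preds, target)

history = tracker.compute_all()  # tensor with one MSE value per epoch
best_mse, best_epoch = tracker.best_metric(return_step=True)  # lowest MSE and the epoch it occurred at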