Add AverageMeter implementation (#138)
* Add AverageMeter * Fix type annotation to accommodate Python 3.6 bug * Add tests * Update changelog * Add AverageMeter to docs * fixup! Add AverageMeter to docs * Code review comments * Add tests for scalar case * Fix behavior on PyTorch <1.8 * fixup! Add tests for scalar case * fixup! fixup! Add tests for scalar case * Update CHANGELOG.md * Add Pearson correlation coefficient (#157) * init files * rest * pep8 * changelog * clamp * suggestions * rename * format * _sk_pearsonr * inline * fix sync * fix tests * fix docs * Apply suggestions from code review * Update torchmetrics/functional/regression/pearson.py * atol * update * pep8 * pep8 * chlog * . Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> * Spearman correlation coefficient (#158) * ranking * init files * update * nearly working * fix tests * pep8 * add docs * fix doctests * fix docs * pep8 * isort * ghlog * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Added changes for Test Differentiability [1/n] (#154) * added test changes * fix style error * fixed typo * added changes for requires_grad * metrics differentiability testing generalization * Update tests/classification/test_accuracy.py Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * fix tests * pep8 * changelog * fix docs * fix tests * pep8 * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Binned PR-related metrics (#128) * WIP: Binned PR-related metrics * attempt to fix types * switch to linspace to make old pytorch happy * make flake happy * clean up * Add more testing, move test input generation to the appropriate place * bugfixes and more stable and thorough tests * flake8 * Reuse python zip-based implementation as it can't be reproduced with torch.where/max * address comments * isort * Add docs and doctests, make APIs 
same as non-binned versions * pep8 * isort * doctests likes longer title underlines :O * use numpy's nan_to_num * add atol to bleu tests to make them more stable * atol=1e-2 for bleu * add more docs * pep8 * remove nlp test hack * address comments * pep8 * abc * flake8 * remove typecheck * chlog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> * version + about (#170) * version + about * flake8 * try * . * fix doc * overload sig * fix * Different import style Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> Co-authored-by: Bhadresh Savani <bhadreshpsavani@gmail.com> Co-authored-by: Maxim Grechkin <maximsch2@gmail.com>
This commit is contained in:
Родитель
b26c20d6f8
Коммит
da0017410c
|
@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),
|
||||
[#135](https://github.com/PyTorchLightning/metrics/pull/135)
|
||||
)
|
||||
- Added `AverageMeter` for ad-hoc averages of values ([#138](https://github.com/PyTorchLightning/metrics/pull/138))
|
||||
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
|
||||
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
|
||||
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))
|
||||
|
|
|
@ -12,6 +12,12 @@ metrics.
|
|||
.. autoclass:: torchmetrics.Metric
|
||||
:noindex:
|
||||
|
||||
We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating
|
||||
your own metric type might be too burdensome.
|
||||
|
||||
.. autoclass:: torchmetrics.AverageMeter
|
||||
:noindex:
|
||||
|
||||
**********************
|
||||
Classification Metrics
|
||||
**********************
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
|
||||
from torchmetrics.average import AverageMeter
|
||||
|
||||
|
||||
def average(values, weights):
    """Reference implementation: weighted mean of ``values`` computed by NumPy.

    Args:
        values: Observations to average.
        weights: Per-observation weights, forwarded to ``np.average``.

    Returns:
        The weighted average as returned by ``np.average``.
    """
    weighted_mean = np.average(values, weights=weights)
    return weighted_mean
|
||||
|
||||
|
||||
def average_ignore_weights(values, weights):
    """Reference implementation: plain (unweighted) mean of ``values``.

    Args:
        values: Observations to average.
        weights: Accepted only so the signature matches ``average``; unused.

    Returns:
        The unweighted average as returned by ``np.average``.
    """
    # ``weights`` is intentionally not forwarded to NumPy.
    unweighted_mean = np.average(values)
    return unweighted_mean
|
||||
|
||||
|
||||
class DefaultWeightWrapper(AverageMeter):
    """Adapter that exercises ``AverageMeter.update`` without an explicit weight."""

    def update(self, values, weights):
        """Forward only ``values`` to the parent ``update``.

        ``weights`` is accepted so this wrapper can be called with two
        positional arguments, but it is deliberately dropped here.
        """
        super().update(values)
|
||||
|
||||
|
||||
class ScalarWrapper(AverageMeter):
    """Adapter that feeds ``AverageMeter.update`` one Python float pair at a time."""

    def update(self, values, weights):
        """Flatten both tensors and update the parent once per element.

        Args:
            values: Tensor of observations.
            weights: Tensor of weights, paired element-wise with ``values``
                after flattening.
        """
        # torch.ravel is PyTorch 1.8 only, so use np.ravel instead
        flat_values = np.ravel(values.cpu().numpy())
        flat_weights = np.ravel(weights.cpu().numpy())
        for single_value, single_weight in zip(flat_values, flat_weights):
            super().update(float(single_value), float(single_weight))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "values, weights",
    [
        (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)),
        (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5),
        (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5),
    ],
)
class TestAverageMeter(MetricTester):
    """Runs ``AverageMeter`` and its test wrappers through ``run_class_metric_test``."""

    def _run_average_case(self, ddp, dist_sync_on_step, values, weights, metric_class, sk_metric):
        """Shared driver: hand one metric class and its NumPy reference to the tester."""
        self.run_class_metric_test(
            ddp=ddp,
            dist_sync_on_step=dist_sync_on_step,
            metric_class=metric_class,
            sk_metric=sk_metric,
            # Abuse of names here: ``preds`` carries the values and
            # ``target`` carries the weights.
            preds=values,
            target=weights,
        )

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_average_fn(self, ddp, dist_sync_on_step, values, weights):
        """``AverageMeter`` with explicit weights, checked against ``average``."""
        self._run_average_case(
            ddp, dist_sync_on_step, values, weights, metric_class=AverageMeter, sk_metric=average
        )

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_average_fn_default(self, ddp, dist_sync_on_step, values, weights):
        """Weights dropped by the wrapper, checked against ``average_ignore_weights``."""
        self._run_average_case(
            ddp,
            dist_sync_on_step,
            values,
            weights,
            metric_class=DefaultWeightWrapper,
            sk_metric=average_ignore_weights,
        )

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_average_fn_scalar(self, ddp, dist_sync_on_step, values, weights):
        """Element-by-element scalar updates, checked against ``average``."""
        self._run_average_case(
            ddp, dist_sync_on_step, values, weights, metric_class=ScalarWrapper, sk_metric=average
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
@pytest.mark.parametrize(
    "weights, expected",
    [
        (1, 11.5),
        (torch.ones(2, 1, 1), 11.5),
        (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5),
    ],
)
def test_AverageMeter_broadcasting(weights, expected):
    """Check that a scalar or ``(2, 1, 1)`` weight is broadcast over ``(2, 3, 4)`` values."""
    values = torch.arange(24).reshape(2, 3, 4)
    meter = AverageMeter()

    result = meter(values, weights)
    assert result == expected
|
|
@ -11,6 +11,7 @@ _logger.setLevel(__logging.INFO)
|
|||
_PACKAGE_ROOT = os.path.dirname(__file__)
|
||||
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
|
||||
|
||||
from torchmetrics.average import AverageMeter # noqa: F401 E402
|
||||
from torchmetrics.classification import ( # noqa: F401 E402
|
||||
AUC,
|
||||
AUROC,
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.metric import Metric
|
||||
|
||||
|
||||
class AverageMeter(Metric):
    """Computes the average of a stream of values.

    Forward accepts

    - ``value`` (float tensor): ``(...)``
    - ``weight`` (float tensor): ``(...)``

    Args:
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is
            set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.
            default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state.
            When `None`, DDP will be used to perform the allgather.

    Example::
        >>> from torchmetrics import AverageMeter
        >>> avg = AverageMeter()
        >>> avg.update(3)
        >>> avg.update(1)
        >>> avg.compute()
        tensor(2.)

        >>> avg = AverageMeter()
        >>> values = torch.tensor([1., 2., 3.])
        >>> avg(values)
        tensor(2.)

        >>> avg = AverageMeter()
        >>> values = torch.tensor([1., 2.])
        >>> weights = torch.tensor([3., 1.])
        >>> avg(values, weights)
        tensor(1.2500)
    """

    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        # Running weighted sum of observations and running sum of weights;
        # both are scalar states registered with a "sum" reduction.
        self.add_state("value", torch.zeros(()), dist_reduce_fx="sum")
        self.add_state("weight", torch.zeros(()), dist_reduce_fx="sum")

    # TODO: need to be strings because Unions are not pickleable in Python 3.6
    def update(  # type: ignore
        self,
        value: "Union[Tensor, float]",
        weight: "Union[Tensor, float]" = 1.0
    ) -> None:
        """Updates the running average with a new batch of observations.

        Args:
            value: A tensor of observations (can also be a scalar value)
            weight: The weight of each observation (automatically broadcasted
                to fit ``value``)

        Raises:
            ValueError:
                On PyTorch <1.8 (no ``torch.broadcast_to``), if ``weight`` is
                neither a scalar nor of the same shape as ``value``.
        """
        # Non-tensor inputs are converted to float32 tensors on the device of
        # the corresponding state.
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value, dtype=torch.float32, device=self.value.device)
        if not isinstance(weight, Tensor):
            weight = torch.as_tensor(weight, dtype=torch.float32, device=self.weight.device)

        # broadcast_to only supported on PyTorch 1.8+
        if not hasattr(torch, "broadcast_to"):
            # Fallback: only scalar weights are expanded; any other shape
            # mismatch is rejected.
            if weight.shape == ():
                weight = torch.ones_like(value) * weight
            if weight.shape != value.shape:
                raise ValueError("Broadcasting not supported on PyTorch <1.8")
        else:
            weight = torch.broadcast_to(weight, value.shape)

        # Accumulate the weighted sum and the total weight separately so the
        # average can be formed at any time in ``compute``.
        self.value += (value * weight).sum()
        self.weight += weight.sum()

    def compute(self) -> Tensor:
        """Returns the weighted average of everything accumulated so far.

        Both states start at zero, so calling this before any ``update``
        (or when the accumulated weight is zero) divides zero by zero and
        yields ``nan``.
        """
        return self.value / self.weight
|
Загрузка…
Ссылка в новой задаче