Add AverageMeter implementation (#138)

* Add AverageMeter

* Fix type annotation to accomodate Python 3.6 bug

* Add tests

* Update changelog

* Add AverageMeter to docs

* fixup! Add AverageMeter to docs

* Code review comments

* Add tests for scalar case

* Fix behavior on PyTorch <1.8

* fixup! Add tests for scalar case

* fixup! fixup! Add tests for scalar case

* Update CHANGELOG.md

* Add Pearson correlation coefficient (#157)

* init files

* rest

* pep8

* changelog

* clamp

* suggestions

* rename

* format

* _sk_pearsonr

* inline

* fix sync

* fix tests

* fix docs

* Apply suggestions from code review

* Update torchmetrics/functional/regression/pearson.py

* atol

* update

* pep8

* pep8

* chlog

* .

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>

* Spearman correlation coefficient (#158)

* ranking

* init files

* update

* nearly working

* fix tests

* pep8

* add docs

* fix doctests

* fix docs

* pep8

* isort

* ghlog

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Added changes for Test Differentiability [1/n] (#154)

* added test changes

* fix style error

* fixed typo

* added changes for requires_grad

* metrics differentiability testing generalization

* Update tests/classification/test_accuracy.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* fix tests

* pep8

* changelog

* fix docs

* fix tests

* pep8

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Binned PR-related metrics (#128)

* WIP: Binned PR-related metrics

* attempt to fix types

* switch to linspace to make old pytorch happy

* make flake happy

* clean up

* Add more testing, move test input generation to the approproate place

* bugfixes and more stable and thorough tests

* flake8

* Reuse python zip-based implementation as it can't be reproduced with torch.where/max

* address comments

* isort

* Add docs and doctests, make APIs same as non-binned versions

* pep8

* isort

* doctests likes longer title underlines :O

* use numpy's nan_to_num

* add atol to bleu tests to make them more stable

* atol=1e-2 for bleu

* add more docs

* pep8

* remove nlp test hack

* address comments

* pep8

* abc

* flake8

* remove typecheck

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>

* version + about (#170)

* version + about

* flake8

* try

* .

* fix doc

* overload sig

* fix

* Different import style

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Bhadresh Savani <bhadreshpsavani@gmail.com>
Co-authored-by: Maxim Grechkin <maximsch2@gmail.com>
This commit is contained in:
Alan Du 2021-04-14 14:51:58 -04:00 коммит произвёл GitHub
Родитель b26c20d6f8
Коммит da0017410c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 206 добавлений и 0 удалений

Просмотреть файл

@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),
[#135](https://github.com/PyTorchLightning/metrics/pull/135)
)
- Added `AverageMeter` for ad-hoc averages of values ([#138](https://github.com/PyTorchLightning/metrics/pull/138))
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))

Просмотреть файл

@ -12,6 +12,12 @@ metrics.
.. autoclass:: torchmetrics.Metric
:noindex:
We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating
your own metric type might be too burdensome.
.. autoclass:: torchmetrics.AverageMeter
:noindex:
**********************
Classification Metrics
**********************

Просмотреть файл

@ -0,0 +1,88 @@
import numpy as np
import pytest
import torch
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.average import AverageMeter
def average(values, weights):
return np.average(values, weights=weights)
def average_ignore_weights(values, weights):
return np.average(values)
class DefaultWeightWrapper(AverageMeter):
def update(self, values, weights):
super().update(values)
class ScalarWrapper(AverageMeter):
def update(self, values, weights):
# torch.ravel is PyTorch 1.8 only, so use np.ravel instead
values = values.cpu().numpy()
weights = weights.cpu().numpy()
for v, w in zip(np.ravel(values), np.ravel(weights)):
super().update(float(v), float(w))
@pytest.mark.parametrize(
"values, weights",
[
(torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)),
(torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5),
(torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5),
],
)
class TestAverageMeter(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_average_fn(self, ddp, dist_sync_on_step, values, weights):
self.run_class_metric_test(
ddp=ddp,
dist_sync_on_step=dist_sync_on_step,
metric_class=AverageMeter,
sk_metric=average,
# Abuse of names here
preds=values,
target=weights,
)
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_average_fn_default(self, ddp, dist_sync_on_step, values, weights):
self.run_class_metric_test(
ddp=ddp,
dist_sync_on_step=dist_sync_on_step,
metric_class=DefaultWeightWrapper,
sk_metric=average_ignore_weights,
# Abuse of names here
preds=values,
target=weights,
)
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_average_fn_scalar(self, ddp, dist_sync_on_step, values, weights):
self.run_class_metric_test(
ddp=ddp,
dist_sync_on_step=dist_sync_on_step,
metric_class=ScalarWrapper,
sk_metric=average,
# Abuse of names here
preds=values,
target=weights,
)
@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
@pytest.mark.parametrize(
"weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)]
)
def test_AverageMeter_broadcasting(weights, expected):
values = torch.arange(24).reshape(2, 3, 4)
avg = AverageMeter()
assert avg(values, weights) == expected

Просмотреть файл

@ -11,6 +11,7 @@ _logger.setLevel(__logging.INFO)
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
from torchmetrics.average import AverageMeter # noqa: F401 E402
from torchmetrics.classification import ( # noqa: F401 E402
AUC,
AUROC,

110
torchmetrics/average.py Normal file
Просмотреть файл

@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
from torch import Tensor
from torchmetrics.metric import Metric
class AverageMeter(Metric):
"""Computes the average of a stream of values.
Forward accepts
- ``value`` (float tensor): ``(...)``
- ``weight`` (float tensor): ``(...)``
Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
Example::
>>> from torchmetrics import AverageMeter
>>> avg = AverageMeter()
>>> avg.update(3)
>>> avg.update(1)
>>> avg.compute()
tensor(2.)
>>> avg = AverageMeter()
>>> values = torch.tensor([1., 2., 3.])
>>> avg(values)
tensor(2.)
>>> avg = AverageMeter()
>>> values = torch.tensor([1., 2.])
>>> weights = torch.tensor([3., 1.])
>>> avg(values, weights)
tensor(1.2500)
"""
def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.add_state("value", torch.zeros(()), dist_reduce_fx="sum")
self.add_state("weight", torch.zeros(()), dist_reduce_fx="sum")
# TODO: need to be strings because Unions are not pickleable in Python 3.6
def update( # type: ignore
self,
value: "Union[Tensor, float]",
weight: "Union[Tensor, float]" = 1.0
) -> None:
"""Updates the average with.
Args:
value: A tensor of observations (can also be a scalar value)
weight: The weight of each observation (automatically broadcasted
to fit ``value``)
"""
if not isinstance(value, Tensor):
value = torch.as_tensor(value, dtype=torch.float32, device=self.value.device)
if not isinstance(weight, Tensor):
weight = torch.as_tensor(weight, dtype=torch.float32, device=self.weight.device)
# braodcast_to only supported on PyTorch 1.8+
if not hasattr(torch, "broadcast_to"):
if weight.shape == ():
weight = torch.ones_like(value) * weight
if weight.shape != value.shape:
raise ValueError("Broadcasting not supported on PyTorch <1.8")
else:
weight = torch.broadcast_to(weight, value.shape)
self.value += (value * weight).sum()
self.weight += weight.sum()
def compute(self) -> Tensor:
return self.value / self.weight