* docs + precision + recall + f_beta + refactor

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* rebase

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* fixes

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* added missing file

* docs

* docs

* extra import

* add metric collection

* add docs + integration with log_dict

* add test

* update

* update

* more test

* more test

* pep8

* fix doctest

* pep8

* add clone method

* add clone method

* merge-2

* changelog

* kwargs filtering and tests

* pep8

* fix test

* update docs

* Update docs/source/metrics.rst

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>

* fix docs

* fix tests

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix docs

* fix doctest

* fix doctest

* fix doctest

* fix doctest

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
(cherry picked from commit c371c57641947be0aa22ffa688d437bed6eeea82)
Nicki Skafte 2021-01-08 11:09:07 +01:00 committed by Jirka Borovec
Parent dfeaaf0713
Commit bcb82d2a0c
4 changed files: 303 additions and 12 deletions

View file

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.metric import Metric # noqa: F401
from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
from pytorch_lightning.metrics.classification import ( # noqa: F401
Accuracy,

View file

@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
"""
def __init__(
self,
compute_on_step: bool = True,
@@ -72,6 +74,7 @@ class Metric(nn.Module, ABC):
self.dist_sync_fn = dist_sync_fn
self._to_sync = True
self._update_signature = inspect.signature(self.update)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
@@ -120,7 +123,7 @@ class Metric(nn.Module, ABC):
"""
if (
not isinstance(default, torch.Tensor)
and not isinstance(default, list) # noqa: W503
and not isinstance(default, list) # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError(
@@ -208,9 +211,11 @@ class Metric(nn.Module, ABC):
return self._computed
dist_sync_fn = self.dist_sync_fn
if (dist_sync_fn is None
and torch.distributed.is_available()
and torch.distributed.is_initialized()):
if (
dist_sync_fn is None
and torch.distributed.is_available()
and torch.distributed.is_initialized()
):
# User provided a bool, so we assume DDP if available
dist_sync_fn = gather_all_tensors
@@ -250,6 +255,10 @@ class Metric(nn.Module, ABC):
else:
setattr(self, attr, deepcopy(default))
def clone(self):
""" Make a copy of the metric """
return deepcopy(self)
def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
@@ -292,3 +301,101 @@ class Metric(nn.Module, ABC):
current_val = getattr(self, key)
state_dict.update({key: current_val})
return state_dict
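The new `clone` method above is simply a `deepcopy` of the metric. A minimal usage sketch (not part of the committed diff; it assumes the `Accuracy` metric with its usual `preds`/`target` update signature, as in the docstring example below) shows why it is handy: clones accumulate state independently.

    import torch
    from pytorch_lightning.metrics import Accuracy

    train_acc = Accuracy()
    val_acc = train_acc.clone()  # deep copy -> completely independent state

    train_acc(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # 2/3 correct
    val_acc(torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1]))    # 3/3 correct

    assert train_acc.compute() != val_acc.compute()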
class MetricCollection(nn.ModuleDict):
"""
The MetricCollection class can be used to chain metrics that have the same
call pattern into one single collection object.
Args:
metrics: One of the following
* list or tuple: if metrics are passed in as a list, the metric's class name
will be used as the key in the output dict. Therefore, two metrics
of the same class cannot be chained this way.
* dict: if metrics are passed in as a dict, each key in the dict will be
used as the key in the output dict. Use this format if you want to chain
together multiple instances of the same metric with different parameters.
Example (input as list):
>>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
>>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
>>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
>>> metrics = MetricCollection([Accuracy(),
... Precision(num_classes=3, average='macro'),
... Recall(num_classes=3, average='macro')])
>>> metrics(preds, target)
{'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
Example (input as dict):
>>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
... 'macro_recall': Recall(num_classes=3, average='macro')})
>>> metrics(preds, target)
{'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
"""
def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
super().__init__()
if isinstance(metrics, dict):
# Check all values are metrics
for name, metric in metrics.items():
if not isinstance(metric, Metric):
raise ValueError(f'Value {metric} belonging to key {name}'
' is not an instance of `pl.metrics.Metric`')
self[name] = metric
elif isinstance(metrics, (tuple, list)):
for metric in metrics:
if not isinstance(metric, Metric):
raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
' of `pl.metrics.Metric`')
name = metric.__class__.__name__
if name in self:
raise ValueError(f'Encountered two metrics both named {name}')
self[name] = metric
else:
raise ValueError('Unknown input to MetricCollection.')
def _filter_kwargs(self, metric: Metric, **kwargs):
""" filter kwargs such that they match the update signature of the metric """
return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}
def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}
def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Iteratively call update for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items():
m_kwargs = self._filter_kwargs(m, **kwargs)
m.update(*args, **m_kwargs)
def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.items()}
def reset(self):
""" Iteratively call reset for each metric """
for _, m in self.items():
m.reset()
def clone(self):
""" Make a copy of the metric collection """
return deepcopy(self)
def persistent(self, mode: bool = True):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""
for _, m in self.items():
m.persistent(mode)
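Taken together, a brief usage sketch of the whole class (not part of the committed code; it assumes the `Accuracy` and `Recall` metrics from the docstring example above, with their usual `preds`/`target` update signature):

    import torch
    from pytorch_lightning.metrics import MetricCollection, Accuracy, Recall

    metrics = MetricCollection({
        'acc': Accuracy(),
        'macro_recall': Recall(num_classes=3, average='macro'),
    })

    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])

    # forward passes positional args to every metric; keyword args would be
    # filtered per metric according to its update signature
    step_vals = metrics(preds, target)   # {'acc': ..., 'macro_recall': ...}
    epoch_vals = metrics.compute()       # accumulated values, same keys

    metrics.persistent(True)             # keep accumulated states in state_dict
    val_metrics = metrics.clone()        # independent deep copy for another stage
    metrics.reset()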

View file

@@ -6,8 +6,7 @@ import cloudpickle
import numpy as np
import pytest
import torch
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.metric import Metric, MetricCollection
torch.manual_seed(42)
@@ -17,7 +16,7 @@ class Dummy(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)
def update(self):
pass
@@ -166,7 +165,7 @@ def test_forward():
assert a.compute() == 13
class ToPickle(Dummy):
class DummyMetric1(Dummy):
def update(self, x):
self.x += x
@@ -174,9 +173,17 @@ class ToPickle(Dummy):
return self.x
class DummyMetric2(Dummy):
def update(self, y):
self.x -= y
def compute(self):
return self.x
def test_pickle(tmpdir):
# doesn't test for DDP
a = ToPickle()
a = DummyMetric1()
a.update(1)
metric_pickled = pickle.dumps(a)
@@ -201,3 +208,130 @@ def test_state_dict(tmpdir):
assert metric.state_dict() == OrderedDict(x=0)
metric.persistent(False)
assert metric.state_dict() == OrderedDict()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer(tmpdir):
metric = DummyMetric1()
assert metric.x.is_cuda is False
assert metric.x.dtype == torch.float32
metric = metric.to(device='cuda')
assert metric.x.is_cuda
metric = metric.double()
assert metric.x.dtype == torch.float64
metric = metric.half()
assert metric.x.dtype == torch.float16
def test_metric_collection(tmpdir):
m1 = DummyMetric1()
m2 = DummyMetric2()
metric_collection = MetricCollection([m1, m2])
# Test correct dict structure
assert len(metric_collection) == 2
assert metric_collection['DummyMetric1'] == m1
assert metric_collection['DummyMetric2'] == m2
# Test correct initialization
for name, metric in metric_collection.items():
assert metric.x == 0, f'Metric {name} not initialized correctly'
# Test every metric gets updated
metric_collection.update(5)
for name, metric in metric_collection.items():
assert metric.x.abs() == 5, f'Metric {name} not updated correctly'
# Test compute on each metric
metric_collection.update(-5)
metric_vals = metric_collection.compute()
assert len(metric_vals) == 2
for name, metric_val in metric_vals.items():
assert metric_val == 0, f'Metric {name}.compute not called correctly'
# Test that everything is reset
for name, metric in metric_collection.items():
assert metric.x == 0, f'Metric {name} not reset correctly'
# Test picklable
metric_pickled = pickle.dumps(metric_collection)
metric_loaded = pickle.loads(metric_pickled)
assert isinstance(metric_loaded, MetricCollection)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer_metriccollection(tmpdir):
m1 = DummyMetric1()
m2 = DummyMetric2()
metric_collection = MetricCollection([m1, m2])
for _, metric in metric_collection.items():
assert metric.x.is_cuda is False
assert metric.x.dtype == torch.float32
metric_collection = metric_collection.to(device='cuda')
for _, metric in metric_collection.items():
assert metric.x.is_cuda
metric_collection = metric_collection.double()
for _, metric in metric_collection.items():
assert metric.x.dtype == torch.float64
metric_collection = metric_collection.half()
for _, metric in metric_collection.items():
assert metric.x.dtype == torch.float16
def test_metric_collection_wrong_input(tmpdir):
""" Check that errors are raised on wrong input """
m1 = DummyMetric1()
# Not all input are metrics (list)
with pytest.raises(ValueError):
_ = MetricCollection([m1, 5])
# Not all input are metrics (dict)
with pytest.raises(ValueError):
_ = MetricCollection({'metric1': m1,
'metric2': 5})
# Same metric passed in multiple times
with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
_ = MetricCollection([m1, m1])
# Not a list or dict passed in
with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
_ = MetricCollection(m1)
def test_metric_collection_args_kwargs(tmpdir):
""" Check that args and kwargs gets passed correctly in metric collection,
Checks both update and forward method
"""
m1 = DummyMetric1()
m2 = DummyMetric2()
metric_collection = MetricCollection([m1, m2])
# args gets passed to all metrics
metric_collection.update(5)
assert metric_collection['DummyMetric1'].x == 5
assert metric_collection['DummyMetric2'].x == -5
metric_collection.reset()
_ = metric_collection(5)
assert metric_collection['DummyMetric1'].x == 5
assert metric_collection['DummyMetric2'].x == -5
metric_collection.reset()
# kwargs gets only passed to metrics that it matches
metric_collection.update(x=10, y=20)
assert metric_collection['DummyMetric1'].x == 10
assert metric_collection['DummyMetric2'].x == -20
metric_collection.reset()
_ = metric_collection(x=10, y=20)
assert metric_collection['DummyMetric1'].x == 10
assert metric_collection['DummyMetric2'].x == -20
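The filtering exercised by this test rests on nothing more than `inspect.signature`; a standalone sketch of the same idea, with illustrative function names that are not part of the library:

    import inspect

    def update_a(x):
        pass

    def update_b(y):
        pass

    def filter_kwargs(fn, **kwargs):
        # keep only the keyword arguments that appear in fn's signature
        return {k: v for k, v in kwargs.items() if k in inspect.signature(fn).parameters}

    assert filter_kwargs(update_a, x=10, y=20) == {'x': 10}
    assert filter_kwargs(update_b, x=10, y=20) == {'y': 20}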

View file

@@ -1,7 +1,7 @@
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics import Metric, MetricCollection
from tests.base.boring_model import BoringModel
@@ -17,6 +17,18 @@ class SumMetric(Metric):
return self.x
class DiffMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
def update(self, x):
self.x -= x
def compute(self):
return self.x
def test_metric_lightning(tmpdir):
class TestModel(BoringModel):
def __init__(self):
@@ -125,3 +137,41 @@ def test_scriptable(tmpdir):
output = model(rand_input)
script_output = script_model(rand_input)
assert torch.allclose(output, script_output)
def test_metric_collection_lightning_log(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = MetricCollection([SumMetric(), DiffMetric()])
self.sum = 0.0
self.diff = 0.0
def training_step(self, batch, batch_idx):
x = batch
metric_vals = self.metric(x.sum())
self.sum += x.sum()
self.diff -= x.sum()
self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
return self.step(x)
def training_epoch_end(self, outputs):
metric_vals = self.metric.compute()
self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})
model = TestModel()
model.val_dataloader = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
logged = trainer.logged_metrics
assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)