MetricCollection (PL^4318)
* docs + precision + recall + f_beta + refactor
  Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
* rebase
  Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
* fixes
  Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
* added missing file
* docs
* docs
* extra import
* add metric collection
* add docs + integration with log_dict
* add test
* update
* update
* more test
* more test
* pep8
* fix doctest
* pep8
* add clone method
* add clone method
* merge-2
* changelog
* kwargs filtering and tests
* pep8
* fix test
* update docs
* Update docs/source/metrics.rst
  Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
* fix docs
* fix tests
* Apply suggestions from code review
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* fix docs
* fix doctest
* fix doctest
* fix doctest
* fix doctest

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

(cherry picked from commit c371c57641947be0aa22ffa688d437bed6eeea82)
Parent: dfeaaf0713
Commit: bcb82d2a0c
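In short, this commit adds a MetricCollection container that bundles several metrics behind a single forward/update/compute call and returns a dict that can be fed straight to log_dict. A rough usage sketch, lifted from the docstring examples further down in this diff (not code from the commit itself):

    import torch
    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall

    # Bundle metrics that share the same (preds, target) call pattern.
    metrics = MetricCollection([
        Accuracy(),
        Precision(num_classes=3, average='macro'),
        Recall(num_classes=3, average='macro'),
    ])

    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

    # Calling the collection returns a dict keyed by metric class name; inside a
    # LightningModule the same dict can be passed directly to self.log_dict(...).
    print(metrics(preds, target))
    # {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
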
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.metrics.metric import Metric  # noqa: F401
+from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F401
 
 from pytorch_lightning.metrics.classification import (  # noqa: F401
     Accuracy,

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import inspect
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -57,6 +58,7 @@ class Metric(nn.Module, ABC):
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
     """
+
     def __init__(
         self,
         compute_on_step: bool = True,
@@ -72,6 +74,7 @@ class Metric(nn.Module, ABC):
         self.dist_sync_fn = dist_sync_fn
         self._to_sync = True
 
+        self._update_signature = inspect.signature(self.update)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
         self._computed = None
@@ -120,7 +123,7 @@ class Metric(nn.Module, ABC):
         """
         if (
             not isinstance(default, torch.Tensor)
-            and not isinstance(default, list) # noqa: W503
+            and not isinstance(default, list)  # noqa: W503
             or (isinstance(default, list) and len(default) != 0)  # noqa: W503
         ):
             raise ValueError(
@@ -208,9 +211,11 @@ class Metric(nn.Module, ABC):
                 return self._computed
 
             dist_sync_fn = self.dist_sync_fn
-            if (dist_sync_fn is None
-                    and torch.distributed.is_available()
-                    and torch.distributed.is_initialized()):
+            if (
+                dist_sync_fn is None
+                and torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+            ):
                 # User provided a bool, so we assume DDP if available
                 dist_sync_fn = gather_all_tensors
 
@@ -250,6 +255,10 @@ class Metric(nn.Module, ABC):
             else:
                 setattr(self, attr, deepcopy(default))
 
+    def clone(self):
+        """ Make a copy of the metric """
+        return deepcopy(self)
+
     def __getstate__(self):
         # ignore update and compute functions for pickling
         return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
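The clone() added above is just a deepcopy, but it gives each stage (train/val/test) its own copy of a stateful metric instead of sharing one instance. A minimal sketch, assuming the existing Accuracy metric with its usual (preds, target) update:

    import torch
    from pytorch_lightning.metrics import Accuracy

    base_acc = Accuracy()
    train_acc = base_acc.clone()   # independent deepcopy of the metric state
    val_acc = base_acc.clone()

    train_acc.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
    # val_acc keeps its own (still default) state; updating one copy never
    # touches the other, which is exactly what sharing a single instance breaks.
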
@@ -292,3 +301,101 @@ class Metric(nn.Module, ABC):
                 current_val = getattr(self, key)
                 state_dict.update({key: current_val})
         return state_dict
+
+
+class MetricCollection(nn.ModuleDict):
+    """
+    MetricCollection class can be used to chain metrics that have the same
+    call pattern into one single class.
+
+    Args:
+        metrics: One of the following
+
+            * list or tuple: if metrics are passed in as a list, will use the
+              metrics class name as key for output dict. Therefore, two metrics
+              of the same class cannot be chained this way.
+
+            * dict: if metrics are passed in as a dict, will use each key in the
+              dict as key for output dict. Use this format if you want to chain
+              together multiple of the same metric with different parameters.
+
+    Example (input as list):
+
+        >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall
+        >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
+        >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
+        >>> metrics = MetricCollection([Accuracy(),
+        ...                             Precision(num_classes=3, average='macro'),
+        ...                             Recall(num_classes=3, average='macro')])
+        >>> metrics(preds, target)
+        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
+
+    Example (input as dict):
+
+        >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
+        ...                             'macro_recall': Recall(num_classes=3, average='macro')})
+        >>> metrics(preds, target)
+        {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)}
+
+    """
+    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
+        super().__init__()
+        if isinstance(metrics, dict):
+            # Check all values are metrics
+            for name, metric in metrics.items():
+                if not isinstance(metric, Metric):
+                    raise ValueError(f'Value {metric} belonging to key {name}'
+                                     ' is not an instance of `pl.metrics.Metric`')
+                self[name] = metric
+        elif isinstance(metrics, (tuple, list)):
+            for metric in metrics:
+                if not isinstance(metric, Metric):
+                    raise ValueError(f'Input {metric} to `MetricCollection` is not an instance'
+                                     ' of `pl.metrics.Metric`')
+                name = metric.__class__.__name__
+                if name in self:
+                    raise ValueError(f'Encountered two metrics both named {name}')
+                self[name] = metric
+        else:
+            raise ValueError('Unknown input to MetricCollection.')
+
+    def _filter_kwargs(self, metric: Metric, **kwargs):
+        """ filter kwargs such that they match the update signature of the metric """
+        return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()}
+
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
+        """
+        Iteratively call forward for each metric. Positional arguments (args) will
+        be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
+        """
+        return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()}
+
+    def update(self, *args, **kwargs):  # pylint: disable=E0202
+        """
+        Iteratively call update for each metric. Positional arguments (args) will
+        be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
+        """
+        for _, m in self.items():
+            m_kwargs = self._filter_kwargs(m, **kwargs)
+            m.update(*args, **m_kwargs)
+
+    def compute(self) -> Dict[str, Any]:
+        return {k: m.compute() for k, m in self.items()}
+
+    def reset(self):
+        """ Iteratively call reset for each metric """
+        for _, m in self.items():
+            m.reset()
+
+    def clone(self):
+        """ Make a copy of the metric collection """
+        return deepcopy(self)
+
+    def persistent(self, mode: bool = True):
+        """ Method for post-init to change if metric states should be saved to
+        its state_dict
+        """
+        for _, m in self.items():
+            m.persistent(mode)
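Before the tests, a short sketch of the argument routing that _filter_kwargs implements (it relies on the _update_signature captured in Metric.__init__ above). The metric classes mirror the DummyMetric1/DummyMetric2 helpers defined in the test file below, and the numbers follow the assertions there:

    # DummyMetric1.update(x) adds, DummyMetric2.update(y) subtracts (see the tests below).
    collection = MetricCollection([DummyMetric1(), DummyMetric2()])

    # Positional args are broadcast to every metric in the collection.
    collection.update(5)            # DummyMetric1.x == 5, DummyMetric2.x == -5
    collection.reset()

    # Keyword args are filtered per metric: only names present in a metric's
    # update() signature are forwarded, so x reaches DummyMetric1 and y reaches DummyMetric2.
    collection.update(x=10, y=20)   # DummyMetric1.x == 10, DummyMetric2.x == -20
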
@@ -6,8 +6,7 @@ import cloudpickle
 import numpy as np
 import pytest
 import torch
 
-from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.metrics.metric import Metric, MetricCollection
 
 torch.manual_seed(42)
@@ -17,7 +16,7 @@ class Dummy(Metric):
 
     def __init__(self):
         super().__init__()
-        self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
+        self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)
 
     def update(self):
         pass
@@ -166,7 +165,7 @@ def test_forward():
     assert a.compute() == 13
 
 
-class ToPickle(Dummy):
+class DummyMetric1(Dummy):
     def update(self, x):
         self.x += x
 
@@ -174,9 +173,17 @@ class ToPickle(Dummy):
         return self.x
 
 
+class DummyMetric2(Dummy):
+    def update(self, y):
+        self.x -= y
+
+    def compute(self):
+        return self.x
+
+
 def test_pickle(tmpdir):
     # doesn't test DDP
-    a = ToPickle()
+    a = DummyMetric1()
     a.update(1)
 
     metric_pickled = pickle.dumps(a)
@@ -201,3 +208,130 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict(x=0)
     metric.persistent(False)
     assert metric.state_dict() == OrderedDict()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+def test_device_and_dtype_transfer(tmpdir):
+    metric = DummyMetric1()
+    assert metric.x.is_cuda is False
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.to(device='cuda')
+    assert metric.x.is_cuda
+
+    metric = metric.double()
+    assert metric.x.dtype == torch.float64
+
+    metric = metric.half()
+    assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # Test correct dict structure
+    assert len(metric_collection) == 2
+    assert metric_collection['DummyMetric1'] == m1
+    assert metric_collection['DummyMetric2'] == m2
+
+    # Test correct initialization
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not initialized correctly'
+
+    # Test every metric gets updated
+    metric_collection.update(5)
+    for name, metric in metric_collection.items():
+        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'
+
+    # Test compute on each metric
+    metric_collection.update(-5)
+    metric_vals = metric_collection.compute()
+    assert len(metric_vals) == 2
+    for name, metric_val in metric_vals.items():
+        assert metric_val == 0, f'Metric {name}.compute not called correctly'
+
+    # Test that everything is reset
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not reset correctly'
+
+    # Test picklable
+    metric_pickled = pickle.dumps(metric_collection)
+    metric_loaded = pickle.loads(metric_pickled)
+    assert isinstance(metric_loaded, MetricCollection)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+def test_device_and_dtype_transfer_metriccollection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda is False
+        assert metric.x.dtype == torch.float32
+
+    metric_collection = metric_collection.to(device='cuda')
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda
+
+    metric_collection = metric_collection.double()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float64
+
+    metric_collection = metric_collection.half()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection_wrong_input(tmpdir):
+    """ Check that errors are raised on wrong input """
+    m1 = DummyMetric1()
+
+    # Not all inputs are metrics (list)
+    with pytest.raises(ValueError):
+        _ = MetricCollection([m1, 5])
+
+    # Not all inputs are metrics (dict)
+    with pytest.raises(ValueError):
+        _ = MetricCollection({'metric1': m1,
+                              'metric2': 5})
+
+    # Same metric passed in multiple times
+    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
+        _ = MetricCollection([m1, m1])
+
+    # Not a list or dict passed in
+    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
+        _ = MetricCollection(m1)
+
+
+def test_metric_collection_args_kwargs(tmpdir):
+    """ Check that args and kwargs get passed correctly in metric collection;
+    checks both the update and forward method
+    """
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # args get passed to all metrics
+    metric_collection.update(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+    _ = metric_collection(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+
+    # kwargs only get passed to the metrics whose update signature accepts them
+    metric_collection.update(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20
+    metric_collection.reset()
+    _ = metric_collection(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20

@@ -1,7 +1,7 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.metrics import Metric
+from pytorch_lightning.metrics import Metric, MetricCollection
 from tests.base.boring_model import BoringModel
 
 
@@ -17,6 +17,18 @@ class SumMetric(Metric):
         return self.x
 
 
+class DiffMetric(Metric):
+    def __init__(self):
+        super().__init__()
+        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, x):
+        self.x -= x
+
+    def compute(self):
+        return self.x
+
+
 def test_metric_lightning(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
@@ -125,3 +137,41 @@ def test_scriptable(tmpdir):
     output = model(rand_input)
     script_output = script_model(rand_input)
     assert torch.allclose(output, script_output)
+
+
+def test_metric_collection_lightning_log(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric = MetricCollection([SumMetric(), DiffMetric()])
+            self.sum = 0.0
+            self.diff = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            metric_vals = self.metric(x.sum())
+            self.sum += x.sum()
+            self.diff -= x.sum()
+            self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
+            return self.step(x)
+
+        def training_epoch_end(self, outputs):
+            metric_vals = self.metric.compute()
+            self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+    assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)