Add prefix arg to metric collection (#70)

* prefix arg

* prefix arg

* Apply suggestions from code review

* chlog

* add types

* fix doctest

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
Nicki Skafte 2021-03-14 09:28:45 +01:00 коммит произвёл GitHub
Родитель 0b9553d8a0
Коммит c32d3e5efb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 82 добавлений и 28 удалений

Просмотреть файл

@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
### Changed

Просмотреть файл

@ -195,27 +195,31 @@ Example:
Similarly it can also reduce the amount of code required to log multiple metrics
inside your LightningModule
.. code-block:: python
.. testcode::
def __init__(self):
...
metrics = pl.metrics.MetricCollection(...)
self.train_metrics = metrics.clone()
self.valid_metrics = metrics.clone()
from torchmetrics import Accuracy, MetricCollection, Precision, Recall
def training_step(self, batch, batch_idx):
logits = self(x)
...
self.train_metrics(logits, y)
# use log_dict instead of log
self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
class MyModule():
def __init__(self):
metrics = MetricCollection(Accuracy(), Precision(), Recall())
self.train_metrics = metrics.clone(prefix='train_')
self.valid_metrics = metrics.clone(prefix='val_')
def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_metrics(logits, y)
# use log_dict instead of log
self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
def training_step(self, batch, batch_idx):
logits = self(x)
# ...
output = self.train_metrics(logits, y)
# use log_dict instead of log
# metrics are logged with keys: train_Accuracy, train_Precision and train_Recall
self.log_dict(output)
def validation_step(self, batch, batch_idx):
logits = self(x)
# ...
output = self.valid_metrics(logits, y)
# use log_dict instead of log
# metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
self.log_dict(output)
.. note::

Просмотреть файл

@ -130,3 +130,28 @@ def test_metric_collection_args_kwargs(tmpdir):
_ = metric_collection(x=10, y=20)
assert metric_collection['DummyMetricSum'].x == 10
assert metric_collection['DummyMetricDiff'].x == -20
def test_metric_collection_prefix_arg(tmpdir):
""" Test that the prefix arg alters the keywords in the output"""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
names = ['DummyMetricSum', 'DummyMetricDiff']
metric_collection = MetricCollection([m1, m2], prefix='prefix_')
# test forward
out = metric_collection(5)
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
# test compute
out = metric_collection.compute()
for name in names:
assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
# test clone
new_metric_collection = metric_collection.clone(prefix='new_prefix_')
out = new_metric_collection(5)
for name in names:
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

Просмотреть файл

@ -13,7 +13,7 @@
# limitations under the License.
from copy import deepcopy
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from torch import nn
@ -36,6 +36,8 @@ class MetricCollection(nn.ModuleDict):
dict as key for output dict. Use this format if you want to chain
together multiple of the same metric with different parameters.
prefix: a string to append in front of the keys of the output dict
Example (input as list):
>>> import torch
>>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall
@ -58,8 +60,11 @@ class MetricCollection(nn.ModuleDict):
>>> metrics.persistent()
"""
def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
def __init__(
self,
metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
prefix: Optional[str] = None
):
super().__init__()
if isinstance(metrics, dict):
# Check all values are metrics
@ -84,13 +89,15 @@ class MetricCollection(nn.ModuleDict):
else:
raise ValueError("Unknown input to MetricCollection.")
self.prefix = self._check_prefix_arg(prefix)
def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
"""
Iteratively call forward for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
def update(self, *args, **kwargs): # pylint: disable=E0202
"""
@ -103,20 +110,36 @@ class MetricCollection(nn.ModuleDict):
m.update(*args, **m_kwargs)
def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.items()}
return {self._set_prefix(k): m.compute() for k, m in self.items()}
def reset(self):
def reset(self) -> None:
""" Iteratively call reset for each metric """
for _, m in self.items():
m.reset()
def clone(self):
""" Make a copy of the metric collection """
return deepcopy(self)
def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
""" Make a copy of the metric collection
Args:
prefix: a string to append in front of the metric keys
"""
mc = deepcopy(self)
mc.prefix = self._check_prefix_arg(prefix)
return mc
def persistent(self, mode: bool = True):
def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to
its state_dict
"""
for _, m in self.items():
m.persistent(mode)
def _set_prefix(self, k: str) -> str:
return k if self.prefix is None else self.prefix + k
def _check_prefix_arg(self, prefix: str) -> Optional[str]:
if prefix is not None:
if isinstance(prefix, str):
return prefix
else:
raise ValueError('Expected input `prefix` to be a string')
return None