Add prefix arg to metric collection (#70)
* prefix arg
* prefix arg
* Apply suggestions from code review
* chlog
* add types
* fix doctest

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Parent: 0b9553d8a0
Commit: c32d3e5efb
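In short, the change lets a `MetricCollection` prepend a fixed string to every key of the dict returned by `forward` and `compute`, either at construction time or when cloning. A minimal sketch of the intended usage (the tensor values are illustrative, not from this PR):

```python
import torch
from torchmetrics import Accuracy, MetricCollection, Precision

# every output key gets 'train_' prepended
metrics = MetricCollection([Accuracy(), Precision()], prefix='train_')

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])

out = metrics(preds, target)  # keys: 'train_Accuracy', 'train_Precision'
out = metrics.compute()       # compute() uses the same prefixed keys
```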
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
+
 ### Changed
@@ -195,27 +195,31 @@ Example:

 Similarly it can also reduce the amount of code required to log multiple metrics
 inside your LightningModule

-.. code-block:: python
+.. testcode::

-    def __init__(self):
-        ...
-        metrics = pl.metrics.MetricCollection(...)
-        self.train_metrics = metrics.clone()
-        self.valid_metrics = metrics.clone()
+    from torchmetrics import Accuracy, MetricCollection, Precision, Recall
+
+    class MyModule():
+        def __init__(self):
+            metrics = MetricCollection([Accuracy(), Precision(), Recall()])
+            self.train_metrics = metrics.clone(prefix='train_')
+            self.valid_metrics = metrics.clone(prefix='val_')

-    def training_step(self, batch, batch_idx):
-        logits = self(x)
-        ...
-        self.train_metrics(logits, y)
-        # use log_dict instead of log
-        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
+        def training_step(self, batch, batch_idx):
+            logits = self(x)
+            # ...
+            output = self.train_metrics(logits, y)
+            # use log_dict instead of log
+            # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall
+            self.log_dict(output)

-    def validation_step(self, batch, batch_idx):
-        logits = self(x)
-        ...
-        self.valid_metrics(logits, y)
-        # use log_dict instead of log
-        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
+        def validation_step(self, batch, batch_idx):
+            logits = self(x)
+            # ...
+            output = self.valid_metrics(logits, y)
+            # use log_dict instead of log
+            # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
+            self.log_dict(output)

 .. note::
@@ -130,3 +130,28 @@ def test_metric_collection_args_kwargs(tmpdir):

     _ = metric_collection(x=10, y=20)
     assert metric_collection['DummyMetricSum'].x == 10
     assert metric_collection['DummyMetricDiff'].x == -20
+
+
+def test_metric_collection_prefix_arg(tmpdir):
+    """Test that the prefix arg alters the keys of the output."""
+    m1 = DummyMetricSum()
+    m2 = DummyMetricDiff()
+    names = ['DummyMetricSum', 'DummyMetricDiff']
+
+    metric_collection = MetricCollection([m1, m2], prefix='prefix_')
+
+    # test forward
+    out = metric_collection(5)
+    for name in names:
+        assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
+
+    # test compute
+    out = metric_collection.compute()
+    for name in names:
+        assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
+
+    # test clone
+    new_metric_collection = metric_collection.clone(prefix='new_prefix_')
+    out = new_metric_collection(5)
+    for name in names:
+        assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'
@@ -13,7 +13,7 @@

 # limitations under the License.

 from copy import deepcopy
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from torch import nn
@@ -36,6 +36,8 @@ class MetricCollection(nn.ModuleDict):

             dict as key for output dict. Use this format if you want to chain
             together multiple of the same metric with different parameters.

+        prefix: a string to prepend to the keys of the output dict
+
     Example (input as list):
         >>> import torch
         >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall
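Since the dict form documented above controls the base keys, combining it with the new `prefix` yields fully custom output keys. A small sketch (the `Recall` parameters are illustrative of "same metric with different parameters", not taken from this PR):

```python
from torchmetrics import MetricCollection, Recall

metrics = MetricCollection(
    {
        'micro_recall': Recall(num_classes=3, average='micro'),
        'macro_recall': Recall(num_classes=3, average='macro'),
    },
    prefix='val_',
)
# forward()/compute() return keys 'val_micro_recall' and 'val_macro_recall'
```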
@@ -58,8 +60,11 @@ class MetricCollection(nn.ModuleDict):

         >>> metrics.persistent()

     """

-    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
+    def __init__(
+        self,
+        metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
+        prefix: Optional[str] = None
+    ):
         super().__init__()
         if isinstance(metrics, dict):
             # Check all values are metrics
@@ -84,13 +89,15 @@ class MetricCollection(nn.ModuleDict):

         else:
             raise ValueError("Unknown input to MetricCollection.")

+        self.prefix = self._check_prefix_arg(prefix)
+
     def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
         """
         Iteratively call forward for each metric. Positional arguments (args) will
         be passed to every metric in the collection, while keyword arguments (kwargs)
         will be filtered based on the signature of the individual metric.
         """
-        return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
+        return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

     def update(self, *args, **kwargs):  # pylint: disable=E0202
         """
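The kwargs filtering mentioned in the `forward` docstring means each metric receives only the keyword arguments its own signature accepts, so heterogeneous metrics can share a single call. A simplified standalone sketch of the idea (not the library's actual `_filter_kwargs` implementation):

```python
import inspect

def filter_kwargs(fn, **kwargs):
    # keep only kwargs that match parameter names in fn's signature
    allowed = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in allowed}

def update_a(preds): ...
def update_b(preds, target): ...

print(filter_kwargs(update_a, preds=1, target=2))  # {'preds': 1}
print(filter_kwargs(update_b, preds=1, target=2))  # {'preds': 1, 'target': 2}
```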
@@ -103,20 +110,36 @@ class MetricCollection(nn.ModuleDict):

             m.update(*args, **m_kwargs)

     def compute(self) -> Dict[str, Any]:
-        return {k: m.compute() for k, m in self.items()}
+        return {self._set_prefix(k): m.compute() for k, m in self.items()}

-    def reset(self):
+    def reset(self) -> None:
         """ Iteratively call reset for each metric """
         for _, m in self.items():
             m.reset()

-    def clone(self):
-        """ Make a copy of the metric collection """
-        return deepcopy(self)
+    def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
+        """ Make a copy of the metric collection
+
+        Args:
+            prefix: a string to prepend to the metric keys
+        """
+        mc = deepcopy(self)
+        mc.prefix = self._check_prefix_arg(prefix)
+        return mc

-    def persistent(self, mode: bool = True):
+    def persistent(self, mode: bool = True) -> None:
         """Method for post-init to change if metric states should be saved to
         its state_dict
         """
         for _, m in self.items():
             m.persistent(mode)
+
+    def _set_prefix(self, k: str) -> str:
+        return k if self.prefix is None else self.prefix + k
+
+    def _check_prefix_arg(self, prefix: Optional[str]) -> Optional[str]:
+        if prefix is not None:
+            if isinstance(prefix, str):
+                return prefix
+            else:
+                raise ValueError('Expected input `prefix` to be a string')
+        return None
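One behavior worth noting in `clone` as written: it always reassigns `prefix`, so cloning without an argument resets the prefix to `None` rather than inheriting the original's. A sketch (assumes this patch is applied):

```python
from torchmetrics import Accuracy, MetricCollection

metrics = MetricCollection([Accuracy()], prefix='train_')

plain = metrics.clone()                      # prefix is reset to None, not 'train_'
val_metrics = metrics.clone(prefix='val_')   # keys become 'val_Accuracy'
```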