[docs] Drop pl.metrics and define classes with explicit superclass (#231)
* Update docs * Fix doctest * Add docs/source/generated/ to gitignore
This commit is contained in:
Родитель
253dfbc443
Коммит
6e205fa886
|
@ -57,6 +57,7 @@ coverage.xml
|
|||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/source/api/
|
||||
docs/source/generated/
|
||||
docs/source/*.md
|
||||
|
||||
# PyBuilder
|
||||
|
|
|
@ -19,23 +19,25 @@ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics
|
|||
|
||||
The example below shows how to use a metric in your `LightningModule <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html>`_:
|
||||
|
||||
.. code-block:: python
|
||||
.. testcode:: python
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
self.accuracy = pl.metrics.Accuracy()
|
||||
class MyModel(LightningModule):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
preds = self(x)
|
||||
...
|
||||
# log step metric
|
||||
self.log('train_acc_step', self.accuracy(preds, y))
|
||||
...
|
||||
def __init__(self):
|
||||
...
|
||||
self.accuracy = torchmetrics.Accuracy()
|
||||
|
||||
def training_epoch_end(self, outs):
|
||||
# log epoch metric
|
||||
self.log('train_acc_epoch', self.accuracy.compute())
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
preds = self(x)
|
||||
...
|
||||
# log step metric
|
||||
self.log('train_acc_step', self.accuracy(preds, y))
|
||||
...
|
||||
|
||||
def training_epoch_end(self, outs):
|
||||
# log epoch metric
|
||||
self.log('train_acc_epoch', self.accuracy.compute())
|
||||
|
||||
********************
|
||||
Logging TorchMetrics
|
||||
|
@ -56,25 +58,27 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
|
|||
or reduction functions.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
.. testcode:: python
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
self.train_acc = pl.metrics.Accuracy()
|
||||
self.valid_acc = pl.metrics.Accuracy()
|
||||
class MyModule(LightningModule):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
preds = self(x)
|
||||
...
|
||||
self.train_acc(preds, y)
|
||||
self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
|
||||
def __init__(self):
|
||||
...
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
self.valid_acc = torchmetrics.Accuracy()
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
logits = self(x)
|
||||
...
|
||||
self.valid_acc(logits, y)
|
||||
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
preds = self(x)
|
||||
...
|
||||
self.train_acc(preds, y)
|
||||
self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
logits = self(x)
|
||||
...
|
||||
self.valid_acc(logits, y)
|
||||
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -85,15 +89,17 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
|
|||
|
||||
.. testcode:: python
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, target = batch
|
||||
preds = self(data)
|
||||
# ...
|
||||
return {'loss' : loss, 'preds' : preds, 'target' : target}
|
||||
class MyModule(LightningModule):
|
||||
|
||||
def training_step_end(self, outputs):
|
||||
#update and log
|
||||
self.metric(outputs['preds'], outputs['target'])
|
||||
self.log('metric', self.metric)
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, target = batch
|
||||
preds = self(data)
|
||||
# ...
|
||||
return {'loss' : loss, 'preds' : preds, 'target' : target}
|
||||
|
||||
def training_step_end(self, outputs):
|
||||
#update and log
|
||||
self.metric(outputs['preds'], outputs['target'])
|
||||
self.log('metric', self.metric)
|
||||
|
||||
For more details see `Lightning Docs <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-from-a-lightningmodule>`_
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
.. testsetup:: *
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
########
|
||||
Overview
|
||||
|
@ -91,7 +95,7 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
|
|||
|
||||
from torchmetrics import Accuracy, MetricCollection
|
||||
|
||||
class MyModule():
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
...
|
||||
# valid ways metrics will be identified as child modules
|
||||
|
@ -238,7 +242,7 @@ inside your LightningModule
|
|||
|
||||
from torchmetrics import Accuracy, MetricCollection, Precision, Recall
|
||||
|
||||
class MyModule():
|
||||
class MyModule(LightningModule):
|
||||
def __init__(self):
|
||||
metrics = MetricCollection([Accuracy(), Precision(), Recall()])
|
||||
self.train_metrics = metrics.clone(prefix='train_')
|
||||
|
|
Загрузка…
Ссылка в новой задаче