[docs] Drop pl.metrics and define classes with explicit superclass (#231)

* Update docs

* Fix doctest

* Add docs/source/generated/ to gitignore
This commit is contained in:
Akihiro Nitta 2021-05-07 23:05:38 +09:00 коммит произвёл GitHub
Родитель 253dfbc443
Коммит 6e205fa886
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 52 добавлений и 41 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -57,6 +57,7 @@ coverage.xml
# Sphinx documentation
docs/_build/
docs/source/api/
docs/source/generated/
docs/source/*.md
# PyBuilder

Просмотреть файл

@ -19,23 +19,25 @@ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics
The example below shows how to use a metric in your `LightningModule <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html>`_:
.. code-block:: python
.. testcode:: python
def __init__(self):
...
self.accuracy = pl.metrics.Accuracy()
class MyModel(LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(preds, y))
...
def __init__(self):
...
self.accuracy = torchmetrics.Accuracy()
def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
def training_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(preds, y))
...
def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
********************
Logging TorchMetrics
@ -56,25 +58,27 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
or reduction functions.
.. code-block:: python
.. testcode:: python
def __init__(self):
...
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()
class MyModule(LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
self.train_acc(preds, y)
self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
def __init__(self):
...
self.train_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy()
def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_acc(logits, y)
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
def training_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
self.train_acc(preds, y)
self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_acc(logits, y)
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
.. note::
@ -85,15 +89,17 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
.. testcode:: python
def training_step(self, batch, batch_idx):
data, target = batch
preds = self(data)
# ...
return {'loss' : loss, 'preds' : preds, 'target' : target}
class MyModule(LightningModule):
def training_step_end(self, outputs):
#update and log
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)
def training_step(self, batch, batch_idx):
data, target = batch
preds = self(data)
# ...
return {'loss' : loss, 'preds' : preds, 'target' : target}
def training_step_end(self, outputs):
#update and log
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)
For more details see `Lightning Docs <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-from-a-lightningmodule>`_

Просмотреть файл

@ -1,3 +1,7 @@
.. testsetup:: *
import torch
from pytorch_lightning.core.lightning import LightningModule
########
Overview
@ -91,7 +95,7 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
from torchmetrics import Accuracy, MetricCollection
class MyModule():
class MyModule(torch.nn.Module):
def __init__(self):
...
# valid ways metrics will be identified as child modules
@ -238,7 +242,7 @@ inside your LightningModule
from torchmetrics import Accuracy, MetricCollection, Precision, Recall
class MyModule():
class MyModule(LightningModule):
def __init__(self):
metrics = MetricCollection([Accuracy(), Precision(), Recall()])
self.train_metrics = metrics.clone(prefix='train_')