[docs] Drop pl.metrics and define classes with explicit superclass (#231)
* Update docs * Fix doctest * Add docs/source/generated/ to gitignore
This commit is contained in:
Родитель
253dfbc443
Коммит
6e205fa886
|
@ -57,6 +57,7 @@ coverage.xml
|
|||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/source/api/
|
||||
docs/source/generated/
|
||||
docs/source/*.md
|
||||
|
||||
# PyBuilder
|
||||
|
|
|
@ -19,11 +19,13 @@ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics
|
|||
|
||||
The example below shows how to use a metric in your `LightningModule <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html>`_:
|
||||
|
||||
.. code-block:: python
|
||||
.. testcode:: python
|
||||
|
||||
class MyModel(LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
self.accuracy = pl.metrics.Accuracy()
|
||||
self.accuracy = torchmetrics.Accuracy()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
|
@ -56,12 +58,14 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
|
|||
or reduction functions.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
.. testcode:: python
|
||||
|
||||
class MyModule(LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
self.train_acc = pl.metrics.Accuracy()
|
||||
self.valid_acc = pl.metrics.Accuracy()
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
self.valid_acc = torchmetrics.Accuracy()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
|
@ -85,6 +89,8 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
|
|||
|
||||
.. testcode:: python
|
||||
|
||||
class MyModule(LightningModule):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, target = batch
|
||||
preds = self(data)
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
.. testsetup:: *
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
########
|
||||
Overview
|
||||
|
@ -91,7 +95,7 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
|
|||
|
||||
from torchmetrics import Accuracy, MetricCollection
|
||||
|
||||
class MyModule():
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
...
|
||||
# valid ways metrics will be identified as child modules
|
||||
|
@ -238,7 +242,7 @@ inside your LightningModule
|
|||
|
||||
from torchmetrics import Accuracy, MetricCollection, Precision, Recall
|
||||
|
||||
class MyModule():
|
||||
class MyModule(LightningModule):
|
||||
def __init__(self):
|
||||
metrics = MetricCollection([Accuracy(), Precision(), Recall()])
|
||||
self.train_metrics = metrics.clone(prefix='train_')
|
||||
|
|
Загрузка…
Ссылка в новой задаче