From 6e205fa88670d9cca23541363481232f569b3724 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Fri, 7 May 2021 23:05:38 +0900
Subject: [PATCH] [docs] Drop pl.metrics and define classes with explicit superclass (#231)

* Update docs
* Fix doctest
* Add docs/source/generated/ to gitignore
---
 .gitignore                      |  1 +
 docs/source/pages/lightning.rst | 84 ++++++++++++++++++---------------
 docs/source/pages/overview.rst  |  8 +++-
 3 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/.gitignore b/.gitignore
index e3ac9e1..9b09b3c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,7 @@ coverage.xml
 # Sphinx documentation
 docs/_build/
 docs/source/api/
+docs/source/generated/
 docs/source/*.md

 # PyBuilder
diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst
index 2e53619..cac9bb7 100644
--- a/docs/source/pages/lightning.rst
+++ b/docs/source/pages/lightning.rst
@@ -19,23 +19,25 @@ While TorchMetrics was built to be used with native PyTorch, using TorchMetrics

 The example below shows how to use a metric in your `LightningModule `_:

-.. code-block:: python
+.. testcode:: python

-    def __init__(self):
-        ...
-        self.accuracy = pl.metrics.Accuracy()
+    class MyModel(LightningModule):

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        preds = self(x)
-        ...
-        # log step metric
-        self.log('train_acc_step', self.accuracy(preds, y))
-        ...
+        def __init__(self):
+            ...
+            self.accuracy = torchmetrics.Accuracy()

-    def training_epoch_end(self, outs):
-        # log epoch metric
-        self.log('train_acc_epoch', self.accuracy.compute())
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            preds = self(x)
+            ...
+            # log step metric
+            self.log('train_acc_step', self.accuracy(preds, y))
+            ...
+
+        def training_epoch_end(self, outs):
+            # log epoch metric
+            self.log('train_acc_epoch', self.accuracy.compute())

 ********************
 Logging TorchMetrics
@@ -56,25 +58,27 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
 or reduction functions.


-.. code-block:: python
+.. testcode:: python

-    def __init__(self):
-        ...
-        self.train_acc = pl.metrics.Accuracy()
-        self.valid_acc = pl.metrics.Accuracy()
+    class MyModule(LightningModule):

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        preds = self(x)
-        ...
-        self.train_acc(preds, y)
-        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
+        def __init__(self):
+            ...
+            self.train_acc = torchmetrics.Accuracy()
+            self.valid_acc = torchmetrics.Accuracy()

-    def validation_step(self, batch, batch_idx):
-        logits = self(x)
-        ...
-        self.valid_acc(logits, y)
-        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            preds = self(x)
+            ...
+            self.train_acc(preds, y)
+            self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)
+
+        def validation_step(self, batch, batch_idx):
+            logits = self(x)
+            ...
+            self.valid_acc(logits, y)
+            self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)


 .. note::
@@ -85,15 +89,17 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v

 .. testcode:: python

-    def training_step(self, batch, batch_idx):
-        data, target = batch
-        preds = self(data)
-        # ...
-        return {'loss' : loss, 'preds' : preds, 'target' : target}
+    class MyModule(LightningModule):

-    def training_step_end(self, outputs):
-        #update and log
-        self.metric(outputs['preds'], outputs['target'])
-        self.log('metric', self.metric)
+        def training_step(self, batch, batch_idx):
+            data, target = batch
+            preds = self(data)
+            # ...
+            return {'loss' : loss, 'preds' : preds, 'target' : target}
+
+        def training_step_end(self, outputs):
+            #update and log
+            self.metric(outputs['preds'], outputs['target'])
+            self.log('metric', self.metric)

 For more details see `Lightning Docs `_
diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index 7b2fa89..a93020b 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -1,3 +1,7 @@
+.. testsetup:: *
+
+    import torch
+    from pytorch_lightning.core.lightning import LightningModule

 ########
 Overview
@@ -91,7 +95,7 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.

     from torchmetrics import Accuracy, MetricCollection

-    class MyModule():
+    class MyModule(torch.nn.Module):
         def __init__(self):
             ...
             # valid ways metrics will be identified as child modules
@@ -238,7 +242,7 @@ inside your LightningModule

     from torchmetrics import Accuracy, MetricCollection, Precision, Recall

-    class MyModule():
+    class MyModule(LightningModule):
         def __init__(self):
             metrics = MetricCollection([Accuracy(), Precision(), Recall()])
             self.train_metrics = metrics.clone(prefix='train_')
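
For readers who want to try the logging pattern this patch documents outside of the docs build, here is a minimal, self-contained sketch. It is not part of the patch: the class name ``TinyClassifier``, the layer sizes, and the optimizer choice are placeholders, and it assumes ``pytorch_lightning`` plus a 2021-era ``torchmetrics`` are installed (newer torchmetrics releases require a ``task`` argument when constructing ``Accuracy``)::

    import torch
    import torchmetrics
    from pytorch_lightning import LightningModule


    class TinyClassifier(LightningModule):
        """Illustrative module showing metric logging; names and sizes are arbitrary."""

        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.layer = torch.nn.Linear(32, num_classes)
            # The metric is a torch.nn.Module, so assigning it here registers it
            # as a child module and Lightning moves it to the correct device.
            self.accuracy = torchmetrics.Accuracy()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            preds = self(x)
            loss = torch.nn.functional.cross_entropy(preds, y)
            # Update the metric state, then pass the metric object to self.log;
            # Lightning handles compute/reset at step and epoch boundaries.
            self.accuracy(preds, y)
            self.log('train_acc', self.accuracy, on_step=True, on_epoch=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)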