Docs: fix trainer metric definitions (#1924)

* Docs: fix trainer metric definitions

* Link to torchmetrics docs

* Teach Sphinx where docs live
This commit is contained in:
Adam J. Stewart 2024-03-03 19:45:08 +01:00 коммит произвёл isaaccorley
Родитель dd38fddcbd
Коммит b9653beb2f
5 изменённых файлов: 34 добавлений и 35 удалений

Просмотреть файл

@ -119,6 +119,7 @@ intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/stable/", None),
"timm": ("https://huggingface.co/docs/timm/main/en/", None),
"torch": ("https://pytorch.org/docs/stable", None),
"torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None),
"torchvision": ("https://pytorch.org/vision/stable", None),
}

Просмотреть файл

@ -97,14 +97,15 @@ class ClassificationTask(BaseTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.
* Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
Uses 'micro' averaging. Higher values are better.
* Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
Uses 'macro' averaging. Higher values are better.
* Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
actual classes. Uses 'macro' averaging. Higher valuers are better.
* Multiclass F1 Score: The harmonic mean of precision and recall.
Uses 'micro' averaging. Higher values are better.
* :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
true positives divided by the dataset size. Both overall accuracy (OA)
using 'micro' averaging and average accuracy (AA) using 'macro' averaging
are reported. Higher values are better.
* :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
over union (IoU). Uses 'macro' averaging. Higher valuers are better.
* :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
The harmonic mean of precision and recall. Uses 'micro' averaging.
Higher values are better.
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
@ -266,12 +267,13 @@ class MultiLabelClassificationTask(ClassificationTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.
* Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
Uses 'micro' averaging. Higher values are better.
* Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
Uses 'macro' averaging. Higher values are better.
* Multiclass F1 Score: The harmonic mean of precision and recall.
Uses 'micro' averaging. Higher values are better.
* :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
true positives divided by the dataset size. Both overall accuracy (OA)
using 'micro' averaging and average accuracy (AA) using 'macro' averaging
are reported. Higher values are better.
* :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
The harmonic mean of precision and recall. Uses 'micro' averaging.
Higher values are better.
.. note::
* 'Micro' averaging suits overall performance evaluation but may not

Просмотреть файл

@ -205,10 +205,11 @@ class ObjectDetectionTask(BaseTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.
* Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and
Mean-Average-Recall (mAR) for object detection. Prediction is based on the
intersection over union (IoU) between the predicted bounding boxes and the
ground truth bounding boxes. Uses 'macro' averaging. Higher values are better.
* :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average
precision (mAP) and mean average recall (mAR). Precision is the number of
true positives divided by the number of true positives + false positives.
Recall is the number of true positives divived by the number of true positives
+ false negatives. Uses 'macro' averaging. Higher values are better.
.. note::
* 'Micro' averaging suits overall performance evaluation but may not
@ -216,7 +217,7 @@ class ObjectDetectionTask(BaseTask):
* 'Macro' averaging gives equal weight to each class, and is useful for
balanced performance assessment across imbalanced classes.
"""
metrics = MetricCollection([MeanAveragePrecision()])
metrics = MetricCollection([MeanAveragePrecision(average="macro")])
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")

Просмотреть файл

@ -98,18 +98,12 @@ class RegressionTask(BaseTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.
* Root Mean Squared Error (RMSE): The square root of the average of the squared
differences between the predicted and actual values. Lower values are better.
* Mean Squared Error (MSE): The average of the squared differences between the
predicted and actual values. Lower values are better.
* Mean Absolute Error (MAE): The average of the absolute differences between the
predicted and actual values. Lower values are better.
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
minority class accuracy.
* 'Macro' averaging gives equal weight to each class, and is useful for
balanced performance assessment across imbalanced classes.
* :class:`~torchmetrics.MeanSquaredError`: The average of the squared
differences between the predicted and actual values (MSE) and its
square root (RMSE). Lower values are better.
* :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute
differences between the predicted and actual values (MAE).
Lower values are better.
"""
metrics = MetricCollection(
{

Просмотреть файл

@ -126,10 +126,11 @@ class SemanticSegmentationTask(BaseTask):
def configure_metrics(self) -> None:
"""Initialize the performance metrics.
* Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
Uses 'micro' averaging. Higher values are better.
* Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
actual segments. Uses 'macro' averaging. Higher values are better.
* :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
(OA) using 'micro' averaging. The number of true positives divided by the
dataset size. Higher values are better.
* :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
over union (IoU). Uses 'micro' averaging. Higher valuers are better.
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect