Mirror of https://github.com/microsoft/torchgeo.git
Docs: fix trainer metric definitions (#1924)
* Docs: fix trainer metric definitions
* Link to torchmetrics docs
* Teach Sphinx where docs live
Parent: 2749d5b5b1
Commit: e40c78d7f2
@@ -121,6 +121,7 @@ intersphinx_mapping = {
    "sklearn": ("https://scikit-learn.org/stable/", None),
    "timm": ("https://huggingface.co/docs/timm/main/en/", None),
    "torch": ("https://pytorch.org/docs/stable", None),
    "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None),
    "torchvision": ("https://pytorch.org/vision/stable", None),
}
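For context (not part of the diff): the new "torchmetrics" entry is what lets the :class:`~torchmetrics...` cross-references in the docstrings below resolve to the TorchMetrics documentation. A minimal conf.py sketch of the relevant pieces; everything other than the mapping entries shown in the hunk is illustrative:

```python
# docs/conf.py (minimal sketch, not the full torchgeo configuration)
# sphinx.ext.intersphinx resolves :class:`~torchmetrics...` roles against the
# objects.inv inventory published at each mapped URL.
extensions = ["sphinx.ext.autodoc", "sphinx.ext.intersphinx"]

intersphinx_mapping = {
    "torch": ("https://pytorch.org/docs/stable", None),
    # Added in this commit: teaches Sphinx where the TorchMetrics docs live, so
    # references like :class:`~torchmetrics.classification.MulticlassAccuracy`
    # become links instead of broken cross-references.
    "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None),
}
```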
@@ -98,14 +98,15 @@ class ClassificationTask(BaseTask):
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
          Uses 'micro' averaging. Higher values are better.
        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
          Uses 'macro' averaging. Higher values are better.
        * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
          actual classes. Uses 'macro' averaging. Higher values are better.
        * Multiclass F1 Score: The harmonic mean of precision and recall.
          Uses 'micro' averaging. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
          true positives divided by the dataset size. Both overall accuracy (OA)
          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
          are reported. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'macro' averaging. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
          The harmonic mean of precision and recall. Uses 'micro' averaging.
          Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
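For context, a rough sketch of a metric collection matching the definitions above. The dictionary keys and the num_classes value are illustrative, not necessarily torchgeo's exact names; the class names and arguments are standard TorchMetrics API:

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassFBetaScore,
    MulticlassJaccardIndex,
)

num_classes = 10  # illustrative value

metrics = MetricCollection(
    {
        # Overall accuracy (OA): one global pool of samples -> 'micro' averaging.
        "OverallAccuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
        # Average accuracy (AA): per-class accuracy averaged equally -> 'macro'.
        "AverageAccuracy": MulticlassAccuracy(num_classes=num_classes, average="macro"),
        # IoU averaged equally over classes -> 'macro'.
        "JaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="macro"),
        # F1 score = F-beta with beta=1, pooled over all samples -> 'micro'.
        "F1Score": MulticlassFBetaScore(num_classes=num_classes, beta=1.0, average="micro"),
    }
)
```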
@@ -270,12 +271,13 @@ class MultiLabelClassificationTask(ClassificationTask):
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
          Uses 'micro' averaging. Higher values are better.
        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
          Uses 'macro' averaging. Higher values are better.
        * Multiclass F1 Score: The harmonic mean of precision and recall.
          Uses 'micro' averaging. Higher values are better.
        * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
          true positives divided by the dataset size. Both overall accuracy (OA)
          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
          are reported. Higher values are better.
        * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
          The harmonic mean of precision and recall. Uses 'micro' averaging.
          Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
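For context, the multilabel counterpart of the sketch above; keys and num_labels are illustrative, while MultilabelAccuracy and MultilabelFBetaScore (which take num_labels rather than num_classes) are standard TorchMetrics classes:

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy, MultilabelFBetaScore

num_labels = 17  # illustrative value

metrics = MetricCollection(
    {
        # Overall accuracy (OA) pooled over all labels -> 'micro'.
        "OverallAccuracy": MultilabelAccuracy(num_labels=num_labels, average="micro"),
        # Average accuracy (AA): per-label accuracy averaged equally -> 'macro'.
        "AverageAccuracy": MultilabelAccuracy(num_labels=num_labels, average="macro"),
        # F1 score pooled over all labels -> 'micro'.
        "F1Score": MultilabelFBetaScore(num_labels=num_labels, beta=1.0, average="micro"),
    }
)
```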
@@ -206,10 +206,11 @@ class ObjectDetectionTask(BaseTask):
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and
          Mean-Average-Recall (mAR) for object detection. Prediction is based on the
          intersection over union (IoU) between the predicted bounding boxes and the
          ground truth bounding boxes. Uses 'macro' averaging. Higher values are better.
        * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average
          precision (mAP) and mean average recall (mAR). Precision is the number of
          true positives divided by the number of true positives + false positives.
          Recall is the number of true positives divided by the number of true positives
          + false negatives. Uses 'macro' averaging. Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
@@ -217,7 +218,7 @@ class ObjectDetectionTask(BaseTask):
           * 'Macro' averaging gives equal weight to each class, and is useful for
             balanced performance assessment across imbalanced classes.
        """
        metrics = MetricCollection([MeanAveragePrecision()])
        metrics = MetricCollection([MeanAveragePrecision(average="macro")])
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")
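For context, a minimal sketch of how the macro-averaged detection metric is fed. The box coordinates, scores, and labels are made-up values; the list-of-dicts format with "boxes", "scores", and "labels" keys is the standard TorchMetrics MeanAveragePrecision input:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# average="macro" averages AP over classes, matching the updated docstring.
metrics = MetricCollection([MeanAveragePrecision(average="macro")])

preds = [
    {
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # (x1, y1, x2, y2)
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([1]),
    }
]
targets = [
    {
        "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
        "labels": torch.tensor([1]),
    }
]

metrics.update(preds, targets)
results = metrics.compute()  # dict with 'map', 'map_50', 'mar_100', etc.
```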
@@ -99,18 +99,12 @@ class RegressionTask(BaseTask):
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * Root Mean Squared Error (RMSE): The square root of the average of the squared
          differences between the predicted and actual values. Lower values are better.
        * Mean Squared Error (MSE): The average of the squared differences between the
          predicted and actual values. Lower values are better.
        * Mean Absolute Error (MAE): The average of the absolute differences between the
          predicted and actual values. Lower values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, and is useful for
             balanced performance assessment across imbalanced classes.
        * :class:`~torchmetrics.MeanSquaredError`: The average of the squared
          differences between the predicted and actual values (MSE) and its
          square root (RMSE). Lower values are better.
        * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute
          differences between the predicted and actual values (MAE).
          Lower values are better.
        """
        metrics = MetricCollection(
            {
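For context, a rough sketch of a collection matching the RMSE/MSE/MAE definitions above; the keys are illustrative, and MeanSquaredError(squared=False) is the standard TorchMetrics way to report RMSE:

```python
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

# Keys are illustrative; torchgeo's actual metric names may differ.
metrics = MetricCollection(
    {
        "RMSE": MeanSquaredError(squared=False),  # squared=False -> square root of MSE
        "MSE": MeanSquaredError(),
        "MAE": MeanAbsoluteError(),
    }
)

# Cloned per stage so logged names get a "val_"/"test_" prefix.
val_metrics = metrics.clone(prefix="val_")
test_metrics = metrics.clone(prefix="test_")
```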
@@ -124,10 +124,11 @@ class SemanticSegmentationTask(BaseTask):
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
          Uses 'micro' averaging. Higher values are better.
        * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
          actual segments. Uses 'macro' averaging. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
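For context, a rough sketch of per-pixel use of these metrics; num_classes, tensor shapes, and keys are illustrative, while the ability of MulticlassAccuracy and MulticlassJaccardIndex to accept (N, C, H, W) logits against (N, H, W) masks is standard TorchMetrics behaviour:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

num_classes = 6  # illustrative value

metrics = MetricCollection(
    {
        # Pixel accuracy pooled over all pixels -> 'micro'.
        "OverallAccuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
        # IoU with 'micro' averaging, matching the updated docstring wording.
        "JaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="micro"),
    }
)

# Per-pixel update: class logits of shape (B, C, H, W) vs. masks of shape (B, H, W).
logits = torch.randn(2, num_classes, 64, 64)
masks = torch.randint(0, num_classes, (2, 64, 64))
print(metrics(logits, masks))
```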