From d91b3913f5685bb8b4dfa317a6ffd8502878b27c Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 30 May 2024 18:11:48 +0100 Subject: [PATCH] Update num_classes parameter description in ObjectDetectionTask class (#2101) * Update num_classes parameter description in ObjectDetectionTask class * Update num_classes parameter description in SemanticSegmentationTask class --- torchgeo/trainers/detection.py | 2 +- torchgeo/trainers/segmentation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 24a7b4e8b..d13d84dcf 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -82,7 +82,7 @@ class ObjectDetectionTask(BaseTask): weights: Initial model weights. True for ImageNet weights, False or None for random weights. in_channels: Number of input channels to model. - num_classes: Number of prediction classes. + num_classes: Number of prediction classes (including the background). trainable_layers: Number of trainable layers. lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index dad26635c..afd715210 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -54,7 +54,7 @@ class SemanticSegmentationTask(BaseTask): model does not support pretrained weights. Pretrained ViT weight enums are not supported yet. in_channels: Number of input channels to model. - num_classes: Number of prediction classes. + num_classes: Number of prediction classes (including the background). num_filters: Number of filters. Only applicable when model='fcn'. loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss.