diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index c94549d7a..57b1a7fbf 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -192,6 +192,7 @@ class ClassificationTask(pl.LightningModule): and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule @@ -376,6 +377,7 @@ class MultiLabelClassificationTask(ClassificationTask): and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 939ae761b..1eafdf1b4 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -230,6 +230,7 @@ class ObjectDetectionTask(pl.LightningModule): and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 7e6137f22..0d32e3436 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -149,6 +149,7 @@ class RegressionTask(pl.LightningModule): and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 7459fa8a9..8913589b9 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -202,6 +202,7 @@ class SemanticSegmentationTask(pl.LightningModule): and hasattr(self.trainer, "datamodule") and self.logger and hasattr(self.logger, "experiment") + and hasattr(self.logger.experiment, "add_figure") ): try: datamodule = self.trainer.datamodule