From f713ba813a172a35dfbccef1e90ec2d53798ffdb Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:55:15 +0100 Subject: [PATCH] Remov Argmax Computation for torchmetrics in Classification and Segmentation (#1777) * remove y_hat_hard * argmax vs softmax :) --- torchgeo/trainers/classification.py | 11 ++++------- torchgeo/trainers/segmentation.py | 11 ++++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 76bea8211..8b46b856c 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -161,10 +161,9 @@ class ClassificationTask(BaseTask): x = batch["image"] y = batch["label"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss: Tensor = self.criterion(y_hat, y) self.log("train_loss", loss) - self.train_metrics(y_hat_hard, y) + self.train_metrics(y_hat, y) self.log_dict(self.train_metrics) return loss @@ -182,10 +181,9 @@ class ClassificationTask(BaseTask): x = batch["image"] y = batch["label"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss = self.criterion(y_hat, y) self.log("val_loss", loss) - self.val_metrics(y_hat_hard, y) + self.val_metrics(y_hat, y) self.log_dict(self.val_metrics) if ( @@ -198,7 +196,7 @@ class ClassificationTask(BaseTask): ): try: datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard + batch["prediction"] = y_hat.argmax(dim=-1) for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -223,10 +221,9 @@ class ClassificationTask(BaseTask): x = batch["image"] y = batch["label"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss = self.criterion(y_hat, y) self.log("test_loss", loss) - self.test_metrics(y_hat_hard, y) + self.test_metrics(y_hat, y) self.log_dict(self.test_metrics) def predict_step( diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 9ee51d826..ad71621f2 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -217,10 +217,9 @@ class SemanticSegmentationTask(BaseTask): x = batch["image"] y = batch["mask"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss: Tensor = self.criterion(y_hat, y) self.log("train_loss", loss) - self.train_metrics(y_hat_hard, y) + self.train_metrics(y_hat, y) self.log_dict(self.train_metrics) return loss @@ -237,10 +236,9 @@ class SemanticSegmentationTask(BaseTask): x = batch["image"] y = batch["mask"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss = self.criterion(y_hat, y) self.log("val_loss", loss) - self.val_metrics(y_hat_hard, y) + self.val_metrics(y_hat, y) self.log_dict(self.val_metrics) if ( @@ -253,7 +251,7 @@ class SemanticSegmentationTask(BaseTask): ): try: datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard + batch["prediction"] = y_hat.argmax(dim=1) for key in ["image", "mask", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -278,10 +276,9 @@ class SemanticSegmentationTask(BaseTask): x = batch["image"] y = batch["mask"] y_hat = self(x) - y_hat_hard = y_hat.argmax(dim=1) loss = self.criterion(y_hat, y) self.log("test_loss", loss) - self.test_metrics(y_hat_hard, y) + self.test_metrics(y_hat, y) self.log_dict(self.test_metrics) def predict_step(