Remov Argmax Computation for torchmetrics in Classification and Segmentation (#1777)

* remove y_hat_hard

* argmax vs softmax :)
This commit is contained in:
Nils Lehmann 2023-12-15 17:55:15 +01:00 коммит произвёл isaaccorley
Родитель b2369d6bd4
Коммит f713ba813a
2 изменённых файлов: 8 добавлений и 14 удалений

Просмотреть файл

@ -161,10 +161,9 @@ class ClassificationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["label"] y = batch["label"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss: Tensor = self.criterion(y_hat, y) loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss) self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y) self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics) self.log_dict(self.train_metrics)
return loss return loss
@ -182,10 +181,9 @@ class ClassificationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["label"] y = batch["label"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y) loss = self.criterion(y_hat, y)
self.log("val_loss", loss) self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y) self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics) self.log_dict(self.val_metrics)
if ( if (
@ -198,7 +196,7 @@ class ClassificationTask(BaseTask):
): ):
try: try:
datamodule = self.trainer.datamodule datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard batch["prediction"] = y_hat.argmax(dim=-1)
for key in ["image", "label", "prediction"]: for key in ["image", "label", "prediction"]:
batch[key] = batch[key].cpu() batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0] sample = unbind_samples(batch)[0]
@ -223,10 +221,9 @@ class ClassificationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["label"] y = batch["label"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y) loss = self.criterion(y_hat, y)
self.log("test_loss", loss) self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y) self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics) self.log_dict(self.test_metrics)
def predict_step( def predict_step(

Просмотреть файл

@ -217,10 +217,9 @@ class SemanticSegmentationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["mask"] y = batch["mask"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss: Tensor = self.criterion(y_hat, y) loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss) self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y) self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics) self.log_dict(self.train_metrics)
return loss return loss
@ -237,10 +236,9 @@ class SemanticSegmentationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["mask"] y = batch["mask"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y) loss = self.criterion(y_hat, y)
self.log("val_loss", loss) self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y) self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics) self.log_dict(self.val_metrics)
if ( if (
@ -253,7 +251,7 @@ class SemanticSegmentationTask(BaseTask):
): ):
try: try:
datamodule = self.trainer.datamodule datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard batch["prediction"] = y_hat.argmax(dim=1)
for key in ["image", "mask", "prediction"]: for key in ["image", "mask", "prediction"]:
batch[key] = batch[key].cpu() batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0] sample = unbind_samples(batch)[0]
@ -278,10 +276,9 @@ class SemanticSegmentationTask(BaseTask):
x = batch["image"] x = batch["image"]
y = batch["mask"] y = batch["mask"]
y_hat = self(x) y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.criterion(y_hat, y) loss = self.criterion(y_hat, y)
self.log("test_loss", loss) self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y) self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics) self.log_dict(self.test_metrics)
def predict_step( def predict_step(