зеркало из https://github.com/microsoft/torchgeo.git
Remov Argmax Computation for torchmetrics in Classification and Segmentation (#1777)
* remove y_hat_hard * argmax vs softmax :)
This commit is contained in:
Родитель
b2369d6bd4
Коммит
f713ba813a
|
@ -161,10 +161,9 @@ class ClassificationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss: Tensor = self.criterion(y_hat, y)
|
||||
self.log("train_loss", loss)
|
||||
self.train_metrics(y_hat_hard, y)
|
||||
self.train_metrics(y_hat, y)
|
||||
self.log_dict(self.train_metrics)
|
||||
|
||||
return loss
|
||||
|
@ -182,10 +181,9 @@ class ClassificationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss = self.criterion(y_hat, y)
|
||||
self.log("val_loss", loss)
|
||||
self.val_metrics(y_hat_hard, y)
|
||||
self.val_metrics(y_hat, y)
|
||||
self.log_dict(self.val_metrics)
|
||||
|
||||
if (
|
||||
|
@ -198,7 +196,7 @@ class ClassificationTask(BaseTask):
|
|||
):
|
||||
try:
|
||||
datamodule = self.trainer.datamodule
|
||||
batch["prediction"] = y_hat_hard
|
||||
batch["prediction"] = y_hat.argmax(dim=-1)
|
||||
for key in ["image", "label", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
sample = unbind_samples(batch)[0]
|
||||
|
@ -223,10 +221,9 @@ class ClassificationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss = self.criterion(y_hat, y)
|
||||
self.log("test_loss", loss)
|
||||
self.test_metrics(y_hat_hard, y)
|
||||
self.test_metrics(y_hat, y)
|
||||
self.log_dict(self.test_metrics)
|
||||
|
||||
def predict_step(
|
||||
|
|
|
@ -217,10 +217,9 @@ class SemanticSegmentationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["mask"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss: Tensor = self.criterion(y_hat, y)
|
||||
self.log("train_loss", loss)
|
||||
self.train_metrics(y_hat_hard, y)
|
||||
self.train_metrics(y_hat, y)
|
||||
self.log_dict(self.train_metrics)
|
||||
return loss
|
||||
|
||||
|
@ -237,10 +236,9 @@ class SemanticSegmentationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["mask"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss = self.criterion(y_hat, y)
|
||||
self.log("val_loss", loss)
|
||||
self.val_metrics(y_hat_hard, y)
|
||||
self.val_metrics(y_hat, y)
|
||||
self.log_dict(self.val_metrics)
|
||||
|
||||
if (
|
||||
|
@ -253,7 +251,7 @@ class SemanticSegmentationTask(BaseTask):
|
|||
):
|
||||
try:
|
||||
datamodule = self.trainer.datamodule
|
||||
batch["prediction"] = y_hat_hard
|
||||
batch["prediction"] = y_hat.argmax(dim=1)
|
||||
for key in ["image", "mask", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
sample = unbind_samples(batch)[0]
|
||||
|
@ -278,10 +276,9 @@ class SemanticSegmentationTask(BaseTask):
|
|||
x = batch["image"]
|
||||
y = batch["mask"]
|
||||
y_hat = self(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
loss = self.criterion(y_hat, y)
|
||||
self.log("test_loss", loss)
|
||||
self.test_metrics(y_hat_hard, y)
|
||||
self.test_metrics(y_hat, y)
|
||||
self.log_dict(self.test_metrics)
|
||||
|
||||
def predict_step(
|
||||
|
|
Загрузка…
Ссылка в новой задаче