зеркало из https://github.com/microsoft/torchgeo.git
Remov Argmax Computation for torchmetrics in Classification and Segmentation (#1777)
* remove y_hat_hard * argmax vs softmax :)
This commit is contained in:
Родитель
b2369d6bd4
Коммит
f713ba813a
|
@ -161,10 +161,9 @@ class ClassificationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["label"]
|
y = batch["label"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss: Tensor = self.criterion(y_hat, y)
|
loss: Tensor = self.criterion(y_hat, y)
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
self.train_metrics(y_hat_hard, y)
|
self.train_metrics(y_hat, y)
|
||||||
self.log_dict(self.train_metrics)
|
self.log_dict(self.train_metrics)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -182,10 +181,9 @@ class ClassificationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["label"]
|
y = batch["label"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss = self.criterion(y_hat, y)
|
loss = self.criterion(y_hat, y)
|
||||||
self.log("val_loss", loss)
|
self.log("val_loss", loss)
|
||||||
self.val_metrics(y_hat_hard, y)
|
self.val_metrics(y_hat, y)
|
||||||
self.log_dict(self.val_metrics)
|
self.log_dict(self.val_metrics)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -198,7 +196,7 @@ class ClassificationTask(BaseTask):
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule
|
datamodule = self.trainer.datamodule
|
||||||
batch["prediction"] = y_hat_hard
|
batch["prediction"] = y_hat.argmax(dim=-1)
|
||||||
for key in ["image", "label", "prediction"]:
|
for key in ["image", "label", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
sample = unbind_samples(batch)[0]
|
sample = unbind_samples(batch)[0]
|
||||||
|
@ -223,10 +221,9 @@ class ClassificationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["label"]
|
y = batch["label"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss = self.criterion(y_hat, y)
|
loss = self.criterion(y_hat, y)
|
||||||
self.log("test_loss", loss)
|
self.log("test_loss", loss)
|
||||||
self.test_metrics(y_hat_hard, y)
|
self.test_metrics(y_hat, y)
|
||||||
self.log_dict(self.test_metrics)
|
self.log_dict(self.test_metrics)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
|
|
|
@ -217,10 +217,9 @@ class SemanticSegmentationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["mask"]
|
y = batch["mask"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss: Tensor = self.criterion(y_hat, y)
|
loss: Tensor = self.criterion(y_hat, y)
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
self.train_metrics(y_hat_hard, y)
|
self.train_metrics(y_hat, y)
|
||||||
self.log_dict(self.train_metrics)
|
self.log_dict(self.train_metrics)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@ -237,10 +236,9 @@ class SemanticSegmentationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["mask"]
|
y = batch["mask"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss = self.criterion(y_hat, y)
|
loss = self.criterion(y_hat, y)
|
||||||
self.log("val_loss", loss)
|
self.log("val_loss", loss)
|
||||||
self.val_metrics(y_hat_hard, y)
|
self.val_metrics(y_hat, y)
|
||||||
self.log_dict(self.val_metrics)
|
self.log_dict(self.val_metrics)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -253,7 +251,7 @@ class SemanticSegmentationTask(BaseTask):
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule
|
datamodule = self.trainer.datamodule
|
||||||
batch["prediction"] = y_hat_hard
|
batch["prediction"] = y_hat.argmax(dim=1)
|
||||||
for key in ["image", "mask", "prediction"]:
|
for key in ["image", "mask", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
sample = unbind_samples(batch)[0]
|
sample = unbind_samples(batch)[0]
|
||||||
|
@ -278,10 +276,9 @@ class SemanticSegmentationTask(BaseTask):
|
||||||
x = batch["image"]
|
x = batch["image"]
|
||||||
y = batch["mask"]
|
y = batch["mask"]
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
y_hat_hard = y_hat.argmax(dim=1)
|
|
||||||
loss = self.criterion(y_hat, y)
|
loss = self.criterion(y_hat, y)
|
||||||
self.log("test_loss", loss)
|
self.log("test_loss", loss)
|
||||||
self.test_metrics(y_hat_hard, y)
|
self.test_metrics(y_hat, y)
|
||||||
self.log_dict(self.test_metrics)
|
self.log_dict(self.test_metrics)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
|
|
Загрузка…
Ссылка в новой задаче