TropicalCyclone dataset: use label for sample key

This commit is contained in:
Adam J. Stewart 2021-11-02 16:10:50 +00:00
Родитель 3cc63def02
Коммит 5a7c80fb0f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 8 добавлений и 8 удалений

Просмотреть файл

@ -56,7 +56,7 @@ import torchgeo.datasets
dataset = torchgeo.datasets.TropicalCycloneWindEstimation(split="train", download=True)
print(dataset[0]["image"].shape)
print(dataset[0]["target"])
print(dataset[0]["label"])
```
## Contributing

Просмотреть файл

@ -70,7 +70,7 @@ class TestTropicalCycloneWindEstimation:
assert isinstance(x["storm_id"], str)
assert isinstance(x["relative_time"], int)
assert isinstance(x["ocean"], int)
assert isinstance(x["target"], int)
assert isinstance(x["label"], int)
assert x["image"].shape == (dataset.size, dataset.size)
def test_len(self, dataset: TropicalCycloneWindEstimation) -> None:

Просмотреть файл

@ -168,7 +168,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
features["relative_time"] = int(features["relative_time"])
features["ocean"] = int(features["ocean"])
features["target"] = int(features["wind_speed"])
features["label"] = int(features["wind_speed"])
return features

Просмотреть файл

@ -64,8 +64,8 @@ class CycloneDataModule(pl.LightningDataModule):
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["target"] = torch.as_tensor( # type: ignore[attr-defined]
sample["target"]
sample["label"] = torch.as_tensor( # type: ignore[attr-defined]
sample["label"]
).float()
return sample

Просмотреть файл

@ -430,7 +430,7 @@ class RegressionTask(pl.LightningModule):
training loss
"""
x = batch["image"]
y = batch["target"].view(-1, 1)
y = batch["label"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)
@ -459,7 +459,7 @@ class RegressionTask(pl.LightningModule):
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["target"].view(-1, 1)
y = batch["label"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)
@ -485,7 +485,7 @@ class RegressionTask(pl.LightningModule):
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["target"].view(-1, 1)
y = batch["label"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)