зеркало из https://github.com/microsoft/torchgeo.git
TropicalCyclone dataset: use label for sample key
This commit is contained in:
Родитель
3cc63def02
Коммит
5a7c80fb0f
|
@ -56,7 +56,7 @@ import torchgeo.datasets
|
|||
|
||||
dataset = torchgeo.datasets.TropicalCycloneWindEstimation(split="train", download=True)
|
||||
print(dataset[0]["image"].shape)
|
||||
print(dataset[0]["target"])
|
||||
print(dataset[0]["label"])
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
|
|
@ -70,7 +70,7 @@ class TestTropicalCycloneWindEstimation:
|
|||
assert isinstance(x["storm_id"], str)
|
||||
assert isinstance(x["relative_time"], int)
|
||||
assert isinstance(x["ocean"], int)
|
||||
assert isinstance(x["target"], int)
|
||||
assert isinstance(x["label"], int)
|
||||
assert x["image"].shape == (dataset.size, dataset.size)
|
||||
|
||||
def test_len(self, dataset: TropicalCycloneWindEstimation) -> None:
|
||||
|
|
|
@ -168,7 +168,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
|
|||
|
||||
features["relative_time"] = int(features["relative_time"])
|
||||
features["ocean"] = int(features["ocean"])
|
||||
features["target"] = int(features["wind_speed"])
|
||||
features["label"] = int(features["wind_speed"])
|
||||
|
||||
return features
|
||||
|
||||
|
|
|
@ -64,8 +64,8 @@ class CycloneDataModule(pl.LightningDataModule):
|
|||
sample["image"] = (
|
||||
sample["image"].unsqueeze(0).repeat(3, 1, 1)
|
||||
) # convert to 3 channel
|
||||
sample["target"] = torch.as_tensor( # type: ignore[attr-defined]
|
||||
sample["target"]
|
||||
sample["label"] = torch.as_tensor( # type: ignore[attr-defined]
|
||||
sample["label"]
|
||||
).float()
|
||||
|
||||
return sample
|
||||
|
|
|
@ -430,7 +430,7 @@ class RegressionTask(pl.LightningModule):
|
|||
training loss
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["target"].view(-1, 1)
|
||||
y = batch["label"].view(-1, 1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss = F.mse_loss(y_hat, y)
|
||||
|
@ -459,7 +459,7 @@ class RegressionTask(pl.LightningModule):
|
|||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["target"].view(-1, 1)
|
||||
y = batch["label"].view(-1, 1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss = F.mse_loss(y_hat, y)
|
||||
|
@ -485,7 +485,7 @@ class RegressionTask(pl.LightningModule):
|
|||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["target"].view(-1, 1)
|
||||
y = batch["label"].view(-1, 1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss = F.mse_loss(y_hat, y)
|
||||
|
|
Загрузка…
Ссылка в новой задаче