Mirror of https://github.com/microsoft/torchgeo.git
OSCD: adding configuration file (#262)
* OSCD: adding configuration file
* add OSCDDataModule
* some changes
* Works, it's training
* some edits
* not required
* Get rid of swap
* Padding in val and test
* don't ignore zeros
* 0.56 IoU
* 0.59 IoU
* 0.6 IoU
* make this the same
* oscd change
* fix tests
* black fix
* change defaults
* Null weights
* null
* use Kornia PadTo
* remove name from config + fix val dataloader
* padto class attribute

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Parent: 25d864d141
Commit: 95a4d89d57
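The diff below adds a training config and a test config for the OSCD change detection task, registers an "oscd" entry in _TASK_TO_MODULES_MAPPING so that the generic SemanticSegmentationTask drives it, and pads validation/test inputs to a fixed size. As a rough, unverified sketch of what the experiment.module section amounts to, assuming its keys are passed straight through as constructor arguments of SemanticSegmentationTask:

# Hypothetical sketch only: construct the task directly from the config values.
# Assumes SemanticSegmentationTask accepts these keyword arguments verbatim.
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights=None,
    loss="jaccard",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    in_channels=26,   # 2 dates x 13 Sentinel-2 bands, flattened in preprocess
    num_classes=2,    # change / no change
    ignore_zeros=True,
)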
@@ -0,0 +1,26 @@
+trainer:
+  gpus: 1
+  min_epochs: 20
+  max_epochs: 500
+  benchmark: True
+
+experiment:
+  task: "oscd"
+  module:
+    loss: "jaccard"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 26
+    num_classes: 2
+    num_filters: 128
+    ignore_zeros: True
+  datamodule:
+    train_batch_size: 2
+    num_workers: 6
+    val_split_pct: 0.1
+    bands: "all"
+    num_patches_per_tile: 128
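For orientation, the trainer section maps onto PyTorch Lightning Trainer flags; a rough equivalent, assuming the keys are forwarded unchanged (PL 1.x-style gpus argument):

# Rough equivalent of the trainer section, assuming the keys are passed straight
# through to pytorch_lightning.Trainer.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,          # single GPU
    min_epochs=20,   # don't stop before 20 epochs
    max_epochs=500,
    benchmark=True,  # cudnn autotuner; helps with fixed-size 1280x1280 inputs
)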
@@ -0,0 +1,21 @@
+experiment:
+  task: "oscd"
+  module:
+    loss: "jaccard"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: null
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    verbose: false
+    in_channels: 26
+    num_classes: 2
+    num_filters: 256
+    ignore_zeros: True
+  datamodule:
+    root_dir: "tests/data/oscd"
+    batch_size: 32
+    num_workers: 0
+    val_split_pct: 0.1
+    bands: "all"
+    num_patches_per_tile: 128
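The datamodule section of this test config is what OSCDDataModule is built from. A minimal, unverified sketch, assuming the keys map directly onto the datamodule's constructor arguments (note the training config above uses train_batch_size while this one uses batch_size):

# Hypothetical sketch: build the datamodule from the test config above.
# Keyword names are copied from the datamodule section and assumed to match
# OSCDDataModule's constructor; tests/data/oscd holds tiny 3x3 fixtures.
from torchgeo.datasets import OSCDDataModule

datamodule = OSCDDataModule(
    root_dir="tests/data/oscd",
    bands="all",              # both Sentinel-2 dates, all 13 bands each
    batch_size=32,
    num_workers=0,
    val_split_pct=0.1,
    num_patches_per_tile=128,
)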
@@ -141,7 +141,9 @@ class TestOSCDDataModule:
     def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
         sample = next(iter(datamodule.val_dataloader()))
         if datamodule.val_split_pct > 0.0:
-            assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
+            assert (
+                sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
+            )
             assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
             if datamodule.bands == "all":
                 assert sample["image"].shape[1] == 26
@@ -150,7 +152,7 @@ class TestOSCDDataModule:
 
     def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
         sample = next(iter(datamodule.test_dataloader()))
-        assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
+        assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
         assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
         if datamodule.bands == "all":
             assert sample["image"].shape[1] == 26
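The new expected shape of (1280, 1280) in these tests comes from the PadTo step added to the datamodule (see the OSCDDataModule diff further down): the tiny 3x3 test fixtures are padded up to a fixed 1280x1280 before batching. A standalone check of that behaviour, assuming kornia.augmentation.PadTo with its usual batched output:

# Standalone check of the padding behaviour (assumes kornia.augmentation.PadTo,
# which returns a batched tensor padded up to the requested size).
import torch
import kornia.augmentation as K

pad = K.PadTo((1280, 1280))
x = torch.rand(1, 26, 3, 3)   # shaped like a tiny 26-channel test fixture
print(pad(x).shape)           # torch.Size([1, 26, 1280, 1280])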
@@ -22,6 +22,7 @@ from .datasets import (
     EuroSATDataModule,
     LandCoverAIDataModule,
     NAIPChesapeakeDataModule,
+    OSCDDataModule,
     RESISC45DataModule,
     SEN12MSDataModule,
     So2SatDataModule,
@@ -54,6 +55,7 @@ _TASK_TO_MODULES_MAPPING: Dict[
     "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
     "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
     "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
+    "oscd": (SemanticSegmentationTask, OSCDDataModule),
     "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
     "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
     "so2sat": (ClassificationTask, So2SatDataModule),
@@ -407,11 +407,12 @@ class OSCDDataModule(pl.LightningDataModule):
         self.rcrop = K.AugmentationSequential(
             K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
         )
+        self.padto = K.PadTo((1280, 1280))
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Transform a single sample from the Dataset."""
         sample["image"] = sample["image"].float()
-        sample["mask"] = sample["mask"].float()
+        sample["mask"] = sample["mask"]
         sample["image"] = self.norm(sample["image"])
         sample["image"] = torch.flatten(  # type: ignore[attr-defined]
             sample["image"], 0, 1
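For reference, the flatten at the end of preprocess is also why the configs set in_channels: 26: an OSCD sample is a pair of Sentinel-2 acquisitions, and with bands: "all" each date contributes 13 bands, which get merged into one 26-channel input. A small sketch with an assumed 256x256 tile size:

# Why the configs use in_channels: 26 — a sketch of the flatten in preprocess.
# An OSCD sample holds two Sentinel-2 acquisitions; with bands: "all" that is
# (2 dates, 13 bands, H, W), which torch.flatten collapses to (26, H, W).
import torch

image = torch.rand(2, 13, 256, 256)   # pre/post image pair (assumed tile size)
flat = torch.flatten(image, 0, 1)     # merge the date and band dimensions
print(flat.shape)                     # torch.Size([26, 256, 256])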
@@ -434,17 +435,24 @@ class OSCDDataModule(pl.LightningDataModule):
         def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
             images, masks = [], []
             for i in range(self.num_patches_per_tile):
-                mask = repeat(sample["mask"], "h w -> t h w", t=2)
+                mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
                 image, mask = self.rcrop(sample["image"], mask)
                 mask = mask.squeeze()[0]
                 images.append(image.squeeze())
-                masks.append(mask)
+                masks.append(mask.long())
             sample["image"] = torch.stack(images)
             sample["mask"] = torch.stack(masks)
             return sample
 
+        def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
+            sample["image"] = self.padto(sample["image"])[0]
+            sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
+            return sample
+
         train_transforms = Compose([self.preprocess, n_random_crop])
-        test_transforms = Compose([self.preprocess])
+        # for testing and validation we pad all inputs to a fixed size to avoid issues
+        # with the upsampling paths in encoder-decoder architectures
+        test_transforms = Compose([self.preprocess, pad_to])
 
         train_dataset = OSCD(
             self.root_dir, split="train", bands=self.bands, transforms=train_transforms
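A rough standalone illustration of the new pad_to transform above (not the datamodule code itself), assuming Kornia's usual promotion of 2D/3D inputs to a batched 4D tensor: the integer mask is cast to float so the augmentation accepts it, padded, then cast back to long and un-batched.

# Rough illustration of the pad_to transform on a single unbatched sample.
# Kornia augmentations expect float input and return a batched 4D tensor, so the
# integer mask is cast to float for padding and back to long afterwards.
import torch
import kornia.augmentation as K

padto = K.PadTo((1280, 1280))

sample = {
    "image": torch.rand(26, 3, 3),            # unbatched (C, H, W) image
    "mask": torch.randint(0, 2, (3, 3)),      # integer change mask, (H, W)
}
sample["image"] = padto(sample["image"])[0]                  # -> (26, 1280, 1280)
sample["mask"] = padto(sample["mask"].float()).long()[0, 0]  # -> (1280, 1280)
print(sample["image"].shape, sample["mask"].shape)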