From 95a4d89d574b675821382b080abaf5d38eb0fcf4 Mon Sep 17 00:00:00 2001 From: Maciej Kilian <61431446+iejMac@users.noreply.github.com> Date: Mon, 20 Dec 2021 23:53:20 +0100 Subject: [PATCH] OSCD: adding configuration file (#262) * OSCD: adding configuration file * add OSCDDataModule * some changes * Works, it's training * some edits * not required * Get rid of swap * Padding in val and test * don't ignore zeros * 0.56 IoU * 0.59 Iou * 0.6 IoU * make this the same * oscd change * fix tests * black fix * change defaults * Null weights * null * use Kornia PadTo * remove name from config + fix val dataloader * padto class attribute Co-authored-by: Caleb Robinson --- conf/oscd.yaml | 26 ++++++++++++++++++++++++++ conf/task_defaults/oscd.yaml | 21 +++++++++++++++++++++ tests/datasets/test_oscd.py | 6 ++++-- torchgeo/__init__.py | 2 ++ torchgeo/datasets/oscd.py | 16 ++++++++++++---- 5 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 conf/oscd.yaml create mode 100644 conf/task_defaults/oscd.yaml diff --git a/conf/oscd.yaml b/conf/oscd.yaml new file mode 100644 index 000000000..d7d182225 --- /dev/null +++ b/conf/oscd.yaml @@ -0,0 +1,26 @@ +trainer: + gpus: 1 + min_epochs: 20 + max_epochs: 500 + benchmark: True + +experiment: + task: "oscd" + module: + loss: "jaccard" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 26 + num_classes: 2 + num_filters: 128 + ignore_zeros: True + datamodule: + train_batch_size: 2 + num_workers: 6 + val_split_pct: 0.1 + bands: "all" + num_patches_per_tile: 128 diff --git a/conf/task_defaults/oscd.yaml b/conf/task_defaults/oscd.yaml new file mode 100644 index 000000000..5ae3fdccb --- /dev/null +++ b/conf/task_defaults/oscd.yaml @@ -0,0 +1,21 @@ +experiment: + task: "oscd" + module: + loss: "jaccard" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 26 + num_classes: 2 + num_filters: 256 + ignore_zeros: True + datamodule: + root_dir: "tests/data/oscd" + batch_size: 32 + num_workers: 0 + val_split_pct: 0.1 + bands: "all" + num_patches_per_tile: 128 diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index c8f302ed4..2bfaf25d5 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -141,7 +141,9 @@ class TestOSCDDataModule: def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: sample = next(iter(datamodule.val_dataloader())) if datamodule.val_split_pct > 0.0: - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3) + assert ( + sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + ) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 if datamodule.bands == "all": assert sample["image"].shape[1] == 26 @@ -150,7 +152,7 @@ class TestOSCDDataModule: def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 if datamodule.bands == "all": assert sample["image"].shape[1] == 26 diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index 28922adda..cc1b2abb9 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -22,6 +22,7 @@ from .datasets import ( EuroSATDataModule, LandCoverAIDataModule, NAIPChesapeakeDataModule, + OSCDDataModule, RESISC45DataModule, SEN12MSDataModule, So2SatDataModule, @@ -54,6 +55,7 @@ _TASK_TO_MODULES_MAPPING: Dict[ "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule), "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule), + "oscd": (SemanticSegmentationTask, OSCDDataModule), "resisc45": (RESISC45ClassificationTask, RESISC45DataModule), "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule), "so2sat": (ClassificationTask, So2SatDataModule), diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index ed51c27ed..c2f807b49 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -407,11 +407,12 @@ class OSCDDataModule(pl.LightningDataModule): self.rcrop = K.AugmentationSequential( K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True ) + self.padto = K.PadTo((1280, 1280)) def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset.""" sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].float() + sample["mask"] = sample["mask"] sample["image"] = self.norm(sample["image"]) sample["image"] = torch.flatten( # type: ignore[attr-defined] sample["image"], 0, 1 @@ -434,17 +435,24 @@ class OSCDDataModule(pl.LightningDataModule): def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: images, masks = [], [] for i in range(self.num_patches_per_tile): - mask = repeat(sample["mask"], "h w -> t h w", t=2) + mask = repeat(sample["mask"], "h w -> t h w", t=2).float() image, mask = self.rcrop(sample["image"], mask) mask = mask.squeeze()[0] images.append(image.squeeze()) - masks.append(mask) + masks.append(mask.long()) sample["image"] = torch.stack(images) sample["mask"] = torch.stack(masks) return sample + def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: + sample["image"] = self.padto(sample["image"])[0] + sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] + return sample + train_transforms = Compose([self.preprocess, n_random_crop]) - test_transforms = Compose([self.preprocess]) + # for testing and validation we pad all inputs to a fixed size to avoid issues + # with the upsampling paths in encoder-decoder architectures + test_transforms = Compose([self.preprocess, pad_to]) train_dataset = OSCD( self.root_dir, split="train", bands=self.bands, transforms=train_transforms