OSCD: adding configuration file (#262)

* OSCD: adding configuration file

* add OSCDDataModule

* some changes

* Works, it's training

* some edits

* not required

* Get rid of swap

* Padding in val and test

* don't ignore zeros

* 0.56 IoU

* 0.59 IoU

* 0.6 IoU

* make this the same

* oscd change

* fix tests

* black fix

* change defaults

* Null weights

* null

* use Kornia PadTo

* remove name from config + fix val dataloader

* padto class attribute

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Maciej Kilian 2021-12-20 23:53:20 +01:00 committed by GitHub
Parent 25d864d141
Commit 95a4d89d57
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 65 additions and 6 deletions

26
conf/oscd.yaml Normal file
View file

@@ -0,0 +1,26 @@
trainer:
gpus: 1
min_epochs: 20
max_epochs: 500
benchmark: True
experiment:
task: "oscd"
module:
loss: "jaccard"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 26
num_classes: 2
num_filters: 128
ignore_zeros: True
datamodule:
train_batch_size: 2
num_workers: 6
val_split_pct: 0.1
bands: "all"
num_patches_per_tile: 128

View file

@@ -0,0 +1,21 @@
experiment:
task: "oscd"
module:
loss: "jaccard"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 26
num_classes: 2
num_filters: 256
ignore_zeros: True
datamodule:
root_dir: "tests/data/oscd"
batch_size: 32
num_workers: 0
val_split_pct: 0.1
bands: "all"
num_patches_per_tile: 128

View file

@@ -141,7 +141,9 @@ class TestOSCDDataModule:
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.val_dataloader()))
if datamodule.val_split_pct > 0.0:
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
assert (
sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
@@ -150,7 +152,7 @@ class TestOSCDDataModule:
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26

View file

@@ -22,6 +22,7 @@ from .datasets import (
EuroSATDataModule,
LandCoverAIDataModule,
NAIPChesapeakeDataModule,
OSCDDataModule,
RESISC45DataModule,
SEN12MSDataModule,
So2SatDataModule,
@@ -54,6 +55,7 @@ _TASK_TO_MODULES_MAPPING: Dict[
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
"landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
"oscd": (SemanticSegmentationTask, OSCDDataModule),
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
"so2sat": (ClassificationTask, So2SatDataModule),

View file

@@ -407,11 +407,12 @@ class OSCDDataModule(pl.LightningDataModule):
self.rcrop = K.AugmentationSequential(
K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
)
self.padto = K.PadTo((1280, 1280))
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].float()
sample["mask"] = sample["mask"]
sample["image"] = self.norm(sample["image"])
sample["image"] = torch.flatten( # type: ignore[attr-defined]
sample["image"], 0, 1
@ -434,17 +435,24 @@ class OSCDDataModule(pl.LightningDataModule):
def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
images, masks = [], []
for i in range(self.num_patches_per_tile):
mask = repeat(sample["mask"], "h w -> t h w", t=2)
mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
image, mask = self.rcrop(sample["image"], mask)
mask = mask.squeeze()[0]
images.append(image.squeeze())
masks.append(mask)
masks.append(mask.long())
sample["image"] = torch.stack(images)
sample["mask"] = torch.stack(masks)
return sample
def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = self.padto(sample["image"])[0]
sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
return sample
train_transforms = Compose([self.preprocess, n_random_crop])
test_transforms = Compose([self.preprocess])
# for testing and validation we pad all inputs to a fixed size to avoid issues
# with the upsampling paths in encoder-decoder architectures
test_transforms = Compose([self.preprocess, pad_to])
train_dataset = OSCD(
self.root_dir, split="train", bands=self.bands, transforms=train_transforms