зеркало из https://github.com/microsoft/torchgeo.git
OSCD: adding configuration file (#262)
* OSCD: adding configuration file * add OSCDDataModule * some changes * Works, it's training * some edits * not required * Get rid of swap * Padding in val and test * don't ignore zeros * 0.56 IoU * 0.59 Iou * 0.6 IoU * make this the same * oscd change * fix tests * black fix * change defaults * Null weights * null * use Kornia PadTo * remove name from config + fix val dataloader * padto class attribute Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Родитель
25d864d141
Коммит
95a4d89d57
|
@ -0,0 +1,26 @@
|
|||
trainer:
|
||||
gpus: 1
|
||||
min_epochs: 20
|
||||
max_epochs: 500
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "oscd"
|
||||
module:
|
||||
loss: "jaccard"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: null
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
verbose: false
|
||||
in_channels: 26
|
||||
num_classes: 2
|
||||
num_filters: 128
|
||||
ignore_zeros: True
|
||||
datamodule:
|
||||
train_batch_size: 2
|
||||
num_workers: 6
|
||||
val_split_pct: 0.1
|
||||
bands: "all"
|
||||
num_patches_per_tile: 128
|
|
@ -0,0 +1,21 @@
|
|||
experiment:
|
||||
task: "oscd"
|
||||
module:
|
||||
loss: "jaccard"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: null
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
verbose: false
|
||||
in_channels: 26
|
||||
num_classes: 2
|
||||
num_filters: 256
|
||||
ignore_zeros: True
|
||||
datamodule:
|
||||
root_dir: "tests/data/oscd"
|
||||
batch_size: 32
|
||||
num_workers: 0
|
||||
val_split_pct: 0.1
|
||||
bands: "all"
|
||||
num_patches_per_tile: 128
|
|
@ -141,7 +141,9 @@ class TestOSCDDataModule:
|
|||
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
|
||||
sample = next(iter(datamodule.val_dataloader()))
|
||||
if datamodule.val_split_pct > 0.0:
|
||||
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
|
||||
assert (
|
||||
sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
|
||||
)
|
||||
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
|
||||
if datamodule.bands == "all":
|
||||
assert sample["image"].shape[1] == 26
|
||||
|
@ -150,7 +152,7 @@ class TestOSCDDataModule:
|
|||
|
||||
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
|
||||
sample = next(iter(datamodule.test_dataloader()))
|
||||
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (3, 3)
|
||||
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
|
||||
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
|
||||
if datamodule.bands == "all":
|
||||
assert sample["image"].shape[1] == 26
|
||||
|
|
|
@ -22,6 +22,7 @@ from .datasets import (
|
|||
EuroSATDataModule,
|
||||
LandCoverAIDataModule,
|
||||
NAIPChesapeakeDataModule,
|
||||
OSCDDataModule,
|
||||
RESISC45DataModule,
|
||||
SEN12MSDataModule,
|
||||
So2SatDataModule,
|
||||
|
@ -54,6 +55,7 @@ _TASK_TO_MODULES_MAPPING: Dict[
|
|||
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
|
||||
"landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
|
||||
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
|
||||
"oscd": (SemanticSegmentationTask, OSCDDataModule),
|
||||
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
|
||||
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"so2sat": (ClassificationTask, So2SatDataModule),
|
||||
|
|
|
@ -407,11 +407,12 @@ class OSCDDataModule(pl.LightningDataModule):
|
|||
self.rcrop = K.AugmentationSequential(
|
||||
K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
|
||||
)
|
||||
self.padto = K.PadTo((1280, 1280))
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset."""
|
||||
sample["image"] = sample["image"].float()
|
||||
sample["mask"] = sample["mask"].float()
|
||||
sample["mask"] = sample["mask"]
|
||||
sample["image"] = self.norm(sample["image"])
|
||||
sample["image"] = torch.flatten( # type: ignore[attr-defined]
|
||||
sample["image"], 0, 1
|
||||
|
@ -434,17 +435,24 @@ class OSCDDataModule(pl.LightningDataModule):
|
|||
def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
images, masks = [], []
|
||||
for i in range(self.num_patches_per_tile):
|
||||
mask = repeat(sample["mask"], "h w -> t h w", t=2)
|
||||
mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
|
||||
image, mask = self.rcrop(sample["image"], mask)
|
||||
mask = mask.squeeze()[0]
|
||||
images.append(image.squeeze())
|
||||
masks.append(mask)
|
||||
masks.append(mask.long())
|
||||
sample["image"] = torch.stack(images)
|
||||
sample["mask"] = torch.stack(masks)
|
||||
return sample
|
||||
|
||||
def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample["image"] = self.padto(sample["image"])[0]
|
||||
sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
|
||||
return sample
|
||||
|
||||
train_transforms = Compose([self.preprocess, n_random_crop])
|
||||
test_transforms = Compose([self.preprocess])
|
||||
# for testing and validation we pad all inputs to a fixed size to avoid issues
|
||||
# with the upsampling paths in encoder-decoder architectures
|
||||
test_transforms = Compose([self.preprocess, pad_to])
|
||||
|
||||
train_dataset = OSCD(
|
||||
self.root_dir, split="train", bands=self.bands, transforms=train_transforms
|
||||
|
|
Загрузка…
Ссылка в новой задаче