adding BYOL for landcover AI, Extending LandcoverAIDataModule to support SSL

This commit is contained in:
Anthony 2021-10-02 21:20:41 +00:00 коммит произвёл Anthony Ortiz
Родитель b7e3b61e39
Коммит aa0c1dc794
6 изменённых файлов: 79 добавлений и 30 удалений

Просмотреть файл

@ -5,7 +5,7 @@ trainer:
benchmark: True
experiment:
task: "ssl"
task: "byol"
name: "test_byol"
module:
model: "byol"
@ -15,10 +15,4 @@ experiment:
imagenet_pretraining: True
datamodule:
batch_size: 64
num_workers: 6
train_splits:
- "de-train"
val_splits:
- "de-val"
test_splits:
- "de-test"
num_workers: 6

Просмотреть файл

@ -0,0 +1,16 @@
trainer:
gpus: 1
min_epochs: 20
max_epochs: 50
experiment:
task: "byol_landcoverai"
module:
model: "byol"
encoder: "resnet50"
learning_rate: 1e-2
input_channels: 3
imagenet_pretraining: True
datamodule:
batch_size: 96
num_workers: 6
unsupervised_mode: True

Просмотреть файл

@ -1,5 +1,5 @@
experiment:
task: "ssl"
task: "byol"
name: "test_byol"
module:
model: "byol"
@ -11,9 +11,3 @@ experiment:
datamodule:
batch_size: 64
num_workers: 6
train_splits:
- "de-train"
val_splits:
- "de-val"
test_splits:
- "de-test"

Просмотреть файл

@ -0,0 +1,13 @@
experiment:
task: "byol_landcoverai"
module:
model: "byol"
encoder: "resnet50"
learning_rate: 1e-2
input_channels: 3
imagenet_pretraining: True
learning_rate_schedule_patience: 6
datamodule:
batch_size: 96
num_workers: 6
unsupervised_mode: True

Просмотреть файл

@ -272,6 +272,7 @@ class LandcoverAIDataModule(pl.LightningDataModule):
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
unsupervised_mode: bool = False,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Landcover.AI based DataLoaders.
@ -280,11 +281,14 @@ class LandcoverAIDataModule(pl.LightningDataModule):
root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.unsupervised_mode = unsupervised_mode
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
@ -314,23 +318,50 @@ class LandcoverAIDataModule(pl.LightningDataModule):
train_transforms = self.preprocess
val_test_transforms = self.preprocess
self.train_dataset = LandCoverAI(
self.root_dir,
split="train",
transforms=train_transforms,
)
if not self.unsupervised_mode:
self.train_dataset = LandCoverAI(
self.root_dir,
split="train",
transforms=train_transforms,
)
self.val_dataset = LandCoverAI(
self.root_dir,
split="val",
transforms=val_test_transforms,
)
self.test_dataset = LandCoverAI(
self.root_dir,
split="test",
transforms=val_test_transforms,
)
else:
temp_train_dataset = LandCoverAI(
self.root_dir,
split="train",
transforms=train_transforms,
)
self.val_dataset = LandCoverAI(
self.root_dir,
split="val",
transforms=val_test_transforms,
)
self.test_dataset = LandCoverAI(
self.root_dir,
split="test",
transforms=val_test_transforms,
)
self.train_dataset = cast(
LandCoverAI, temp_train_dataset + self.val_dataset + self.test_dataset
)
self.val_dataset = LandCoverAI(
self.root_dir,
split="val",
transforms=val_test_transforms,
)
self.test_dataset = LandCoverAI(
self.root_dir,
split="test",
transforms=val_test_transforms,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""

Просмотреть файл

@ -35,6 +35,7 @@ TASK_TO_MODULES_MAPPING: Dict[
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
"byol": (BYOLTask, So2SatDataModule),
"byol_landcoverai": (BYOLTask, LandcoverAIDataModule),
"chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
"cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
"landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),