зеркало из https://github.com/microsoft/torchgeo.git
adding BYOL for landcover AI, Extending LandcoverAIDataModule to support SSL
This commit is contained in:
Родитель
b7e3b61e39
Коммит
aa0c1dc794
|
@ -5,7 +5,7 @@ trainer:
|
|||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "ssl"
|
||||
task: "byol"
|
||||
name: "test_byol"
|
||||
module:
|
||||
model: "byol"
|
||||
|
@ -15,10 +15,4 @@ experiment:
|
|||
imagenet_pretraining: True
|
||||
datamodule:
|
||||
batch_size: 64
|
||||
num_workers: 6
|
||||
train_splits:
|
||||
- "de-train"
|
||||
val_splits:
|
||||
- "de-val"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
num_workers: 6
|
|
@ -0,0 +1,16 @@
|
|||
trainer:
|
||||
gpus: 1
|
||||
min_epochs: 20
|
||||
max_epochs: 50
|
||||
experiment:
|
||||
task: "byol_landcoverai"
|
||||
module:
|
||||
model: "byol"
|
||||
encoder: "resnet50"
|
||||
learning_rate: 1e-2
|
||||
input_channels: 3
|
||||
imagenet_pretraining: True
|
||||
datamodule:
|
||||
batch_size: 96
|
||||
num_workers: 6
|
||||
unsupervised_mode: True
|
|
@ -1,5 +1,5 @@
|
|||
experiment:
|
||||
task: "ssl"
|
||||
task: "byol"
|
||||
name: "test_byol"
|
||||
module:
|
||||
model: "byol"
|
||||
|
@ -11,9 +11,3 @@ experiment:
|
|||
datamodule:
|
||||
batch_size: 64
|
||||
num_workers: 6
|
||||
train_splits:
|
||||
- "de-train"
|
||||
val_splits:
|
||||
- "de-val"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
experiment:
|
||||
task: "byol_landcoverai"
|
||||
module:
|
||||
model: "byol"
|
||||
encoder: "resnet50"
|
||||
learning_rate: 1e-2
|
||||
input_channels: 3
|
||||
imagenet_pretraining: True
|
||||
learning_rate_schedule_patience: 6
|
||||
datamodule:
|
||||
batch_size: 96
|
||||
num_workers: 6
|
||||
unsupervised_mode: True
|
|
@ -272,6 +272,7 @@ class LandcoverAIDataModule(pl.LightningDataModule):
|
|||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
unsupervised_mode: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for Landcover.AI based DataLoaders.
|
||||
|
@ -280,11 +281,14 @@ class LandcoverAIDataModule(pl.LightningDataModule):
|
|||
root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.unsupervised_mode = unsupervised_mode
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset."""
|
||||
|
@ -314,23 +318,50 @@ class LandcoverAIDataModule(pl.LightningDataModule):
|
|||
train_transforms = self.preprocess
|
||||
val_test_transforms = self.preprocess
|
||||
|
||||
self.train_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="train",
|
||||
transforms=train_transforms,
|
||||
)
|
||||
if not self.unsupervised_mode:
|
||||
|
||||
self.train_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="train",
|
||||
transforms=train_transforms,
|
||||
)
|
||||
|
||||
self.val_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="val",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
self.test_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="test",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
else:
|
||||
temp_train_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="train",
|
||||
transforms=train_transforms,
|
||||
)
|
||||
|
||||
self.val_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="val",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
self.test_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="test",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
self.train_dataset = cast(
|
||||
LandCoverAI, temp_train_dataset + self.val_dataset + self.test_dataset
|
||||
)
|
||||
|
||||
self.val_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="val",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
self.test_dataset = LandCoverAI(
|
||||
self.root_dir,
|
||||
split="test",
|
||||
transforms=val_test_transforms,
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for training."""
|
||||
|
|
1
train.py
1
train.py
|
@ -35,6 +35,7 @@ TASK_TO_MODULES_MAPPING: Dict[
|
|||
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
|
||||
] = {
|
||||
"byol": (BYOLTask, So2SatDataModule),
|
||||
"byol_landcoverai": (BYOLTask, LandcoverAIDataModule),
|
||||
"chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
|
||||
"landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
|
||||
|
|
Загрузка…
Ссылка в новой задаче