diff --git a/conf/l8biometile.yaml b/conf/l8biometile.yaml
new file mode 100644
index 000000000..1f3a9cbe1
--- /dev/null
+++ b/conf/l8biometile.yaml
@@ -0,0 +1,37 @@
+module:
+  _target_: torchgeo.trainers.SemanticSegmentationTask
+  loss: "ce"
+  model: "unet"
+  backbone: "resnet18"
+  weights: true
+  learning_rate: 1e-4
+  learning_rate_schedule_patience: 6
+  in_channels: 11
+  num_classes: 5
+  num_filters: 64
+  ignore_index: 0
+  weight_decay: 0
+
+datamodule:
+  _target_: torchgeo.datamodules.L8BiomeTileDataModule
+  root: "/home/calebrobinson/ssdprivate/data/L8BiomeSimple/"
+  batch_size: 32
+  patch_size: 256
+  train_batches_per_epoch: 2000
+  val_batches_per_epoch: 200
+  num_workers: 6
+
+trainer:
+  _target_: lightning.pytorch.Trainer
+  accelerator: gpu
+  devices:
+    - 3
+  min_epochs: 15
+  max_epochs: 100
+
+program:
+  seed: 0
+  output_dir: output/l8biome/
+  log_dir: logs/l8biome/
+  overwrite: True
+  experiment_name: unet_imagenet_lr1e-4_wd0
\ No newline at end of file
diff --git a/experiments/ssl4eo/run_l8biome.py b/experiments/ssl4eo/run_l8biome.py
new file mode 100644
index 000000000..5196f1b97
--- /dev/null
+++ b/experiments/ssl4eo/run_l8biome.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Runs the train script with a grid of hyperparameters."""
+import itertools
+import os
+import subprocess
+from multiprocessing import Process, Queue
+
+# list of GPU IDs that we want to use, one job will be started for every ID in the list
+GPUS = [0, 1, 2, 3]
+DRY_RUN = False  # if True, only print the commands to be run; if False, run them
+
+# Hyperparameter options
+model_options = ["fcn"]
+backbone_options = ["resnet18"]
+lr_options = [0.0001]
+loss_options = ["ce"]
+wd_options = [0]
+weight_options = [True]
+seed_options = [1, 2, 3, 4]
+
+
+def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
+    """Process for each ID in GPUS."""
+    while not work.empty():
+        experiment = work.get()
+        experiment = experiment.replace("GPU", str(gpu_idx))
+        print(experiment)
+        if not DRY_RUN:
+            subprocess.call(experiment.split(" "))
+    return True
+
+
+if __name__ == "__main__":
+    work: "Queue[str]" = Queue()
+
+    for model, backbone, lr, loss, wd, weights, seed in itertools.product(
+        model_options,
+        backbone_options,
+        lr_options,
+        loss_options,
+        wd_options,
+        weight_options,
+        seed_options,
+    ):
+        if model == "fcn" and not weights:
+            continue
+
+        if model != "unet":
+            experiment_name = f"{model}_{backbone}_{lr}_{loss}_{wd}_{weights}_{seed}"
+        else:
+            experiment_name = f"{model}_{lr}_{loss}_{wd}_{weights}_{seed}"
+
+        config_file = os.path.join("conf", "l8biometile.yaml")
+
+        command = (
+            "python train.py"
+            + f" config_file={config_file}"
+            + f" module.model={model}"
+            + f" module.backbone={backbone}"
+            + f" module.learning_rate={lr}"
+            + f" module.loss={loss}"
+            + f" module.weight_decay={wd}"
+            + f" module.weights={weights}"
+            + f" program.seed={seed}"
+            + f" program.experiment_name={experiment_name}"
+            + " trainer.devices=[GPU]"
+        )
+        command = command.strip()
+
+        work.put(command)
+
+    processes = []
+    for gpu_idx in GPUS:
+        p = Process(target=do_work, args=(work, gpu_idx))
+        processes.append(p)
+        p.start()
+    for p in processes:
+        p.join()
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 917643164..d769c96f0 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -16,7 +16,7 @@ from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
 from .gid15 import GID15DataModule
 from .inria import InriaAerialImageLabelingDataModule
 from .l7irish import L7IrishDataModule, L7IrishTileDataModule
-from .l8biome import L8BiomeDataModule
+from .l8biome import L8BiomeDataModule, L8BiomeTileDataModule
 from .landcoverai import LandCoverAIDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
@@ -43,6 +43,7 @@ __all__ = (
     "L7IrishDataModule",
     "L7IrishTileDataModule",
     "L8BiomeDataModule",
+    "L8BiomeTileDataModule",
     "NAIPChesapeakeDataModule",
     # NonGeoDataset
     "BigEarthNetDataModule",
diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py
index 4b9dc4b15..94abd136f 100644
--- a/torchgeo/datamodules/l8biome.py
+++ b/torchgeo/datamodules/l8biome.py
@@ -5,10 +5,12 @@
 
 from typing import Any, Optional, Union
 
+from lightning.pytorch import LightningDataModule
 import torch
+from torch.utils.data import DataLoader
 
-from ..datasets import L8Biome, random_bbox_assignment
-from ..samplers import GridGeoSampler, RandomBatchGeoSampler
+from ..datasets import L8Biome, random_bbox_assignment, TileDataset
+from ..samplers import GridGeoSampler, RandomBatchGeoSampler, RandomTileGeoSampler, GridTileGeoSampler
 from .geo import GeoDataModule
 
 
@@ -74,3 +76,82 @@
         self.test_sampler = GridGeoSampler(
             self.test_dataset, self.patch_size, self.patch_size
         )
+
+
+class L8BiomeTileDataModule(LightningDataModule):
+    """LightningDataModule for tile-based training on the L8 Biome dataset."""
+
+    @staticmethod
+    def preprocess(sample):
+        """Normalize the image by 255 and remap mask values to class indices 1-4."""
+        sample["image"] = sample["image"] / 255.0
+
+        # remap raw mask values to contiguous class indices; 0 is left unchanged
+        mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
+        if "mask" in sample:
+            mask = sample["mask"].squeeze()
+            for k, v in mask_mapping.items():
+                mask[mask == k] = v
+            sample["mask"] = mask
+        return sample
+
+    def _get_all_the_fns(self, root):
+        """Return lists of image and mask filenames for every scene under ``root``."""
+        import os
+        areas = L8Biome.filenames_to_md5.keys()
+        image_fns = []
+        mask_fns = []
+        for area in areas:
+            for scene_idx in os.listdir(os.path.join(root, area)):
+                image_fns.append(os.path.join(root, area, scene_idx, f"{scene_idx}.TIF"))
+                mask_fns.append(os.path.join(root, area, scene_idx, f"{scene_idx}_fixedmask.TIF"))
+        return image_fns, mask_fns
+
+    def __init__(self, root, batch_size=1, patch_size=32, train_batches_per_epoch=None, val_batches_per_epoch=None, num_workers=0, seed=0):
+        """Initialize the datamodule and split the scenes into train/val/test."""
+        super().__init__()
+        self.image_fns, self.mask_fns = self._get_all_the_fns(root)
+        self.batch_size = batch_size
+        self.patch_size = patch_size
+        self.train_batches_per_epoch = train_batches_per_epoch
+        self.val_batches_per_epoch = val_batches_per_epoch
+        self.num_workers = num_workers
+
+        generator = torch.Generator().manual_seed(seed)
+
+        # random 60/20/20 train/val/test split over scenes, fixed by ``seed``
+        idxs = torch.randperm(len(self.image_fns), generator=generator)
+        train_idxs = idxs[:int(len(idxs) * 0.6)]
+        val_idxs = idxs[int(len(idxs) * 0.6):int(len(idxs) * 0.8)]
+        test_idxs = idxs[int(len(idxs) * 0.8):]
+
+        self.train_image_fns = [self.image_fns[i] for i in train_idxs]
+        self.train_mask_fns = [self.mask_fns[i] for i in train_idxs]
+        self.val_image_fns = [self.image_fns[i] for i in val_idxs]
+        self.val_mask_fns = [self.mask_fns[i] for i in val_idxs]
+        self.test_image_fns = [self.image_fns[i] for i in test_idxs]
+        self.test_mask_fns = [self.mask_fns[i] for i in test_idxs]
+
+    def setup(self, stage):
+        """Create the train, validation, and test datasets."""
+        self.train_dataset = TileDataset(self.train_image_fns, self.train_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
+        self.val_dataset = TileDataset(self.val_image_fns, self.val_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
+        self.test_dataset = TileDataset(self.test_image_fns, self.test_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
+
+    # def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
+    #     return super().on_after_batch_transfer(batch, dataloader_idx)
+
+    def train_dataloader(self):
+        """Return a DataLoader of randomly sampled training patches."""
+        sampler = RandomTileGeoSampler(self.train_dataset, self.patch_size, self.batch_size * self.train_batches_per_epoch)
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
+
+    def val_dataloader(self):
+        """Return a DataLoader of randomly sampled validation patches."""
+        sampler = RandomTileGeoSampler(self.val_dataset, self.patch_size, self.batch_size * self.val_batches_per_epoch)
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
+
+    def test_dataloader(self):
+        """Return a DataLoader that covers the test tiles with a regular grid of patches."""
+        sampler = GridTileGeoSampler(self.test_dataset, self.patch_size, self.patch_size)
+        return DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
+
+    def plot(self, sample):
+        """Plot an RGB rendering of a sample image next to its mask."""
+        import matplotlib.pyplot as plt
+        image = sample["image"].permute(1, 2, 0).numpy()
+        mask = sample["mask"].numpy().squeeze()
+        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+        axs[0].imshow(image[:, :, [2, 1, 0]])
+        axs[0].axis("off")
+        axs[1].imshow(mask, vmin=0, vmax=4)
+        axs[1].axis("off")
+        return fig
\ No newline at end of file
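
Usage note: a minimal sketch (illustrative, not part of the patch) of how the new L8BiomeTileDataModule could be driven directly from Python, reusing the hyperparameters in conf/l8biometile.yaml; the dataset root and device index are placeholders, and it assumes the TileDataset and tile samplers referenced above are available in the installed torchgeo build.

    from lightning.pytorch import Trainer

    from torchgeo.datamodules import L8BiomeTileDataModule
    from torchgeo.trainers import SemanticSegmentationTask

    # Arguments mirror the "datamodule" block of conf/l8biometile.yaml.
    datamodule = L8BiomeTileDataModule(
        root="/path/to/L8BiomeSimple/",  # placeholder path to the prepared dataset
        batch_size=32,
        patch_size=256,
        train_batches_per_epoch=2000,
        val_batches_per_epoch=200,
        num_workers=6,
    )

    # Arguments mirror the "module" block of conf/l8biometile.yaml.
    task = SemanticSegmentationTask(
        loss="ce",
        model="unet",
        backbone="resnet18",
        weights=True,
        learning_rate=1e-4,
        learning_rate_schedule_patience=6,
        in_channels=11,
        num_classes=5,
        num_filters=64,
        ignore_index=0,
        weight_decay=0,
    )

    # Placeholder device index; the config pins device 3, the grid script fills in GPU.
    trainer = Trainer(accelerator="gpu", devices=[0], min_epochs=15, max_epochs=100)
    trainer.fit(model=task, datamodule=datamodule)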