This commit is contained in:
Caleb Robinson 2023-05-30 05:14:42 +00:00
Родитель 0ae66f9937
Коммит 6934dc5592
4 изменённых файлов: 203 добавлений и 3 удалений

37
conf/l8biometile.yaml Normal file
Просмотреть файл

@ -0,0 +1,37 @@
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: true
learning_rate: 1e-4
learning_rate_schedule_patience: 6
in_channels: 11
num_classes: 5
num_filters: 64
ignore_index: 0
weight_decay: 0
datamodule:
_target_: torchgeo.datamodules.L8BiomeTileDataModule
root: "/home/calebrobinson/ssdprivate/data/L8BiomeSimple/"
batch_size: 32
patch_size: 256
train_batches_per_epoch: 2000
val_batches_per_epoch: 200
num_workers: 6
trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices:
- 3
min_epochs: 15
max_epochs: 100
program:
seed: 0
output_dir: output/l8biome/
log_dir: logs/l8biome/
overwrite: True
experiment_name: unet_imagenet_lr1e-4_wd0

Просмотреть файл

@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue
# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0, 1, 2, 3]
DRY_RUN = False # if False then print out the commands to be run, if True then run
# Hyperparameter options
model_options = ["fcn"]
backbone_options = ["resnet18"]
lr_options = [0.0001]
loss_options = ["ce"]
wd_options = [0]
weight_options = [True]
seed_options = [1,2,3,4]
def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"""Process for each ID in GPUS."""
while not work.empty():
experiment = work.get()
experiment = experiment.replace("GPU", str(gpu_idx))
print(experiment)
if not DRY_RUN:
subprocess.call(experiment.split(" "))
return True
if __name__ == "__main__":
work: "Queue[str]" = Queue()
for model, backbone, lr, loss, wd, weights, seed in itertools.product(
model_options,
backbone_options,
lr_options,
loss_options,
wd_options,
weight_options,
seed_options,
):
if model == "fcn" and not weights:
continue
if model != "unet":
experiment_name = f"{model}_{backbone}_{lr}_{loss}_{wd}_{weights}_{seed}"
else:
experiment_name = f"{model}_{lr}_{loss}_{wd}_{weights}_{seed}"
config_file = os.path.join("conf", "l8biometile.yaml")
command = (
"python train.py"
+ f" config_file={config_file}"
+ f" module.model={model}"
+ f" module.backbone={backbone}"
+ f" module.learning_rate={lr}"
+ f" module.loss={loss}"
+ f" module.weight_decay={wd}"
+ f" module.weights={weights}"
+ f" program.seed={seed}"
+ f" program.experiment_name={experiment_name}"
+ " trainer.devices=[GPU]"
)
command = command.strip()
work.put(command)
processes = []
for gpu_idx in GPUS:
p = Process(target=do_work, args=(work, gpu_idx))
processes.append(p)
p.start()
for p in processes:
p.join()

Просмотреть файл

@ -16,7 +16,7 @@ from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .l7irish import L7IrishDataModule, L7IrishTileDataModule
from .l8biome import L8BiomeDataModule
from .l8biome import L8BiomeDataModule, L8BiomeTileDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
@ -43,6 +43,7 @@ __all__ = (
"L7IrishDataModule",
"L7IrishTileDataModule",
"L8BiomeDataModule",
"L8BiomeTileDataModule",
"NAIPChesapeakeDataModule",
# NonGeoDataset
"BigEarthNetDataModule",

Просмотреть файл

@ -5,10 +5,12 @@
from typing import Any, Optional, Union
from lightning.pytorch import LightningDataModule
import torch
from torch.utils.data import DataLoader
from ..datasets import L8Biome, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..datasets import L8Biome, random_bbox_assignment, TileDataset
from ..samplers import GridGeoSampler, RandomBatchGeoSampler, RandomTileGeoSampler, GridTileGeoSampler
from .geo import GeoDataModule
@ -74,3 +76,82 @@ class L8BiomeDataModule(GeoDataModule):
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)
class L8BiomeTileDataModule(LightningDataModule):
@staticmethod
def preprocess(sample):
sample["image"] = sample["image"] / 255.0
mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
if "mask" in sample:
mask = sample["mask"].squeeze()
for k, v in mask_mapping.items():
mask[mask == k] = v
sample["mask"] = mask
return sample
def _get_all_the_fns(self, root):
import os
areas = L8Biome.filenames_to_md5.keys()
image_fns = []
mask_fns = []
for area in areas:
for scene_idx in os.listdir(os.path.join(root,area)):
image_fns.append(os.path.join(root,area,scene_idx,f"{scene_idx}.TIF"))
mask_fns.append(os.path.join(root,area,scene_idx,f"{scene_idx}_fixedmask.TIF"))
return image_fns, mask_fns
def __init__(self, root, batch_size=1, patch_size=32, train_batches_per_epoch=None, val_batches_per_epoch=None, num_workers=0, seed=0):
super().__init__()
self.image_fns, self.mask_fns = self._get_all_the_fns(root)
self.batch_size = batch_size
self.patch_size = patch_size
self.train_batches_per_epoch = train_batches_per_epoch
self.val_batches_per_epoch = val_batches_per_epoch
self.num_workers = num_workers
generator = torch.Generator().manual_seed(seed)
idxs = torch.randperm(len(self.image_fns), generator=generator)
train_idxs = idxs[:int(len(idxs)*0.6)]
val_idxs = idxs[int(len(idxs)*0.6):int(len(idxs)*0.8)]
test_idxs = idxs[int(len(idxs)*0.8):]
self.train_image_fns = [self.image_fns[i] for i in train_idxs]
self.train_mask_fns = [self.mask_fns[i] for i in train_idxs]
self.val_image_fns = [self.image_fns[i] for i in val_idxs]
self.val_mask_fns = [self.mask_fns[i] for i in val_idxs]
self.test_image_fns = [self.image_fns[i] for i in test_idxs]
self.test_mask_fns = [self.mask_fns[i] for i in test_idxs]
def setup(self, stage):
self.train_dataset = TileDataset(self.train_image_fns, self.train_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
self.val_dataset = TileDataset(self.val_image_fns, self.val_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
self.test_dataset = TileDataset(self.test_image_fns, self.test_mask_fns, transforms=L8BiomeTileDataModule.preprocess)
# def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
# return super().on_after_batch_transfer(batch, dataloader_idx)
def train_dataloader(self):
sampler = RandomTileGeoSampler(self.train_dataset, self.patch_size, self.batch_size * self.train_batches_per_epoch)
return DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
def val_dataloader(self):
sampler = RandomTileGeoSampler(self.val_dataset, self.patch_size, self.batch_size * self.val_batches_per_epoch)
return DataLoader(self.val_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
def test_dataloader(self):
sampler = GridTileGeoSampler(self.test_dataset, self.patch_size, self.patch_size)
return DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers)
def plot(self, sample):
import matplotlib.pyplot as plt
image = sample["image"].permute(1,2,0).numpy()
mask = sample["mask"].numpy().squeeze()
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(image[:,:,[2,1,0]])
axs[0].axis("off")
axs[1].imshow(mask, vmin=0, vmax=4)
axs[1].axis("off")
return fig