Mirror of https://github.com/microsoft/torchgeo.git
Add L8BiomeTile
Parent: 0ae66f9937
Commit: 6934dc5592
@@ -0,0 +1,37 @@
module:
  _target_: torchgeo.trainers.SemanticSegmentationTask
  loss: "ce"
  model: "unet"
  backbone: "resnet18"
  weights: true
  learning_rate: 1e-4
  learning_rate_schedule_patience: 6
  in_channels: 11
  num_classes: 5
  num_filters: 64
  ignore_index: 0
  weight_decay: 0

datamodule:
  _target_: torchgeo.datamodules.L8BiomeTileDataModule
  root: "/home/calebrobinson/ssdprivate/data/L8BiomeSimple/"
  batch_size: 32
  patch_size: 256
  train_batches_per_epoch: 2000
  val_batches_per_epoch: 200
  num_workers: 6

trainer:
  _target_: lightning.pytorch.Trainer
  accelerator: gpu
  devices:
    - 3
  min_epochs: 15
  max_epochs: 100

program:
  seed: 0
  output_dir: output/l8biome/
  log_dir: logs/l8biome/
  overwrite: True
  experiment_name: unet_imagenet_lr1e-4_wd0
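This config is passed to train.py by the runner script added below, which overrides individual keys on the command line. A minimal invocation sketch, using the same key=value override style the runner itself constructs:

    python train.py config_file=conf/l8biometile.yaml trainer.devices=[3]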
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0, 1, 2, 3]
DRY_RUN = False  # if True, only print the commands to be run; if False, run them

# Hyperparameter options
model_options = ["fcn"]
backbone_options = ["resnet18"]
lr_options = [0.0001]
loss_options = ["ce"]
wd_options = [0]
weight_options = [True]
seed_options = [1, 2, 3, 4]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        # fill in the GPU placeholder with this worker's assigned device
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for model, backbone, lr, loss, wd, weights, seed in itertools.product(
        model_options,
        backbone_options,
        lr_options,
        loss_options,
        wd_options,
        weight_options,
        seed_options,
    ):
        # skip combinations that are not needed for this experiment
        if model == "fcn" and not weights:
            continue

        if model != "unet":
            experiment_name = f"{model}_{backbone}_{lr}_{loss}_{wd}_{weights}_{seed}"
        else:
            experiment_name = f"{model}_{lr}_{loss}_{wd}_{weights}_{seed}"

        config_file = os.path.join("conf", "l8biometile.yaml")

        command = (
            "python train.py"
            + f" config_file={config_file}"
            + f" module.model={model}"
            + f" module.backbone={backbone}"
            + f" module.learning_rate={lr}"
            + f" module.loss={loss}"
            + f" module.weight_decay={wd}"
            + f" module.weights={weights}"
            + f" program.seed={seed}"
            + f" program.experiment_name={experiment_name}"
            + " trainer.devices=[GPU]"
        )
        command = command.strip()

        work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(target=do_work, args=(work, gpu_idx))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
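For reference, the option lists above expand to four jobs (one per seed). Each queued command looks like this, with the literal GPU token later replaced by the worker's device index:

    python train.py config_file=conf/l8biometile.yaml module.model=fcn module.backbone=resnet18 module.learning_rate=0.0001 module.loss=ce module.weight_decay=0 module.weights=True program.seed=1 program.experiment_name=fcn_resnet18_0.0001_ce_0_True_1 trainer.devices=[GPU]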
@@ -16,7 +16,7 @@ from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
 from .gid15 import GID15DataModule
 from .inria import InriaAerialImageLabelingDataModule
 from .l7irish import L7IrishDataModule, L7IrishTileDataModule
-from .l8biome import L8BiomeDataModule
+from .l8biome import L8BiomeDataModule, L8BiomeTileDataModule
 from .landcoverai import LandCoverAIDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
@@ -43,6 +43,7 @@ __all__ = (
     "L7IrishDataModule",
     "L7IrishTileDataModule",
     "L8BiomeDataModule",
+    "L8BiomeTileDataModule",
     "NAIPChesapeakeDataModule",
     # NonGeoDataset
     "BigEarthNetDataModule",
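With this registration in place, the new datamodule is importable from the package root:

    from torchgeo.datamodules import L8BiomeTileDataModule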
@@ -5,10 +5,12 @@

 from typing import Any, Optional, Union

+from lightning.pytorch import LightningDataModule
 import torch
 from torch.utils.data import DataLoader

-from ..datasets import L8Biome, random_bbox_assignment
-from ..samplers import GridGeoSampler, RandomBatchGeoSampler
+from ..datasets import L8Biome, random_bbox_assignment, TileDataset
+from ..samplers import GridGeoSampler, RandomBatchGeoSampler, RandomTileGeoSampler, GridTileGeoSampler
 from .geo import GeoDataModule

@@ -74,3 +76,82 @@ class L8BiomeDataModule(GeoDataModule):
        self.test_sampler = GridGeoSampler(
            self.test_dataset, self.patch_size, self.patch_size
        )


class L8BiomeTileDataModule(LightningDataModule):
    """Tile-based LightningDataModule for the L8 Biome dataset."""

    @staticmethod
    def preprocess(sample):
        # scale the 8-bit imagery to [0, 1]
        sample["image"] = sample["image"] / 255.0

        # remap raw mask values to contiguous class indices;
        # 0 stays 0 and is skipped by the loss via ignore_index
        mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
        if "mask" in sample:
            mask = sample["mask"].squeeze()
            for k, v in mask_mapping.items():
                mask[mask == k] = v
            sample["mask"] = mask
        return sample

    def _get_all_the_fns(self, root):
        """Collect the image and mask filenames for every scene under root."""
        import os

        areas = L8Biome.filenames_to_md5.keys()
        image_fns = []
        mask_fns = []
        for area in areas:
            for scene_idx in os.listdir(os.path.join(root, area)):
                image_fns.append(
                    os.path.join(root, area, scene_idx, f"{scene_idx}.TIF")
                )
                mask_fns.append(
                    os.path.join(root, area, scene_idx, f"{scene_idx}_fixedmask.TIF")
                )
        return image_fns, mask_fns

    def __init__(
        self,
        root,
        batch_size=1,
        patch_size=32,
        train_batches_per_epoch=None,
        val_batches_per_epoch=None,
        num_workers=0,
        seed=0,
    ):
        super().__init__()
        self.image_fns, self.mask_fns = self._get_all_the_fns(root)
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.train_batches_per_epoch = train_batches_per_epoch
        self.val_batches_per_epoch = val_batches_per_epoch
        self.num_workers = num_workers

        # deterministic 60/20/20 train/val/test split over scenes
        generator = torch.Generator().manual_seed(seed)
        idxs = torch.randperm(len(self.image_fns), generator=generator)
        train_idxs = idxs[: int(len(idxs) * 0.6)]
        val_idxs = idxs[int(len(idxs) * 0.6) : int(len(idxs) * 0.8)]
        test_idxs = idxs[int(len(idxs) * 0.8) :]

        self.train_image_fns = [self.image_fns[i] for i in train_idxs]
        self.train_mask_fns = [self.mask_fns[i] for i in train_idxs]
        self.val_image_fns = [self.image_fns[i] for i in val_idxs]
        self.val_mask_fns = [self.mask_fns[i] for i in val_idxs]
        self.test_image_fns = [self.image_fns[i] for i in test_idxs]
        self.test_mask_fns = [self.mask_fns[i] for i in test_idxs]

    def setup(self, stage):
        self.train_dataset = TileDataset(
            self.train_image_fns,
            self.train_mask_fns,
            transforms=L8BiomeTileDataModule.preprocess,
        )
        self.val_dataset = TileDataset(
            self.val_image_fns,
            self.val_mask_fns,
            transforms=L8BiomeTileDataModule.preprocess,
        )
        self.test_dataset = TileDataset(
            self.test_image_fns,
            self.test_mask_fns,
            transforms=L8BiomeTileDataModule.preprocess,
        )

    # def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
    #     return super().on_after_batch_transfer(batch, dataloader_idx)

    def train_dataloader(self):
        # random patches; epoch length is batch_size * train_batches_per_epoch
        sampler = RandomTileGeoSampler(
            self.train_dataset,
            self.patch_size,
            self.batch_size * self.train_batches_per_epoch,
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        sampler = RandomTileGeoSampler(
            self.val_dataset,
            self.patch_size,
            self.batch_size * self.val_batches_per_epoch,
        )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        # non-overlapping grid over each tile (stride == patch size)
        sampler = GridTileGeoSampler(self.test_dataset, self.patch_size, self.patch_size)
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def plot(self, sample):
        import matplotlib.pyplot as plt

        image = sample["image"].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy().squeeze()
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        # reverse the first three bands for an RGB-like preview
        axs[0].imshow(image[:, :, [2, 1, 0]])
        axs[0].axis("off")
        axs[1].imshow(mask, vmin=0, vmax=4)
        axs[1].axis("off")
        return fig
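A minimal usage sketch of the new datamodule (the root path is a placeholder; the remaining values mirror conf/l8biometile.yaml):

    dm = L8BiomeTileDataModule(
        root="/data/L8BiomeSimple/",  # placeholder path
        batch_size=32,
        patch_size=256,
        train_batches_per_epoch=2000,
        val_batches_per_epoch=200,
        num_workers=6,
    )
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    # batch["image"]: (32, 11, 256, 256) given in_channels: 11 in the config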