зеркало из https://github.com/microsoft/torchgeo.git
Reorganize configuration files (#352)
* Reorganize configuration files * Undo changes to tests/conf files
This commit is contained in:
Родитель
f69bdaf5f8
Коммит
5136d819c0
|
@ -1,9 +1,8 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
gpus: 1
|
||||
min_epochs: 10
|
||||
max_epochs: 40
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "bigearthnet"
|
||||
module:
|
||||
|
@ -11,9 +10,12 @@ experiment:
|
|||
classification_model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
in_channels: 14
|
||||
datamodule:
|
||||
num_classes: 19
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
datamodule:
|
||||
root_dir: "data/bigearthnet"
|
||||
bands: "all"
|
||||
num_classes: ${experiment.module.num_classes}
|
||||
batch_size: 128
|
||||
num_workers: 4
|
||||
|
|
|
@ -3,7 +3,6 @@ trainer:
|
|||
min_epochs: 20
|
||||
max_epochs: 100
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "ssl"
|
||||
name: "test_byol"
|
||||
|
@ -12,12 +11,15 @@ experiment:
|
|||
encoder: "resnet18"
|
||||
input_channels: 4
|
||||
imagenet_pretraining: True
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
datamodule:
|
||||
batch_size: 64
|
||||
num_workers: 6
|
||||
root_dir: "data/chesapeake/cvpr"
|
||||
train_splits:
|
||||
- "de-train"
|
||||
val_splits:
|
||||
- "de-val"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
|
|
|
@ -1,29 +1,35 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
gpus: 1
|
||||
min_epochs: 20
|
||||
max_epochs: 100
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "chesapeake_cvpr"
|
||||
name: "chesapeake_cvpr_example"
|
||||
module:
|
||||
loss: "ce" # cross entropy loss
|
||||
loss: "ce"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: null # use ImageNet weight initialization
|
||||
encoder_weights: null
|
||||
encoder_output_stride: 16
|
||||
learning_rate: 1e-2
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 4
|
||||
num_classes: 7
|
||||
num_filters: 256
|
||||
ignore_zeros: False
|
||||
imagenet_pretraining: True
|
||||
datamodule:
|
||||
batch_size: 64
|
||||
num_workers: 6
|
||||
root_dir: "data/chesapeake/cvpr"
|
||||
train_splits:
|
||||
- "de-train"
|
||||
val_splits:
|
||||
- "de-val"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
patches_per_tile: 200
|
||||
patch_size: 256
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
class_set: ${experiment.module.num_classes}
|
||||
use_prior_labels: False
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
program: # These are the arguments that define how the train.py script works
|
||||
seed: 1337
|
||||
overwrite: True
|
||||
|
||||
trainer:
|
||||
gpus: 1
|
||||
min_epochs: 15
|
||||
experiment:
|
||||
task: cowc_counting
|
||||
name: cowc_counting_test
|
||||
|
@ -9,10 +8,9 @@ experiment:
|
|||
model: resnet18
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: True
|
||||
datamodule:
|
||||
root_dir: "data/cowc_counting"
|
||||
seed: 0
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
|
||||
trainer:
|
||||
min_epochs: 15
|
||||
gpus: 1
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
program: # These are the arguments that define how the train.py script works
|
||||
seed: 1337
|
||||
overwrite: True
|
||||
|
||||
trainer:
|
||||
gpus: 1
|
||||
min_epochs: 15
|
||||
experiment:
|
||||
task: "cyclone"
|
||||
name: cyclone_test
|
||||
name: "cyclone_test"
|
||||
module:
|
||||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: True
|
||||
datamodule:
|
||||
root_dir: "data/cyclone"
|
||||
seed: 0
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
|
||||
trainer:
|
||||
min_epochs: 15
|
||||
gpus: 1
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
experiment:
|
||||
task: "etci2021"
|
||||
module:
|
||||
loss: "ce"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: "imagenet"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 7
|
||||
num_classes: 2
|
||||
ignore_zeros: True
|
||||
datamodule:
|
||||
root_dir: "data/etci2021"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
|
@ -0,0 +1,14 @@
|
|||
experiment:
|
||||
task: "eurosat"
|
||||
module:
|
||||
loss: "ce"
|
||||
classification_model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
in_channels: 13
|
||||
num_classes: 10
|
||||
datamodule:
|
||||
root_dir: "data/eurosat"
|
||||
batch_size: 128
|
||||
num_workers: 4
|
|
@ -1,22 +1,22 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
gpus: 1
|
||||
min_epochs: 20
|
||||
max_epochs: 100
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "landcoverai"
|
||||
module:
|
||||
loss: "ce"
|
||||
segmentation_model: "deeplabv3+"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: null
|
||||
encoder_output_stride: 16
|
||||
encoder_weights: "imagenet"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 3
|
||||
num_classes: 6
|
||||
num_filters: 256
|
||||
ignore_zeros: False
|
||||
datamodule:
|
||||
root_dir: "data/landcoverai"
|
||||
batch_size: 32
|
||||
num_workers: 6
|
||||
num_workers: 4
|
||||
|
|
|
@ -1,13 +1,25 @@
|
|||
program: # These are experiment level arguments
|
||||
experiment_name: naip_chesapeake_test
|
||||
program:
|
||||
experiment_name: "naip_chesapeake_test"
|
||||
overwrite: True
|
||||
naip_data_dir: data/naip
|
||||
chesapeake_data_dir: data/chesapeake
|
||||
|
||||
trainer: # These are all the arguments that will be passed to the pl.Trainer
|
||||
trainer:
|
||||
min_epochs: 15
|
||||
|
||||
task: # These are all the arguments that will be used to create an appropriate task
|
||||
name: naipchesapeake
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
experiment:
|
||||
task: "naipchesapeake"
|
||||
module:
|
||||
loss: "ce"
|
||||
segmentation_model: "deeplabv3+"
|
||||
encoder_name: "resnet34"
|
||||
encoder_weights: "imagenet"
|
||||
encoder_output_stride: 16
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
in_channels: 4
|
||||
num_classes: 13
|
||||
num_filters: 64
|
||||
ignore_zeros: False
|
||||
datamodule:
|
||||
naip_root_dir: "data/naip"
|
||||
chesapeake_root_dir: "data/chesapeake/BAYWIDE"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
patch_size: 32
|
||||
|
|
|
@ -3,7 +3,6 @@ trainer:
|
|||
min_epochs: 20
|
||||
max_epochs: 500
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "oscd"
|
||||
module:
|
||||
|
@ -16,11 +15,15 @@ experiment:
|
|||
verbose: false
|
||||
in_channels: 26
|
||||
num_classes: 2
|
||||
num_filters: 128
|
||||
num_filters: 256
|
||||
ignore_zeros: True
|
||||
datamodule:
|
||||
train_batch_size: 2
|
||||
num_workers: 6
|
||||
root_dir: "data/oscd"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
val_split_pct: 0.1
|
||||
bands: "all"
|
||||
pad_size:
|
||||
- 1028
|
||||
- 1028
|
||||
num_patches_per_tile: 128
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
gpus: 1
|
||||
min_epochs: 10
|
||||
max_epochs: 40
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "resisc45"
|
||||
module:
|
||||
|
@ -12,6 +11,9 @@ experiment:
|
|||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
in_channels: 3
|
||||
num_classes: 45
|
||||
datamodule:
|
||||
root_dir: "data/resisc45"
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
num_workers: 4
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
program: # These are experiment level arguments
|
||||
program:
|
||||
experiment_name: sen12ms_test
|
||||
overwrite: True
|
||||
|
||||
trainer: # These are all the arguments that will be passed to the pl.Trainer
|
||||
trainer:
|
||||
min_epochs: 15
|
||||
|
||||
task: # These are all the arguments that will be used to create an appropriate task
|
||||
name: sen12ms
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
experiment:
|
||||
task: "sen12ms"
|
||||
module:
|
||||
loss: "ce"
|
||||
segmentation_model: "unet"
|
||||
encoder_name: "resnet18"
|
||||
encoder_weights: null
|
||||
encoder_output_stride: 16
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
in_channels: 15
|
||||
num_classes: 11
|
||||
ignore_zeros: False
|
||||
datamodule:
|
||||
root_dir: "data/sen12ms"
|
||||
band_set: "all"
|
||||
batch_size: 32
|
||||
num_workers: 4
|
||||
seed: 0
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
gpus: 1
|
||||
min_epochs: 10
|
||||
max_epochs: 40
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "so2sat"
|
||||
module:
|
||||
|
@ -11,8 +10,12 @@ experiment:
|
|||
classification_model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
in_channels: 3
|
||||
num_classes: 17
|
||||
datamodule:
|
||||
root_dir: "data/so2sat"
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
num_workers: 4
|
||||
bands: "rgb"
|
||||
unsupervised_mode: False
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
experiment:
|
||||
task: "ucmerced"
|
||||
module:
|
||||
loss: "ce"
|
||||
classification_model: "resnet18"
|
||||
weights: null
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 3
|
||||
num_classes: 21
|
||||
datamodule:
|
||||
root_dir: "data/ucmerced"
|
||||
batch_size: 128
|
||||
num_workers: 4
|
|
@ -14,9 +14,7 @@ from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
|||
class TestChesapeakeCVPRDataModule:
|
||||
@pytest.fixture(scope="class")
|
||||
def datamodule(self) -> ChesapeakeCVPRDataModule:
|
||||
conf = OmegaConf.load(
|
||||
os.path.join("conf", "task_defaults", "chesapeake_cvpr_5.yaml")
|
||||
)
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "chesapeake_cvpr_5.yaml"))
|
||||
kwargs = OmegaConf.to_object(conf.experiment.datamodule)
|
||||
kwargs = cast(Dict[str, Any], kwargs)
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class TestBYOLTask:
|
|||
],
|
||||
)
|
||||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class TestClassificationTask:
|
|||
if name.startswith("so2sat"):
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
@ -68,7 +68,7 @@ class TestClassificationTask:
|
|||
trainer.test(model=model, datamodule=datamodule)
|
||||
|
||||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", "ucmerced.yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
@ -142,7 +142,7 @@ class TestMultiLabelClassificationTask:
|
|||
name: str,
|
||||
classname: Type[LightningDataModule],
|
||||
) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
@ -163,9 +163,7 @@ class TestMultiLabelClassificationTask:
|
|||
trainer.test(model=model, datamodule=datamodule)
|
||||
|
||||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(
|
||||
os.path.join("conf", "task_defaults", "bigearthnet_s1.yaml")
|
||||
)
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class TestRegressionTask:
|
|||
[("cowc_counting", COWCCountingDataModule), ("cyclone", CycloneDataModule)],
|
||||
)
|
||||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class TestRegressionTask:
|
|||
trainer.test(model=model, datamodule=datamodule)
|
||||
|
||||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", "cyclone.yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class TestSemanticSegmentationTask:
|
|||
name: str,
|
||||
classname: Type[LightningDataModule],
|
||||
) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
@ -72,7 +72,7 @@ class TestSemanticSegmentationTask:
|
|||
trainer.test(model=model, datamodule=datamodule)
|
||||
|
||||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("conf", "task_defaults", "landcoverai.yaml"))
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
|
||||
|
|
21
train.py
21
train.py
|
@ -39,28 +39,19 @@ from torchgeo.trainers import (
|
|||
TASK_TO_MODULES_MAPPING: Dict[
|
||||
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
|
||||
] = {
|
||||
"bigearthnet_all": (MultiLabelClassificationTask, BigEarthNetDataModule),
|
||||
"bigearthnet_s1": (MultiLabelClassificationTask, BigEarthNetDataModule),
|
||||
"bigearthnet_s2": (MultiLabelClassificationTask, BigEarthNetDataModule),
|
||||
"bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
|
||||
"byol": (BYOLTask, ChesapeakeCVPRDataModule),
|
||||
"chesapeake_cvpr_5": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"chesapeake_cvpr_7": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"chesapeake_cvpr_prior": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"cowc_counting": (RegressionTask, COWCCountingDataModule),
|
||||
"cyclone": (RegressionTask, CycloneDataModule),
|
||||
"eurosat": (ClassificationTask, EuroSATDataModule),
|
||||
"etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
|
||||
"landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
|
||||
"naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
|
||||
"oscd_all": (SemanticSegmentationTask, OSCDDataModule),
|
||||
"oscd_rgb": (SemanticSegmentationTask, OSCDDataModule),
|
||||
"oscd": (SemanticSegmentationTask, OSCDDataModule),
|
||||
"resisc45": (ClassificationTask, RESISC45DataModule),
|
||||
"sen12ms_all": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"sen12ms_s1": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"sen12ms_s2_all": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"sen12ms_s2_reduced": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"so2sat_supervised": (ClassificationTask, So2SatDataModule),
|
||||
"so2sat_unsupervised": (ClassificationTask, So2SatDataModule),
|
||||
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"so2sat": (ClassificationTask, So2SatDataModule),
|
||||
"ucmerced": (ClassificationTask, UCMercedDataModule),
|
||||
}
|
||||
|
||||
|
@ -104,7 +95,7 @@ def set_up_omegaconf() -> DictConfig:
|
|||
# These OmegaConf structured configs enforce a schema at runtime, see:
|
||||
# https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
|
||||
task_name = conf.experiment.task
|
||||
task_config_fn = os.path.join("conf", "task_defaults", f"{task_name}.yaml")
|
||||
task_config_fn = os.path.join("conf", f"{task_name}.yaml")
|
||||
if task_name == "test":
|
||||
task_conf = OmegaConf.create()
|
||||
elif os.path.exists(task_config_fn):
|
||||
|
|
Загрузка…
Ссылка в новой задаче