зеркало из https://github.com/microsoft/torchgeo.git
Add tests for train.py
This commit is contained in:
Родитель
829297f386
Коммит
0bc26a667d
|
@ -11,7 +11,7 @@ def test_required_args() -> None:
|
|||
args = [sys.executable, "train.py"]
|
||||
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
assert ps.returncode != 0
|
||||
assert b"Missing mandatory value" in ps.stderr
|
||||
assert b"MissingMandatoryValue" in ps.stderr
|
||||
|
||||
|
||||
def test_output_file(tmp_path: Path) -> None:
|
||||
|
@ -21,7 +21,7 @@ def test_output_file(tmp_path: Path) -> None:
|
|||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
f"program.output_dir={str(output_file)}",
|
||||
"program.output_dir=" + str(output_file),
|
||||
"task.name=test",
|
||||
]
|
||||
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
@ -39,7 +39,7 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
|
|||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
f"program.output_dir={str(output_dir)}",
|
||||
"program.output_dir=" + str(output_dir),
|
||||
"task.name=test",
|
||||
]
|
||||
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
@ -60,9 +60,9 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
|
|||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
f"program.output_dir={str(output_dir)}",
|
||||
f"program.data_dir={data_dir}",
|
||||
f"program.log_dir={str(log_dir)}",
|
||||
"program.output_dir=" + str(output_dir),
|
||||
"program.data_dir=" + data_dir,
|
||||
"program.log_dir=" + str(log_dir),
|
||||
"task.name=cyclone",
|
||||
"program.overwrite=True",
|
||||
"trainer.fast_dev_run=1",
|
||||
|
@ -76,6 +76,57 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["test", "foo"])
|
||||
def test_invalid_task(task: str, tmp_path: Path) -> None:
|
||||
args = [
|
||||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
"task.name=" + task,
|
||||
]
|
||||
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
assert ps.returncode != 0
|
||||
assert b"ValueError" in ps.stderr
|
||||
|
||||
|
||||
def test_missing_config_file(tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "config.yaml"
|
||||
args = [
|
||||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
"task.name=test",
|
||||
"config_file=" + str(config_file),
|
||||
]
|
||||
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
assert ps.returncode != 0
|
||||
assert b"FileNotFoundError" in ps.stderr
|
||||
|
||||
|
||||
def test_config_file(tmp_path: Path) -> None:
|
||||
output_dir = tmp_path / "output"
|
||||
data_dir = os.path.join("tests", "data")
|
||||
log_dir = tmp_path / "logs"
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(f"""
|
||||
program:
|
||||
experiment_name: test
|
||||
output_dir: {output_dir}
|
||||
data_dir: {data_dir}
|
||||
log_dir: {log_dir}
|
||||
task:
|
||||
name: cyclone
|
||||
trainer:
|
||||
fast_dev_run: true
|
||||
""")
|
||||
args = [
|
||||
sys.executable,
|
||||
"train.py",
|
||||
"config_file=" + str(config_file),
|
||||
]
|
||||
ps = subprocess.run(args, check=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["cyclone", "sen12ms"])
|
||||
def test_tasks(task: str, tmp_path: Path) -> None:
|
||||
output_dir = tmp_path / "output"
|
||||
|
@ -85,11 +136,11 @@ def test_tasks(task: str, tmp_path: Path) -> None:
|
|||
sys.executable,
|
||||
"train.py",
|
||||
"program.experiment_name=test",
|
||||
f"program.output_dir={str(output_dir)}",
|
||||
f"program.data_dir={data_dir}",
|
||||
f"program.log_dir={str(log_dir)}",
|
||||
"program.output_dir=" + str(output_dir),
|
||||
"program.data_dir=" + data_dir,
|
||||
"program.log_dir=" + str(log_dir),
|
||||
"trainer.fast_dev_run=1",
|
||||
f"task.name={task}",
|
||||
"task.name=" + task,
|
||||
"program.overwrite=True",
|
||||
]
|
||||
subprocess.run(args, check=True)
|
||||
|
|
12
train.py
12
train.py
|
@ -39,7 +39,9 @@ def set_up_omegaconf() -> DictConfig:
|
|||
Returns:
|
||||
an OmegaConf DictConfig containing all the validated program arguments
|
||||
|
||||
# TODO: Raises
|
||||
Raises:
|
||||
FileNotFoundError: when ``config_file`` does not exist
|
||||
ValueError: when ``task.name`` is not a valid task
|
||||
"""
|
||||
conf = OmegaConf.load("conf/defaults.yaml")
|
||||
command_line_conf = OmegaConf.from_cli()
|
||||
|
@ -50,7 +52,7 @@ def set_up_omegaconf() -> DictConfig:
|
|||
user_conf = OmegaConf.load(config_fn)
|
||||
conf = OmegaConf.merge(conf, user_conf)
|
||||
else:
|
||||
raise IOError(f"config_file={config_fn} is not a valid file")
|
||||
raise FileNotFoundError(f"config_file={config_fn} is not a valid file")
|
||||
|
||||
conf = OmegaConf.merge( # Merge in any arguments passed via the command line
|
||||
conf, command_line_conf
|
||||
|
@ -65,7 +67,7 @@ def set_up_omegaconf() -> DictConfig:
|
|||
elif conf.task.name == "test":
|
||||
task_conf = OmegaConf.create()
|
||||
else:
|
||||
raise ValueError(f"task.name={conf.task.name} is not recognized as a validtask")
|
||||
raise ValueError(f"task.name={conf.task.name} is not recognized as a valid task")
|
||||
|
||||
conf = OmegaConf.merge(task_conf, conf)
|
||||
conf = cast(DictConfig, conf) # convince mypy that everything is alright
|
||||
|
@ -80,7 +82,7 @@ def main(conf: DictConfig) -> None:
|
|||
######################################
|
||||
|
||||
if os.path.isfile(conf.program.output_dir):
|
||||
raise NotADirectoryError("`--output_dir` must be a directory")
|
||||
raise NotADirectoryError("`program.output_dir` must be a directory")
|
||||
os.makedirs(conf.program.output_dir, exist_ok=True)
|
||||
|
||||
experiment_dir = os.path.join(conf.program.output_dir, conf.program.experiment_name)
|
||||
|
@ -135,6 +137,8 @@ def main(conf: DictConfig) -> None:
|
|||
)
|
||||
loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
|
||||
task = SEN12MSSegmentationTask(model, loss, **task_args)
|
||||
else:
|
||||
raise ValueError(f"task.name={conf.task.name} is not recognized as a valid task")
|
||||
|
||||
######################################
|
||||
# Setup trainer
|
||||
|
|
Загрузка…
Ссылка в новой задаче