This commit is contained in:
Adam J. Stewart 2021-07-25 18:02:30 -04:00
Родитель 829297f386
Коммит 0bc26a667d
2 изменённых файлов: 69 добавлений и 14 удалений

Просмотреть файл

@ -11,7 +11,7 @@ def test_required_args() -> None:
args = [sys.executable, "train.py"]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"Missing mandatory value" in ps.stderr
assert b"MissingMandatoryValue" in ps.stderr
def test_output_file(tmp_path: Path) -> None:
@ -21,7 +21,7 @@ def test_output_file(tmp_path: Path) -> None:
sys.executable,
"train.py",
"program.experiment_name=test",
f"program.output_dir={str(output_file)}",
"program.output_dir=" + str(output_file),
"task.name=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@ -39,7 +39,7 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
sys.executable,
"train.py",
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
"program.output_dir=" + str(output_dir),
"task.name=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@ -60,9 +60,9 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
sys.executable,
"train.py",
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
f"program.data_dir={data_dir}",
f"program.log_dir={str(log_dir)}",
"program.output_dir=" + str(output_dir),
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"task.name=cyclone",
"program.overwrite=True",
"trainer.fast_dev_run=1",
@ -76,6 +76,57 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
)
@pytest.mark.parametrize("task", ["test", "foo"])
def test_invalid_task(task: str, tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"task.name=" + task,
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"ValueError" in ps.stderr
def test_missing_config_file(tmp_path: Path) -> None:
config_file = tmp_path / "config.yaml"
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"task.name=test",
"config_file=" + str(config_file),
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"FileNotFoundError" in ps.stderr
def test_config_file(tmp_path: Path) -> None:
output_dir = tmp_path / "output"
data_dir = os.path.join("tests", "data")
log_dir = tmp_path / "logs"
config_file = tmp_path / "config.yaml"
config_file.write_text(f"""
program:
experiment_name: test
output_dir: {output_dir}
data_dir: {data_dir}
log_dir: {log_dir}
task:
name: cyclone
trainer:
fast_dev_run: true
""")
args = [
sys.executable,
"train.py",
"config_file=" + str(config_file),
]
ps = subprocess.run(args, check=True)
@pytest.mark.parametrize("task", ["cyclone", "sen12ms"])
def test_tasks(task: str, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
@ -85,11 +136,11 @@ def test_tasks(task: str, tmp_path: Path) -> None:
sys.executable,
"train.py",
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
f"program.data_dir={data_dir}",
f"program.log_dir={str(log_dir)}",
"program.output_dir=" + str(output_dir),
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"trainer.fast_dev_run=1",
f"task.name={task}",
"task.name=" + task,
"program.overwrite=True",
]
subprocess.run(args, check=True)

Просмотреть файл

@ -39,7 +39,9 @@ def set_up_omegaconf() -> DictConfig:
Returns:
an OmegaConf DictConfig containing all the validated program arguments
# TODO: Raises
Raises:
FileNotFoundError: when ``config_file`` does not exist
ValueError: when ``task.name`` is not a valid task
"""
conf = OmegaConf.load("conf/defaults.yaml")
command_line_conf = OmegaConf.from_cli()
@ -50,7 +52,7 @@ def set_up_omegaconf() -> DictConfig:
user_conf = OmegaConf.load(config_fn)
conf = OmegaConf.merge(conf, user_conf)
else:
raise IOError(f"config_file={config_fn} is not a valid file")
raise FileNotFoundError(f"config_file={config_fn} is not a valid file")
conf = OmegaConf.merge( # Merge in any arguments passed via the command line
conf, command_line_conf
@ -65,7 +67,7 @@ def set_up_omegaconf() -> DictConfig:
elif conf.task.name == "test":
task_conf = OmegaConf.create()
else:
raise ValueError(f"task.name={conf.task.name} is not recognized as a validtask")
raise ValueError(f"task.name={conf.task.name} is not recognized as a valid task")
conf = OmegaConf.merge(task_conf, conf)
conf = cast(DictConfig, conf) # convince mypy that everything is alright
@ -80,7 +82,7 @@ def main(conf: DictConfig) -> None:
######################################
if os.path.isfile(conf.program.output_dir):
raise NotADirectoryError("`--output_dir` must be a directory")
raise NotADirectoryError("`program.output_dir` must be a directory")
os.makedirs(conf.program.output_dir, exist_ok=True)
experiment_dir = os.path.join(conf.program.output_dir, conf.program.experiment_name)
@ -135,6 +137,8 @@ def main(conf: DictConfig) -> None:
)
loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
task = SEN12MSSegmentationTask(model, loss, **task_args)
else:
raise ValueError(f"task.name={conf.task.name} is not recognized as a valid task")
######################################
# Setup trainer