Updated test_train to match new way of passing arguments. Fixed mypy problem.

This commit is contained in:
Caleb Robinson 2021-07-22 23:52:45 +00:00 коммит произвёл Adam J. Stewart
Родитель 34769f9f6c
Коммит 5d2dc8f96e
4 изменённых файлов: 27 добавлений и 46 удалений

Просмотреть файл

@ -7,16 +7,11 @@ from pathlib import Path
import pytest
def test_help() -> None:
args = [sys.executable, "train.py", "--help"]
subprocess.run(args, check=True)
def test_required_args() -> None:
args = [sys.executable, "train.py"]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"error: the following arguments are required:" in ps.stderr
assert b"Missing mandatory value" in ps.stderr
def test_output_file(tmp_path: Path) -> None:
@ -25,10 +20,9 @@ def test_output_file(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"--experiment_name",
"test",
"--output_dir",
str(output_file),
"program.experiment_name=test",
f"program.output_dir={str(output_file)}",
"task.name=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
@ -44,10 +38,9 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"--experiment_name",
"test",
"--output_dir",
str(output_dir),
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
"task.name=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
@ -66,17 +59,13 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"--experiment_name",
experiment_name,
"--output_dir",
str(output_dir),
"--data_dir",
data_dir,
"--log_dir",
str(log_dir),
"--overwrite",
"--fast_dev_run",
"1",
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
f"program.data_dir={data_dir}",
f"program.log_dir={str(log_dir)}",
"task.name=cyclone",
"program.overwrite=True",
"trainer.fast_dev_run=1",
]
ps = subprocess.run(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
@ -89,24 +78,18 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
@pytest.mark.parametrize("task", ["cyclone", "sen12ms"])
def test_tasks(task: str, tmp_path: Path) -> None:
experiment_name = "test"
output_dir = tmp_path / "output"
data_dir = os.path.join("tests", "data")
log_dir = tmp_path / "logs"
args = [
sys.executable,
"train.py",
"--experiment_name",
experiment_name,
"--output_dir",
str(output_dir),
"--data_dir",
data_dir,
"--log_dir",
str(log_dir),
"--fast_dev_run",
"1",
"--task",
task,
"program.experiment_name=test",
f"program.output_dir={str(output_dir)}",
f"program.data_dir={data_dir}",
f"program.log_dir={str(log_dir)}",
"trainer.fast_dev_run=1",
f"task.name={task}",
"program.overwrite=True",
]
subprocess.run(args, check=True)

Просмотреть файл

@ -90,16 +90,14 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
"""Initialize the optimizer and learning rate scheduler."""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"], # type: ignore[index]
lr=self.hparams["learning_rate"],
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(
optimizer,
patience=self.hparams[ # type: ignore[index]
"learning_rate_schedule_patience"
],
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
},

Просмотреть файл

@ -112,16 +112,14 @@ class SEN12MSSegmentationTask(pl.LightningModule):
"""Initialize the optimizer and learning rate scheduler."""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"], # type: ignore[index]
lr=self.hparams["learning_rate"],
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(
optimizer,
patience=self.hparams[ # type: ignore[index]
"learning_rate_schedule_patience"
],
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
},

Просмотреть файл

@ -62,6 +62,8 @@ def set_up_omegaconf() -> DictConfig:
task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
elif conf.task.name == "sen12ms":
task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
elif conf.task.name == "test":
task_conf = OmegaConf.create()
else:
raise ValueError(f"task.name={conf.task.name} is not recognized as a validtask")