Mirror of https://github.com/microsoft/torchgeo.git

Updated test_train to match new way of passing arguments. Fixed mypy problem.

Parent: 34769f9f6c
Commit: 5d2dc8f96e
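Note: the "new way of passing arguments" replaces argparse flags such as "--output_dir" with OmegaConf dotlist overrides such as "program.output_dir=...". A minimal sketch of how that parsing style works, using only the real OmegaConf API; the tiny driver script is illustrative, not torchgeo's actual train.py:

import sys

from omegaconf import OmegaConf

# Arguments like "program.output_dir=/tmp/out" become nested config keys
# instead of argparse flags.
conf = OmegaConf.from_dotlist(sys.argv[1:])
print(OmegaConf.to_yaml(conf))

# $ python demo.py program.experiment_name=test trainer.fast_dev_run=1
# program:
#   experiment_name: test
# trainer:
#   fast_dev_run: 1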
@@ -7,16 +7,11 @@ from pathlib import Path
 import pytest
 
 
-def test_help() -> None:
-    args = [sys.executable, "train.py", "--help"]
-    subprocess.run(args, check=True)
-
-
 def test_required_args() -> None:
     args = [sys.executable, "train.py"]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
-    assert b"error: the following arguments are required:" in ps.stderr
+    assert b"Missing mandatory value" in ps.stderr
 
 
 def test_output_file(tmp_path: Path) -> None:
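Note on the changed assertion above: argparse printed "error: the following arguments are required:", whereas OmegaConf raises MissingMandatoryValue for any key left as "???". A small sketch of that failure mode; the config layout shown is illustrative:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

# A key declared as "???" is mandatory; reading it without an override fails.
conf = OmegaConf.create({"program": {"experiment_name": "???"}})
try:
    _ = conf.program.experiment_name
except MissingMandatoryValue as err:
    print(err)  # the message begins with "Missing mandatory value"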
@@ -25,10 +20,9 @@ def test_output_file(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "--experiment_name",
-        "test",
-        "--output_dir",
-        str(output_file),
+        "program.experiment_name=test",
+        f"program.output_dir={str(output_file)}",
+        "task.name=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -44,10 +38,9 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "--experiment_name",
-        "test",
-        "--output_dir",
-        str(output_dir),
+        "program.experiment_name=test",
+        f"program.output_dir={str(output_dir)}",
+        "task.name=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -66,17 +59,13 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "--experiment_name",
-        experiment_name,
-        "--output_dir",
-        str(output_dir),
-        "--data_dir",
-        data_dir,
-        "--log_dir",
-        str(log_dir),
-        "--overwrite",
-        "--fast_dev_run",
-        "1",
+        "program.experiment_name=test",
+        f"program.output_dir={str(output_dir)}",
+        f"program.data_dir={data_dir}",
+        f"program.log_dir={str(log_dir)}",
+        "task.name=cyclone",
+        "program.overwrite=True",
+        "trainer.fast_dev_run=1",
     ]
     ps = subprocess.run(
         args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
@@ -89,24 +78,18 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
 
 @pytest.mark.parametrize("task", ["cyclone", "sen12ms"])
 def test_tasks(task: str, tmp_path: Path) -> None:
-    experiment_name = "test"
     output_dir = tmp_path / "output"
     data_dir = os.path.join("tests", "data")
     log_dir = tmp_path / "logs"
     args = [
         sys.executable,
         "train.py",
-        "--experiment_name",
-        experiment_name,
-        "--output_dir",
-        str(output_dir),
-        "--data_dir",
-        data_dir,
-        "--log_dir",
-        str(log_dir),
-        "--fast_dev_run",
-        "1",
-        "--task",
-        task,
+        "program.experiment_name=test",
+        f"program.output_dir={str(output_dir)}",
+        f"program.data_dir={data_dir}",
+        f"program.log_dir={str(log_dir)}",
+        "trainer.fast_dev_run=1",
+        f"task.name={task}",
+        "program.overwrite=True",
     ]
     subprocess.run(args, check=True)

@@ -90,16 +90,14 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
         """Initialize the optimizer and learning rate scheduler."""
         optimizer = torch.optim.Adam(
             self.model.parameters(),
-            lr=self.hparams["learning_rate"],  # type: ignore[index]
+            lr=self.hparams["learning_rate"],
         )
         return {
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": ReduceLROnPlateau(
                     optimizer,
-                    patience=self.hparams[  # type: ignore[index]
-                        "learning_rate_schedule_patience"
-                    ],
+                    patience=self.hparams["learning_rate_schedule_patience"],
                 ),
                 "monitor": "val_loss",
             },
@@ -112,16 +112,14 @@ class SEN12MSSegmentationTask(pl.LightningModule):
         """Initialize the optimizer and learning rate scheduler."""
         optimizer = torch.optim.Adam(
             self.model.parameters(),
-            lr=self.hparams["learning_rate"],  # type: ignore[index]
+            lr=self.hparams["learning_rate"],
        )
         return {
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": ReduceLROnPlateau(
                     optimizer,
-                    patience=self.hparams[  # type: ignore[index]
-                        "learning_rate_schedule_patience"
-                    ],
+                    patience=self.hparams["learning_rate_schedule_patience"],
                 ),
                 "monitor": "val_loss",
             },
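Note on the two hunks above: the deleted "# type: ignore[index]" comments are the "mypy problem" from the commit message; indexing self.hparams evidently now passes mypy without suppression. A self-contained sketch of the optimizer/scheduler pattern both tasks share; the tiny model and default hyperparameter values are illustrative:

from typing import Any, Dict

import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class DemoTask(pl.LightningModule):
    def __init__(
        self,
        learning_rate: float = 1e-3,
        learning_rate_schedule_patience: int = 2,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()  # exposes the kwargs via self.hparams
        self.model = torch.nn.Linear(4, 1)

    def configure_optimizers(self) -> Dict[str, Any]:
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams["learning_rate"],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                # Lightning steps ReduceLROnPlateau against this logged metric.
                "monitor": "val_loss",
            },
        }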

train.py
@@ -62,6 +62,8 @@ def set_up_omegaconf() -> DictConfig:
         task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
     elif conf.task.name == "sen12ms":
         task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
+    elif conf.task.name == "test":
+        task_conf = OmegaConf.create()
     else:
         raise ValueError(f"task.name={conf.task.name} is not recognized as a valid task")
 
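Note: returning an empty config from OmegaConf.create() makes the "test" task a no-op when task defaults are merged with user overrides, so the tests can exercise train.py without a real task YAML. A hedged sketch of the merge this implies; the exact merge order inside set_up_omegaconf may differ:

from omegaconf import OmegaConf

conf = OmegaConf.from_dotlist(["task.name=test", "program.output_dir=/tmp/out"])
task_conf = OmegaConf.create()  # empty config for the "test" task
merged = OmegaConf.merge(task_conf, conf)  # later configs win; empty is a no-op
print(OmegaConf.to_yaml(merged))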