Mirror of https://github.com/microsoft/torchgeo.git
Add COWC counting experiment (#706)
* Add COWC counting experiment
* Fix tests
* Search event logs for all versions
* Add seed experiments
This commit is contained in:
Parent: 90458ab608
Commit: 9dee4b166c
@@ -12,5 +12,5 @@ experiment:
   datamodule:
     root_dir: "data/cowc_counting"
     seed: 0
-    batch_size: 32
+    batch_size: 64
     num_workers: 4
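Reassembled from the context and changed lines of this hunk, the datamodule block of the config now reads as follows (the nesting under experiment: is assumed from the hunk header):

experiment:
  datamodule:
    root_dir: "data/cowc_counting"
    seed: 0
    batch_size: 64
    num_workers: 4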
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Find the optimal set of hyperparameters given experiment checkpoints."""

import glob
import json
import os
from collections import defaultdict
from typing import DefaultDict

from tbparse import SummaryReader

# path to the output directory used by the training runs (program.output_dir)
OUTPUT_DIR = ""


# mypy does not yet support recursive type hints
def nested_dict() -> DefaultDict[str, defaultdict]:  # type: ignore[type-arg]
    """Recursive defaultdict.

    Returns:
        a nested dictionary
    """
    return defaultdict(nested_dict)


if __name__ == "__main__":
    metrics = nested_dict()

    logs = os.path.join(OUTPUT_DIR, "logs", "*", "version_*", "events*")
    for log in glob.iglob(logs):
        hyperparams = log.split(os.sep)[-3]
        reader = SummaryReader(log)
        df = reader.scalars

        # Some event logs are for train/val, others are for test
        for split in ["train", "val", "test"]:
            rmse = df.loc[df["tag"] == f"{split}_RMSE"]
            mae = df.loc[df["tag"] == f"{split}_MAE"]
            if len(rmse):
                metrics[hyperparams][split]["RMSE"] = rmse.iloc[-1]["value"]
            if len(mae):
                metrics[hyperparams][split]["MAE"] = mae.iloc[-1]["value"]

    print(json.dumps(metrics, sort_keys=True, indent=4))
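For orientation, the JSON printed by this script is keyed by experiment directory name, then by split, then by metric. An illustrative excerpt, assuming runs named by the grid script below (the numeric values are placeholders, not real results):

{
    "resnet18_0.0001_True": {
        "test": {
            "MAE": 2.0,
            "RMSE": 3.0
        },
        "val": {
            "MAE": 2.1,
            "RMSE": 3.1
        }
    }
}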
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""

import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use; one job will be started for every ID in the list
GPUS = range(0)
DRY_RUN = True  # if True, print the commands that would be run; if False, run them
DATA_DIR = ""  # path to the COWC data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
pretrained_options = [True, False]
lr_options = [1e-2, 1e-3, 1e-4]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, pretrained) in itertools.product(
        model_options, lr_options, pretrained_options
    ):
        experiment_name = f"{model}_{lr}_{pretrained}"

        output_dir = os.path.join("output", "cowc_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "cowc_counting.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):
            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.pretrained={pretrained}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()

            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(target=do_work, args=(work, gpu_idx))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
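To make the dry run concrete, one of the 12 commands this grid builds (2 models x 3 learning rates x 2 pretrained settings) would look roughly like the line below. DATA_DIR is assumed to have been set to data/cowc_counting for the example, and the [GPU] token is substituted with a real GPU ID by do_work before the command runs:

python train.py config_file=conf/cowc_counting.yaml experiment.name=resnet18_0.01_True experiment.module.model=resnet18 experiment.module.learning_rate=0.01 experiment.module.pretrained=True program.output_dir=output/cowc_experiments program.log_dir=output/cowc_experiments/logs program.data_dir=data/cowc_counting trainer.gpus=[GPU]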
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""

import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use; one job will be started for every ID in the list
GPUS = range(1)
DRY_RUN = True  # if True, print the commands that would be run; if False, run them
DATA_DIR = ""  # path to the COWC data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
pretrained_options = [True]
lr_options = [1e-4]
seeds = range(10)


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, pretrained, seed) in itertools.product(
        model_options, lr_options, pretrained_options, seeds
    ):
        experiment_name = f"{model}_{lr}_{pretrained}_{seed}"

        output_dir = os.path.join("output", "cowc_seed_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "cowc_counting.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):
            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.pretrained={pretrained}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + f" program.seed={seed}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()

            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(target=do_work, args=(work, gpu_idx))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
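Because only the trailing seed varies across these runs, the per-run metrics printed by the evaluation script above can be summarized across seeds afterwards. A minimal sketch, assuming that script's JSON output has been saved to a file named cowc_seed_metrics.json (a hypothetical filename, not part of this commit):

#!/usr/bin/env python3
import json
from collections import defaultdict
from statistics import mean, stdev

with open("cowc_seed_metrics.json") as f:
    metrics = json.load(f)

# Group test RMSE by "model_lr_pretrained", dropping the trailing seed component
groups = defaultdict(list)
for name, splits in metrics.items():
    base = name.rsplit("_", 1)[0]
    if "test" in splits and "RMSE" in splits["test"]:
        groups[base].append(splits["test"]["RMSE"])

# Report mean and standard deviation across seeds (needs at least 2 seeds per group)
for base, values in sorted(groups.items()):
    print(f"{base}: RMSE {mean(values):.3f} +/- {stdev(values):.3f} over {len(values)} seeds")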
@@ -59,6 +59,6 @@ class TestRegressionTask:
         trainer.fit(model=model, datamodule=datamodule)
 
     def test_invalid_model(self) -> None:
-        match = "Model type 'invalid_model' is not valid."
-        with pytest.raises(ValueError, match=match):
-            RegressionTask(model="invalid_model")
+        match = "module 'torchvision.models' has no attribute 'invalid_model'"
+        with pytest.raises(AttributeError, match=match):
+            RegressionTask(model="invalid_model", pretrained=False)
@@ -15,7 +15,6 @@ from torch import Tensor
 from torch.nn.modules import Conv2d, Linear
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
-from torchvision import models
 
 from ..datasets.utils import unbind_samples
 
@@ -30,20 +29,24 @@ class RegressionTask(pl.LightningModule):
 
     def config_task(self) -> None:
         """Configures the task based on kwargs parameters."""
-        if self.hyperparams["model"] == "resnet18":
-            pretrained = self.hyperparams["pretrained"]
-            if parse(torchvision.__version__) >= parse("0.12"):
-                if pretrained:
-                    kwargs = {"weights": models.ResNet18_Weights.DEFAULT}
-                else:
-                    kwargs = {"weights": None}
-            else:
-                kwargs = {"pretrained": pretrained}
-            self.model = models.resnet18(**kwargs)
-            in_features = self.model.fc.in_features
-            self.model.fc = nn.Linear(in_features, out_features=1)
-        else:
-            raise ValueError(f"Model type '{self.hyperparams['model']}' is not valid.")
+        model = self.hyperparams["model"]
+        pretrained = self.hyperparams["pretrained"]
+
+        if parse(torchvision.__version__) >= parse("0.12"):
+            if pretrained:
+                kwargs = {
+                    "weights": getattr(
+                        torchvision.models, f"ResNet{model[6:]}_Weights"
+                    ).DEFAULT
+                }
+            else:
+                kwargs = {"weights": None}
+        else:
+            kwargs = {"pretrained": pretrained}
+
+        self.model = getattr(torchvision.models, model)(**kwargs)
+        in_features = self.model.fc.in_features
+        self.model.fc = nn.Linear(in_features, out_features=1)
 
     def __init__(self, **kwargs: Any) -> None:
         """Initialize a new LightningModule for training simple regression models.
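For context on the version branch above, a standalone sketch of the same name-based weight lookup the new config_task performs (requires torchvision >= 0.12; the model name resnet50 is just an example):

import torchvision
from packaging.version import parse

model = "resnet50"
if parse(torchvision.__version__) >= parse("0.12"):
    # "resnet50"[6:] == "50", so this resolves to torchvision.models.ResNet50_Weights
    weights = getattr(torchvision.models, f"ResNet{model[6:]}_Weights").DEFAULT
    net = getattr(torchvision.models, model)(weights=weights)
else:
    net = getattr(torchvision.models, model)(pretrained=True)

print(net.fc.in_features)  # 2048 for resnet50, the layer replaced by the regression head above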