2021-01-12 03:19:44 +03:00
|
|
|
import logging
|
2021-04-17 02:59:55 +03:00
|
|
|
import os
|
2021-04-23 03:08:11 +03:00
|
|
|
import pathlib
|
2021-01-15 11:16:44 +03:00
|
|
|
import hydra
|
2021-04-23 03:08:11 +03:00
|
|
|
import numpy as np
|
|
|
|
from math import floor
|
2021-04-17 02:59:55 +03:00
|
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
|
|
|
2021-04-23 03:08:11 +03:00
|
|
|
# Module-level logger registered under the fixed name "datamodeler".
logger = logging.getLogger(name="datamodeler")
|
2021-01-15 11:16:44 +03:00
|
|
|
# Directory containing this script, computed from its symlink-resolved path.
dir_path = os.path.split(os.path.realpath(__file__))[0]
|
2021-01-13 02:24:38 +03:00
|
|
|
|
2021-01-12 03:19:44 +03:00
|
|
|
|
2021-01-15 11:16:44 +03:00
|
|
|
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train a model, or run a hyper-parameter sweep, as described by ``cfg``.

    Steps, in order:

    1. Log the resolved configuration.
    2. Look up the model class named by ``cfg["model"]["name"]`` and load the
       dataset through its ``load_csv`` method.
    3. Keep the first ``1 - test_perc`` fraction of rows for training/sweeping
       and the remaining rows for testing (no shuffling is done here).
    4. Save both splits as ``.npy`` files in ``<cwd>/data``.
    5. Build the model, then either sweep (``cfg["model"]["sweep"]["run"]``
       truthy) or fit on the training split.
    6. Save the model to ``<script dir>/<saver filename>.pkl``.

    Args:
        cfg: Hydra/OmegaConf configuration; the ``data`` and ``model``
            sections are read.
    """
    logger.info("Configuration: ")
    logger.info(f"\n{OmegaConf.to_yaml(cfg)}")

    # for readability, read common data args into variables
    data_cfg = cfg["data"]
    input_cols = data_cfg["inputs"]
    output_cols = data_cfg["outputs"]
    augmented_cols = data_cfg["augmented_cols"]
    iteration_order = data_cfg["iteration_order"]
    episode_col = data_cfg["episode_col"]
    iteration_col = data_cfg["iteration_col"]
    dataset_path = data_cfg["path"]
    max_rows = data_cfg["max_rows"]
    test_perc = data_cfg["test_perc"]
    delta_state = data_cfg["diff_state"]

    # common model args
    model_cfg = cfg["model"]
    save_path = model_cfg["saver"]["filename"]
    model_name = model_cfg["name"]
    sweep_cfg = model_cfg["sweep"]
    run_sweep = sweep_cfg["run"]
    split_strategy = sweep_cfg["split_strategy"]
    results_csv_path = sweep_cfg["results_csv_path"]

    # The registry module is chosen at run time, so the import stays local to
    # this function and only the selected module is imported.
    if model_name.lower() == "pytorch":
        from all_models import available_models
    else:
        from model_loader import available_models

    Model = available_models[model_name]

    if data_cfg["full_or_relative"] == "relative":
        dataset_path = os.path.join(dir_path, dataset_path)

    # The saved model always lives next to this script (unless the configured
    # filename is itself absolute, in which case os.path.join keeps it).
    save_path = os.path.join(dir_path, save_path + ".pkl")

    # Hand plain Python lists to the model instead of OmegaConf ListConfig
    # objects; any other value (e.g. a single string or None) passes through.
    input_cols, output_cols, augmented_cols = (
        list(cols) if isinstance(cols, ListConfig) else cols
        for cols in (input_cols, output_cols, augmented_cols)
    )

    model = Model()
    X, y = model.load_csv(
        input_cols=input_cols,
        output_cols=output_cols,
        augm_cols=augmented_cols,
        dataset_path=dataset_path,
        iteration_order=iteration_order,
        episode_col=episode_col,
        iteration_col=iteration_col,
        max_rows=max_rows,
        diff_state=delta_state,
    )

    logger.info(
        f"Saving last {test_perc * 100}% for test, using first {(1 - test_perc) * 100}% for training/sweeping"
    )
    train_id_end = floor(X.shape[0] * (1 - test_perc))

    # One-element tuple indices, i.e. exactly ``X[:train_id_end,]`` and
    # ``X[train_id_end:,]``: slice along the first axis only.
    train_rows = (slice(None, train_id_end),)
    test_rows = (slice(train_id_end, None),)
    X_train, y_train = X[train_rows], y[train_rows]
    X_test, y_test = X[test_rows], y[test_rows]

    # save training and test sets
    save_data_path = os.path.join(os.getcwd(), "data")
    # exist_ok=True already covers "directory is there", so no separate
    # existence check (which would be racy) is needed.
    pathlib.Path(save_data_path).mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving data to {os.path.abspath(save_data_path)}")
    np.save(os.path.join(save_data_path, "x_train.npy"), X_train)
    np.save(os.path.join(save_data_path, "y_train.npy"), y_train)
    np.save(os.path.join(save_data_path, "x_test.npy"), X_test)
    np.save(os.path.join(save_data_path, "y_test.npy"), y_test)

    logger.info("Building model...")
    model.build_model(**model_cfg["build_params"])

    if run_sweep:
        params = OmegaConf.to_container(sweep_cfg["params"])
        logger.info(f"Sweeping with parameters: {params}")

        sweep_df = model.sweep(
            params=params,
            X=X_train,
            y=y_train,
            search_algorithm=sweep_cfg["search_algorithm"],
            num_trials=sweep_cfg["num_trials"],
            scoring_func=sweep_cfg["scoring_func"],
            results_csv_path=results_csv_path,
            splitting_criteria=split_strategy,
        )
        logger.info(f"Sweep results: {sweep_df}")
    else:
        logger.info("Fitting model...")
        model.fit(X_train, y_train)

    logger.info(f"Saving model to {save_path}")
    model.save_model(filename=save_path)
|
2021-01-12 03:19:44 +03:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() is what
    # provides the `cfg` parameter.
    main()
|