UPDATE: bug fixes when missing config variables
This commit is contained in:
Родитель
a3fb08063a
Коммит
9d46161353
|
@ -1,14 +1,14 @@
|
|||
defaults:
|
||||
# - data: house_energy.yaml
|
||||
- data: house_energy.yaml
|
||||
# - data: cartpole_mixed.yaml
|
||||
- data: hvac_b1.yaml
|
||||
# - data: hvac_b1.yaml
|
||||
# - data: quanser_rand.yaml
|
||||
- model: xgboost.yaml
|
||||
# - model: lightgbm.yaml
|
||||
# - model: SVR.yaml
|
||||
# - model: torch.yaml
|
||||
# - simulator: house_energy_simparam.yaml
|
||||
- simulator: house_energy_simparam.yaml
|
||||
# - simulator: cartpole_mixed_simparam.yaml
|
||||
# - simulator: quanser-log.yaml
|
||||
# - simulator: quanser_simparam.yaml
|
||||
- simulator: hvac_b1_simparam.yaml
|
||||
# - simulator: hvac_b1_simparam.yaml
|
||||
|
|
|
@ -60,6 +60,13 @@ def type_conversion(obj, type, minimum, maximum):
|
|||
return obj
|
||||
|
||||
|
||||
# helper function that return None if element is not present in config
|
||||
def hydra_read_config_var(cfg: DictConfig, level: str, key_name: str):
|
||||
"""Reads the config file and returns the config as a dictionary"""
|
||||
|
||||
return cfg[level][key_name] if key_name in cfg[level] else None
|
||||
|
||||
|
||||
class Simulator(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -194,7 +201,8 @@ class Simulator(BaseModel):
|
|||
|
||||
# if you are using both initial values and exogeneous variables, then
|
||||
# make sure to sample a single episode from each and play it through
|
||||
if self.initial_values_df is not None:
|
||||
if hasattr(self, "initial_values_df"):
|
||||
# if self.initial_values_df is not None:
|
||||
initial_values_episode = (
|
||||
self.initial_values_df["episode"].sample(1).values[0]
|
||||
)
|
||||
|
@ -212,8 +220,8 @@ class Simulator(BaseModel):
|
|||
|
||||
# if using exogeneous variables
|
||||
# sample from exog df and play it through the episode
|
||||
if self.exog_df is not None:
|
||||
if self.initial_values_df is not None:
|
||||
if hasattr(self, "exog_df"):
|
||||
if hasattr(self, "initial_values_df"):
|
||||
logger.info(f"Using sampled episode from initial values dataset")
|
||||
exog_episode = initial_values_episode
|
||||
else:
|
||||
|
@ -410,7 +418,7 @@ class Simulator(BaseModel):
|
|||
f"Iteration used as a feature. Iteration #: {self.iteration_counter}"
|
||||
)
|
||||
|
||||
if self.exogeneous_variables:
|
||||
if hasattr(self, "exogeneous_variables"):
|
||||
logger.info(
|
||||
f"Updating {self.exogeneous_variables} using next iteration from episode #: {self.exog_ep['episode'].values[0]}"
|
||||
)
|
||||
|
@ -632,13 +640,17 @@ def main(cfg: DictConfig):
|
|||
else:
|
||||
scale_data = cfg["model"]["build_params"]["scale_data"]
|
||||
# scale_data = cfg["data"]["scale_data"]
|
||||
diff_state = cfg["data"]["diff_state"]
|
||||
concatenated_steps = cfg["data"]["concatenated_steps"]
|
||||
concatenated_zero_padding = cfg["data"]["concatenated_zero_padding"]
|
||||
concatenate_var_length = cfg["data"]["concatenate_length"]
|
||||
exogeneous_variables = cfg["data"]["exogeneous_variables"]
|
||||
exogeneous_path = cfg["data"]["exogeneous_save_path"]
|
||||
initial_values_save_path = cfg["data"]["initial_values_save_path"]
|
||||
diff_state = hydra_read_config_var(cfg, "data", "diff_state")
|
||||
concatenated_steps = hydra_read_config_var(cfg, "data", "concatenated_steps")
|
||||
concatenated_zero_padding = hydra_read_config_var(
|
||||
cfg, "data", "concatenated_zero_padding"
|
||||
)
|
||||
concatenate_var_length = hydra_read_config_var(cfg, "data", "concatenate_length")
|
||||
exogeneous_variables = hydra_read_config_var(cfg, "data", "exogeneous_variables")
|
||||
exogeneous_save_path = hydra_read_config_var(cfg, "data", "exogeneous_save_path")
|
||||
initial_values_save_path = hydra_read_config_var(
|
||||
cfg, "data", "initial_values_save_path"
|
||||
)
|
||||
|
||||
workspace_setup = cfg["simulator"]["workspace_setup"]
|
||||
episode_inits = cfg["simulator"]["episode_inits"]
|
||||
|
@ -646,8 +658,8 @@ def main(cfg: DictConfig):
|
|||
input_cols = cfg["data"]["inputs"]
|
||||
output_cols = cfg["data"]["outputs"]
|
||||
augmented_cols = cfg["data"]["augmented_cols"]
|
||||
prep_pipeline = cfg["data"]["preprocess"]
|
||||
iteration_col = cfg["data"]["iteration_col"]
|
||||
preprocess = hydra_read_config_var(cfg, "data", "preprocess")
|
||||
iteration_col = hydra_read_config_var(cfg, "data", "iteration_col")
|
||||
iteration_col = iteration_col if iteration_col in input_cols else None
|
||||
if type(input_cols) == ListConfig:
|
||||
input_cols = list(input_cols)
|
||||
|
@ -703,10 +715,10 @@ def main(cfg: DictConfig):
|
|||
concatenated_steps,
|
||||
concatenated_zero_padding,
|
||||
concatenate_var_length,
|
||||
prep_pipeline=prep_pipeline,
|
||||
prep_pipeline=preprocess,
|
||||
iteration_col=iteration_col,
|
||||
exogeneous_variables=exogeneous_variables,
|
||||
exogeneous_save_path=exogeneous_path,
|
||||
exogeneous_save_path=exogeneous_save_path,
|
||||
initial_values_save_path=initial_values_save_path,
|
||||
)
|
||||
|
||||
|
|
|
@ -15,31 +15,43 @@ logger = logging.getLogger("datamodeler")
|
|||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
# helper function that return None if element is not present in config
|
||||
def hydra_read_config_var(cfg: DictConfig, level: str, key_name: str):
|
||||
"""Reads the config file and returns the config as a dictionary"""
|
||||
|
||||
return cfg[level][key_name] if key_name in cfg[level] else None
|
||||
|
||||
|
||||
@hydra.main(config_path="conf", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
logger.info("Configuration: ")
|
||||
logger.info(f"\n{OmegaConf.to_yaml(cfg)}")
|
||||
|
||||
# for readability, read common data args into variables
|
||||
input_cols = cfg["data"]["inputs"]
|
||||
output_cols = cfg["data"]["outputs"]
|
||||
augmented_cols = cfg["data"]["augmented_cols"]
|
||||
input_cols = hydra_read_config_var(cfg, "data", "inputs")
|
||||
output_cols = hydra_read_config_var(cfg, "data", "outputs")
|
||||
augmented_cols = hydra_read_config_var(cfg, "data", "augmented_cols")
|
||||
|
||||
iteration_order = cfg["data"]["iteration_order"]
|
||||
episode_col = cfg["data"]["episode_col"]
|
||||
iteration_col = cfg["data"]["iteration_col"]
|
||||
dataset_path = cfg["data"]["path"]
|
||||
max_rows = cfg["data"]["max_rows"]
|
||||
test_perc = cfg["data"]["test_perc"]
|
||||
delta_state = cfg["data"]["diff_state"]
|
||||
concatenated_steps = cfg["data"]["concatenated_steps"]
|
||||
concatenated_zero_padding = cfg["data"]["concatenated_zero_padding"]
|
||||
concatenate_var_length = cfg["data"]["concatenate_length"]
|
||||
pipeline = cfg["data"]["preprocess"]
|
||||
var_rename = cfg["data"]["var_rename"]
|
||||
exogeneous_variables = cfg["data"]["exogeneous_variables"]
|
||||
exogeneous_path = cfg["data"]["exogeneous_save_path"]
|
||||
initial_values_save_path = cfg["data"]["initial_values_save_path"]
|
||||
iteration_order = hydra_read_config_var(cfg, "data", "iteration_order")
|
||||
episode_col = hydra_read_config_var(cfg, "data", "episode_col")
|
||||
iteration_col = hydra_read_config_var(cfg, "data", "iteration_col")
|
||||
dataset_path = hydra_read_config_var(cfg, "data", "path")
|
||||
max_rows = hydra_read_config_var(cfg, "data", "max_rows")
|
||||
test_perc = hydra_read_config_var(cfg, "data", "test_perc")
|
||||
|
||||
diff_state = hydra_read_config_var(cfg, "data", "diff_state")
|
||||
concatenated_steps = hydra_read_config_var(cfg, "data", "concatenated_steps")
|
||||
concatenated_zero_padding = hydra_read_config_var(
|
||||
cfg, "data", "concatenated_zero_padding"
|
||||
)
|
||||
concatenate_var_length = hydra_read_config_var(cfg, "data", "concatenate_length")
|
||||
preprocess = hydra_read_config_var(cfg, "data", "preprocess")
|
||||
var_rename = hydra_read_config_var(cfg, "data", "var_rename")
|
||||
exogeneous_variables = hydra_read_config_var(cfg, "data", "exogeneous_variables")
|
||||
exogeneous_save_path = hydra_read_config_var(cfg, "data", "exogeneous_save_path")
|
||||
initial_values_save_path = hydra_read_config_var(
|
||||
cfg, "data", "initial_values_save_path"
|
||||
)
|
||||
|
||||
# common model args
|
||||
save_path = cfg["model"]["saver"]["filename"]
|
||||
|
@ -99,7 +111,7 @@ def main(cfg: DictConfig) -> None:
|
|||
return_ts=False,
|
||||
var_rename=var_rename,
|
||||
exogeneous_variables=exogeneous_variables,
|
||||
exogeneous_path=exogeneous_path,
|
||||
exogeneous_path=exogeneous_save_path,
|
||||
)
|
||||
else:
|
||||
X_train, y_train, X_test, y_test = model.load_csv(
|
||||
|
@ -113,14 +125,14 @@ def main(cfg: DictConfig) -> None:
|
|||
# drop_nulls: bool = True,
|
||||
max_rows=max_rows,
|
||||
test_perc=test_perc,
|
||||
diff_state=delta_state,
|
||||
prep_pipeline=pipeline,
|
||||
diff_state=diff_state,
|
||||
prep_pipeline=preprocess,
|
||||
var_rename=var_rename,
|
||||
concatenated_steps=concatenated_steps,
|
||||
concatenated_zero_padding=concatenated_zero_padding,
|
||||
concatenate_var_length=concatenate_var_length,
|
||||
exogeneous_variables=exogeneous_variables,
|
||||
exogeneous_save_path=exogeneous_path,
|
||||
exogeneous_save_path=exogeneous_save_path,
|
||||
initial_values_save_path=initial_values_save_path,
|
||||
)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче