UPDATE: if no augm columns then gracefully skip

This commit is contained in:
Ali Zaidi 2021-01-15 15:02:14 -08:00
Родитель aa81e93157
Коммит e133e814f1
2 изменённых файлов: 7 добавлений и 33 удалений

Просмотреть файл

@ -85,7 +85,9 @@ class BaseModel(abc.ABC):
raise TypeError(
f"input_cols expected type List[str] or str but received type {type(input_cols)}"
)
if type(augm_cols) == str:
if not augm_cols:
logging.debug(f"No augmented columns...")
elif type(augm_cols) == str:
augm_features = [str(col) for col in df if col.startswith(augm_cols)]
elif type(augm_cols) == list:
augm_features = augm_cols
@ -94,7 +96,10 @@ class BaseModel(abc.ABC):
f"augm_cols expected type List[str] or str but received type {type(augm_cols)}"
)
features = base_features + augm_features
if augm_cols:
features = base_features + augm_features
else:
features = base_features
self.features = features
logging.info(f"Using {features} as the features for modeling DDM")

Просмотреть файл

@ -155,7 +155,6 @@ def main(cfg: DictConfig):
model.load_model(filename=save_path, scale_data=scale_data)
# Grab standardized way to interact with sim API
# sc1_path = os.path.join(os.getcwd(), "models/sc1-small.pkl")
sim = Simulator(model, states, actions, configs)
# do a random action to get initial state
@ -301,33 +300,3 @@ def main(cfg: DictConfig):
if __name__ == "__main__":
main()
# import argparse
# parser = argparse.ArgumentParser(description="Bonsai and Simulator Integration...")
# parser.add_argument(
# "--log-iterations",
# type=lambda x: bool(strtobool(x)),
# default=False,
# help="Log iterations during training",
# )
# parser.add_argument(
# "--config-setup",
# type=lambda x: bool(strtobool(x)),
# default=False,
# help="Use a local environment file to setup access keys and workspace ids",
# )
# parser.add_argument(
# "--test-local",
# type=lambda x: bool(strtobool(x)),
# default=True,
# help="Run simulator locally without connecting to platform",
# )
# args = parser.parse_args()
# if args.test_local:
# test_random_policy(num_episodes=100, num_iterations=1)
# else:
# main(config_setup=args.config_setup,)