UPDATE: if no augm columns then gracefully skip
This commit is contained in:
Родитель
aa81e93157
Коммит
e133e814f1
9
base.py
9
base.py
|
@ -85,7 +85,9 @@ class BaseModel(abc.ABC):
|
|||
raise TypeError(
|
||||
f"input_cols expected type List[str] or str but received type {type(input_cols)}"
|
||||
)
|
||||
if type(augm_cols) == str:
|
||||
if not augm_cols:
|
||||
logging.debug(f"No augmented columns...")
|
||||
elif type(augm_cols) == str:
|
||||
augm_features = [str(col) for col in df if col.startswith(augm_cols)]
|
||||
elif type(augm_cols) == list:
|
||||
augm_features = augm_cols
|
||||
|
@ -94,7 +96,10 @@ class BaseModel(abc.ABC):
|
|||
f"augm_cols expected type List[str] or str but received type {type(augm_cols)}"
|
||||
)
|
||||
|
||||
features = base_features + augm_features
|
||||
if augm_cols:
|
||||
features = base_features + augm_features
|
||||
else:
|
||||
features = base_features
|
||||
self.features = features
|
||||
logging.info(f"Using {features} as the features for modeling DDM")
|
||||
|
||||
|
|
|
@ -155,7 +155,6 @@ def main(cfg: DictConfig):
|
|||
model.load_model(filename=save_path, scale_data=scale_data)
|
||||
|
||||
# Grab standardized way to interact with sim API
|
||||
# sc1_path = os.path.join(os.getcwd(), "models/sc1-small.pkl")
|
||||
sim = Simulator(model, states, actions, configs)
|
||||
|
||||
# do a random action to get initial state
|
||||
|
@ -301,33 +300,3 @@ def main(cfg: DictConfig):
|
|||
if __name__ == "__main__":
|
||||
|
||||
main()
|
||||
|
||||
# import argparse
|
||||
|
||||
# parser = argparse.ArgumentParser(description="Bonsai and Simulator Integration...")
|
||||
# parser.add_argument(
|
||||
# "--log-iterations",
|
||||
# type=lambda x: bool(strtobool(x)),
|
||||
# default=False,
|
||||
# help="Log iterations during training",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-setup",
|
||||
# type=lambda x: bool(strtobool(x)),
|
||||
# default=False,
|
||||
# help="Use a local environment file to setup access keys and workspace ids",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--test-local",
|
||||
# type=lambda x: bool(strtobool(x)),
|
||||
# default=True,
|
||||
# help="Run simulator locally without connecting to platform",
|
||||
# )
|
||||
|
||||
# args = parser.parse_args()
|
||||
|
||||
# if args.test_local:
|
||||
# test_random_policy(num_episodes=100, num_iterations=1)
|
||||
# else:
|
||||
# main(config_setup=args.config_setup,)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче