import logging
import os
import random
import time
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import hydra
import numpy as np
import pandas as pd
from omegaconf import DictConfig, ListConfig

from policies import random_policy, brain_policy
from signal_builder import SignalBuilder

# see reason below for why commented out (UPDATE #comment-out-azure-cli)
# from azure.core.exceptions import HttpResponseError
from dotenv import load_dotenv, set_key
from microsoft_bonsai_api.simulator.client import BonsaiClient, BonsaiClientConfig
from microsoft_bonsai_api.simulator.generated.models import (
    SimulatorInterface,
    SimulatorSessionResponse,
    SimulatorState,
)

from base import BaseModel

logging.basicConfig()
logging.root.setLevel(logging.INFO)
# quiet the noisy azure SDK loggers while letting their records propagate
for name in logging.Logger.manager.loggerDict.keys():
    if "azure" in name:
        logging.getLogger(name).setLevel(logging.WARNING)
        logging.getLogger(name).propagate = True
logger = logging.getLogger("datamodeler")

dir_path = os.path.dirname(os.path.realpath(__file__))
env_name = "DDM"


def type_conversion(obj, type_name, minimum, maximum):
    """Cast `obj` to the named type, clipping numeric values into
    [minimum, maximum]. Booleans and strings pass through unclipped."""

    if type_name == "str":
        return str(obj)
    elif type_name == "int":
        if obj <= minimum:
            return int(minimum)
        elif obj >= maximum:
            return int(maximum)
        else:
            return int(obj)
    elif type_name == "float":
        if obj <= minimum:
            return float(minimum)
        elif obj >= maximum:
            return float(maximum)
        else:
            return float(obj)
    elif type_name == "bool":
        return obj
    else:
        raise ValueError(f"Unknown type name for conversion: {type_name}")
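
# For example (hypothetical values): type_conversion(12.7, "int", 0, 10) clips
# to the maximum and returns 10, while type_conversion(3.2, "float", 0, 10)
# returns 3.2 unchanged.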


# helper that returns None if the element is not present in the config
def hydra_read_config_var(cfg: DictConfig, level: str, key_name: str):
    """Return cfg[level][key_name] if the key exists, otherwise None."""

    return cfg[level][key_name] if key_name in cfg[level] else None
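
# Example (hypothetical config): with `data: {diff_state: true}` in the Hydra
# config, hydra_read_config_var(cfg, "data", "diff_state") returns True, and
# hydra_read_config_var(cfg, "data", "missing_key") returns None.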


class Simulator(BaseModel):
    def __init__(
        self,
        model,
        states: List[str],
        actions: List[str],
        configs: List[str],
        inputs: List[str],
        outputs: Union[List[str], Dict[str, str]],
        episode_inits: Dict[str, float],
        initial_states: Dict[str, float],
        signal_builder: Dict[str, float],
        diff_state: bool = False,
        lagged_inputs: int = 1,
        lagged_padding: bool = False,
        concatenate_var_length: Optional[Dict[str, int]] = None,
        prep_pipeline: Optional[Callable] = None,
        iteration_col: Optional[str] = None,
        exogeneous_variables: Optional[List[str]] = None,
        exogeneous_save_path: Optional[str] = None,
        initial_values_save_path: Optional[str] = None,
    ):
        self.model = model
        # self.features = states + configs + actions
        # self.labels = states
        self.features = inputs
        if isinstance(outputs, ListConfig):
            outputs = list(outputs)
            self.label_types = None
        elif isinstance(outputs, DictConfig):
            output_types = outputs
            outputs = list(outputs.keys())
            self.label_types = output_types
        else:
            self.label_types = None
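        # Note: when `outputs` is a mapping, each value is a type spec string
        # such as "float" or (hypothetically) "float -10,10", where the second
        # token gives "min,max" bounds; get_state() parses these below.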

        # if you're using exogeneous variables these will be looked up
        # from a saved dataset and appended during episode_step
        if exogeneous_variables and exogeneous_save_path:
            if os.path.dirname(exogeneous_save_path) == "":
                exogeneous_save_path = os.path.join(dir_path, exogeneous_save_path)
            if not os.path.exists(exogeneous_save_path):
                raise ValueError(
                    f"Exogeneous variables not found at {exogeneous_save_path}"
                )
            logger.info(f"Reading exogeneous variables from {exogeneous_save_path}")
            exogeneous_vars_df = pd.read_csv(exogeneous_save_path)
            self.exogeneous_variables = exogeneous_variables
            self.exog_df = exogeneous_vars_df

        if initial_values_save_path:
            if os.path.dirname(initial_values_save_path) == "":
                initial_values_save_path = os.path.join(
                    dir_path, initial_values_save_path
                )
            if not os.path.exists(initial_values_save_path):
                raise ValueError(
                    f"Initial values not found at {initial_values_save_path}"
                )
            logger.info(f"Reading initial values from {initial_values_save_path}")
            initial_values_df = pd.read_csv(initial_values_save_path)
            self.initial_values_df = initial_values_df

        self.labels = outputs
        self.config_keys = configs
        self.episode_inits = episode_inits
        self.state_keys = states
        self.action_keys = actions
        self.signal_builder = signal_builder
        self.diff_state = diff_state
        self.lagged_inputs = lagged_inputs
        self.lagged_padding = lagged_padding
        self.concatenate_var_length = concatenate_var_length
        self.prep_pipeline = prep_pipeline
        self.iteration_col = iteration_col
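
        # Naming sketch (hypothetical features): lagged copies are stored as
        # "<feature>_<i>" with _1 the most recent value, so lagged_inputs=2 with
        # features ["Tc", "Cr"] yields ["Tc_1", "Cr_1", "Tc_2", "Cr_2"].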
        if self.concatenate_var_length:
            logger.info(f"Using variable length lags: {self.concatenate_var_length}")
            self.lagged_feature_cols = [
                feat + f"_{i}"
                for feat in list(self.concatenate_var_length.keys())
                for i in range(1, self.concatenate_var_length[feat] + 1)
            ]
            self.non_lagged_feature_cols = list(
                set(self.features) - set(list(self.concatenate_var_length.keys()))
            )
            # need to verify order here
            # this matches dataclass when concatenating inputs
            self.features = self.non_lagged_feature_cols + self.lagged_feature_cols
        elif self.lagged_inputs > 1:
            logger.info(f"Using {self.lagged_inputs} lagged inputs as features")
            self.lagged_feature_cols = [
                feat + f"_{i}"
                for i in range(1, self.lagged_inputs + 1)
                for feat in self.features
            ]
            self.features = self.lagged_feature_cols
        else:
            self.lagged_feature_cols = []

        # create a dictionary of initial_states with placeholder values;
        # these come from simulator.yaml and the exact values don't matter,
        # since they get updated in self.episode_start

        # also create a mapper from config (Inkling) names to initial state
        # names; it is used to map scenario keys onto self.initial_states
        # values during episode_start
        initial_states_mapper = {}
        if isinstance(list(initial_states.values())[0], DictConfig):
            self.initial_states = {k: v["min"] for k, v in initial_states.items()}
            for k, v in initial_states.items():
                initial_states_mapper[v["inkling_name"]] = k
        else:
            self.initial_states = initial_states
        self.initial_states_mapper = initial_states_mapper
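
        # For example (hypothetical yaml): initial_states of
        # {"Tc": {"min": 20, "max": 30, "inkling_name": "initial_Tc"}} gives
        # initial_states_mapper == {"initial_Tc": "Tc"} and a starting Tc of 20.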

        logger.info(f"DDM features: {self.features}")
        logger.info(f"DDM outputs: {self.labels}")

    def episode_start(self, config: Optional[Dict[str, Any]] = None):
        """Initialize DDM. This could include initializations of configs
        as well as initial values for states.

        Parameters
        ----------
        config : Dict[str, Any], optional
            episode initializations, by default None
        """

        self.iteration_counter = 0

        # if you are using both initial values and exogeneous variables, then
        # make sure to sample a single episode from each and play it through
        if hasattr(self, "initial_values_df"):
            initial_values_episode = (
                self.initial_values_df["episode"].sample(1).values[0]
            )
            initial_values_data = self.initial_values_df[
                self.initial_values_df["episode"] == initial_values_episode
            ]
            for i in list(self.initial_states.keys()):
                # terminals are not assumed to be in the lookup dataset;
                # however, we need to terminate the episode when we reach
                # the end of the dataset, so the MDP needs a terminal variable
                if i == "terminal":
                    self.initial_states[i] = False
                else:
                    self.initial_states[i] = initial_values_data[i].values[0]

        # if using exogeneous variables,
        # sample from exog df and play it through the episode
        if hasattr(self, "exog_df"):
            if hasattr(self, "initial_values_df"):
                logger.info("Using sampled episode from initial values dataset")
                exog_episode = initial_values_episode
            else:
                exog_episode = self.exog_df["episode"].sample(1).values[0]
            exog_data = self.exog_df[self.exog_df["episode"] == exog_episode]
            self.exog_ep = exog_data
            for i in self.exogeneous_variables:
                self.initial_states[i] = self.exog_ep[i].values.tolist()[0]

        # initialize states based on simulator.yaml;
        # the initial dict was defined in the constructor
        initial_state = self.initial_states

        # use initial state from config if available (e.g., when brain training),
        # and skip if the config is missing:
        # check whether any keys from config exist in the mapper; if so, update
        # self.initial_states from config, and collect the remaining keys into
        # a new config used to update self.all_data
        if config:
            new_config = {}
            for k, v in config.items():
                if k in self.initial_states_mapper.keys():
                    initial_state[self.initial_states_mapper[k]] = v
                else:
                    new_config[k] = v
            logger.info(f"Initial states: {initial_state}")
        else:
            new_config = None

        # if config:
        #     initial_state.update(
        #         (k, config[k]) for k in initial_state.keys() & config.keys()
        #     )

        initial_action = {k: random.random() for k in self.action_keys}
        if new_config:
            logger.info(f"Initializing episode with provided config: {new_config}")
            self.config = new_config
        elif not new_config and self.episode_inits:
            logger.info(
                "No episode initializations provided, using initializations in yaml `episode_inits`"
            )
            logger.info(f"Episode config: {self.episode_inits}")
            self.config = self.episode_inits
        else:
            logger.warning(
                "No config provided, so using uniform random values. This is probably not what you want!"
            )
            # TODO: during ddm_trainer save the ranges of configs (and maybe states too for initial conditions)
            # to a file so we can sample from that range instead of random values
            # request_continue = input("Are you sure you want to continue with random configs?")
            if self.config_keys:
                self.config = {k: random.random() for k in self.config_keys}
            else:
                self.config = None

        # update state with initial_state values if
        # provided by config,
        # otherwise the default is used
        self.state = initial_state
        self.action = initial_action

        # Grab signal params pertaining to specific format of key_parameter from Inkling
        self.config_signals = {}
        if new_config and self.signal_builder:
            for k, v in self.signal_builder["signal_params"].items():
                for key, value in new_config.items():
                    if k in key:
                        self.config_signals.update({key: value})
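
        # Assumes Inkling signal keys follow a "<signal>_<param>" convention
        # (e.g., hypothetically, "Tset_amplitude"), so splitting on "_" below
        # recovers the SignalBuilder parameter name.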
        if self.config_signals:
            # If signal params come from Inkling, use those for building signals
            self.signals = {}
            for key, val in self.signal_builder["signal_types"].items():
                self.signals.update(
                    {
                        key: SignalBuilder(
                            val,
                            new_config["horizon"],
                            {
                                k.split("_")[1]: v
                                for k, v in self.config_signals.items()
                                if key in k
                            },
                        )
                    }
                )

            self.current_signals = {}
            for key, val in self.signals.items():
                self.current_signals.update(
                    {key: float(self.signals[key].get_current_signal())}
                )
        elif self.signal_builder:
            # Otherwise use the signal builder from simulator/conf
            self.signals = {}
            for key, val in self.signal_builder["signal_types"].items():
                self.signals.update(
                    {
                        key: SignalBuilder(
                            val,
                            self.signal_builder["horizon"],
                            self.signal_builder["signal_params"][key],
                        )
                    }
                )

            self.current_signals = {}
            for key, val in self.signals.items():
                self.current_signals.update(
                    {key: float(self.signals[key].get_current_signal())}
                )
        else:
            logger.info("No signal builder used")

        # capture all data
        # TODO: check if we can pick a subset of the data yaml, i.e., what happens if
        # {simulator.state, simulator.action, simulator.config} is a strict subset of {data.inputs + data.augmented_cols, self.outputs}
        if self.config:
            self.all_data = {**self.state, **self.action, **self.config}
        else:
            self.all_data = {**self.state, **self.action}
        if self.prep_pipeline:
            from preprocess import pipeline

            self.all_data = pipeline(self.all_data)

        if self.iteration_col:
            self.all_data[self.iteration_col] = self.iteration_counter
            logger.info(
                f"Iteration used as a feature. Iteration #: {self.iteration_counter}"
            )

        # if you're using lagged features, we need to initialize them;
        # they will initially be set to the same value, which is either 0
        # or the initial value of the state depending on lagged_padding,
        # and get updated during each episode step
        if self.lagged_inputs > 1 or self.concatenate_var_length:
            self.lagged_all_data = {
                k: self.all_data["_".join(k.split("_")[:-1])]
                if not self.lagged_padding
                else 0
                for k in self.lagged_feature_cols
            }
            self.all_data = {**self.all_data, **self.lagged_all_data}
        self.all_data["terminal"] = False

    def episode_step(self, action: Dict[str, float]) -> Dict:
        # load design matrix for self.model.predict
        # should match the shape of conf.data.inputs
        # make dict of D={states, actions, configs}
        # ddm_inputs = filter D \ (conf.data.inputs+conf.data.augmented_cols)
        # ddm_outputs = filter D \ conf.data.outputs
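        # Shifting sketch (hypothetical action "u" with two lags): on each step
        # u_1 takes the incoming action and u_2 takes the previous u_1, i.e.
        # {"u": 0.5} with all_data {"u_1": 0.4, "u_2": 0.3} becomes
        # {"u_1": 0.5, "u_2": 0.4}.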
        # initialize matrix of all actions:
        # set current action to action_1;
        # all other actions get pushed back one value to action_{i+1}
        if self.concatenate_var_length:
            # only create lagged actions if they were provided in
            # concatenate_var_length
            actions_to_lag = list(
                set(list(self.concatenate_var_length.keys())) & set(list(action.keys()))
            )
            if actions_to_lag:
                lagged_action = {
                    f"{k}_{i}": action[k] if i == 1 else self.all_data[f"{k}_{i-1}"]
                    for k in actions_to_lag
                    for i in range(1, self.concatenate_var_length[k] + 1)
                }
                action = lagged_action
        elif self.lagged_inputs > 1:
            lagged_action = {
                f"{k}_{i}": v if i == 1 else self.all_data[f"{k}_{i-1}"]
                for k, v in action.items()
                for i in range(1, self.lagged_inputs + 1)
            }
            action = lagged_action
        self.all_data.update(action)
        if self.prep_pipeline:
            from preprocess import pipeline

            self.all_data = pipeline(self.all_data)
        self.iteration_counter += 1
        if self.iteration_col:
            logger.info(
                f"Iteration used as a feature. Iteration #: {self.iteration_counter}"
            )

        if hasattr(self, "exogeneous_variables"):
            logger.info(
                f"Updating {self.exogeneous_variables} using next iteration from episode #: {self.exog_ep['episode'].values[0]}"
            )
            next_iteration = self.exog_ep[
                self.exog_ep["iteration"] == self.iteration_counter + 1
            ]
            self.all_data.update(
                next_iteration.reset_index()[self.exogeneous_variables].loc[0].to_dict()
            )
            # set terminal to true if at the last iteration
            if self.iteration_counter == self.exog_ep["iteration"].max() - 1:
                self.all_data["terminal"] = True

        # Use the signal builder's value as input to DDM if specified
        if self.signal_builder:
            for key in self.features:
                if key in self.signals:
                    self.all_data.update({key: self.current_signals[key]})

        # MAKE SURE THIS IS SORTED ACCORDING TO THE ORDER USED IN TRAINING
        ddm_input = {k: self.all_data[k] for k in self.features}

        # input_list = [
        #     list(self.state.values()),
        #     list(self.config.values()),
        #     list(action.values()),
        # ]

        # input_array = [item for subl in input_list for item in subl]
        input_array = list(ddm_input.values())
        X = np.array(input_array).reshape(1, -1)
        if self.diff_state:
            # compensating for the output being the delta state (st+1 - st)
            preds = np.array(list(self.state.values())) + self.model.predict(X)
        else:
            preds = self.model.predict(X)  # absolute prediction
        ddm_output = dict(zip(self.labels, preds.reshape(preds.shape[1]).tolist()))

        # update lagged values in ddm_output -> which updates self.all_data:
        # current predictions become the new t1, everything else is pushed back by 1
        if self.concatenate_var_length:
            lagged_ddm_output = {
                f"{k}_{i}": v if i == 1 else self.all_data[f"{k}_{i-1}"]
                for k, v in ddm_output.items()
                for i in range(1, self.concatenate_var_length[k] + 1)
            }
            ddm_output = lagged_ddm_output
        elif self.lagged_inputs > 1:
            lagged_ddm_output = {
                f"{k}_{i}": v if i == 1 else self.all_data[f"{k}_{i-1}"]
                for k, v in ddm_output.items()
                for i in range(1, self.lagged_inputs + 1)
            }
            ddm_output = lagged_ddm_output
        self.all_data.update(ddm_output)
        if self.iteration_col:
            self.all_data[self.iteration_col] = self.iteration_counter

        # current state is just the first (most recent) lagged value
        states_lagged = (
            list(set(self.concatenate_var_length.keys()) & set(self.state_keys))
            if self.concatenate_var_length
            else []
        )
        if self.lagged_inputs > 1 and not self.concatenate_var_length:
            self.state = {k: self.all_data[f"{k}_1"] for k in self.state_keys}
        elif self.concatenate_var_length:
            self.state = {
                k: self.all_data[f"{k}_1"] if k in states_lagged else self.all_data[k]
                for k in self.state_keys
            }
        else:
            self.state = {k: self.all_data[k] for k in self.state_keys}
        # self.state = dict(zip(self.state_keys, preds.reshape(preds.shape[1]).tolist()))

        if self.signal_builder:
            self.current_signals = {}
            for key, val in self.signals.items():
                self.current_signals.update(
                    {key: float(self.signals[key].get_current_signal())}
                )

        return dict(self.state)

    def get_state(self) -> Dict:
        if getattr(self, "label_types", None):
            for key, val_type in self.label_types.items():
                state_val = self.state[key]
                val_type = val_type.split(" ")
                if len(val_type) < 2:
                    # no explicit range provided, so clip within +/- 10
                    # of the current value
                    bottom = state_val - 10
                    top = state_val + 10
                    val_type = val_type[0]
                elif len(val_type) == 2:
                    # second token is a "min,max" range
                    val_range = val_type[1]
                    val_type = val_type[0]
                    val_range = val_range.split(",")
                    bottom = float(val_range[0])
                    top = float(val_range[1])
                else:
                    raise ValueError(f"Invalid label type provided: {val_type}")
                state_val = type_conversion(state_val, val_type, bottom, top)
                self.state[key] = state_val

        if self.signal_builder:
            state_plus_signals = {**self.state, **self.current_signals}
            logger.info(f"Current state with signals: {state_plus_signals}")
            return state_plus_signals
        else:
            logger.info(f"Current state: {self.state}")
            return dict(self.state)

    def halted(self) -> bool:
        # the DDM has no notion of an invalid internal state, so never halt
        return False


def env_setup():
    """Helper function to set up the connection with Project Bonsai

    Returns
    -------
    Tuple
        workspace and access_key
    """

    load_dotenv(verbose=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")

    env_file_path = os.path.join(dir_path, ".env")
    env_file_exists = os.path.exists(env_file_path)
    if not env_file_exists:
        open(env_file_path, "a").close()

    if not all([env_file_exists, workspace]):
        workspace = input("Please enter your workspace id: ")
        set_key(env_file_path, "SIM_WORKSPACE", workspace)
    if not all([env_file_exists, access_key]):
        access_key = input("Please enter your access key: ")
        set_key(env_file_path, "SIM_ACCESS_KEY", access_key)

    load_dotenv(verbose=True, override=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")

    return workspace, access_key
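
# Expected .env contents (values are placeholders, not real credentials):
#   SIM_WORKSPACE=<your-bonsai-workspace-id>
#   SIM_ACCESS_KEY=<your-bonsai-access-key>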


def test_policy(
    num_episodes: int = 5,
    num_iterations: int = 5,
    sim: Optional[Simulator] = None,
    config: Optional[Dict[str, float]] = None,
    policy=random_policy,
):
    """Test a policy over a fixed number of episodes

    Parameters
    ----------
    num_episodes : int, optional
        number of episodes to run, by default 5
    num_iterations : int, optional
        number of iterations per episode, by default 5
    sim : Simulator, optional
        simulator instance to exercise
    config : Dict[str, float], optional
        episode configuration passed to episode_start, by default None
    policy : Callable, optional
        maps a state dict to an action dict, by default random_policy
    """

    def _config_clean(in_config):
        new_config = {}
        for k, v in in_config.items():
            if type(v) in [DictConfig, dict]:
                v = random.uniform(in_config[k]["min"], in_config[k]["max"])
                k = in_config[k]["inkling_name"]
            new_config[k] = v
        return new_config
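
    # _config_clean example (hypothetical entry): {"Tc": {"min": 20, "max": 30,
    # "inkling_name": "initial_Tc"}} becomes {"initial_Tc": <uniform(20, 30)>},
    # while plain scalar entries pass through unchanged.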
    for episode in range(num_episodes):
        iteration = 0
        terminal = False
        if config:
            new_config = _config_clean(config)
            logger.info(f"Configuration: {new_config}")
            sim.episode_start(new_config)
        else:
            sim.episode_start()
        sim_state = sim.get_state()
        while not terminal:
            action = policy(sim_state)
            sim.episode_step(action)
            sim_state = sim.get_state()
            logger.info(f"Running iteration #{iteration} for episode #{episode}")
            logger.info(f"Action: {action}")
            logger.info(f"Observations: {sim_state}")
            iteration += 1
            terminal = iteration >= num_iterations

    return sim


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    save_path = cfg["model"]["saver"]["filename"]
    save_path = os.path.join(dir_path, save_path)
    model_name = cfg["model"]["name"]
    states = cfg["simulator"]["states"]
    actions = cfg["simulator"]["actions"]
    configs = cfg["simulator"]["configs"]
    initial_states = cfg["simulator"]["initial_states"]
    policy = cfg["simulator"]["policy"]
    # logflag = cfg["simulator"]["logging"]
    # logging not yet implemented

    ts_model = model_name.lower() in ["nhits", "tftmodel", "varima", "ets", "sfarima"]
    if ts_model:
        scale_data = cfg["model"]["scale_data"]
    else:
        scale_data = cfg["model"]["build_params"]["scale_data"]
        # scale_data = cfg["data"]["scale_data"]
    diff_state = hydra_read_config_var(cfg, "data", "diff_state")
    concatenated_steps = hydra_read_config_var(cfg, "data", "concatenated_steps")
    concatenated_zero_padding = hydra_read_config_var(
        cfg, "data", "concatenated_zero_padding"
    )
    concatenate_var_length = hydra_read_config_var(cfg, "data", "concatenate_length")
    exogeneous_variables = hydra_read_config_var(cfg, "data", "exogeneous_variables")
    exogeneous_save_path = hydra_read_config_var(cfg, "data", "exogeneous_save_path")
    initial_values_save_path = hydra_read_config_var(
        cfg, "data", "initial_values_save_path"
    )

    workspace_setup = cfg["simulator"]["workspace_setup"]
    episode_inits = cfg["simulator"]["episode_inits"]

    input_cols = cfg["data"]["inputs"]
    output_cols = cfg["data"]["outputs"]
    augmented_cols = cfg["data"]["augmented_cols"]
    preprocess = hydra_read_config_var(cfg, "data", "preprocess")
    iteration_col = hydra_read_config_var(cfg, "data", "iteration_col")
    iteration_col = iteration_col if iteration_col in input_cols else None
    if isinstance(input_cols, ListConfig):
        input_cols = list(input_cols)
    if isinstance(output_cols, ListConfig):
        output_cols = list(output_cols)
    if isinstance(augmented_cols, ListConfig):
        augmented_cols = list(augmented_cols)

    input_cols = input_cols + augmented_cols

    ts_model = False
    logger.info(f"Using DDM with {policy} policy")
    if model_name.lower() == "pytorch":
        from all_models import available_models
    elif model_name.lower() in ["nhits", "tftmodel", "varima", "ets", "sfarima"]:
        from timeseriesclass import darts_models as available_models

        ts_model = True
    else:
        from model_loader import available_models

    Model = available_models[model_name]
    model = Model()
    if ts_model:
        # time-series models are constructed via build_model() before loading
        model.build_model()

    model.load_model(filename=save_path, scale_data=scale_data)
    # model.build_model(**cfg["model"]["build_params"])

    if not initial_states:
        if not initial_values_save_path:
            logger.warning(
                "No initial values provided, using randomly initialized states which is probably NOT what you want"
            )
        initial_states = {k: random.random() for k in states}

    signal_builder = cfg["simulator"]["signal_builder"]

    # Grab standardized way to interact with sim API
    sim = Simulator(
        model,
        states,
        actions,
        configs,
        input_cols,
        output_cols,
        episode_inits,
        initial_states,
        signal_builder,
        diff_state,
        concatenated_steps,
        concatenated_zero_padding,
        concatenate_var_length,
        prep_pipeline=preprocess,
        iteration_col=iteration_col,
        exogeneous_variables=exogeneous_variables,
        exogeneous_save_path=exogeneous_save_path,
        initial_values_save_path=initial_values_save_path,
    )

    if policy == "random":
        random_policy_from_keys = partial(random_policy, action_keys=sim.action_keys)
        test_policy(
            sim=sim,
            config=None,
            policy=random_policy_from_keys,
        )
    elif isinstance(policy, int):
        # If a docker PORT is provided, use it as the exported brain PORT
        port = policy
        url = f"http://localhost:{port}"
        print(f"Connecting to exported brain running at {url}...")
        trained_brain_policy = partial(brain_policy, exported_brain_url=url)
        test_policy(
            sim=sim,
            config={**initial_states},
            policy=trained_brain_policy,
        )
    elif policy == "bonsai":
        if workspace_setup:
            logger.info("Loading workspace information from .env")
            env_setup()
            load_dotenv(verbose=True, override=True)
        # Configure client to interact with Bonsai service
        config_client = BonsaiClientConfig()
        client = BonsaiClient(config_client)

        # SimulatorInterface needs to be initialized with an
        # existing state attribute
        # TODO: see if we can move this into the constructor method
        sim.episode_start()

        # Create simulator session and init sequence id
        registration_info = SimulatorInterface(
            name=env_name,
            timeout=60,
            simulator_context=config_client.simulator_context,
        )

        def CreateSession(
            registration_info: SimulatorInterface, config_client: BonsaiClientConfig
        ):
            """Creates a new simulator session and returns (session, sequenceId)"""

            try:
                print(
                    "config: {}, {}".format(
                        config_client.server, config_client.workspace
                    )
                )
                registered_session: SimulatorSessionResponse = client.session.create(
                    workspace_name=config_client.workspace, body=registration_info
                )
                print("Registered simulator. {}".format(registered_session.session_id))

                return registered_session, 1
            # except HttpResponseError as ex:
            #     print(
            #         "HttpResponseError in Registering session: StatusCode: {}, Error: {}, Exception: {}".format(
            #             ex.status_code, ex.error.message, ex
            #         )
            #     )
            #     raise ex
            except Exception as ex:
                print(
                    "Unexpected error: {}. This is most likely a network connectivity issue; make sure you can reach the Bonsai platform from your network.".format(
                        ex
                    )
                )
                raise ex

        registered_session, sequence_id = CreateSession(
            registration_info, config_client
        )
        episode = 0
        iteration = 0

        try:
            while True:
                # Advance by the new state depending on the event type
                sim_state = SimulatorState(
                    sequence_id=sequence_id,
                    state=sim.get_state(),
                    halted=sim.halted(),
                )
                try:
                    event = client.session.advance(
                        workspace_name=config_client.workspace,
                        session_id=registered_session.session_id,
                        body=sim_state,
                    )
                    sequence_id = event.sequence_id
                    print(
                        "[{}] Last Event: {}".format(
                            time.strftime("%H:%M:%S"), event.type
                        )
                    )
                # UPDATE #comment-out-azure-cli:
                # - commented out the HttpResponseError since it relies on azure-cli-core, which has
                #   conflicting dependencies with microsoft-bonsai-api
                # - the catch-all exception below should still re-connect on disconnects
                # except HttpResponseError as ex:
                #     print(
                #         "HttpResponseError in Advance: StatusCode: {}, Error: {}, Exception: {}".format(
                #             ex.status_code, ex.error.message, ex
                #         )
                #     )
                #     # This can happen on network connectivity issues. The SDK has retry logic,
                #     # but even after that the request may fail if your network has an issue
                #     # or the sim session on the platform is going away.
                #     # So let's re-register the sim session, get a new session, and continue iterating. :-)
                #     registered_session, sequence_id = CreateSession(
                #         registration_info, config_client
                #     )
                #     continue
                except Exception as err:
                    print("Unexpected error in Advance: {}".format(err))
                    # Ideally this shouldn't happen, but for very long-running sims
                    # it can happen for various reasons, so let's re-register the sim and move on.
                    # If possible, notify the Bonsai team to see if this is a platform issue that can be fixed.
                    registered_session, sequence_id = CreateSession(
                        registration_info, config_client
                    )
                    continue

                # Event loop
                if event.type == "Idle":
                    time.sleep(event.idle.callback_time)
                    print("Idling...")
                elif event.type == "EpisodeStart":
                    print(event.episode_start.config)
                    sim.episode_start(event.episode_start.config)
                    episode += 1
                elif event.type == "EpisodeStep":
                    iteration += 1
                    sim.episode_step(event.episode_step.action)
                elif event.type == "EpisodeFinish":
                    print("Episode Finishing...")
                    iteration = 0
                elif event.type == "Unregister":
                    print(
                        "Simulator Session unregistered by platform because '{}', Registering again!".format(
                            event.unregister.details
                        )
                    )
                    registered_session, sequence_id = CreateSession(
                        registration_info, config_client
                    )
                    continue
                else:
                    pass
        except KeyboardInterrupt:
            # Gracefully unregister with keyboard interrupt
            client.session.delete(
                workspace_name=config_client.workspace,
                session_id=registered_session.session_id,
            )
            print("Unregistered simulator.")
        except Exception as err:
            # Gracefully unregister for any other exceptions
            client.session.delete(
                workspace_name=config_client.workspace,
                session_id=registered_session.session_id,
            )
            print("Unregistered simulator because: {}".format(err))


if __name__ == "__main__":
    main()