diff --git a/conf/config.yaml b/conf/config.yaml
index 2b6aa4c..bbf1c3c 100644
--- a/conf/config.yaml
+++ b/conf/config.yaml
@@ -1,4 +1,4 @@
 defaults:
   - data: cartpole-updated.yaml
-  - model: SVR.yaml
+  - model: xgboost.yaml
   - simulator: cartpole-updated-simparam.yaml
diff --git a/conf/simulator/cartpole-updated-simparam.yaml b/conf/simulator/cartpole-updated-simparam.yaml
index 2a5ed57..12cc056 100644
--- a/conf/simulator/cartpole-updated-simparam.yaml
+++ b/conf/simulator/cartpole-updated-simparam.yaml
@@ -6,6 +6,16 @@ simulator:
   # estimate these during training
   # e.g.,: episode_inits: { "pole_length": 0.4, "pole_mass": 0.055, "cart_mass": 0.31 }
+  # e.g.,: your simulator may need to know the initial state
+  # before the first episode. define these here as a dictionary
+  # you can include these in your Inkling scenarios during brain training
+  initial_states:
+    {
+      "cart_position": 0,
+      "cart_velocity": 0,
+      "pole_angle": 0,
+      "pole_angular_velocity": 0,
+    }
   # episode_inits:
   policy: bonsai
   logging: enable
diff --git a/ddm_predictor.py b/ddm_predictor.py
index 5be88e0..829ffa6 100644
--- a/ddm_predictor.py
+++ b/ddm_predictor.py
@@ -44,6 +44,7 @@ class Simulator(BaseModel):
         inputs: List[str],
         outputs: List[str],
         episode_inits: Dict[str, float],
+        initial_states: Dict[str, float],
         diff_state: bool = False,
     ):

@@ -57,14 +58,30 @@ class Simulator(BaseModel):
         self.state_keys = states
         self.action_keys = actions
         self.diff_state = diff_state
+        self.initial_states = initial_states
         # TODO: Add logging
         logger.info(f"DDM features: {self.features}")
         logger.info(f"DDM outputs: {self.labels}")

     def episode_start(self, config: Dict[str, Any] = None):
+        """Initialize DDM with initial states. This could include initializations of configs
+        as well as initial values for actions

-        initial_state = {k: random.random() for k in self.state_keys}
+        Parameters
+        ----------
+        config : Dict[str, Any], optional
+            episode initializations, by default None
+        """
+
+        # initialize states based on simulator.yaml
+        initial_state = self.initial_states
+        # define initial state from config if available (e.g. when brain training)
+        # skip if config missing
+        if config:
+            initial_state.update(
+                (k, config[k]) for k in initial_state.keys() & config.keys()
+            )
         initial_action = {k: random.random() for k in self.action_keys}
         if config:
             logger.info(f"Initializing episode with provided config: {config}")
@@ -82,6 +99,7 @@ class Simulator(BaseModel):
             # request_continue = input("Are you sure you want to continue with random configs?")
             self.config = {k: random.random() for k in self.config_keys}
         self.state = initial_state
+        logger.info(f"Initial states: {initial_state}")
         self.action = initial_action
         # capture all data
         # TODO: check if we can pick a subset of data yaml, i.e., what happens if
@@ -206,6 +224,7 @@ def main(cfg: DictConfig):
     states = cfg["simulator"]["states"]
     actions = cfg["simulator"]["actions"]
     configs = cfg["simulator"]["configs"]
+    initial_states = cfg["simulator"]["initial_states"]
     policy = cfg["simulator"]["policy"]
     logflag = cfg["simulator"]["logging"]
     # logging not yet implemented
@@ -238,6 +257,12 @@ def main(cfg: DictConfig):
         model.load_model(filename=save_path, scale_data=scale_data)
         # model.build_model(**cfg["model"]["build_params"])

+    if not initial_states:
+        logger.warn(
+            "No initial values provided, using randomly initialized states which is probably NOT what you want"
+        )
+        initial_states = {k: random.random() for k in states}
+
     # Grab standardized way to interact with sim API
     sim = Simulator(
         model,
@@ -247,6 +272,7 @@ def main(cfg: DictConfig):
         input_cols,
         output_cols,
         episode_inits,
+        initial_states,
         diff_state,
     )
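
For context, here is a minimal standalone sketch (not code from this repo) of how the new `initial_states` block is meant to be consumed: the values added to `conf/simulator/cartpole-updated-simparam.yaml` seed the first state, and any matching keys passed in a per-episode config (e.g., from an Inkling scenario during brain training) override them, using the same key-intersection update as `Simulator.episode_start` above. The helper name `start_state` is hypothetical; unlike the diff, the sketch copies the defaults before updating, a defensive choice so per-episode overrides do not leak back into the stored dictionary.

```python
from typing import Any, Dict

# assumed to mirror the initial_states entry added to the simulator YAML above
initial_states: Dict[str, float] = {
    "cart_position": 0,
    "cart_velocity": 0,
    "pole_angle": 0,
    "pole_angular_velocity": 0,
}


def start_state(config: Dict[str, Any] = None) -> Dict[str, float]:
    # copy the YAML defaults so per-episode overrides do not mutate them
    state = dict(initial_states)
    if config:
        # only keys present in both dicts are overridden, as in episode_start;
        # unknown config keys (e.g. pole_length) are left for episode_inits
        state.update((k, config[k]) for k in state.keys() & config.keys())
    return state


# e.g., a brain-training scenario that sets the starting pole angle
print(start_state({"pole_angle": 0.05, "pole_length": 0.4}))
# -> {'cart_position': 0, 'cart_velocity': 0, 'pole_angle': 0.05, 'pole_angular_velocity': 0}
```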