UPDATE: initial states in episode_start and in cartpole-updated-simparam.yaml
initial states should be defined before episode_step. Previously these were randomly initialized as standard gaussians in episode_start using the keys defined by simulator.states. Now a new value is provdied in simulator.yaml called initial_states which takes a dictionary of initial_values. Moreover, episode_start looks in the `config` Dict (for instance when using scenario parameters in Inkling from lessons) and updates initial_states if provided. TODO: update documentation and all simulator.yaml files
This commit is contained in:
Родитель
335f018b71
Коммит
4041ff825e
|
@ -1,4 +1,4 @@
|
||||||
defaults:
|
defaults:
|
||||||
- data: cartpole-updated.yaml
|
- data: cartpole-updated.yaml
|
||||||
- model: SVR.yaml
|
- model: xgboost.yaml
|
||||||
- simulator: cartpole-updated-simparam.yaml
|
- simulator: cartpole-updated-simparam.yaml
|
||||||
|
|
|
@ -6,6 +6,16 @@ simulator:
|
||||||
# estimate these during training
|
# estimate these during training
|
||||||
# e.g.,:
|
# e.g.,:
|
||||||
episode_inits: { "pole_length": 0.4, "pole_mass": 0.055, "cart_mass": 0.31 }
|
episode_inits: { "pole_length": 0.4, "pole_mass": 0.055, "cart_mass": 0.31 }
|
||||||
|
# e.g.,: your simulator may need to know the initial state
|
||||||
|
# before the first episode. define these here as a dictionary
|
||||||
|
# you can include these in your Inkling scenarios during brain training
|
||||||
|
initial_states:
|
||||||
|
{
|
||||||
|
"cart_position": 0,
|
||||||
|
"cart_velocity": 0,
|
||||||
|
"pole_angle": 0,
|
||||||
|
"pole_angular_velocity": 0,
|
||||||
|
}
|
||||||
# episode_inits:
|
# episode_inits:
|
||||||
policy: bonsai
|
policy: bonsai
|
||||||
logging: enable
|
logging: enable
|
||||||
|
|
|
@ -44,6 +44,7 @@ class Simulator(BaseModel):
|
||||||
inputs: List[str],
|
inputs: List[str],
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
episode_inits: Dict[str, float],
|
episode_inits: Dict[str, float],
|
||||||
|
initial_states: Dict[str, float],
|
||||||
diff_state: bool = False,
|
diff_state: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -57,14 +58,30 @@ class Simulator(BaseModel):
|
||||||
self.state_keys = states
|
self.state_keys = states
|
||||||
self.action_keys = actions
|
self.action_keys = actions
|
||||||
self.diff_state = diff_state
|
self.diff_state = diff_state
|
||||||
|
self.initial_states = initial_states
|
||||||
# TODO: Add logging
|
# TODO: Add logging
|
||||||
|
|
||||||
logger.info(f"DDM features: {self.features}")
|
logger.info(f"DDM features: {self.features}")
|
||||||
logger.info(f"DDM outputs: {self.labels}")
|
logger.info(f"DDM outputs: {self.labels}")
|
||||||
|
|
||||||
def episode_start(self, config: Dict[str, Any] = None):
|
def episode_start(self, config: Dict[str, Any] = None):
|
||||||
|
"""Initial DDM with initial states. This could include initializations of configs
|
||||||
|
as well as initial values for actions
|
||||||
|
|
||||||
initial_state = {k: random.random() for k in self.state_keys}
|
Parameters
|
||||||
|
----------
|
||||||
|
config : Dict[str, Any], optional
|
||||||
|
episode initializations, by default None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# initialize states based on simulator.yaml
|
||||||
|
initial_state = self.initial_states
|
||||||
|
# define initial state from config if available (e.g. when brain training)
|
||||||
|
# skip if config missing
|
||||||
|
if config:
|
||||||
|
initial_state.update(
|
||||||
|
(k, config[k]) for k in initial_state.keys() & config.keys()
|
||||||
|
)
|
||||||
initial_action = {k: random.random() for k in self.action_keys}
|
initial_action = {k: random.random() for k in self.action_keys}
|
||||||
if config:
|
if config:
|
||||||
logger.info(f"Initializing episode with provided config: {config}")
|
logger.info(f"Initializing episode with provided config: {config}")
|
||||||
|
@ -82,6 +99,7 @@ class Simulator(BaseModel):
|
||||||
# request_continue = input("Are you sure you want to continue with random configs?")
|
# request_continue = input("Are you sure you want to continue with random configs?")
|
||||||
self.config = {k: random.random() for k in self.config_keys}
|
self.config = {k: random.random() for k in self.config_keys}
|
||||||
self.state = initial_state
|
self.state = initial_state
|
||||||
|
logger.info(f"Initial states: {initial_state}")
|
||||||
self.action = initial_action
|
self.action = initial_action
|
||||||
# capture all data
|
# capture all data
|
||||||
# TODO: check if we can pick a subset of data yaml, i.e., what happens if
|
# TODO: check if we can pick a subset of data yaml, i.e., what happens if
|
||||||
|
@ -206,6 +224,7 @@ def main(cfg: DictConfig):
|
||||||
states = cfg["simulator"]["states"]
|
states = cfg["simulator"]["states"]
|
||||||
actions = cfg["simulator"]["actions"]
|
actions = cfg["simulator"]["actions"]
|
||||||
configs = cfg["simulator"]["configs"]
|
configs = cfg["simulator"]["configs"]
|
||||||
|
initial_states = cfg["simulator"]["initial_states"]
|
||||||
policy = cfg["simulator"]["policy"]
|
policy = cfg["simulator"]["policy"]
|
||||||
logflag = cfg["simulator"]["logging"]
|
logflag = cfg["simulator"]["logging"]
|
||||||
# logging not yet implemented
|
# logging not yet implemented
|
||||||
|
@ -238,6 +257,12 @@ def main(cfg: DictConfig):
|
||||||
model.load_model(filename=save_path, scale_data=scale_data)
|
model.load_model(filename=save_path, scale_data=scale_data)
|
||||||
# model.build_model(**cfg["model"]["build_params"])
|
# model.build_model(**cfg["model"]["build_params"])
|
||||||
|
|
||||||
|
if not initial_states:
|
||||||
|
logger.warn(
|
||||||
|
"No initial values provided, using randomly initialized states which is probably NOT what you want"
|
||||||
|
)
|
||||||
|
initial_states = {k: random.random() for k in states}
|
||||||
|
|
||||||
# Grab standardized way to interact with sim API
|
# Grab standardized way to interact with sim API
|
||||||
sim = Simulator(
|
sim = Simulator(
|
||||||
model,
|
model,
|
||||||
|
@ -247,6 +272,7 @@ def main(cfg: DictConfig):
|
||||||
input_cols,
|
input_cols,
|
||||||
output_cols,
|
output_cols,
|
||||||
episode_inits,
|
episode_inits,
|
||||||
|
initial_states,
|
||||||
diff_state,
|
diff_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче