import logging
import os
import random
import time
from typing import Any, Dict, List, Optional

import numpy as np

# see reason below for why commented out (UPDATE #comment-out-azure-cli)
# from azure.core.exceptions import HttpResponseError
from dotenv import load_dotenv, set_key
from microsoft_bonsai_api.simulator.client import BonsaiClient, BonsaiClientConfig
from microsoft_bonsai_api.simulator.generated.models import (
    SimulatorInterface,
    SimulatorSessionResponse,
    SimulatorState,
)
from omegaconf import ListConfig

from base import BaseModel

# Attach the default handler to the root logger and show INFO and above.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
# Quiet the Azure SDK loggers that are already registered at import time.
# Iterate over a snapshot so the registry cannot change size under the loop.
for name in list(logging.Logger.manager.loggerDict):
    if "azure" in name:
        azure_logger = logging.getLogger(name)
        azure_logger.setLevel(logging.WARNING)
        # `propagate` is a per-logger attribute; setting it on the `logging`
        # module (as the code used to do) has no effect.
        azure_logger.propagate = True
# Logger shared by everything in this module.
logger = logging.getLogger("datamodeler")

import hydra
from omegaconf import DictConfig

# Directory containing this script (symlinks resolved). Relative model paths
# and the `.env` file are resolved against it.
dir_path = os.path.dirname(os.path.realpath(__file__))
# Simulator name used when registering a session with the Bonsai service.
env_name = "DDM"


class Simulator(BaseModel):
    """Step-wise simulator interface around a trained data-driven model (DDM).

    The simulator keeps one flat dictionary, ``all_data``, with the latest
    value of every state, action and config variable. Each step builds the
    model input from that dictionary, calls ``model.predict`` and writes the
    predictions back into it.
    """

    def __init__(
        self,
        model,
        states: List[str],
        actions: List[str],
        configs: List[str],
        inputs: List[str],
        outputs: List[str],
        episode_inits: Dict[str, float],
        diff_state: bool = False,
    ):
        """Store the model and the variable names that drive it.

        Parameters
        ----------
        model
            object exposing ``predict(X)``; called once per step
        states : List[str]
            names of the variables reported as simulator state
        actions : List[str]
            names of the action variables
        configs : List[str]
            names of the episode configuration variables
        inputs : List[str]
            names, in order, of the columns fed to ``model.predict``
        outputs : List[str]
            names, in order, of the values returned by ``model.predict``
        episode_inits : Dict[str, float]
            fallback episode configuration used when ``episode_start`` is
            called without a config
        diff_state : bool, optional
            if True the model output is treated as a state delta and added to
            the current state, by default False
        """
        self.model = model
        self.features = inputs
        self.labels = outputs
        self.config_keys = configs
        self.episode_inits = episode_inits
        self.state_keys = states
        self.action_keys = actions
        self.diff_state = diff_state
        # TODO: Add logging

        logger.info(f"DDM features: {self.features}")
        logger.info(f"DDM outputs: {self.labels}")

    def episode_start(self, config: Optional[Dict[str, Any]] = None):
        """Reset the simulator for a new episode.

        States and actions are initialised with uniform random values. The
        episode config is taken, in order of preference, from ``config``, from
        ``self.episode_inits``, or from uniform random values.

        Parameters
        ----------
        config : Dict[str, Any], optional
            episode configuration; an empty or missing value triggers the
            fallbacks described above
        """
        # Draw order (states, actions, then configs) is kept stable so seeded
        # runs stay reproducible.
        initial_state = {k: random.random() for k in self.state_keys}
        initial_action = {k: random.random() for k in self.action_keys}

        if config:
            logger.info(f"Initializing episode with provided config: {config}")
            self.config = config
        elif self.episode_inits:
            logger.info(
                "No episode initializations provided, using initializations in yaml `episode_inits`"
            )
            logger.info(f"Episode config: {self.episode_inits}")
            self.config = self.episode_inits
        else:
            # random.random() is uniform on [0, 1), not Gaussian.
            logger.warning(
                "No config provided, so using random uniform values in [0, 1). This is probably not what you want!"
            )
            self.config = {k: random.random() for k in self.config_keys}

        self.state = initial_state
        self.action = initial_action
        # capture all data; later entries win when names collide
        self.all_data = {**self.state, **self.action, **self.config}

    def episode_step(self, action: Dict[str, int]):
        """Advance the simulator by one step.

        Parameters
        ----------
        action : Dict[str, int]
            action values to apply; merged into ``all_data`` before predicting

        Returns
        -------
        dict
            the new state, keyed by ``self.state_keys``
        """
        self.all_data.update(action)

        # Design matrix for the model: one row, columns ordered as self.features.
        X = np.array([self.all_data[k] for k in self.features]).reshape(1, -1)

        if self.diff_state:
            # Model output is the delta st+1 - st, so add it to the current state.
            # NOTE(review): this assumes the model outputs line up one-to-one
            # with self.state in order - confirm against the training setup.
            preds = np.array(list(self.state.values())) + self.model.predict(X)
        else:
            preds = self.model.predict(X)  # absolute prediction

        # Flatten so both (1, k) and (k,) shaped predictions are accepted.
        flat_preds = np.asarray(preds).reshape(-1).tolist()
        self.all_data.update(dict(zip(self.labels, flat_preds)))
        self.state = {k: self.all_data[k] for k in self.state_keys}
        return self.state

    def get_state(self):
        """Return the current state dictionary."""
        return self.state

    def halted(self):
        """Return whether the episode must stop early.

        No terminal condition is implemented, so this is always ``False``
        (an explicit boolean instead of an implicit ``None``).
        """
        return False


def env_setup():
    """Helper function to setup connection with Project Bonsai

    Reads ``SIM_WORKSPACE`` and ``SIM_ACCESS_KEY`` from the environment after
    loading the dotenv file. Each value that is missing, or all of them when
    the ``.env`` file next to this script did not exist yet, is requested
    interactively and written to that ``.env`` file.

    Returns
    -------
    Tuple
        workspace, and access_key
    """

    load_dotenv(verbose=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")

    env_file_path = os.path.join(dir_path, ".env")
    env_file_exists = os.path.exists(env_file_path)
    if not env_file_exists:
        # Create an empty file up front; presumably so set_key below always
        # has a file to write into.
        with open(env_file_path, "a"):
            pass

    # Prompt when the file was missing or the value is unset/empty.
    if not (env_file_exists and workspace):
        workspace = input("Please enter your workspace id: ")
        set_key(env_file_path, "SIM_WORKSPACE", workspace)
    if not (env_file_exists and access_key):
        access_key = input("Please enter your access key: ")
        set_key(env_file_path, "SIM_ACCESS_KEY", access_key)

    # Reload so freshly written values override what is in the environment.
    load_dotenv(verbose=True, override=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")

    return workspace, access_key


def test_random_policy(
    num_episodes: int = 500, num_iterations: int = 250, sim: Simulator = None,
):
    """Test a policy using random actions over a fixed number of episodes

    Parameters
    ----------
    num_episodes : int, optional
        number of episodes to run, by default 500
    num_iterations : int, optional
        number of steps to run in each episode, by default 250
    sim : Simulator
        simulator to drive; required even though the signature defaults to None

    Returns
    -------
    Simulator
        the simulator that was passed in, after the last episode

    Raises
    ------
    ValueError
        if ``sim`` is None
    """

    if sim is None:
        raise ValueError("test_random_policy requires a simulator via `sim`")

    for episode in range(num_episodes):
        sim.episode_start()
        # Runs exactly num_iterations steps, including zero of them.
        for iteration in range(num_iterations):
            # One uniform random value per action variable.
            action = {k: random.random() for k in sim.action_keys}
            sim.episode_step(action)
            sim_state = sim.get_state()
            print(f"Running iteration #{iteration} for episode #{episode}")
            print(f"Observations: {sim_state}")

    return sim


def _listify(columns):
    """Convert an OmegaConf ``ListConfig`` to a plain list; pass anything else through."""
    return list(columns) if isinstance(columns, ListConfig) else columns


def _create_session(
    client: BonsaiClient,
    registration_info: SimulatorInterface,
    config_client: BonsaiClientConfig,
):
    """Creates a new Simulator Session and returns new session, sequenceId"""
    try:
        print("config: {}, {}".format(config_client.server, config_client.workspace))
        registered_session: SimulatorSessionResponse = client.session.create(
            workspace_name=config_client.workspace, body=registration_info
        )
        print("Registered simulator. {}".format(registered_session.session_id))
        return registered_session, 1
    except Exception as ex:
        print(
            "UnExpected error: {}, Most likely, it's some network connectivity issue, make sure you are able to reach bonsai platform from your network.".format(
                ex
            )
        )
        # Re-raise the same exception after reporting it.
        raise


def _run_bonsai_session(sim, client, config_client, registration_info):
    """Register ``sim`` with Bonsai and serve events until interrupted or an error occurs."""
    registered_session, sequence_id = _create_session(
        client, registration_info, config_client
    )

    try:
        while True:
            # Advance by the new state depending on the event type
            sim_state = SimulatorState(
                sequence_id=sequence_id,
                state=sim.get_state(),
                # Always send a real boolean, even if halted() returns None.
                halted=bool(sim.halted()),
            )
            try:
                event = client.session.advance(
                    workspace_name=config_client.workspace,
                    session_id=registered_session.session_id,
                    body=sim_state,
                )
                sequence_id = event.sequence_id
                print(
                    "[{}] Last Event: {}".format(time.strftime("%H:%M:%S"), event.type)
                )
            # UPDATE #comment-out-azure-cli:
            # - the dedicated HttpResponseError handler was removed since it relies on
            #   azure-cli-core which has conflicting dependencies with microsoft-bonsai-api
            # - the catch-all exception below should still re-connect on disconnects
            except Exception as err:
                print("Unexpected error in Advance: {}".format(err))
                # Ideally this shouldn't happen, but for very long-running sims it can
                # happen for various reasons, so re-register the sim and move on.
                # If possible notify the Bonsai team to see whether this is a platform issue.
                registered_session, sequence_id = _create_session(
                    client, registration_info, config_client
                )
                continue

            # Event loop
            if event.type == "Idle":
                time.sleep(event.idle.callback_time)
                print("Idling...")
            elif event.type == "EpisodeStart":
                print(event.episode_start.config)
                sim.episode_start(event.episode_start.config)
            elif event.type == "EpisodeStep":
                sim.episode_step(event.episode_step.action)
            elif event.type == "EpisodeFinish":
                print("Episode Finishing...")
            elif event.type == "Unregister":
                print(
                    "Simulator Session unregistered by platform because '{}', Registering again!".format(
                        event.unregister.details
                    )
                )
                registered_session, sequence_id = _create_session(
                    client, registration_info, config_client
                )
                continue
            # Any other event type is ignored.
    except KeyboardInterrupt:
        # Gracefully unregister with keyboard interrupt
        client.session.delete(
            workspace_name=config_client.workspace,
            session_id=registered_session.session_id,
        )
        print("Unregistered simulator.")
    except Exception as err:
        # Gracefully unregister for any other exceptions
        client.session.delete(
            workspace_name=config_client.workspace,
            session_id=registered_session.session_id,
        )
        print("Unregistered simulator because: {}".format(err))


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Load the trained model named in ``cfg`` and run it as a simulator.

    Reads the ``model``, ``data`` and ``simulator`` sections of the config,
    loads the saved model, wraps it in ``Simulator`` and then either drives it
    with random actions (policy ``"random"``) or serves it to the Bonsai
    platform (policy ``"bonsai"``).

    Parameters
    ----------
    cfg : DictConfig
        composed configuration
    """

    save_path = cfg["model"]["saver"]["filename"]
    if cfg["data"]["full_or_relative"] == "relative":
        save_path = os.path.join(dir_path, save_path)
    model_name = cfg["model"]["name"]
    states = cfg["simulator"]["states"]
    actions = cfg["simulator"]["actions"]
    configs = cfg["simulator"]["configs"]
    policy = cfg["simulator"]["policy"]
    # NOTE: cfg["simulator"]["logging"] is not read here; logging not yet implemented
    scale_data = cfg["model"]["build_params"]["scale_data"]
    diff_state = cfg["data"]["diff_state"]
    workspace_setup = cfg["simulator"]["workspace_setup"]
    episode_inits = cfg["simulator"]["episode_inits"]

    input_cols = _listify(cfg["data"]["inputs"])
    output_cols = _listify(cfg["data"]["outputs"])
    augmented_cols = _listify(cfg["data"]["augmented_cols"])
    if augmented_cols is None:
        # A null `augmented_cols` means there are no extra columns.
        augmented_cols = []

    # The model was fed the raw inputs followed by the augmented columns.
    input_cols = input_cols + augmented_cols

    logger.info(f"Training with a new {policy} policy")

    # Imported lazily: the torch registry lives in a separate module.
    if model_name.lower() == "torch":
        from all_models import available_models
    else:
        from model_loader import available_models

    Model = available_models[model_name]
    model = Model()
    model.load_model(filename=save_path, scale_data=scale_data)

    # Grab standardized way to interact with sim API
    sim = Simulator(
        model,
        states,
        actions,
        configs,
        input_cols,
        output_cols,
        episode_inits,
        diff_state,
    )

    # start an episode so the simulator has an initial state
    sim.episode_start()

    if policy == "random":
        test_random_policy(1000, 250, sim)
    elif policy == "bonsai":
        if workspace_setup:
            logger.info("Loading workspace information form .env")
            env_setup()
            load_dotenv(verbose=True, override=True)

        # Configure client to interact with Bonsai service
        config_client = BonsaiClientConfig()
        client = BonsaiClient(config_client)

        # Create simulator session and init sequence id
        registration_info = SimulatorInterface(
            name=env_name,
            timeout=60,
            simulator_context=config_client.simulator_context,
        )
        _run_bonsai_session(sim, client, config_client, registration_info)
    else:
        logger.warning(
            f"Unknown policy '{policy}'; expected 'random' or 'bonsai'. Nothing to run."
        )


if __name__ == "__main__":
    # main is wrapped by @hydra.main, which supplies the `cfg` argument.
    main()