datadrivenmodel/train_bonsai_main.py

291 строка
10 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
MSFT Bonsai SDK3 Template for Simulator Integration using Python
Copyright 2020 Microsoft
Usage:
To register the simulator with the Bonsai service for training:
python train_bonsai_main.py --api-host https://api.bons.ai \
--workspace <workspace_id> \
--accesskey="<access_key>"
Then connect your registered simulator to a Brain via the UI.
Alternatively, one can set the SIM_ACCESS_KEY and SIM_WORKSPACE as
environment variables.
"""
import time
import os
from dotenv import load_dotenv, set_key
from typing import Dict, Any, Optional
from microsoft_bonsai_api.simulator.client import BonsaiClientConfig, BonsaiClient
from microsoft_bonsai_api.simulator.generated.models import (
SimulatorState,
SimulatorInterface,
)
from azure.core.exceptions import HttpResponseError
import numpy as np
import yaml
from predictor import ModelPredictor
class TemplateSimulatorSession:
    """Bonsai simulator session backed by a data-driven model predictor.

    ``config/config_model.yml`` labels every feature in ``IO.feature_name``
    as either ``"state"`` or ``"action"``. That labelling determines the
    state/action vector sizes handed to the predictor and the order in which
    values are exchanged with the Bonsai platform.
    """

    def __init__(self, limit_checks: bool = False):
        """Load the model configuration and build the predictor.

        Parameters
        ----------
        limit_checks : bool
            When True, also load ``config/model_limits.yml`` into
            ``self.model_limits``.

        Raises
        ------
        SystemExit
            With status 1 if a feature in ``IO.feature_name`` is labelled
            as anything other than "state" or "action".
        """
        with open("config/config_model.yml") as cmfile:
            self.model_config = yaml.full_load(cmfile)

        # Obtain model limitations
        self.limit_checks = limit_checks
        if self.limit_checks:
            with open("config/model_limits.yml") as mlimfile:
                self.model_limits = yaml.full_load(mlimfile)

        # Derive the dimensions from the config. This keeps the predictor,
        # get_state() and episode_step() in agreement on how many state and
        # action features exist.
        state_space_dim, action_space_dim = self._count_features(
            self.model_config["IO"]["feature_name"]
        )

        # Placeholder values; episode_start() replaces the state with the
        # predictor's reset state.
        self.state = [0] * state_space_dim
        self.action = [0] * action_space_dim

        # TODO: fix this to input_shape and output_shape
        self.predictor = ModelPredictor(
            modeltype=self.model_config["MODEL"]["type"],
            noise_percentage=0,
            state_space_dim=state_space_dim,
            action_space_dim=action_space_dim,
            markovian_order=self.model_config["LSTM"]["markovian_order"],
        )

    @staticmethod
    def _count_features(feature_names: Dict[str, Any]) -> tuple:
        """Count the features labelled "state" and "action".

        Parameters
        ----------
        feature_names : Dict[str, Any]
            Mapping of feature name to its role, as found under
            ``IO.feature_name`` in the model config.

        Returns
        -------
        tuple
            ``(state_space_dim, action_space_dim)``.

        Raises
        ------
        SystemExit
            With status 1 if any role is neither "state" nor "action".
        """
        state_space_dim = 0
        action_space_dim = 0
        for role in feature_names.values():
            if role == "state":
                state_space_dim += 1
            elif role == "action":
                action_space_dim += 1
            else:
                print("Please fix config_model.yml to specify either state or action")
                # Exit with a non-zero status so shells/CI see the failure.
                raise SystemExit(1)
        return state_space_dim, action_space_dim

    def get_state(self) -> Dict[str, Any]:
        """Called to retrieve the current state of the simulator.

        State feature names are taken in config order and paired
        positionally with ``self.state``.

        Returns
        -------
        dictionary
            Dictionary of sim_state elements at current iteration
        """
        state_names = (
            key
            for key, role in self.model_config["IO"]["feature_name"].items()
            if role == "state"
        )
        state_dict = {}
        for index, name in enumerate(state_names):
            state_dict[name] = float(self.state[index])
        return state_dict

    def episode_start(self, config: Dict[str, Any]):
        """Method invoked at the start of each episode with a given
        episode configuration.

        Parameters
        ----------
        config : Dict[str, Any]
            SimConfig parameters for the current episode defined in Inkling.
            An optional ``noise_percentage`` entry updates the predictor's
            noise level.
        """
        self.state = self.predictor.reset_state(config)
        # noise_percentage is optional; keep the predictor's current value
        # when it is missing or when no config mapping was supplied.
        try:
            self.predictor.noise_percentage = config["noise_percentage"]
        except (KeyError, TypeError):
            pass

    def episode_step(self, action: Dict[str, Any]):
        """Called for each step of the episode.

        Parameters
        ----------
        action : Dict[str, Any]
            BrainAction chosen from the Bonsai Service, prediction or
            exploration. Must contain every feature labelled "action" in the
            config; values are ordered as in the config.
        """
        feature_names = self.model_config["IO"]["feature_name"]
        self.action = [
            action[key] for key, role in feature_names.items() if role == "action"
        ]
        self.state = self.predictor.predict(
            state=self.state,
            action=self.action,
        )

    def halted(self) -> bool:
        """Return True if the simulator cannot continue.

        The predictor's ``warn_limitation`` is given the concatenated state
        and action vectors; any non-zero count halts the episode.
        """
        features = np.concatenate([self.state, self.action])
        num_tripped = self.predictor.warn_limitation(features)
        return bool(num_tripped > 0)
def env_setup():
    """Helper function to setup connection with Project Bonsai.

    Reads SIM_WORKSPACE and SIM_ACCESS_KEY from the process environment,
    after loading any ``.env`` file in the working directory. The user is
    prompted only for values that are still missing, and prompted values are
    written to ``.env``, which is created if needed.

    Returns
    -------
    Tuple
        workspace, and access_key
    """
    load_dotenv(verbose=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")

    # Create .env up front so set_key() below has a file to write into.
    if not os.path.exists(".env"):
        with open(".env", "a"):
            pass

    # Prompt only when a value is genuinely missing. Users may supply both
    # values as environment variables, and an unconditional prompt would
    # block non-interactive runs.
    if not workspace:
        workspace = input("Please enter your workspace id: ")
        set_key(".env", "SIM_WORKSPACE", workspace)
    if not access_key:
        access_key = input("Please enter your access key: ")
        set_key(".env", "SIM_ACCESS_KEY", access_key)

    load_dotenv(verbose=True, override=True)
    workspace = os.getenv("SIM_WORKSPACE")
    access_key = os.getenv("SIM_ACCESS_KEY")
    return workspace, access_key
def _http_error_message(ex):
    """Return the most specific error text available on an HttpResponseError.

    ``ex.error`` (the deserialized service error) may be None, e.g. when the
    response body is not a recognised error payload. Reading
    ``ex.error.message`` unguarded would then raise AttributeError from
    inside an ``except`` handler. Fall back to ``ex.message``, then
    ``str(ex)``.
    """
    message = getattr(getattr(ex, "error", None), "message", None)
    if message is None:
        message = getattr(ex, "message", None)
    return message if message is not None else str(ex)


def main():
    """Register with Bonsai and serve simulator events until stopped.

    The session is re-registered whenever advancing it fails or the platform
    unregisters it. On Ctrl+C or any other unhandled exception, the current
    session is deleted before returning.
    """
    # Grab standardized way to interact with sim API
    env_setup()
    load_dotenv(verbose=True, override=True)
    sim = TemplateSimulatorSession()

    # Configure client to interact with Bonsai service
    config_client = BonsaiClientConfig()
    client = BonsaiClient(config_client)

    # Create simulator session and init sequence id
    registration_info = SimulatorInterface(
        name="YourSimDDM",
        timeout=60,
        simulator_context=config_client.simulator_context,
    )

    def create_session(
        registration_info: SimulatorInterface, config_client: BonsaiClientConfig
    ):
        """Create a new simulator session; return (session, initial sequence id)."""
        try:
            print(
                "config: {}, {}".format(config_client.server, config_client.workspace)
            )
            registered_session = client.session.create(
                workspace_name=config_client.workspace, body=registration_info
            )
            print("Registered simulator. {}".format(registered_session.session_id))
            return registered_session, 1
        except HttpResponseError as ex:
            print(
                "HttpResponseError in Registering session: StatusCode: {}, Error: {}, Exception: {}".format(
                    ex.status_code, _http_error_message(ex), ex
                )
            )
            raise
        except Exception as ex:
            print(
                "UnExpected error: {}, Most likely, It's some network connectivity issue, make sure, you are able to reach bonsai platform from your PC.".format(
                    ex
                )
            )
            raise

    registered_session, sequence_id = create_session(registration_info, config_client)

    try:
        while True:
            # Advance by the new state depending on the event type
            sim_state = SimulatorState(
                sequence_id=sequence_id, state=sim.get_state(), halted=sim.halted()
            )
            try:
                event = client.session.advance(
                    workspace_name=config_client.workspace,
                    session_id=registered_session.session_id,
                    body=sim_state,
                )
                sequence_id = event.sequence_id
                print(
                    "[{}] Last Event: {}".format(time.strftime("%H:%M:%S"), event.type)
                )
            except HttpResponseError as ex:
                print(
                    "HttpResponseError in Advance: StatusCode: {}, Error: {}, Exception: {}".format(
                        ex.status_code, _http_error_message(ex), ex
                    )
                )
                # Network problems can outlast the SDK's own retries, or the
                # platform may be tearing the session down. Re-register and
                # keep iterating.
                registered_session, sequence_id = create_session(
                    registration_info, config_client
                )
                continue
            except Exception as err:
                print("Unexpected error in Advance: {}".format(err))
                # Should be rare, but long-running sims can hit this for
                # various reasons. Re-register and move on. If it recurs,
                # report it to the Bonsai team as a possible platform issue.
                registered_session, sequence_id = create_session(
                    registration_info, config_client
                )
                continue

            # Event loop
            if event.type == "Idle":
                time.sleep(event.idle.callback_time)
                print("Idling...")
            elif event.type == "EpisodeStart":
                sim.episode_start(event.episode_start.config)
            elif event.type == "EpisodeStep":
                sim.episode_step(event.episode_step.action)
            elif event.type == "EpisodeFinish":
                print("Episode Finishing...")
            elif event.type == "Unregister":
                print("Simulator Session unregistered by platform because '{}', Registering again!".format(event.unregister.details))
                registered_session, sequence_id = create_session(
                    registration_info, config_client
                )
                continue
            else:
                pass
    except KeyboardInterrupt:
        # Gracefully unregister with keyboard interrupt
        client.session.delete(
            workspace_name=config_client.workspace,
            session_id=registered_session.session_id,
        )
        print("Unregistered simulator.")
    except Exception as err:
        # Gracefully unregister for any other exceptions
        client.session.delete(
            workspace_name=config_client.workspace,
            session_id=registered_session.session_id,
        )
        print("Unregistered simulator because: {}".format(err))


if __name__ == "__main__":
    main()