diff --git a/conf/simulator/cartpole-updated-simparam.yaml b/conf/simulator/cartpole-updated-simparam.yaml
index 9b9832e..a1ff24e 100644
--- a/conf/simulator/cartpole-updated-simparam.yaml
+++ b/conf/simulator/cartpole-updated-simparam.yaml
@@ -16,20 +16,20 @@ simulator:
   initial_states:
     cart_position:
       inkling_name: initial_cart_position
-      min: 0
-      max: 1
+      min: -0.05
+      max: 0.05
     cart_velocity:
       inkling_name: initial_cart_velocity
-      min: 0
-      max: 1
+      min: -0.05
+      max: 0.05
     pole_angle:
      inkling_name: initial_pole_angle
-      min: 0
-      max: 1
+      min: -0.05
+      max: 0.05
     pole_angular_velocity:
       inkling_name: initial_angular_velocity
-      min: 0
-      max: 1
+      min: -0.05
+      max: 0.05
   policy: bonsai
   logging: enable
   workspace_setup: True
diff --git a/ddm_predictor.py b/ddm_predictor.py
index ae3ca9e..8f4e358 100644
--- a/ddm_predictor.py
+++ b/ddm_predictor.py
@@ -4,6 +4,8 @@ import random
 import time
 from typing import Any, Dict, List
 from omegaconf import ListConfig
+from functools import partial
+from policies import random_policy, brain_policy
 
 import numpy as np
 
@@ -224,11 +226,12 @@ def env_setup():
 
     return workspace, access_key
 
 
-def test_random_policy(
+def test_policy(
     num_episodes: int = 5,
     num_iterations: int = 5,
     sim: Simulator = None,
     config: Dict[str, float] = None,
+    policy=random_policy,
 ):
     """Test a policy using random actions over a fixed number of episodes
@@ -238,9 +241,6 @@
         number of iterations to run, by default 10
     """
 
-    def random_action():
-        return {k: random.random() for k in sim.action_keys}
-
     def _config_clean(in_config: Dict):
 
         new_config = {}
@@ -259,7 +259,7 @@
             sim.episode_start(new_config)
         sim_state = sim.get_state()
         while not terminal:
-            action = random_action()
+            action = policy(sim_state)
             sim.episode_step(action)
             sim_state = sim.get_state()
             logger.info(f"Running iteration #{iteration} for episode #{episode}")
@@ -337,7 +337,15 @@ def main(cfg: DictConfig):
         sim.episode_start()
 
     if policy == "random":
-        test_random_policy(sim=sim, config={**episode_inits, **initial_states})
+        random_policy_from_keys = partial(random_policy, action_keys=sim.action_keys)
+        test_policy(sim=sim, config={**episode_inits, **initial_states}, policy=random_policy_from_keys)
+    elif isinstance(policy, int):
+        # If docker PORT provided, set as exported brain PORT
+        port = policy
+        url = f"http://localhost:{port}"
+        print(f"Connecting to exported brain running at {url}...")
+        trained_brain_policy = partial(brain_policy, exported_brain_url=url)
+        test_policy(sim=sim, config={**episode_inits, **initial_states}, policy=trained_brain_policy)
     elif policy == "bonsai":
         if workspace_setup:
             logger.info(f"Loading workspace information form .env")
diff --git a/policies.py b/policies.py
new file mode 100644
index 0000000..5ff1347
--- /dev/null
+++ b/policies.py
@@ -0,0 +1,26 @@
+"""
+Fixed policies to test our sim integration with. These are intended to take
+Brain states and return Brain actions.
+"""
+
+import random
+from typing import Dict
+import requests
+
+
+def random_policy(state, action_keys):
+    """
+    Ignore the state, move randomly.
+    """
+    action = {k: random.random() for k in action_keys}
+    return action
+
+
+def brain_policy(
+    state: Dict[str, float], exported_brain_url: str = "http://localhost:5000"
+):
+
+    prediction_endpoint = f"{exported_brain_url}/v1/prediction"
+    response = requests.get(prediction_endpoint, json=state)
+
+    return response.json()
\ No newline at end of file
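
Usage sketch (not part of the patch, assumptions noted inline): with this change,
test_policy accepts any callable that maps a sim state to an action dict, so the
same episode loop can exercise either the random policy or an exported brain. A
minimal smoke test of the new policies module, assuming the cartpole state keys
from the YAML above and a hypothetical "command" action key:

    from functools import partial
    from policies import random_policy, brain_policy

    state = {
        "cart_position": 0.0,
        "cart_velocity": 0.0,
        "pole_angle": 0.01,
        "pole_angular_velocity": 0.0,
    }

    # random_policy ignores the state; main() binds the sim's action keys the
    # same way via functools.partial. "command" is an assumed action key, not
    # something defined in this diff.
    policy = partial(random_policy, action_keys=["command"])
    print(policy(state))  # e.g. {'command': 0.42}

    # brain_policy sends the state as the JSON body of a GET request to the
    # exported brain's /v1/prediction endpoint; this call only succeeds if a
    # brain container is already serving on port 5000.
    brain = partial(brain_policy, exported_brain_url="http://localhost:5000")
    print(brain(state))

The new isinstance(policy, int) branch in main() wires up the second case from
config: passing a port number instead of "random" or "bonsai" points test_policy
at an exported brain served locally on that port.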