Generalized test_policy: it now supports both an exported brain and a random policy. Changed min/max in the sample.

This commit is contained in:
Journey McDowell 2021-08-13 13:46:12 -07:00
Родитель acedf1e98d
Коммит 268e1827cd
3 изменённых файлов: 48 добавлений и 14 удалений

Просмотреть файл

@ -16,20 +16,20 @@ simulator:
initial_states: initial_states:
cart_position: cart_position:
inkling_name: initial_cart_position inkling_name: initial_cart_position
min: 0 min: -0.05
max: 1 max: 0.05
cart_velocity: cart_velocity:
inkling_name: initial_cart_velocity inkling_name: initial_cart_velocity
min: 0 min: -0.05
max: 1 max: 0.05
pole_angle: pole_angle:
inkling_name: initial_pole_angle inkling_name: initial_pole_angle
min: 0 min: -0.05
max: 1 max: 0.05
pole_angular_velocity: pole_angular_velocity:
inkling_name: initial_angular_velocity inkling_name: initial_angular_velocity
min: 0 min: -0.05
max: 1 max: 0.05
policy: bonsai policy: bonsai
logging: enable logging: enable
workspace_setup: True workspace_setup: True

Просмотреть файл

@ -4,6 +4,8 @@ import random
import time import time
from typing import Any, Dict, List from typing import Any, Dict, List
from omegaconf import ListConfig from omegaconf import ListConfig
from functools import partial
from policies import random_policy, brain_policy
import numpy as np import numpy as np
@ -224,11 +226,12 @@ def env_setup():
return workspace, access_key return workspace, access_key
def test_random_policy( def test_policy(
num_episodes: int = 5, num_episodes: int = 5,
num_iterations: int = 5, num_iterations: int = 5,
sim: Simulator = None, sim: Simulator = None,
config: Dict[str, float] = None, config: Dict[str, float] = None,
policy=random_policy,
): ):
"""Test a policy using random actions over a fixed number of episodes """Test a policy using random actions over a fixed number of episodes
@ -238,9 +241,6 @@ def test_random_policy(
number of iterations to run, by default 10 number of iterations to run, by default 10
""" """
def random_action():
return {k: random.random() for k in sim.action_keys}
def _config_clean(in_config: Dict): def _config_clean(in_config: Dict):
new_config = {} new_config = {}
@ -259,7 +259,7 @@ def test_random_policy(
sim.episode_start(new_config) sim.episode_start(new_config)
sim_state = sim.get_state() sim_state = sim.get_state()
while not terminal: while not terminal:
action = random_action() action = policy(sim_state)
sim.episode_step(action) sim.episode_step(action)
sim_state = sim.get_state() sim_state = sim.get_state()
logger.info(f"Running iteration #{iteration} for episode #{episode}") logger.info(f"Running iteration #{iteration} for episode #{episode}")
@ -337,7 +337,15 @@ def main(cfg: DictConfig):
sim.episode_start() sim.episode_start()
if policy == "random": if policy == "random":
test_random_policy(sim=sim, config={**episode_inits, **initial_states}) random_policy_from_keys = partial(random_policy, action_keys=sim.action_keys)
test_policy(sim=sim, config={**episode_inits, **initial_states}, policy=random_policy_from_keys)
elif isinstance(policy, int):
# If docker PORT provided, set as exported brain PORT
port = policy
url = f"http://localhost:{port}"
print(f"Connecting to exported brain running at {url}...")
trained_brain_policy = partial(brain_policy, exported_brain_url=url)
test_policy(sim=sim, config={**episode_inits, **initial_states}, policy=trained_brain_policy)
elif policy == "bonsai": elif policy == "bonsai":
if workspace_setup: if workspace_setup:
logger.info(f"Loading workspace information form .env") logger.info(f"Loading workspace information form .env")

26
policies.py Normal file
Просмотреть файл

@ -0,0 +1,26 @@
"""
Fixed policies to test our sim integration with. These are intended to take
Brain states and return Brain actions.
"""
import random
from typing import Dict
import requests
def random_policy(state, action_keys):
    """Return a random action, disregarding the supplied state.

    Parameters
    ----------
    state
        Current state; accepted so the signature matches other policies,
        but never read.
    action_keys
        Iterable of action names to generate values for.

    Returns
    -------
    dict
        One entry per action key, each mapped to a fresh ``random.random()``
        draw (a float in ``[0.0, 1.0)``).
    """
    sampled = {}
    for key in action_keys:
        # One independent uniform draw per action name.
        sampled[key] = random.random()
    return sampled
def brain_policy(
    state: Dict[str, float],
    exported_brain_url: str = "http://localhost:5000",
    timeout: float = 30.0,
):
    """Ask an exported brain for the action to take in ``state``.

    Sends ``state`` as the JSON body of a GET request to the
    ``/v1/prediction`` endpoint under ``exported_brain_url`` and returns the
    decoded JSON response.

    Parameters
    ----------
    state : Dict[str, float]
        State to send to the brain.
    exported_brain_url : str, optional
        Base URL of the running exported brain, by default
        "http://localhost:5000". A trailing slash is tolerated.
    timeout : float, optional
        Seconds to wait for the brain to respond before giving up, by
        default 30.0. Without a timeout ``requests`` would block forever if
        the container is unreachable or stalled.

    Returns
    -------
    The JSON-decoded body of the brain's response.

    Raises
    ------
    requests.exceptions.HTTPError
        If the brain answers with a 4xx/5xx status.
    requests.exceptions.RequestException
        On connection failures or timeouts.
    """
    # Strip a trailing slash so the endpoint never contains "//v1/prediction".
    base_url = exported_brain_url.rstrip("/")
    prediction_endpoint = f"{base_url}/v1/prediction"
    response = requests.get(prediction_endpoint, json=state, timeout=timeout)
    # Fail loudly on an HTTP error instead of handing an error payload back
    # to the caller as if it were a valid action.
    response.raise_for_status()
    return response.json()