#!/usr/bin/env python3

'''
Script for visualizing learned policies in the gym Atari environment.
'''

import argparse
import gym
import numpy as np
import os
import pickle
import time

import ray
from ray.tune.registry import get_trainable_cls
from supersuit import frame_stack_v1, color_reduction_v0, resize_v0, normalize_obs_v0, dtype_v0

# Imported for their side effects - these modules are expected to register the
# project's custom trainables and environments with Ray Tune
import algorithms
import environments


class RLLibPolicy:
    """ Represents a single policy contained in an RLLib Trainer """

    def __init__(self, trainer, policy_id):
        self._trainer = trainer
        self._policy_id = policy_id

        # Get the local policy object for the given ID
        policy = trainer.get_policy(policy_id)

        # Use a fixed null action as the previous action for the first step
        self._initial_action = 0

        # Get the initial state for a recurrent policy if needed
        initial_rnn_state = policy.get_initial_state()

        if initial_rnn_state is not None and len(initial_rnn_state) > 0:
            self._initial_rnn_state = initial_rnn_state
        else:
            self._initial_rnn_state = None

        # Initialize the policy - only affects the wrapper, not the underlying policy
        self.reset()

    def reset(self):
        self._prev_action = self._initial_action
        self._prev_rnn_state = self._initial_rnn_state

    def action(self, obs, prev_reward=0.0):
        if self._initial_rnn_state is not None:
            # Store the returned state back into _prev_rnn_state so the recurrent
            # state actually advances between steps
            self._prev_action, self._prev_rnn_state, _ = self._trainer.compute_action(obs,
                                                                                      state=self._prev_rnn_state,
                                                                                      prev_action=self._prev_action,
                                                                                      prev_reward=prev_reward,
                                                                                      policy_id=self._policy_id)
        else:
            self._prev_action = self._trainer.compute_action(obs,
                                                             prev_action=self._prev_action,
                                                             prev_reward=prev_reward,
                                                             policy_id=self._policy_id)

        return self._prev_action


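# A sketch of the trial directory layout this loader assumes (the layout written
# by older Ray Tune versions; the checkpoint numbers here are illustrative):
#
#   <run>/
#       params.pkl
#       checkpoint_1/
#           checkpoint-1
#       checkpoint_2/
#           checkpoint-2
#       ...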
def load_last_checkpoint(run, trainer_cls):

    # Load the configuration saved for this run
    with open(os.path.join(run, "params.pkl"), "rb") as f:
        config = pickle.load(f)

    # Visualization runs locally, so no rollout workers or GPUs are needed
    config["num_workers"] = 0
    config["num_gpus"] = 0

    # Build the trainer from the saved configuration
    trainer = trainer_cls(config=config)

    # Get the IDs of all checkpoints saved for this run
    checkpoint_ids = []

    for obj in os.listdir(run):
        if obj.startswith("checkpoint_"):
            checkpoint_ids.append(int(obj[len("checkpoint_"):]))

    checkpoint_ids.sort()

    # Restore the trainer from the final checkpoint
    checkpoint = str(checkpoint_ids[-1])
    checkpoint = os.path.join(run, "checkpoint_" + checkpoint, "checkpoint-" + checkpoint)
    trainer.restore(checkpoint)

    return trainer, config


def parse_args():
    # Note: the description must be passed as a keyword argument, since the first
    # positional argument of ArgumentParser is the program name
    parser = argparse.ArgumentParser(description="Visualizes a set of trained policies in the Gym Atari environments")

    parser.add_argument("trial", type=str,
                        help="path to the training run to visualize")
    parser.add_argument("-a", "--alg", type=str, default="PPO",
                        help="name of the Trainable class from which the checkpoints were generated")
    parser.add_argument("-e", "--episodes", type=int, default=20,
                        help="the number of episodes to roll out")

    return parser.parse_args()


def main(args):

    # Initialize ray
    ray.init(num_cpus=1)

    # Get trainable class and restore the final checkpoint
    trainer_cls = get_trainable_cls(args.alg)
    trainer, config = load_last_checkpoint(args.trial, trainer_cls)

    # Get environment config
    env_config = config["env_config"].copy()
    env_name = env_config.pop("game")
    frame_stack = env_config.pop("frame_stack", 4)
    color_reduction = env_config.pop("color_reduction", True)

    agent_id = env_config.pop("agent_id", "first_0")

    # Build Atari environment
    env = gym.make(env_name)

    # Wrap environment in preprocessors - this stack has to match the
    # preprocessing the policies were trained with
    wrapped_env = env

    if color_reduction:
        wrapped_env = color_reduction_v0(wrapped_env, mode='full')

    wrapped_env = resize_v0(wrapped_env, 84, 84)
    wrapped_env = dtype_v0(wrapped_env, dtype=np.float32)
    wrapped_env = frame_stack_v1(wrapped_env, frame_stack)
    wrapped_env = normalize_obs_v0(wrapped_env, env_min=0, env_max=1)

    # Reset environment
    obs = wrapped_env.reset()
    env.render()

    # Initialize policies
    policy = RLLibPolicy(trainer, f"policy_{agent_id}")

    # Roll-out and visualize policies
    step = 0
    episode = 0
    reward = 0.0

    while episode < args.episodes:

        print(f"episode {episode}, step {step}")
        step += 1

        # Feed the previous step's reward to the policy wrapper
        action = policy.action(obs, reward)
        obs, reward, done, info = wrapped_env.step(action)

        # Render environment
        env.render()
        time.sleep(0.01)

        # Reset if necessary, clearing the policy's recurrent state as well
        if done:
            obs = wrapped_env.reset()
            policy.reset()
            reward = 0.0
            episode += 1


if __name__ == '__main__':
    args = parse_args()
    main(args)
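
# Example invocation (the script and trial directory names here are hypothetical;
# point the positional argument at a real Ray Tune results directory):
#
#   python visualize_atari.py ~/ray_results/PPO_Breakout_0 --alg PPO --episodes 5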