This commit is contained in:
ysqyang 2020-01-13 18:27:35 +08:00
Родитель 4c6e11229f
Коммит 915311587a
3 изменённых файлов: 22 добавлений и 13 удалений

Просмотреть файл

@ -173,7 +173,7 @@ class EnvRunner(Runner):
Load policy net parameters for the given agent's algorithm
"""
if msg.body[MsgKey.POLICY_NET_PARAMETERS] != None:
if msg.body[MsgKey.POLICY_NET_PARAMETERS] is not None:
self._agent_dict[msg.body[MsgKey.AGENT_ID]].load_policy_net_parameters(
msg.body[MsgKey.POLICY_NET_PARAMETERS])
@ -182,15 +182,19 @@ class EnvRunner(Runner):
Wait until all agents have the updated policy net parameters; each message may
contain the policy net parameters.
"""
print('force syncing...')
pending_updated_agents = len(self._agent_idx_list)
for msg in self._proxy.receive():
print(f'received a {msg.type} message from {msg.source}')
if msg.type == MsgType.UPDATED_PARAMETERS:
self.on_updated_parameters(msg)
pending_updated_agents -= 1
elif msg.type == MsgType.NOT_READY_FOR_TRAINING:
pending_updated_agents -= 1
else:
raise Exception(f'Unrecognized message type: {msg.type}')
if not pending_updated_agents:
if pending_updated_agents == 0:
break

Просмотреть файл

@ -60,7 +60,7 @@ proxy = Proxy(group_name=os.environ['GROUP'],
redis_address=(config.redis.host, config.redis.port),
logger=logger)
pending_envs = set(proxy.peers) # environments the learner expects experiences from
pending_envs = set(proxy.peers) # environments the learner expects experiences from, required for forced sync
if DASHBOARD_ENABLE:
dashboard = DashboardECR(config.experiment_name, LOG_FOLDER)
@ -78,15 +78,19 @@ def on_new_experience(local_instance, proxy, message):
if message.source in pending_envs:
pending_envs.remove(message.source)
if len(pending_envs) == 0 and local_instance.experience_pool.size['info'] > MIN_TRAIN_EXP_NUM:
local_instance.train(message.body[MsgKey.EPISODE], message.body[MsgKey.AGENT_NAME])
policy_net_parameters = local_instance.algorithm.policy_net.state_dict()
# send updated policy net parameters to the target environment runner
message = Message(type=MsgType.UPDATED_PARAMETERS, source=proxy.name,
destination=message.source,
body={MsgKey.AGENT_ID: message.body[MsgKey.AGENT_ID],
MsgKey.POLICY_NET_PARAMETERS: policy_net_parameters})
proxy.send(message)
if len(pending_envs) == 0:
if local_instance.experience_pool.size['info'] > 0:
local_instance.train(message.body[MsgKey.EPISODE], message.body[MsgKey.AGENT_NAME])
policy_net_parameters = local_instance.algorithm.policy_net.state_dict()
# send updated policy net parameters to the target environment runner
for env in proxy.peers:
proxy.send(Message(type=MsgType.UPDATED_PARAMETERS, source=proxy.name, destination=env,
body={MsgKey.AGENT_ID: message.body[MsgKey.AGENT_ID],
MsgKey.POLICY_NET_PARAMETERS: policy_net_parameters}))
else:
for env in proxy.peers:
proxy.send(Message(type=MsgType.NOT_READY_FOR_TRAINING, source=proxy.name, destination=env))
pending_envs.update(proxy.peers) # reset pending environments to the full list
else:
logger.info(f'Pending experiences from {pending_envs}')

Просмотреть файл

@ -9,7 +9,8 @@ class MsgType(Enum):
STORE_EXPERIENCE = 0 # message contains actual experience data
INITIAL_PARAMETERS = 1 # message contains model's parameter
UPDATED_PARAMETERS = 2 # message notify the learner is ready for training
ENV_CHECKOUT = 3 # message notify the environment is finish and checkout
NOT_READY_FOR_TRAINING = 3 # message to indicate that the learner has not collected enough experiences for training
ENV_CHECKOUT = 4 # message notify the environment is finish and checkout
class MsgKey(Enum):