revised forced sync logic
This commit is contained in:
Родитель
4c6e11229f
Коммит
915311587a
|
@ -173,7 +173,7 @@ class EnvRunner(Runner):
|
|||
|
||||
Load policy net parameters for the given agent's algorithm
|
||||
"""
|
||||
if msg.body[MsgKey.POLICY_NET_PARAMETERS] != None:
|
||||
if msg.body[MsgKey.POLICY_NET_PARAMETERS] is not None:
|
||||
self._agent_dict[msg.body[MsgKey.AGENT_ID]].load_policy_net_parameters(
|
||||
msg.body[MsgKey.POLICY_NET_PARAMETERS])
|
||||
|
||||
|
@ -182,15 +182,19 @@ class EnvRunner(Runner):
|
|||
Wait until all agents have the updated policy net parameters; a message may
|
||||
contain the policy net parameters.
|
||||
"""
|
||||
print('force syncing...')
|
||||
pending_updated_agents = len(self._agent_idx_list)
|
||||
for msg in self._proxy.receive():
|
||||
print(f'received a {msg.type} message from {msg.source}')
|
||||
if msg.type == MsgType.UPDATED_PARAMETERS:
|
||||
self.on_updated_parameters(msg)
|
||||
pending_updated_agents -= 1
|
||||
elif msg.type == MsgType.NOT_READY_FOR_TRAINING:
|
||||
pending_updated_agents -= 1
|
||||
else:
|
||||
raise Exception(f'Unrecognized message type: {msg.type}')
|
||||
|
||||
if not pending_updated_agents:
|
||||
if pending_updated_agents == 0:
|
||||
break
|
||||
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ proxy = Proxy(group_name=os.environ['GROUP'],
|
|||
redis_address=(config.redis.host, config.redis.port),
|
||||
logger=logger)
|
||||
|
||||
pending_envs = set(proxy.peers) # environments the learner expects experiences from
|
||||
pending_envs = set(proxy.peers) # environments the learner expects experiences from, required for forced sync
|
||||
|
||||
if DASHBOARD_ENABLE:
|
||||
dashboard = DashboardECR(config.experiment_name, LOG_FOLDER)
|
||||
|
@ -78,15 +78,19 @@ def on_new_experience(local_instance, proxy, message):
|
|||
|
||||
if message.source in pending_envs:
|
||||
pending_envs.remove(message.source)
|
||||
if len(pending_envs) == 0 and local_instance.experience_pool.size['info'] > MIN_TRAIN_EXP_NUM:
|
||||
local_instance.train(message.body[MsgKey.EPISODE], message.body[MsgKey.AGENT_NAME])
|
||||
policy_net_parameters = local_instance.algorithm.policy_net.state_dict()
|
||||
# send updated policy net parameters to the target environment runner
|
||||
message = Message(type=MsgType.UPDATED_PARAMETERS, source=proxy.name,
|
||||
destination=message.source,
|
||||
body={MsgKey.AGENT_ID: message.body[MsgKey.AGENT_ID],
|
||||
MsgKey.POLICY_NET_PARAMETERS: policy_net_parameters})
|
||||
proxy.send(message)
|
||||
if len(pending_envs) == 0:
|
||||
if local_instance.experience_pool.size['info'] > 0:
|
||||
local_instance.train(message.body[MsgKey.EPISODE], message.body[MsgKey.AGENT_NAME])
|
||||
policy_net_parameters = local_instance.algorithm.policy_net.state_dict()
|
||||
# send updated policy net parameters to every environment runner
|
||||
for env in proxy.peers:
|
||||
proxy.send(Message(type=MsgType.UPDATED_PARAMETERS, source=proxy.name, destination=env,
|
||||
body={MsgKey.AGENT_ID: message.body[MsgKey.AGENT_ID],
|
||||
MsgKey.POLICY_NET_PARAMETERS: policy_net_parameters}))
|
||||
else:
|
||||
for env in proxy.peers:
|
||||
proxy.send(Message(type=MsgType.NOT_READY_FOR_TRAINING, source=proxy.name, destination=env))
|
||||
|
||||
pending_envs.update(proxy.peers) # reset pending environments to the full list
|
||||
else:
|
||||
logger.info(f'Pending experiences from {pending_envs}')
|
||||
|
|
|
@ -9,7 +9,8 @@ class MsgType(Enum):
|
|||
STORE_EXPERIENCE = 0 # message contains actual experience data
|
||||
INITIAL_PARAMETERS = 1 # message contains model's parameter
|
||||
UPDATED_PARAMETERS = 2 # message contains the learner's updated policy net parameters
|
||||
ENV_CHECKOUT = 3 # message notifies that the environment has finished and is checking out
|
||||
NOT_READY_FOR_TRAINING = 3 # message to indicate that the learner has not collected enough experiences for training
|
||||
ENV_CHECKOUT = 4 # message notifies that the environment has finished and is checking out
|
||||
|
||||
|
||||
class MsgKey(Enum):
|
||||
|
|
Загрузка…
Ссылка в новой задаче