* Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test (#590)

* Runnable. Should set up a benchmark and test performance.

* Refine logic

* Test DQN on GYM passed

* Refine explore strategy

* Minor

* Minor

* Add Dueling DQN in CIM scenario

* Resolve PR comments

* Add one more explanation

* Fix env_sampler eval info list issue

* Update version to 0.3.2a4

---------

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
This commit is contained in:
Jinyu-W 2023-10-27 14:12:46 +08:00 committed by GitHub
Parent b3c6a589ad
Commit 297709736c
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 538 additions and 331 deletions

View file

@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple
import torch
from torch.optim import RMSprop
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.exploration import EpsilonGreedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer
@ -23,32 +24,62 @@ learning_rate = 0.05
class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int) -> None:
def __init__(
self,
state_dim: int,
action_num: int,
dueling_param: Optional[Tuple[dict, dict]] = None,
) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
self._use_dueling = dueling_param is not None
self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
if self._use_dueling:
q_kwargs, v_kwargs = dueling_param
self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)
self._optim = RMSprop(self.parameters(), lr=learning_rate)
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
return self._fc(states)
logits = self._fc(states)
if self._use_dueling:
q = self._q(logits)
v = self._v(logits)
logits = q - q.mean(dim=1, keepdim=True) + v
return logits
def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
q_kwargs = {
"hidden_dims": [128],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}
v_kwargs = {
"hidden_dims": [128],
"activation": torch.nn.LeakyReLU,
"output_activation": None,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num),
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
},
q_net=MyQNet(
state_dim,
action_num,
dueling_param=(q_kwargs, v_kwargs),
),
],
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
warmup=100,
)
@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
num_epochs=10,
soft_update_coef=0.1,
double=False,
random_overwrite=False,
alpha=1.0,
beta=1.0,
),
)
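
The dueling head above follows the standard aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). Below is a minimal, self-contained sketch of the same computation in plain PyTorch; the layer sizes are illustrative only and not taken from the config above.

import torch
import torch.nn as nn

state_dim, action_num, hidden = 4, 3, 16
trunk = nn.Sequential(nn.Linear(state_dim, hidden), nn.LeakyReLU())   # shared feature extractor
adv_head = nn.Linear(hidden, action_num)                              # advantage stream A(s, a)
val_head = nn.Linear(hidden, 1)                                       # state-value stream V(s)

states = torch.randn(8, state_dim)
feats = trunk(states)
adv, val = adv_head(feats), val_head(feats)
# Same form as `q - q.mean(dim=1, keepdim=True) + v` in MyQNet above:
q_values = val + adv - adv.mean(dim=1, keepdim=True)
print(q_values.shape)  # torch.Size([8, 3])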

View file

@ -35,4 +35,4 @@ state_dim = (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_por
action_num = len(action_shaping_conf["action_space"])
algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg
algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg

View file

@ -6,7 +6,7 @@ import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from maro.rl.exploration import MultiLinearExplorationScheduler
from maro.rl.exploration import EpsilonGreedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer
@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num, num_features),
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(100, 0.32)],
"initial_value": 0.4,
"last_ep": 400,
"final_value": 0.0,
},
),
],
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
warmup=100,
)

View file

@ -2,6 +2,6 @@
# Licensed under the MIT license.
__version__ = "0.3.2a3"
__version__ = "0.3.2a4"
__data_version__ = "0.2"

View file

@ -1,14 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration
__all__ = [
"AbsExplorationScheduler",
"LinearExplorationScheduler",
"MultiLinearExplorationScheduler",
"epsilon_greedy",
"gaussian_noise",
"uniform_noise",
"ExploreStrategy",
"EpsilonGreedy",
"LinearExploration",
]
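
For reference, the two call sites changed in this commit show the API move from the removed function-plus-scheduler pair to a single strategy object. This sketch only re-states values copied from the CIM example above; `action_num` is illustrative.

# Old style (removed in this commit):
#   exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
#   exploration_scheduling_options=[
#       ("epsilon", MultiLinearExplorationScheduler,
#        {"splits": [(2, 0.32)], "initial_value": 0.4, "last_ep": 5, "final_value": 0.0}),
#   ],

# New style:
from maro.rl.exploration import EpsilonGreedy

action_num = 3  # illustrative
explore_strategy = EpsilonGreedy(epsilon=0.4, num_actions=action_num)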

View file

@ -1,127 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
from typing import List, Tuple
class AbsExplorationScheduler(ABC):
"""Abstract exploration scheduler.
Args:
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
scheduler is applied.
param_name (str): Name of the exploration parameter to which the scheduler is applied.
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used
when instantiating the policy will be used as the initial value.
"""
def __init__(self, exploration_params: dict, param_name: str, initial_value: float = None) -> None:
super().__init__()
self._exploration_params = exploration_params
self.param_name = param_name
if initial_value is not None:
self._exploration_params[self.param_name] = initial_value
def get_value(self) -> float:
return self._exploration_params[self.param_name]
@abstractmethod
def step(self) -> None:
raise NotImplementedError
class LinearExplorationScheduler(AbsExplorationScheduler):
"""Linear exploration parameter schedule.
Args:
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
scheduler is applied.
param_name (str): Name of the exploration parameter to which the scheduler is applied.
last_ep (int): Last episode.
final_value (float): The value of the exploration parameter corresponding to ``last_ep``.
start_ep (int, default=1): starting episode.
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used
when instantiating the policy will be used as the initial value.
"""
def __init__(
self,
exploration_params: dict,
param_name: str,
*,
last_ep: int,
final_value: float,
start_ep: int = 1,
initial_value: float = None,
) -> None:
super().__init__(exploration_params, param_name, initial_value=initial_value)
self.final_value = final_value
if last_ep > 1:
self.delta = (self.final_value - self._exploration_params[self.param_name]) / (last_ep - start_ep)
else:
self.delta = 0
def step(self) -> None:
if self._exploration_params[self.param_name] == self.final_value:
return
self._exploration_params[self.param_name] += self.delta
class MultiLinearExplorationScheduler(AbsExplorationScheduler):
"""Exploration parameter schedule that consists of multiple linear phases.
Args:
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
scheduler is applied.
param_name (str): Name of the exploration parameter to which the scheduler is applied.
splits (List[Tuple[int, float]]): List of points that separate adjacent linear phases. Each
point is a (episode, parameter_value) tuple that indicates the end of one linear phase and
the start of another. These points do not have to be given in any particular order. There
cannot be two points with the same first element (episode), or a ``ValueError`` will be raised.
last_ep (int): Last episode.
final_value (float): The value of the exploration parameter corresponding to ``last_ep``.
start_ep (int, default=1): starting episode.
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value from
the original dictionary the policy is instantiated with will be used as the initial value.
"""
def __init__(
self,
exploration_params: dict,
param_name: str,
*,
splits: List[Tuple[int, float]],
last_ep: int,
final_value: float,
start_ep: int = 1,
initial_value: float = None,
) -> None:
super().__init__(exploration_params, param_name, initial_value=initial_value)
# validate splits
splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
splits.sort()
for (ep, _), (ep2, _) in zip(splits, splits[1:]):
if ep == ep2:
raise ValueError("The zeroth element of split points must be unique")
self.final_value = final_value
self._splits = splits
self._ep = start_ep
self._split_index = 1
self._delta = (self._splits[1][1] - self._exploration_params[self.param_name]) / (self._splits[1][0] - start_ep)
def step(self) -> None:
if self._split_index == len(self._splits):
return
self._exploration_params[self.param_name] += self._delta
self._ep += 1
if self._ep == self._splits[self._split_index][0]:
self._split_index += 1
if self._split_index < len(self._splits):
self._delta = (self._splits[self._split_index][1] - self._splits[self._split_index - 1][1]) / (
self._splits[self._split_index][0] - self._splits[self._split_index - 1][0]
)

View file

@ -1,93 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Union
from abc import abstractmethod
from typing import Any
import numpy as np
def epsilon_greedy(
class ExploreStrategy:
def __init__(self) -> None:
pass
@abstractmethod
def get_action(
self,
state: np.ndarray,
action: np.ndarray,
num_actions: int,
*,
epsilon: float,
) -> np.ndarray:
"""Epsilon-greedy exploration.
**kwargs: Any,
) -> np.ndarray:
"""
Args:
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla
eps-greedy exploration and is put here to conform to the function signature required for the exploration
strategy parameter for ``DQN``.
action (np.ndarray): Action(s) chosen greedily by the policy.
num_actions (int): Number of possible actions.
epsilon (float): The probability that a random action will be selected.
Returns:
Exploratory actions.
"""
return np.array([act if np.random.random() > epsilon else np.random.randint(num_actions) for act in action])
raise NotImplementedError
def uniform_noise(
state: np.ndarray,
action: np.ndarray,
min_action: Union[float, list, np.ndarray] = None,
max_action: Union[float, list, np.ndarray] = None,
*,
low: Union[float, list, np.ndarray],
high: Union[float, list, np.ndarray],
) -> Union[float, np.ndarray]:
"""Apply a uniform noise to a continuous multidimensional action.
class EpsilonGreedy(ExploreStrategy):
"""Epsilon-greedy exploration. Returns uniformly random action with probability `epsilon` or returns original
action with probability `1.0 - epsilon`.
Args:
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise
exploration scheme and is put here to conform to the function signature for the exploration in continuous
action spaces.
action (np.ndarray): Action(s) chosen greedily by the policy.
min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space.
max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space.
low (Union[float, list, np.ndarray]): Lower bound for the noise range.
high (Union[float, list, np.ndarray]): Upper bound for the noise range.
Returns:
Exploration actions with added noise.
num_actions (int): Number of possible actions.
epsilon (float): The probability that a random action will be selected.
"""
if min_action is None and max_action is None:
return action + np.random.uniform(low, high, size=action.shape)
else:
return np.clip(action + np.random.uniform(low, high, size=action.shape), min_action, max_action)
def __init__(self, num_actions: int, epsilon: float) -> None:
super(EpsilonGreedy, self).__init__()
def gaussian_noise(
assert 0.0 <= epsilon <= 1.0
self._num_actions = num_actions
self._eps = epsilon
def get_action(
self,
state: np.ndarray,
action: np.ndarray,
min_action: Union[float, list, np.ndarray] = None,
max_action: Union[float, list, np.ndarray] = None,
*,
mean: Union[float, list, np.ndarray] = 0.0,
stddev: Union[float, list, np.ndarray] = 1.0,
relative: bool = False,
) -> Union[float, np.ndarray]:
"""Apply a gaussian noise to a continuous multidimensional action.
**kwargs: Any,
) -> np.ndarray:
return np.array(
[act if np.random.random() > self._eps else np.random.randint(self._num_actions) for act in action],
)
class LinearExploration(ExploreStrategy):
"""Epsilon greedy which the probability `epsilon` is linearly interpolated between `start_explore_prob` and
`end_explore_prob` over `explore_steps`. After this many timesteps pass, `epsilon` is fixed to `end_explore_prob`.
Args:
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise
exploration scheme and is put here to conform to the function signature for the exploration in continuous
action spaces.
action (np.ndarray): Action(s) chosen greedily by the policy.
min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space.
max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space.
mean (Union[float, list, np.ndarray], default=0.0): Gaussian noise mean.
stddev (Union[float, list, np.ndarray], default=1.0): Standard deviation for the Gaussian noise.
relative (bool, default=False): If True, the generated noise is treated as a relative measure and will
be multiplied by the action itself before being added to the action.
Returns:
Exploration actions with added noise (a numpy ndarray).
num_actions (int): Number of possible actions.
explore_steps (int): Maximum number of steps to interpolate probability.
start_explore_prob (float): Starting explore probability.
end_explore_prob (float): Ending explore probability.
"""
noise = np.random.normal(loc=mean, scale=stddev, size=action.shape)
if min_action is None and max_action is None:
return action + ((noise * action) if relative else noise)
else:
return np.clip(action + ((noise * action) if relative else noise), min_action, max_action)
def __init__(
self,
num_actions: int,
explore_steps: int,
start_explore_prob: float,
end_explore_prob: float,
) -> None:
super(LinearExploration, self).__init__()
self._call_count = 0
self._num_actions = num_actions
self._explore_steps = explore_steps
self._start_explore_prob = start_explore_prob
self._end_explore_prob = end_explore_prob
def get_action(
self,
state: np.ndarray,
action: np.ndarray,
**kwargs: Any,
) -> np.ndarray:
ratio = min(self._call_count / self._explore_steps, 1.0)
epsilon = self._start_explore_prob + (self._end_explore_prob - self._start_explore_prob) * ratio
explore_flag = np.random.random() < epsilon
action = np.array([np.random.randint(self._num_actions) if explore_flag else act for act in action])
self._call_count += 1
return action
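
A small usage sketch of the two strategies defined above, assuming only NumPy and the constructors shown in this file; the printed arrays will vary because exploration is stochastic.

import numpy as np

from maro.rl.exploration import EpsilonGreedy, LinearExploration

greedy_actions = np.array([2, 0, 1, 2])   # actions chosen greedily by the policy
dummy_states = np.zeros((4, 8))           # states are unused by both strategies

eps = EpsilonGreedy(num_actions=3, epsilon=0.4)
print(eps.get_action(state=dummy_states, action=greedy_actions))   # ~40% of entries re-sampled

lin = LinearExploration(num_actions=3, explore_steps=10000, start_explore_prob=1.0, end_explore_prob=0.02)
print(lin.get_action(state=dummy_states, action=greedy_actions))   # epsilon decays with each call

Note that `EpsilonGreedy` re-samples each action in the batch independently, while `LinearExploration` draws one explore flag per call and applies it to the whole batch.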

View file

@ -13,7 +13,7 @@ class FullyConnected(nn.Module):
Args:
input_dim (int): Network input dimension.
output_dim (int): Network output dimension.
output_dim (int): Network output dimension. If it is 0, the top layer will not be created.
hidden_dims (List[int]): Dimensions of hidden layers. Its length is the number of hidden layers. For example,
`hidden_dims=[128, 256]` refers to two hidden layers with output dim of 128 and 256, respectively.
activation (Optional[Type[torch.nn.Module], default=nn.ReLU): Activation class provided by ``torch.nn`` or a
@ -52,7 +52,6 @@ class FullyConnected(nn.Module):
super(FullyConnected, self).__init__()
self._input_dim = input_dim
self._hidden_dims = hidden_dims if hidden_dims is not None else []
self._output_dim = output_dim
# network features
self._activation = activation if activation else None
@ -76,9 +75,13 @@ class FullyConnected(nn.Module):
self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])
]
# top layer
if output_dim != 0:
layers.append(
self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation),
self._build_layer(dims[-1], output_dim, head=self._head, activation=self._output_activation),
)
self._output_dim = output_dim
else:
self._output_dim = hidden_dims[-1]
self._net = nn.Sequential(*layers)
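
With this change, `output_dim=0` turns `FullyConnected` into a headless trunk whose `output_dim` falls back to `hidden_dims[-1]`, which is how the dueling Q-net above stacks its two heads. A sketch, assuming the remaining constructor arguments keep their defaults:

from maro.rl.model import FullyConnected

trunk = FullyConnected(input_dim=8, output_dim=0, hidden_dims=[128])   # no top layer is created
assert trunk.output_dim == 128                                         # falls back to hidden_dims[-1]
q_head = FullyConnected(input_dim=trunk.output_dim, output_dim=4, hidden_dims=[64])
v_head = FullyConnected(input_dim=trunk.output_dim, output_dim=1, hidden_dims=[64])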

View file

@ -2,15 +2,14 @@
# Licensed under the MIT license.
from abc import ABCMeta
from typing import Callable, Dict, List, Tuple
from typing import Dict, Optional, Tuple
import numpy as np
import torch
from maro.rl.exploration import epsilon_greedy
from maro.rl.exploration import ExploreStrategy
from maro.rl.model import DiscretePolicyNet, DiscreteQNet
from maro.rl.utils import match_shape, ndarray_to_tensor
from maro.utils import clone
from .abs_policy import RLPolicy
@ -69,8 +68,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
name (str): Name of the policy.
q_net (DiscreteQNet): Q-net used in this value-based policy.
trainable (bool, default=True): Whether this policy is trainable.
exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy.
exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options.
explore_strategy (Optional[ExploreStrategy], default=None): Explore strategy.
warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy.
Helps exploration.
"""
@ -80,8 +78,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
name: str,
q_net: DiscreteQNet,
trainable: bool = True,
exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}),
exploration_scheduling_options: List[tuple] = None,
explore_strategy: Optional[ExploreStrategy] = None,
warmup: int = 50000,
) -> None:
assert isinstance(q_net, DiscreteQNet)
@ -94,15 +91,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
warmup=warmup,
)
self._q_net = q_net
self._exploration_func = exploration_strategy[0]
self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing
self._exploration_schedulers = (
[opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
if exploration_scheduling_options is not None
else []
)
self._explore_strategy = explore_strategy
self._softmax = torch.nn.Softmax(dim=1)
@property
@ -176,9 +165,6 @@ class ValueBasedPolicy(DiscreteRLPolicy):
assert match_shape(q_values, (states.shape[0],)) # [B]
return q_values
def explore(self) -> None:
pass # Overwrite the base method and turn off explore mode.
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return self._get_actions_with_probs_impl(states, **kwargs)[0]
@ -187,17 +173,11 @@ class ValueBasedPolicy(DiscreteRLPolicy):
q_matrix_softmax = self._softmax(q_matrix)
_, actions = q_matrix.max(dim=1) # [B], [B]
if self._is_exploring:
actions = self._exploration_func(
states,
actions.cpu().numpy(),
self.action_num,
**self._exploration_params,
**kwargs,
)
if self._is_exploring and self._explore_strategy is not None:
actions = self._explore_strategy.get_action(state=states.cpu().numpy(), action=actions.cpu().numpy())
actions = ndarray_to_tensor(actions, device=self._device)
actions = actions.unsqueeze(1)
actions = actions.unsqueeze(1).long()
return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1]
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:

View file

@ -533,7 +533,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
return {
"experiences": [total_experiences],
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
"info": [deepcopy(self._info)],
}
def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None:
@ -592,7 +592,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._step(list(env_action_dict.values()))
cache_element.next_state = self._state
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
@ -606,7 +606,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
info_list.append(self._info)
info_list.append(deepcopy(self._info))
return {"info": info_list}

View file

@ -2,7 +2,13 @@
# Licensed under the MIT license.
from .proxy import TrainingProxy
from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
from .replay_memory import (
FIFOMultiReplayMemory,
FIFOReplayMemory,
PrioritizedReplayMemory,
RandomMultiReplayMemory,
RandomReplayMemory,
)
from .train_ops import AbsTrainOps, RemoteOps, remote
from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer
from .training_manager import TrainingManager
@ -12,6 +18,7 @@ __all__ = [
"TrainingProxy",
"FIFOMultiReplayMemory",
"FIFOReplayMemory",
"PrioritizedReplayMemory",
"RandomMultiReplayMemory",
"RandomReplayMemory",
"AbsTrainOps",

View file

@ -261,9 +261,6 @@ class DDPGTrainer(SingleAgentTrainer):
assert isinstance(policy, ContinuousRLPolicy)
self._policy = policy
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch
def get_local_ops(self) -> AbsTrainOps:
return DDPGOps(
name=self._policy.name,

View file

@ -2,12 +2,21 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Dict, cast
from typing import Dict, Tuple, cast
import numpy as np
import torch
from maro.rl.policy import RLPolicy, ValueBasedPolicy
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.training import (
AbsTrainOps,
BaseTrainerParams,
PrioritizedReplayMemory,
RandomReplayMemory,
RemoteOps,
SingleAgentTrainer,
remote,
)
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
from maro.utils import clone
@ -15,6 +24,9 @@ from maro.utils import clone
@dataclass
class DQNParams(BaseTrainerParams):
"""
use_prioritized_replay (bool, default=False): Whether to use prioritized replay memory.
alpha (float, default=0.4): Alpha in prioritized replay memory.
beta (float, default=0.6): Beta in prioritized replay memory.
num_epochs (int, default=1): Number of training epochs.
update_target_every (int, default=5): Number of gradient steps between target model updates.
soft_update_coef (float, default=0.1): Soft update coefficient, e.g.,
@ -27,11 +39,13 @@ class DQNParams(BaseTrainerParams):
sequentially with wrap-around.
"""
use_prioritized_replay: bool = False
alpha: float = 0.4
beta: float = 0.6
num_epochs: int = 1
update_target_every: int = 5
soft_update_coef: float = 0.1
double: bool = False
random_overwrite: bool = False
class DQNOps(AbsTrainOps):
@ -54,20 +68,21 @@ class DQNOps(AbsTrainOps):
self._reward_discount = reward_discount
self._soft_update_coef = params.soft_update_coef
self._double = params.double
self._loss_func = torch.nn.MSELoss()
self._target_policy: ValueBasedPolicy = clone(self._policy)
self._target_policy.set_name(f"target_{self._policy.name}")
self._target_policy.eval()
def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor:
def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the loss of the batch.
Args:
batch (TransitionBatch): Batch.
weight (np.ndarray): Weight of each data entry.
Returns:
loss (torch.Tensor): The loss of the batch.
td_error (torch.Tensor): TD-error of the batch.
"""
assert isinstance(batch, TransitionBatch)
assert isinstance(self._policy, ValueBasedPolicy)
@ -79,19 +94,21 @@ class DQNOps(AbsTrainOps):
rewards = ndarray_to_tensor(batch.rewards, device=self._device)
terminals = ndarray_to_tensor(batch.terminals, device=self._device).float()
weight = ndarray_to_tensor(weight, device=self._device)
with torch.no_grad():
if self._double:
self._policy.exploit()
actions_by_eval_policy = self._policy.get_actions_tensor(next_states)
next_q_values = self._target_policy.q_values_tensor(next_states, actions_by_eval_policy)
else:
self._target_policy.exploit()
actions = self._target_policy.get_actions_tensor(next_states)
next_q_values = self._target_policy.q_values_tensor(next_states, actions)
next_q_values = self._target_policy.q_values_for_all_actions_tensor(next_states).max(dim=1)[0]
target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach()
q_values = self._policy.q_values_tensor(states, actions)
return self._loss_func(q_values, target_q_values)
td_error = target_q_values - q_values
return (td_error.pow(2) * weight).mean(), td_error
@remote
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
@ -103,7 +120,8 @@ class DQNOps(AbsTrainOps):
Returns:
grad (torch.Tensor): The gradient of the batch.
"""
return self._policy.get_gradients(self._get_batch_loss(batch))
loss, _ = self._get_batch_loss(batch)
return self._policy.get_gradients(loss)
def update_with_grad(self, grad_dict: dict) -> None:
"""Update the network with remotely computed gradients.
@ -114,14 +132,20 @@ class DQNOps(AbsTrainOps):
self._policy.train()
self._policy.apply_gradients(grad_dict)
def update(self, batch: TransitionBatch) -> None:
def update(self, batch: TransitionBatch, weight: np.ndarray) -> np.ndarray:
"""Update the network using a batch.
Args:
batch (TransitionBatch): Batch.
weight (np.ndarray): Weight of each data entry.
Returns:
td_errors (np.ndarray)
"""
self._policy.train()
self._policy.train_step(self._get_batch_loss(batch))
loss, td_error = self._get_batch_loss(batch, weight)
self._policy.train_step(loss)
return td_error.detach().numpy()
def get_non_policy_state(self) -> dict:
return {
@ -168,20 +192,27 @@ class DQNTrainer(SingleAgentTrainer):
def build(self) -> None:
self._ops = cast(DQNOps, self.get_ops())
if self._params.use_prioritized_replay:
self._replay_memory = PrioritizedReplayMemory(
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
alpha=self._params.alpha,
beta=self._params.beta,
)
else:
self._replay_memory = RandomReplayMemory(
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite,
random_overwrite=False,
)
def _register_policy(self, policy: RLPolicy) -> None:
assert isinstance(policy, ValueBasedPolicy)
self._policy = policy
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch
def get_local_ops(self) -> AbsTrainOps:
return DQNOps(
name=self._policy.name,
@ -191,13 +222,24 @@ class DQNTrainer(SingleAgentTrainer):
params=self._params,
)
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)
def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray, np.ndarray]:
indexes = self.replay_memory.get_sample_indexes(batch_size or self._batch_size)
batch = self.replay_memory.sample_by_indexes(indexes)
if self._params.use_prioritized_replay:
weight = cast(PrioritizedReplayMemory, self.replay_memory).get_weight(indexes)
else:
weight = np.ones(len(indexes))
return batch, indexes, weight
def train_step(self) -> None:
assert isinstance(self._ops, DQNOps)
for _ in range(self._params.num_epochs):
self._ops.update(self._get_batch())
batch, indexes, weight = self._get_batch()
td_error = self._ops.update(batch, weight)
if self._params.use_prioritized_replay:
cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error)
self._try_soft_update_target()
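
When `use_prioritized_replay` is on, each update threads importance-sampling weights through the loss and feeds the absolute TD-errors back as new priorities; with `double=True`, the next-state action is chosen by the online net and evaluated by the target net, otherwise the max over the target net's Q-values is used. A toy sketch of the loss arithmetic only (the tensors stand in for a sampled batch; this is not MARO API):

import torch

q_values = torch.tensor([1.0, 0.5, 2.0])          # Q(s, a) from the online net
target_q = torch.tensor([1.2, 0.0, 1.5])          # r + gamma * (1 - done) * next-state value
weight = torch.tensor([0.9, 1.0, 0.7])            # importance-sampling weights from the replay memory
td_error = target_q - q_values
loss = (td_error.pow(2) * weight).mean()          # same form as _get_batch_loss above
new_priorities = td_error.abs().detach().numpy()  # passed back via update_weight(indexes, td_error)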

View file

@ -272,9 +272,6 @@ class SoftActorCriticTrainer(SingleAgentTrainer):
if early_stop:
break
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch
def get_local_ops(self) -> SoftActorCriticOps:
return SoftActorCriticOps(
name=self._policy.name,

View file

@ -88,6 +88,73 @@ class RandomIndexScheduler(AbsIndexScheduler):
return np.random.choice(self._size, size=batch_size, replace=True)
class PriorityReplayIndexScheduler(AbsIndexScheduler):
"""
Index scheduler for prioritized replay memory: https://arxiv.org/abs/1511.05952.
Args:
capacity (int): Maximum capacity of the replay memory.
alpha (float): Alpha (see the original paper for an explanation).
beta (float): Beta (see the original paper for an explanation).
"""
def __init__(
self,
capacity: int,
alpha: float,
beta: float,
) -> None:
super(PriorityReplayIndexScheduler, self).__init__(capacity)
self._alpha = alpha
self._beta = beta
self._max_prio = self._min_prio = 1.0
self._weights = np.zeros(capacity, dtype=np.float32)
self._ptr = self._size = 0
def init_weights(self, indexes: np.ndarray) -> None:
self._weights[indexes] = self._max_prio**self._alpha
def get_weight(self, indexes: np.ndarray) -> np.ndarray:
# importance sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
return (self._weights[indexes] / self._min_prio) ** (-self._beta)
def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None:
assert indexes.shape == weight.shape
weight = np.abs(weight) + np.finfo(np.float32).eps.item()
self._weights[indexes] = weight**self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def get_put_indexes(self, batch_size: int) -> np.ndarray:
if self._ptr + batch_size <= self._capacity:
indexes = np.arange(self._ptr, self._ptr + batch_size)
self._ptr += batch_size
else:
overwrites = self._ptr + batch_size - self._capacity
indexes = np.concatenate(
[
np.arange(self._ptr, self._capacity),
np.arange(overwrites),
],
)
self._ptr = overwrites
self._size = min(self._size + batch_size, self._capacity)
self.init_weights(indexes)
return indexes
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
assert self._size > 0, "Cannot sample from an empty memory."
weights = self._weights[: self._size]
weights = weights / weights.sum()
return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=True)
class FIFOIndexScheduler(AbsIndexScheduler):
"""First-in-first-out index scheduler.
@ -154,11 +221,11 @@ class AbsReplayMemory(object, metaclass=ABCMeta):
def state_dim(self) -> int:
return self._state_dim
def _get_put_indexes(self, batch_size: int) -> np.ndarray:
def get_put_indexes(self, batch_size: int) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_put_indexes(batch_size)
def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_sample_indexes(batch_size)
@ -225,10 +292,10 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
if transition_batch.old_logps is not None:
match_shape(transition_batch.old_logps, (batch_size,))
self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch)
self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch)
self._n_sample = min(self._n_sample + transition_batch.size, self._capacity)
def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None:
def put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None:
"""Store a transition batch into the memory at the give indexes.
Args:
@ -258,7 +325,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
Returns:
batch (TransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size)
indexes = self.get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)
def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
@ -306,6 +373,31 @@ class RandomReplayMemory(ReplayMemory):
return self._random_overwrite
class PrioritizedReplayMemory(ReplayMemory):
def __init__(
self,
capacity: int,
state_dim: int,
action_dim: int,
alpha: float,
beta: float,
) -> None:
super(PrioritizedReplayMemory, self).__init__(
capacity,
state_dim,
action_dim,
PriorityReplayIndexScheduler(capacity, alpha, beta),
)
def get_weight(self, indexes: np.ndarray) -> np.ndarray:
assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler)
return self._idx_scheduler.get_weight(indexes)
def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None:
assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler)
self._idx_scheduler.update_weight(indexes, weight)
class FIFOReplayMemory(ReplayMemory):
def __init__(
self,
@ -393,9 +485,9 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
assert match_shape(transition_batch.agent_states[i], (batch_size, self._agent_states_dims[i]))
assert match_shape(transition_batch.next_agent_states[i], (batch_size, self._agent_states_dims[i]))
self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch=transition_batch)
self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch=transition_batch)
def _put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None:
def put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None:
"""Store a transition batch into the memory at the give indexes.
Args:
@ -424,7 +516,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
Returns:
batch (MultiTransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size)
indexes = self.get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)
def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
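
The index scheduler samples transitions in proportion to their stored priorities and corrects the induced bias with the simplified weight (p_j / p_min) ** (-beta) noted in the comment above. A pure-NumPy sketch of that arithmetic with toy numbers, not the class internals:

import numpy as np

alpha, beta = 0.4, 0.6
td_abs = np.abs(np.array([0.5, 2.0, 0.1])) + np.finfo(np.float32).eps   # |TD-error| per stored transition
priorities = td_abs ** alpha                              # what update_weight stores
probs = priorities / priorities.sum()                     # sampling distribution used by get_sample_indexes
is_weight = (priorities / priorities.min()) ** (-beta)    # importance-sampling correction per sample
print(probs, is_weight)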

View file

@ -271,9 +271,8 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
transition_batch = self._preprocess_batch(transition_batch)
self.replay_memory.put(transition_batch)
@abstractmethod
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
raise NotImplementedError
return transition_batch
def _assert_ops_exists(self) -> None:
if not self.ops:

View file

@ -3,12 +3,14 @@
from typing import cast
from gym import spaces
from maro.simulator import Env
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
env_conf = {
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
"topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1
"start_tick": 0,
"durations": 100000, # Set a very large number
"options": {},
@ -19,8 +21,18 @@ test_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
num_agents = len(learn_env.agent_idx_list)
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
gym_action_space = gym_env.action_space
gym_state_dim = gym_env.observation_space.shape[0]
gym_action_dim = gym_action_space.shape[0]
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
action_limit = gym_action_space.high[0]
gym_action_space = gym_env.action_space
is_discrete = isinstance(gym_action_space, spaces.Discrete)
if is_discrete:
gym_action_space = cast(spaces.Discrete, gym_action_space)
gym_action_dim = 1
gym_action_num = gym_action_space.n
action_lower_bound, action_upper_bound = None, None # Should never be used
action_limit = None # Should never be used
else:
gym_action_space = cast(spaces.Box, gym_action_space)
gym_action_dim = gym_action_space.shape[0]
gym_action_num = -1 # Should never be used
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
action_limit = action_upper_bound[0]

View file

@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, List, Tuple, Type, Union
from typing import Any, Dict, List, Tuple, Type, Union, cast
import numpy as np
from gym import spaces
from maro.rl.policy.abs_policy import AbsPolicy
from maro.rl.rollout import AbsEnvSampler, CacheElement
@ -40,6 +41,10 @@ class GymEnvSampler(AbsEnvSampler):
self._sample_rewards = []
self._eval_rewards = []
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
gym_action_space = gym_env.action_space
self._is_discrete = isinstance(gym_action_space, spaces.Discrete)
def _get_global_and_agent_state_impl(
self,
event: DecisionEvent,
@ -48,7 +53,7 @@ class GymEnvSampler(AbsEnvSampler):
return None, {0: event.state}
def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict:
return {k: Action(v) for k, v in action_dict.items()}
return {k: Action(v.item() if self._is_discrete else v) for k, v in action_dict.items()}
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
be = self._env.business_engine

View file

@ -1,11 +1,15 @@
# Performance for Gym Task Suite
We benchmarked the MARO RL Toolkit implementation on the Gym task suite. Some results are compared to the benchmarks in
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). We've tried to align the
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#) and [RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md). We've tried to align the
hyper-parameters for these benchmarks, but due to environment version differences there may be some gaps
between the performance reported here and that in the Spinning Up benchmarks. Generally speaking, the performance is comparable.
## Experimental Setting
## Compare with OpenAI Spinning Up
We compare the performance of PPO, SAC, and DDPG in MARO with [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#).
### Experimental Setting
The hyper-parameters are set to align with those used in
[Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#experiment-details):
@ -29,7 +33,7 @@ The hyper-parameters are set to align with those used in
More details about the parameters can be found in *tests/rl/tasks/*.
## Performance
### Performance
Five environments from the MuJoCo Gym task suite are reported in Spinning Up: HalfCheetah, Hopper, Walker2d,
Swimmer, and Ant. The commit id of the code used to conduct the experiments for the MARO RL benchmarks is ee25ce1e97.
@ -52,3 +56,28 @@ python tests/rl/plot.py --smooth WINDOWSIZE
| [**Walker2d**](https://gymnasium.farama.org/environments/mujoco/walker2d/) | ![Wab](https://spinningup.openai.com/en/latest/_images/pytorch_walker2d_performance.svg) | ![Wa1](./log/Walker2d_1.png) | ![Wa11](./log/Walker2d_11.png) |
| [**Swimmer**](https://gymnasium.farama.org/environments/mujoco/swimmer/) | ![Swb](https://spinningup.openai.com/en/latest/_images/pytorch_swimmer_performance.svg) | ![Sw1](./log/Swimmer_1.png) | ![Sw11](./log/Swimmer_11.png) |
| [**Ant**](https://gymnasium.farama.org/environments/mujoco/ant/) | ![Anb](https://spinningup.openai.com/en/latest/_images/pytorch_ant_performance.svg) | ![An1](./log/Ant_1.png) | ![An11](./log/Ant_11.png) |
## Compare with RL Baseline Zoo
[RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md) provides a comprehensive set of benchmarks for multiple algorithms and environments.
However, unlike OpenAI Spinning Up, it does not provide complete learning curves; only the final metrics are reported.
We therefore include the comparison with RL Baseline Zoo as a minor addition.
We compare the performance of DQN with RL Baseline Zoo.
### Experimental Setting
- Batch size: size 64 for each gradient descent step;
- Network: size (256) with relu units;
- Performance metric: measured as the average trajectory return across the batch collected at 10 epochs;
- Total timesteps: 150,000.
### Performance
More details about the parameters can be found in *tests/rl/tasks/*.
Please refer to the original link of RL Baseline Zoo for the baseline metrics.
| algo | env_id         | mean_reward |
|------|----------------|------------:|
| DQN  | CartPole-v1    |      500.00 |
| DQN  | MountainCar-v0 |     -116.90 |

View file

@ -19,6 +19,7 @@ from tests.rl.gym_wrapper.common import (
action_upper_bound,
gym_action_dim,
gym_state_dim,
is_discrete,
learn_env,
num_agents,
test_env,
@ -109,6 +110,8 @@ def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer:
)
assert not is_discrete
algorithm = "ac"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [

View file

@ -20,6 +20,7 @@ from tests.rl.gym_wrapper.common import (
gym_action_dim,
gym_action_space,
gym_state_dim,
is_discrete,
learn_env,
num_agents,
test_env,
@ -123,6 +124,8 @@ def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer:
)
assert not is_discrete
algorithm = "ddpg"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [

View file

@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from torch.optim import Adam
from maro.rl.exploration import LinearExploration
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training.algorithms import DQNParams, DQNTrainer
from tests.rl.gym_wrapper.common import gym_action_num, gym_state_dim, is_discrete, learn_env, num_agents, test_env
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
net_conf = {
"hidden_dims": [256],
"activation": torch.nn.ReLU,
"output_activation": None,
}
lr = 1e-3
class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._mlp = FullyConnected(
input_dim=state_dim,
output_dim=action_num,
**net_conf,
)
self._optim = Adam(self._mlp.parameters(), lr=lr)
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
return self._mlp(states)
def get_dqn_policy(
name: str,
state_dim: int,
action_num: int,
) -> ValueBasedPolicy:
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim=state_dim, action_num=action_num),
explore_strategy=LinearExploration(
num_actions=action_num,
explore_steps=10000,
start_explore_prob=1.0,
end_explore_prob=0.02,
),
warmup=0, # TODO: check this
)
def get_dqn_trainer(
name: str,
) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
use_prioritized_replay=False, #
# alpha=0.4,
# beta=0.6,
num_epochs=50,
update_target_every=10,
soft_update_coef=1.0,
),
replay_memory_capacity=50000,
batch_size=64,
reward_discount=1.0,
)
assert is_discrete
algorithm = "dqn"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [
get_dqn_policy(
f"{algorithm}_{i}.policy",
state_dim=gym_state_dim,
action_num=gym_action_num,
)
for i in range(num_agents)
]
trainers = [get_dqn_trainer(f"{algorithm}_{i}") for i in range(num_agents)]
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)} if torch.cuda.is_available() else None
rl_component_bundle = RLComponentBundle(
env_sampler=GymEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
device_mapping=device_mapping,
)
__all__ = ["rl_component_bundle"]

View file

@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for GYM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
job: gym_rl_workflow
scenario_path: "tests/rl/tasks/dqn"
log_path: "tests/rl/log/dqn_cartpole"
main:
num_episodes: 3000
num_steps: 50
eval_schedule: 50
num_eval_episodes: 10
min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: null
interval: 5
logging:
stdout: INFO
file: DEBUG

View file

@ -11,6 +11,7 @@ from tests.rl.gym_wrapper.common import (
action_upper_bound,
gym_action_dim,
gym_state_dim,
is_discrete,
learn_env,
num_agents,
test_env,
@ -36,6 +37,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
)
assert not is_discrete
algorithm = "ppo"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [

View file

@ -24,6 +24,7 @@ from tests.rl.gym_wrapper.common import (
gym_action_dim,
gym_action_space,
gym_state_dim,
is_discrete,
learn_env,
num_agents,
test_env,
@ -133,6 +134,8 @@ def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCrit
)
assert not is_discrete
algorithm = "sac"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [