Add DDQN (#598)
* Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test (#590) * Runnable. Should setup a benchmark and test performance. * Refine logic * Test DQN on GYM passed * Refine explore strategy * Minor * Minor * Add Dueling DQN in CIM scenario * Resolve PR comments * Add one more explanation * fix env_sampler eval info list issue * update version to 0.3.2a4 --------- Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
This commit is contained in:
Родитель
b3c6a589ad
Коммит
297709736c
|
@ -1,10 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.optim import RMSprop
|
||||
|
||||
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
|
||||
from maro.rl.exploration import EpsilonGreedy
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
||||
|
@ -23,32 +24,62 @@ learning_rate = 0.05
|
|||
|
||||
|
||||
class MyQNet(DiscreteQNet):
|
||||
def __init__(self, state_dim: int, action_num: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
state_dim: int,
|
||||
action_num: int,
|
||||
dueling_param: Optional[Tuple[dict, dict]] = None,
|
||||
) -> None:
|
||||
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
|
||||
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
|
||||
|
||||
self._use_dueling = dueling_param is not None
|
||||
self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
|
||||
if self._use_dueling:
|
||||
q_kwargs, v_kwargs = dueling_param
|
||||
self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
|
||||
self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)
|
||||
|
||||
self._optim = RMSprop(self.parameters(), lr=learning_rate)
|
||||
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._fc(states)
|
||||
logits = self._fc(states)
|
||||
if self._use_dueling:
|
||||
q = self._q(logits)
|
||||
v = self._v(logits)
|
||||
logits = q - q.mean(dim=1, keepdim=True) + v
|
||||
return logits
|
||||
|
||||
|
||||
def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
|
||||
q_kwargs = {
|
||||
"hidden_dims": [128],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"skip_connection": False,
|
||||
"head": True,
|
||||
"dropout_p": 0.0,
|
||||
}
|
||||
v_kwargs = {
|
||||
"hidden_dims": [128],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": None,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"skip_connection": False,
|
||||
"head": True,
|
||||
"dropout_p": 0.0,
|
||||
}
|
||||
|
||||
return ValueBasedPolicy(
|
||||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num),
|
||||
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[
|
||||
(
|
||||
"epsilon",
|
||||
MultiLinearExplorationScheduler,
|
||||
{
|
||||
"splits": [(2, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 5,
|
||||
"final_value": 0.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
q_net=MyQNet(
|
||||
state_dim,
|
||||
action_num,
|
||||
dueling_param=(q_kwargs, v_kwargs),
|
||||
),
|
||||
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
|
|||
num_epochs=10,
|
||||
soft_update_coef=0.1,
|
||||
double=False,
|
||||
random_overwrite=False,
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -35,4 +35,4 @@ state_dim = (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_por
|
|||
|
||||
action_num = len(action_shaping_conf["action_space"])
|
||||
|
||||
algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg
|
||||
algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
|
||||
from maro.rl.exploration import MultiLinearExplorationScheduler
|
||||
from maro.rl.exploration import EpsilonGreedy
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
||||
|
@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
|
|||
return ValueBasedPolicy(
|
||||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num, num_features),
|
||||
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[
|
||||
(
|
||||
"epsilon",
|
||||
MultiLinearExplorationScheduler,
|
||||
{
|
||||
"splits": [(100, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 400,
|
||||
"final_value": 0.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
__version__ = "0.3.2a3"
|
||||
__version__ = "0.3.2a4"
|
||||
|
||||
__data_version__ = "0.2"
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
|
||||
from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
|
||||
from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration
|
||||
|
||||
__all__ = [
|
||||
"AbsExplorationScheduler",
|
||||
"LinearExplorationScheduler",
|
||||
"MultiLinearExplorationScheduler",
|
||||
"epsilon_greedy",
|
||||
"gaussian_noise",
|
||||
"uniform_noise",
|
||||
"ExploreStrategy",
|
||||
"EpsilonGreedy",
|
||||
"LinearExploration",
|
||||
]
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class AbsExplorationScheduler(ABC):
|
||||
"""Abstract exploration scheduler.
|
||||
|
||||
Args:
|
||||
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
|
||||
scheduler is applied.
|
||||
param_name (str): Name of the exploration parameter to which the scheduler is applied.
|
||||
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used
|
||||
when instantiating the policy will be used as the initial value.
|
||||
"""
|
||||
|
||||
def __init__(self, exploration_params: dict, param_name: str, initial_value: float = None) -> None:
|
||||
super().__init__()
|
||||
self._exploration_params = exploration_params
|
||||
self.param_name = param_name
|
||||
if initial_value is not None:
|
||||
self._exploration_params[self.param_name] = initial_value
|
||||
|
||||
def get_value(self) -> float:
|
||||
return self._exploration_params[self.param_name]
|
||||
|
||||
@abstractmethod
|
||||
def step(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LinearExplorationScheduler(AbsExplorationScheduler):
|
||||
"""Linear exploration parameter schedule.
|
||||
|
||||
Args:
|
||||
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
|
||||
scheduler is applied.
|
||||
param_name (str): Name of the exploration parameter to which the scheduler is applied.
|
||||
last_ep (int): Last episode.
|
||||
final_value (float): The value of the exploration parameter corresponding to ``last_ep``.
|
||||
start_ep (int, default=1): starting episode.
|
||||
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used
|
||||
when instantiating the policy will be used as the initial value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exploration_params: dict,
|
||||
param_name: str,
|
||||
*,
|
||||
last_ep: int,
|
||||
final_value: float,
|
||||
start_ep: int = 1,
|
||||
initial_value: float = None,
|
||||
) -> None:
|
||||
super().__init__(exploration_params, param_name, initial_value=initial_value)
|
||||
self.final_value = final_value
|
||||
if last_ep > 1:
|
||||
self.delta = (self.final_value - self._exploration_params[self.param_name]) / (last_ep - start_ep)
|
||||
else:
|
||||
self.delta = 0
|
||||
|
||||
def step(self) -> None:
|
||||
if self._exploration_params[self.param_name] == self.final_value:
|
||||
return
|
||||
|
||||
self._exploration_params[self.param_name] += self.delta
|
||||
|
||||
|
||||
class MultiLinearExplorationScheduler(AbsExplorationScheduler):
|
||||
"""Exploration parameter schedule that consists of multiple linear phases.
|
||||
|
||||
Args:
|
||||
exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the
|
||||
scheduler is applied.
|
||||
param_name (str): Name of the exploration parameter to which the scheduler is applied.
|
||||
splits (List[Tuple[int, float]]): List of points that separate adjacent linear phases. Each
|
||||
point is a (episode, parameter_value) tuple that indicates the end of one linear phase and
|
||||
the start of another. These points do not have to be given in any particular order. There
|
||||
cannot be two points with the same first element (episode), or a ``ValueError`` will be raised.
|
||||
last_ep (int): Last episode.
|
||||
final_value (float): The value of the exploration parameter corresponding to ``last_ep``.
|
||||
start_ep (int, default=1): starting episode.
|
||||
initial_value (float, default=None): Initial value for the exploration parameter. If None, the value from
|
||||
the original dictionary the policy is instantiated with will be used as the initial value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exploration_params: dict,
|
||||
param_name: str,
|
||||
*,
|
||||
splits: List[Tuple[int, float]],
|
||||
last_ep: int,
|
||||
final_value: float,
|
||||
start_ep: int = 1,
|
||||
initial_value: float = None,
|
||||
) -> None:
|
||||
super().__init__(exploration_params, param_name, initial_value=initial_value)
|
||||
|
||||
# validate splits
|
||||
splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
|
||||
splits.sort()
|
||||
for (ep, _), (ep2, _) in zip(splits, splits[1:]):
|
||||
if ep == ep2:
|
||||
raise ValueError("The zeroth element of split points must be unique")
|
||||
|
||||
self.final_value = final_value
|
||||
self._splits = splits
|
||||
self._ep = start_ep
|
||||
self._split_index = 1
|
||||
self._delta = (self._splits[1][1] - self._exploration_params[self.param_name]) / (self._splits[1][0] - start_ep)
|
||||
|
||||
def step(self) -> None:
|
||||
if self._split_index == len(self._splits):
|
||||
return
|
||||
|
||||
self._exploration_params[self.param_name] += self._delta
|
||||
self._ep += 1
|
||||
if self._ep == self._splits[self._split_index][0]:
|
||||
self._split_index += 1
|
||||
if self._split_index < len(self._splits):
|
||||
self._delta = (self._splits[self._split_index][1] - self._splits[self._split_index - 1][1]) / (
|
||||
self._splits[self._split_index][0] - self._splits[self._split_index - 1][0]
|
||||
)
|
|
@ -1,93 +1,100 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Union
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def epsilon_greedy(
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
num_actions: int,
|
||||
*,
|
||||
epsilon: float,
|
||||
) -> np.ndarray:
|
||||
"""Epsilon-greedy exploration.
|
||||
class ExploreStrategy:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_action(
|
||||
self,
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
**kwargs: Any,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Args:
|
||||
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla
|
||||
eps-greedy exploration and is put here to conform to the function signature required for the exploration
|
||||
strategy parameter for ``DQN``.
|
||||
action (np.ndarray): Action(s) chosen greedily by the policy.
|
||||
|
||||
Returns:
|
||||
Exploratory actions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EpsilonGreedy(ExploreStrategy):
|
||||
"""Epsilon-greedy exploration. Returns uniformly random action with probability `epsilon` or returns original
|
||||
action with probability `1.0 - epsilon`.
|
||||
|
||||
Args:
|
||||
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla
|
||||
eps-greedy exploration and is put here to conform to the function signature required for the exploration
|
||||
strategy parameter for ``DQN``.
|
||||
action (np.ndarray): Action(s) chosen greedily by the policy.
|
||||
num_actions (int): Number of possible actions.
|
||||
epsilon (float): The probability that a random action will be selected.
|
||||
|
||||
Returns:
|
||||
Exploratory actions.
|
||||
"""
|
||||
return np.array([act if np.random.random() > epsilon else np.random.randint(num_actions) for act in action])
|
||||
|
||||
def __init__(self, num_actions: int, epsilon: float) -> None:
|
||||
super(EpsilonGreedy, self).__init__()
|
||||
|
||||
assert 0.0 <= epsilon <= 1.0
|
||||
|
||||
self._num_actions = num_actions
|
||||
self._eps = epsilon
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
**kwargs: Any,
|
||||
) -> np.ndarray:
|
||||
return np.array(
|
||||
[act if np.random.random() > self._eps else np.random.randint(self._num_actions) for act in action],
|
||||
)
|
||||
|
||||
|
||||
def uniform_noise(
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
min_action: Union[float, list, np.ndarray] = None,
|
||||
max_action: Union[float, list, np.ndarray] = None,
|
||||
*,
|
||||
low: Union[float, list, np.ndarray],
|
||||
high: Union[float, list, np.ndarray],
|
||||
) -> Union[float, np.ndarray]:
|
||||
"""Apply a uniform noise to a continuous multidimensional action.
|
||||
class LinearExploration(ExploreStrategy):
|
||||
"""Epsilon greedy which the probability `epsilon` is linearly interpolated between `start_explore_prob` and
|
||||
`end_explore_prob` over `explore_steps`. After this many timesteps pass, `epsilon` is fixed to `end_explore_prob`.
|
||||
|
||||
Args:
|
||||
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise
|
||||
exploration scheme and is put here to conform to the function signature for the exploration in continuous
|
||||
action spaces.
|
||||
action (np.ndarray): Action(s) chosen greedily by the policy.
|
||||
min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space.
|
||||
max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space.
|
||||
low (Union[float, list, np.ndarray]): Lower bound for the noise range.
|
||||
high (Union[float, list, np.ndarray]): Upper bound for the noise range.
|
||||
|
||||
Returns:
|
||||
Exploration actions with added noise.
|
||||
num_actions (int): Number of possible actions.
|
||||
explore_steps (int): Maximum number of steps to interpolate probability.
|
||||
start_explore_prob (float): Starting explore probability.
|
||||
end_explore_prob (float): Ending explore probability.
|
||||
"""
|
||||
if min_action is None and max_action is None:
|
||||
return action + np.random.uniform(low, high, size=action.shape)
|
||||
else:
|
||||
return np.clip(action + np.random.uniform(low, high, size=action.shape), min_action, max_action)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_actions: int,
|
||||
explore_steps: int,
|
||||
start_explore_prob: float,
|
||||
end_explore_prob: float,
|
||||
) -> None:
|
||||
super(LinearExploration, self).__init__()
|
||||
|
||||
def gaussian_noise(
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
min_action: Union[float, list, np.ndarray] = None,
|
||||
max_action: Union[float, list, np.ndarray] = None,
|
||||
*,
|
||||
mean: Union[float, list, np.ndarray] = 0.0,
|
||||
stddev: Union[float, list, np.ndarray] = 1.0,
|
||||
relative: bool = False,
|
||||
) -> Union[float, np.ndarray]:
|
||||
"""Apply a gaussian noise to a continuous multidimensional action.
|
||||
self._call_count = 0
|
||||
|
||||
Args:
|
||||
state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise
|
||||
exploration scheme and is put here to conform to the function signature for the exploration in continuous
|
||||
action spaces.
|
||||
action (np.ndarray): Action(s) chosen greedily by the policy.
|
||||
min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space.
|
||||
max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space.
|
||||
mean (Union[float, list, np.ndarray], default=0.0): Gaussian noise mean.
|
||||
stddev (Union[float, list, np.ndarray], default=1.0): Standard deviation for the Gaussian noise.
|
||||
relative (bool, default=False): If True, the generated noise is treated as a relative measure and will
|
||||
be multiplied by the action itself before being added to the action.
|
||||
self._num_actions = num_actions
|
||||
self._explore_steps = explore_steps
|
||||
self._start_explore_prob = start_explore_prob
|
||||
self._end_explore_prob = end_explore_prob
|
||||
|
||||
Returns:
|
||||
Exploration actions with added noise (a numpy ndarray).
|
||||
"""
|
||||
noise = np.random.normal(loc=mean, scale=stddev, size=action.shape)
|
||||
if min_action is None and max_action is None:
|
||||
return action + ((noise * action) if relative else noise)
|
||||
else:
|
||||
return np.clip(action + ((noise * action) if relative else noise), min_action, max_action)
|
||||
def get_action(
|
||||
self,
|
||||
state: np.ndarray,
|
||||
action: np.ndarray,
|
||||
**kwargs: Any,
|
||||
) -> np.ndarray:
|
||||
ratio = min(self._call_count / self._explore_steps, 1.0)
|
||||
epsilon = self._start_explore_prob + (self._end_explore_prob - self._start_explore_prob) * ratio
|
||||
explore_flag = np.random.random() < epsilon
|
||||
action = np.array([np.random.randint(self._num_actions) if explore_flag else act for act in action])
|
||||
|
||||
self._call_count += 1
|
||||
return action
|
||||
|
|
|
@ -13,7 +13,7 @@ class FullyConnected(nn.Module):
|
|||
|
||||
Args:
|
||||
input_dim (int): Network input dimension.
|
||||
output_dim (int): Network output dimension.
|
||||
output_dim (int): Network output dimension. If it is 0, will not create the top layer.
|
||||
hidden_dims (List[int]): Dimensions of hidden layers. Its length is the number of hidden layers. For example,
|
||||
`hidden_dims=[128, 256]` refers to two hidden layers with output dim of 128 and 256, respectively.
|
||||
activation (Optional[Type[torch.nn.Module], default=nn.ReLU): Activation class provided by ``torch.nn`` or a
|
||||
|
@ -52,7 +52,6 @@ class FullyConnected(nn.Module):
|
|||
super(FullyConnected, self).__init__()
|
||||
self._input_dim = input_dim
|
||||
self._hidden_dims = hidden_dims if hidden_dims is not None else []
|
||||
self._output_dim = output_dim
|
||||
|
||||
# network features
|
||||
self._activation = activation if activation else None
|
||||
|
@ -76,9 +75,13 @@ class FullyConnected(nn.Module):
|
|||
self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])
|
||||
]
|
||||
# top layer
|
||||
layers.append(
|
||||
self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation),
|
||||
)
|
||||
if output_dim != 0:
|
||||
layers.append(
|
||||
self._build_layer(dims[-1], output_dim, head=self._head, activation=self._output_activation),
|
||||
)
|
||||
self._output_dim = output_dim
|
||||
else:
|
||||
self._output_dim = hidden_dims[-1]
|
||||
|
||||
self._net = nn.Sequential(*layers)
|
||||
|
||||
|
|
|
@ -2,15 +2,14 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from abc import ABCMeta
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from maro.rl.exploration import epsilon_greedy
|
||||
from maro.rl.exploration import ExploreStrategy
|
||||
from maro.rl.model import DiscretePolicyNet, DiscreteQNet
|
||||
from maro.rl.utils import match_shape, ndarray_to_tensor
|
||||
from maro.utils import clone
|
||||
|
||||
from .abs_policy import RLPolicy
|
||||
|
||||
|
@ -69,8 +68,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
name (str): Name of the policy.
|
||||
q_net (DiscreteQNet): Q-net used in this value-based policy.
|
||||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy.
|
||||
exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options.
|
||||
explore_strategy (Optional[ExploreStrategy], default=None): Explore strategy.
|
||||
warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
@ -80,8 +78,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
name: str,
|
||||
q_net: DiscreteQNet,
|
||||
trainable: bool = True,
|
||||
exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}),
|
||||
exploration_scheduling_options: List[tuple] = None,
|
||||
explore_strategy: Optional[ExploreStrategy] = None,
|
||||
warmup: int = 50000,
|
||||
) -> None:
|
||||
assert isinstance(q_net, DiscreteQNet)
|
||||
|
@ -94,15 +91,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
warmup=warmup,
|
||||
)
|
||||
self._q_net = q_net
|
||||
|
||||
self._exploration_func = exploration_strategy[0]
|
||||
self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing
|
||||
self._exploration_schedulers = (
|
||||
[opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
|
||||
if exploration_scheduling_options is not None
|
||||
else []
|
||||
)
|
||||
|
||||
self._explore_strategy = explore_strategy
|
||||
self._softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
@property
|
||||
|
@ -176,9 +165,6 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
assert match_shape(q_values, (states.shape[0],)) # [B]
|
||||
return q_values
|
||||
|
||||
def explore(self) -> None:
|
||||
pass # Overwrite the base method and turn off explore mode.
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._get_actions_with_probs_impl(states, **kwargs)[0]
|
||||
|
||||
|
@ -187,17 +173,11 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
q_matrix_softmax = self._softmax(q_matrix)
|
||||
_, actions = q_matrix.max(dim=1) # [B], [B]
|
||||
|
||||
if self._is_exploring:
|
||||
actions = self._exploration_func(
|
||||
states,
|
||||
actions.cpu().numpy(),
|
||||
self.action_num,
|
||||
**self._exploration_params,
|
||||
**kwargs,
|
||||
)
|
||||
if self._is_exploring and self._explore_strategy is not None:
|
||||
actions = self._explore_strategy.get_action(state=states.cpu().numpy(), action=actions.cpu().numpy())
|
||||
actions = ndarray_to_tensor(actions, device=self._device)
|
||||
|
||||
actions = actions.unsqueeze(1)
|
||||
actions = actions.unsqueeze(1).long()
|
||||
return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1]
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
|
|
@ -533,7 +533,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
return {
|
||||
"experiences": [total_experiences],
|
||||
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
|
||||
"info": [deepcopy(self._info)],
|
||||
}
|
||||
|
||||
def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None:
|
||||
|
@ -592,7 +592,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._step(list(env_action_dict.values()))
|
||||
cache_element.next_state = self._state
|
||||
|
||||
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
|
||||
if self._reward_eval_delay is None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
|
||||
|
@ -606,7 +606,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
|
||||
info_list.append(self._info)
|
||||
info_list.append(deepcopy(self._info))
|
||||
|
||||
return {"info": info_list}
|
||||
|
||||
|
|
|
@ -2,7 +2,13 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from .proxy import TrainingProxy
|
||||
from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
|
||||
from .replay_memory import (
|
||||
FIFOMultiReplayMemory,
|
||||
FIFOReplayMemory,
|
||||
PrioritizedReplayMemory,
|
||||
RandomMultiReplayMemory,
|
||||
RandomReplayMemory,
|
||||
)
|
||||
from .train_ops import AbsTrainOps, RemoteOps, remote
|
||||
from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer
|
||||
from .training_manager import TrainingManager
|
||||
|
@ -12,6 +18,7 @@ __all__ = [
|
|||
"TrainingProxy",
|
||||
"FIFOMultiReplayMemory",
|
||||
"FIFOReplayMemory",
|
||||
"PrioritizedReplayMemory",
|
||||
"RandomMultiReplayMemory",
|
||||
"RandomReplayMemory",
|
||||
"AbsTrainOps",
|
||||
|
|
|
@ -261,9 +261,6 @@ class DDPGTrainer(SingleAgentTrainer):
|
|||
assert isinstance(policy, ContinuousRLPolicy)
|
||||
self._policy = policy
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
||||
def get_local_ops(self) -> AbsTrainOps:
|
||||
return DDPGOps(
|
||||
name=self._policy.name,
|
||||
|
|
|
@ -2,12 +2,21 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, cast
|
||||
from typing import Dict, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from maro.rl.policy import RLPolicy, ValueBasedPolicy
|
||||
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
|
||||
from maro.rl.training import (
|
||||
AbsTrainOps,
|
||||
BaseTrainerParams,
|
||||
PrioritizedReplayMemory,
|
||||
RandomReplayMemory,
|
||||
RemoteOps,
|
||||
SingleAgentTrainer,
|
||||
remote,
|
||||
)
|
||||
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
|
||||
from maro.utils import clone
|
||||
|
||||
|
@ -15,6 +24,9 @@ from maro.utils import clone
|
|||
@dataclass
|
||||
class DQNParams(BaseTrainerParams):
|
||||
"""
|
||||
use_prioritized_replay (bool, default=False): Whether to use prioritized replay memory.
|
||||
alpha (float, default=0.4): Alpha in prioritized replay memory.
|
||||
beta (float, default=0.6): Beta in prioritized replay memory.
|
||||
num_epochs (int, default=1): Number of training epochs.
|
||||
update_target_every (int, default=5): Number of gradient steps between target model updates.
|
||||
soft_update_coef (float, default=0.1): Soft update coefficient, e.g.,
|
||||
|
@ -27,11 +39,13 @@ class DQNParams(BaseTrainerParams):
|
|||
sequentially with wrap-around.
|
||||
"""
|
||||
|
||||
use_prioritized_replay: bool = False
|
||||
alpha: float = 0.4
|
||||
beta: float = 0.6
|
||||
num_epochs: int = 1
|
||||
update_target_every: int = 5
|
||||
soft_update_coef: float = 0.1
|
||||
double: bool = False
|
||||
random_overwrite: bool = False
|
||||
|
||||
|
||||
class DQNOps(AbsTrainOps):
|
||||
|
@ -54,20 +68,21 @@ class DQNOps(AbsTrainOps):
|
|||
self._reward_discount = reward_discount
|
||||
self._soft_update_coef = params.soft_update_coef
|
||||
self._double = params.double
|
||||
self._loss_func = torch.nn.MSELoss()
|
||||
|
||||
self._target_policy: ValueBasedPolicy = clone(self._policy)
|
||||
self._target_policy.set_name(f"target_{self._policy.name}")
|
||||
self._target_policy.eval()
|
||||
|
||||
def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute the loss of the batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
weight (np.ndarray): Weight of each data entry.
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor): The loss of the batch.
|
||||
td_error (torch.Tensor): TD-error of the batch.
|
||||
"""
|
||||
assert isinstance(batch, TransitionBatch)
|
||||
assert isinstance(self._policy, ValueBasedPolicy)
|
||||
|
@ -79,19 +94,21 @@ class DQNOps(AbsTrainOps):
|
|||
rewards = ndarray_to_tensor(batch.rewards, device=self._device)
|
||||
terminals = ndarray_to_tensor(batch.terminals, device=self._device).float()
|
||||
|
||||
weight = ndarray_to_tensor(weight, device=self._device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self._double:
|
||||
self._policy.exploit()
|
||||
actions_by_eval_policy = self._policy.get_actions_tensor(next_states)
|
||||
next_q_values = self._target_policy.q_values_tensor(next_states, actions_by_eval_policy)
|
||||
else:
|
||||
self._target_policy.exploit()
|
||||
actions = self._target_policy.get_actions_tensor(next_states)
|
||||
next_q_values = self._target_policy.q_values_tensor(next_states, actions)
|
||||
next_q_values = self._target_policy.q_values_for_all_actions_tensor(next_states).max(dim=1)[0]
|
||||
|
||||
target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach()
|
||||
q_values = self._policy.q_values_tensor(states, actions)
|
||||
return self._loss_func(q_values, target_q_values)
|
||||
td_error = target_q_values - q_values
|
||||
|
||||
return (td_error.pow(2) * weight).mean(), td_error
|
||||
|
||||
@remote
|
||||
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
|
||||
|
@ -103,7 +120,8 @@ class DQNOps(AbsTrainOps):
|
|||
Returns:
|
||||
grad (torch.Tensor): The gradient of the batch.
|
||||
"""
|
||||
return self._policy.get_gradients(self._get_batch_loss(batch))
|
||||
loss, _ = self._get_batch_loss(batch)
|
||||
return self._policy.get_gradients(loss)
|
||||
|
||||
def update_with_grad(self, grad_dict: dict) -> None:
|
||||
"""Update the network with remotely computed gradients.
|
||||
|
@ -114,14 +132,20 @@ class DQNOps(AbsTrainOps):
|
|||
self._policy.train()
|
||||
self._policy.apply_gradients(grad_dict)
|
||||
|
||||
def update(self, batch: TransitionBatch) -> None:
|
||||
def update(self, batch: TransitionBatch, weight: np.ndarray) -> np.ndarray:
|
||||
"""Update the network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
weight (np.ndarray): Weight of each data entry.
|
||||
|
||||
Returns:
|
||||
td_errors (np.ndarray)
|
||||
"""
|
||||
self._policy.train()
|
||||
self._policy.train_step(self._get_batch_loss(batch))
|
||||
loss, td_error = self._get_batch_loss(batch, weight)
|
||||
self._policy.train_step(loss)
|
||||
return td_error.detach().numpy()
|
||||
|
||||
def get_non_policy_state(self) -> dict:
|
||||
return {
|
||||
|
@ -168,20 +192,27 @@ class DQNTrainer(SingleAgentTrainer):
|
|||
|
||||
def build(self) -> None:
|
||||
self._ops = cast(DQNOps, self.get_ops())
|
||||
self._replay_memory = RandomReplayMemory(
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
random_overwrite=self._params.random_overwrite,
|
||||
)
|
||||
|
||||
if self._params.use_prioritized_replay:
|
||||
self._replay_memory = PrioritizedReplayMemory(
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
alpha=self._params.alpha,
|
||||
beta=self._params.beta,
|
||||
)
|
||||
else:
|
||||
self._replay_memory = RandomReplayMemory(
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
random_overwrite=False,
|
||||
)
|
||||
|
||||
def _register_policy(self, policy: RLPolicy) -> None:
|
||||
assert isinstance(policy, ValueBasedPolicy)
|
||||
self._policy = policy
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
||||
def get_local_ops(self) -> AbsTrainOps:
|
||||
return DQNOps(
|
||||
name=self._policy.name,
|
||||
|
@ -191,13 +222,24 @@ class DQNTrainer(SingleAgentTrainer):
|
|||
params=self._params,
|
||||
)
|
||||
|
||||
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
|
||||
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)
|
||||
def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray, np.ndarray]:
|
||||
indexes = self.replay_memory.get_sample_indexes(batch_size or self._batch_size)
|
||||
batch = self.replay_memory.sample_by_indexes(indexes)
|
||||
|
||||
if self._params.use_prioritized_replay:
|
||||
weight = cast(PrioritizedReplayMemory, self.replay_memory).get_weight(indexes)
|
||||
else:
|
||||
weight = np.ones(len(indexes))
|
||||
|
||||
return batch, indexes, weight
|
||||
|
||||
def train_step(self) -> None:
|
||||
assert isinstance(self._ops, DQNOps)
|
||||
for _ in range(self._params.num_epochs):
|
||||
self._ops.update(self._get_batch())
|
||||
batch, indexes, weight = self._get_batch()
|
||||
td_error = self._ops.update(batch, weight)
|
||||
if self._params.use_prioritized_replay:
|
||||
cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error)
|
||||
|
||||
self._try_soft_update_target()
|
||||
|
||||
|
|
|
@ -272,9 +272,6 @@ class SoftActorCriticTrainer(SingleAgentTrainer):
|
|||
if early_stop:
|
||||
break
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
||||
def get_local_ops(self) -> SoftActorCriticOps:
|
||||
return SoftActorCriticOps(
|
||||
name=self._policy.name,
|
||||
|
|
|
@ -88,6 +88,73 @@ class RandomIndexScheduler(AbsIndexScheduler):
|
|||
return np.random.choice(self._size, size=batch_size, replace=True)
|
||||
|
||||
|
||||
class PriorityReplayIndexScheduler(AbsIndexScheduler):
|
||||
"""
|
||||
Indexer for priority replay memory: https://arxiv.org/abs/1511.05952.
|
||||
|
||||
Args:
|
||||
capacity (int): Maximum capacity of the replay memory.
|
||||
alpha (float): Alpha (see original paper for explanation).
|
||||
beta (float): Alpha (see original paper for explanation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
) -> None:
|
||||
super(PriorityReplayIndexScheduler, self).__init__(capacity)
|
||||
self._alpha = alpha
|
||||
self._beta = beta
|
||||
self._max_prio = self._min_prio = 1.0
|
||||
self._weights = np.zeros(capacity, dtype=np.float32)
|
||||
|
||||
self._ptr = self._size = 0
|
||||
|
||||
def init_weights(self, indexes: np.ndarray) -> None:
|
||||
self._weights[indexes] = self._max_prio**self._alpha
|
||||
|
||||
def get_weight(self, indexes: np.ndarray) -> np.ndarray:
|
||||
# important sampling weight calculation
|
||||
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
|
||||
# simplified formula: (p_j/p_min)**(-beta)
|
||||
return (self._weights[indexes] / self._min_prio) ** (-self._beta)
|
||||
|
||||
def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None:
|
||||
assert indexes.shape == weight.shape
|
||||
weight = np.abs(weight) + np.finfo(np.float32).eps.item()
|
||||
self._weights[indexes] = weight**self._alpha
|
||||
self._max_prio = max(self._max_prio, weight.max())
|
||||
self._min_prio = min(self._min_prio, weight.min())
|
||||
|
||||
def get_put_indexes(self, batch_size: int) -> np.ndarray:
|
||||
if self._ptr + batch_size <= self._capacity:
|
||||
indexes = np.arange(self._ptr, self._ptr + batch_size)
|
||||
self._ptr += batch_size
|
||||
else:
|
||||
overwrites = self._ptr + batch_size - self._capacity
|
||||
indexes = np.concatenate(
|
||||
[
|
||||
np.arange(self._ptr, self._capacity),
|
||||
np.arange(overwrites),
|
||||
],
|
||||
)
|
||||
self._ptr = overwrites
|
||||
|
||||
self._size = min(self._size + batch_size, self._capacity)
|
||||
self.init_weights(indexes)
|
||||
return indexes
|
||||
|
||||
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
|
||||
assert self._size > 0, "Cannot sample from an empty memory."
|
||||
|
||||
weights = self._weights[: self._size]
|
||||
weights = weights / weights.sum()
|
||||
return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=True)
|
||||
|
||||
|
||||
class FIFOIndexScheduler(AbsIndexScheduler):
|
||||
"""First-in-first-out index scheduler.
|
||||
|
||||
|
@ -154,11 +221,11 @@ class AbsReplayMemory(object, metaclass=ABCMeta):
|
|||
def state_dim(self) -> int:
|
||||
return self._state_dim
|
||||
|
||||
def _get_put_indexes(self, batch_size: int) -> np.ndarray:
|
||||
def get_put_indexes(self, batch_size: int) -> np.ndarray:
|
||||
"""Please refer to the doc string in AbsIndexScheduler."""
|
||||
return self._idx_scheduler.get_put_indexes(batch_size)
|
||||
|
||||
def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
"""Please refer to the doc string in AbsIndexScheduler."""
|
||||
return self._idx_scheduler.get_sample_indexes(batch_size)
|
||||
|
||||
|
@ -225,10 +292,10 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
if transition_batch.old_logps is not None:
|
||||
match_shape(transition_batch.old_logps, (batch_size,))
|
||||
|
||||
self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch)
|
||||
self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch)
|
||||
self._n_sample = min(self._n_sample + transition_batch.size, self._capacity)
|
||||
|
||||
def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None:
|
||||
def put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None:
|
||||
"""Store a transition batch into the memory at the give indexes.
|
||||
|
||||
Args:
|
||||
|
@ -258,7 +325,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
Returns:
|
||||
batch (TransitionBatch): The sampled batch.
|
||||
"""
|
||||
indexes = self._get_sample_indexes(batch_size)
|
||||
indexes = self.get_sample_indexes(batch_size)
|
||||
return self.sample_by_indexes(indexes)
|
||||
|
||||
def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
|
||||
|
@ -306,6 +373,31 @@ class RandomReplayMemory(ReplayMemory):
|
|||
return self._random_overwrite
|
||||
|
||||
|
||||
class PrioritizedReplayMemory(ReplayMemory):
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
) -> None:
|
||||
super(PrioritizedReplayMemory, self).__init__(
|
||||
capacity,
|
||||
state_dim,
|
||||
action_dim,
|
||||
PriorityReplayIndexScheduler(capacity, alpha, beta),
|
||||
)
|
||||
|
||||
def get_weight(self, indexes: np.ndarray) -> np.ndarray:
|
||||
assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler)
|
||||
return self._idx_scheduler.get_weight(indexes)
|
||||
|
||||
def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None:
|
||||
assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler)
|
||||
self._idx_scheduler.update_weight(indexes, weight)
|
||||
|
||||
|
||||
class FIFOReplayMemory(ReplayMemory):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -393,9 +485,9 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
assert match_shape(transition_batch.agent_states[i], (batch_size, self._agent_states_dims[i]))
|
||||
assert match_shape(transition_batch.next_agent_states[i], (batch_size, self._agent_states_dims[i]))
|
||||
|
||||
self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch=transition_batch)
|
||||
self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch=transition_batch)
|
||||
|
||||
def _put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None:
|
||||
def put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None:
|
||||
"""Store a transition batch into the memory at the give indexes.
|
||||
|
||||
Args:
|
||||
|
@ -424,7 +516,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
Returns:
|
||||
batch (MultiTransitionBatch): The sampled batch.
|
||||
"""
|
||||
indexes = self._get_sample_indexes(batch_size)
|
||||
indexes = self.get_sample_indexes(batch_size)
|
||||
return self.sample_by_indexes(indexes)
|
||||
|
||||
def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
|
||||
|
|
|
@ -271,9 +271,8 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
|
|||
transition_batch = self._preprocess_batch(transition_batch)
|
||||
self.replay_memory.put(transition_batch)
|
||||
|
||||
@abstractmethod
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
raise NotImplementedError
|
||||
return transition_batch
|
||||
|
||||
def _assert_ops_exists(self) -> None:
|
||||
if not self.ops:
|
||||
|
|
|
@ -3,12 +3,14 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from maro.simulator import Env
|
||||
|
||||
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
|
||||
|
||||
env_conf = {
|
||||
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
|
||||
"topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1
|
||||
"start_tick": 0,
|
||||
"durations": 100000, # Set a very large number
|
||||
"options": {},
|
||||
|
@ -19,8 +21,18 @@ test_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
|
|||
num_agents = len(learn_env.agent_idx_list)
|
||||
|
||||
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
|
||||
gym_action_space = gym_env.action_space
|
||||
gym_state_dim = gym_env.observation_space.shape[0]
|
||||
gym_action_dim = gym_action_space.shape[0]
|
||||
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
|
||||
action_limit = gym_action_space.high[0]
|
||||
gym_action_space = gym_env.action_space
|
||||
is_discrete = isinstance(gym_action_space, spaces.Discrete)
|
||||
if is_discrete:
|
||||
gym_action_space = cast(spaces.Discrete, gym_action_space)
|
||||
gym_action_dim = 1
|
||||
gym_action_num = gym_action_space.n
|
||||
action_lower_bound, action_upper_bound = None, None # Should never be used
|
||||
action_limit = None # Should never be used
|
||||
else:
|
||||
gym_action_space = cast(spaces.Box, gym_action_space)
|
||||
gym_action_dim = gym_action_space.shape[0]
|
||||
gym_action_num = -1 # Should never be used
|
||||
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
|
||||
action_limit = action_upper_bound[0]
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Tuple, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from maro.rl.policy.abs_policy import AbsPolicy
|
||||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
|
@ -40,6 +41,10 @@ class GymEnvSampler(AbsEnvSampler):
|
|||
self._sample_rewards = []
|
||||
self._eval_rewards = []
|
||||
|
||||
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
|
||||
gym_action_space = gym_env.action_space
|
||||
self._is_discrete = isinstance(gym_action_space, spaces.Discrete)
|
||||
|
||||
def _get_global_and_agent_state_impl(
|
||||
self,
|
||||
event: DecisionEvent,
|
||||
|
@ -48,7 +53,7 @@ class GymEnvSampler(AbsEnvSampler):
|
|||
return None, {0: event.state}
|
||||
|
||||
def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict:
|
||||
return {k: Action(v) for k, v in action_dict.items()}
|
||||
return {k: Action(v.item() if self._is_discrete else v) for k, v in action_dict.items()}
|
||||
|
||||
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
|
||||
be = self._env.business_engine
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
# Performance for Gym Task Suite
|
||||
|
||||
We benchmarked the MARO RL Toolkit implementation in Gym task suite. Some are compared to the benchmarks in
|
||||
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). We've tried to align the
|
||||
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#) and [RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md). We've tried to align the
|
||||
hyper-parameters for these benchmarks , but limited by the environment version difference, there may be some gaps
|
||||
between the performance here and that in Spinning Up benchmarks. Generally speaking, the performance is comparable.
|
||||
|
||||
## Experimental Setting
|
||||
## Compare with OpenAI Spinning Up
|
||||
|
||||
We compare the performance of PPO, SAC, and DDPG in MARO with [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#).
|
||||
|
||||
### Experimental Setting
|
||||
|
||||
The hyper-parameters are set to align with those used in
|
||||
[Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#experiment-details):
|
||||
|
@ -29,7 +33,7 @@ The hyper-parameters are set to align with those used in
|
|||
|
||||
More details about the parameters can be found in *tests/rl/tasks/*.
|
||||
|
||||
## Performance
|
||||
### Performance
|
||||
|
||||
Five environments from the MuJoCo Gym task suite are reported in Spinning Up, they are: HalfCheetah, Hopper, Walker2d,
|
||||
Swimmer, and Ant. The commit id of the code used to conduct the experiments for MARO RL benchmarks is ee25ce1e97.
|
||||
|
@ -52,3 +56,28 @@ python tests/rl/plot.py --smooth WINDOWSIZE
|
|||
| [**Walker2d**](https://gymnasium.farama.org/environments/mujoco/walker2d/) | ![Wab](https://spinningup.openai.com/en/latest/_images/pytorch_walker2d_performance.svg) | ![Wa1](./log/Walker2d_1.png) | ![Wa11](./log/Walker2d_11.png) |
|
||||
| [**Swimmer**](https://gymnasium.farama.org/environments/mujoco/swimmer/) | ![Swb](https://spinningup.openai.com/en/latest/_images/pytorch_swimmer_performance.svg) | ![Sw1](./log/Swimmer_1.png) | ![Sw11](./log/Swimmer_11.png) |
|
||||
| [**Ant**](https://gymnasium.farama.org/environments/mujoco/ant/) | ![Anb](https://spinningup.openai.com/en/latest/_images/pytorch_ant_performance.svg) | ![An1](./log/Ant_1.png) | ![An11](./log/Ant_11.png) |
|
||||
|
||||
## Compare with RL Baseline Zoo
|
||||
|
||||
[RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md) provides a comprehensive set of benchmarks for multiple algorithms and environments.
|
||||
However, unlike OpenAI Spinning Up, it does not provide the complete learning curve. Instead, we can only find the final metrics in it.
|
||||
We therefore leave the comparison with RL Baseline Zoo as a minor addition.
|
||||
|
||||
We compare the performance of DQN with RL Baseline Zoo.
|
||||
|
||||
### Experimental Setting
|
||||
|
||||
- Batch size: size 64 for each gradient descent step;
|
||||
- Network: size (256) with relu units;
|
||||
- Performance metric: measured as the average trajectory return across the batch collected at 10 epochs;
|
||||
- Total timesteps: 150,000.
|
||||
|
||||
### Performance
|
||||
|
||||
More details about the parameters can be found in *tests/rl/tasks/*.
|
||||
Please refer to the original link of RL Baseline Zoo for the baseline metrics.
|
||||
|
||||
| algo | env_id |mean_reward|
|
||||
|--------|-------------------------------|----------:|
|
||||
|DQN |CartPole-v1 | 500.00 |
|
||||
|DQN |MountainCar-v0 | -116.90 |
|
||||
|
|
|
@ -19,6 +19,7 @@ from tests.rl.gym_wrapper.common import (
|
|||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_state_dim,
|
||||
is_discrete,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
|
@ -109,6 +110,8 @@ def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer:
|
|||
)
|
||||
|
||||
|
||||
assert not is_discrete
|
||||
|
||||
algorithm = "ac"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
|
|
|
@ -20,6 +20,7 @@ from tests.rl.gym_wrapper.common import (
|
|||
gym_action_dim,
|
||||
gym_action_space,
|
||||
gym_state_dim,
|
||||
is_discrete,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
|
@ -123,6 +124,8 @@ def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer:
|
|||
)
|
||||
|
||||
|
||||
assert not is_discrete
|
||||
|
||||
algorithm = "ddpg"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
|
||||
from maro.rl.exploration import LinearExploration
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
||||
|
||||
from tests.rl.gym_wrapper.common import gym_action_num, gym_state_dim, is_discrete, learn_env, num_agents, test_env
|
||||
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
|
||||
|
||||
net_conf = {
|
||||
"hidden_dims": [256],
|
||||
"activation": torch.nn.ReLU,
|
||||
"output_activation": None,
|
||||
}
|
||||
lr = 1e-3
|
||||
|
||||
|
||||
class MyQNet(DiscreteQNet):
|
||||
def __init__(self, state_dim: int, action_num: int) -> None:
|
||||
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
|
||||
self._mlp = FullyConnected(
|
||||
input_dim=state_dim,
|
||||
output_dim=action_num,
|
||||
**net_conf,
|
||||
)
|
||||
self._optim = Adam(self._mlp.parameters(), lr=lr)
|
||||
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._mlp(states)
|
||||
|
||||
|
||||
def get_dqn_policy(
|
||||
name: str,
|
||||
state_dim: int,
|
||||
action_num: int,
|
||||
) -> ValueBasedPolicy:
|
||||
return ValueBasedPolicy(
|
||||
name=name,
|
||||
q_net=MyQNet(state_dim=state_dim, action_num=action_num),
|
||||
explore_strategy=LinearExploration(
|
||||
num_actions=action_num,
|
||||
explore_steps=10000,
|
||||
start_explore_prob=1.0,
|
||||
end_explore_prob=0.02,
|
||||
),
|
||||
warmup=0, # TODO: check this
|
||||
)
|
||||
|
||||
|
||||
def get_dqn_trainer(
|
||||
name: str,
|
||||
) -> DQNTrainer:
|
||||
return DQNTrainer(
|
||||
name=name,
|
||||
params=DQNParams(
|
||||
use_prioritized_replay=False, #
|
||||
# alpha=0.4,
|
||||
# beta=0.6,
|
||||
num_epochs=50,
|
||||
update_target_every=10,
|
||||
soft_update_coef=1.0,
|
||||
),
|
||||
replay_memory_capacity=50000,
|
||||
batch_size=64,
|
||||
reward_discount=1.0,
|
||||
)
|
||||
|
||||
|
||||
assert is_discrete
|
||||
|
||||
algorithm = "dqn"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
get_dqn_policy(
|
||||
f"{algorithm}_{i}.policy",
|
||||
state_dim=gym_state_dim,
|
||||
action_num=gym_action_num,
|
||||
)
|
||||
for i in range(num_agents)
|
||||
]
|
||||
trainers = [get_dqn_trainer(f"{algorithm}_{i}") for i in range(num_agents)]
|
||||
|
||||
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)} if torch.cuda.is_available() else None
|
||||
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=GymEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
device_mapping=device_mapping,
|
||||
)
|
||||
|
||||
__all__ = ["rl_component_bundle"]
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for GYM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
job: gym_rl_workflow
|
||||
scenario_path: "tests/rl/tasks/dqn"
|
||||
log_path: "tests/rl/log/dqn_cartpole"
|
||||
main:
|
||||
num_episodes: 3000
|
||||
num_steps: 50
|
||||
eval_schedule: 50
|
||||
num_eval_episodes: 10
|
||||
min_n_sample: 1
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: null
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -11,6 +11,7 @@ from tests.rl.gym_wrapper.common import (
|
|||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_state_dim,
|
||||
is_discrete,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
|
@ -36,6 +37,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
|
|||
)
|
||||
|
||||
|
||||
assert not is_discrete
|
||||
|
||||
algorithm = "ppo"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
|
|
|
@ -24,6 +24,7 @@ from tests.rl.gym_wrapper.common import (
|
|||
gym_action_dim,
|
||||
gym_action_space,
|
||||
gym_state_dim,
|
||||
is_discrete,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
|
@ -133,6 +134,8 @@ def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCrit
|
|||
)
|
||||
|
||||
|
||||
assert not is_discrete
|
||||
|
||||
algorithm = "sac"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
|
|
Загрузка…
Ссылка в новой задаче