refine import path for examples (#195)
* refine import path for examples
* refine indents
* fixed formatting issues
* update code style
* add editorconfig-checker, add editorconfig path into lint, change super-linter version
* change path for code saving in cim.gnn

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com>
Parent: 8c800d721b
Commit: 25642a48ca
@@ -40,13 +40,14 @@ jobs:
       # Run Linter against code base #
       ################################
       - name: Lint Code Base
-        uses: github/super-linter@v3
+        uses: github/super-linter@latest
         env:
           VALIDATE_ALL_CODEBASE: false
           VALIDATE_PYTHON_PYLINT: false  # disable pylint, as we have not configured it
           VALIDATE_PYTHON_BLACK: false  # same as above
           PYTHON_FLAKE8_CONFIG_FILE: tox.ini
           PYTHON_ISORT_CONFIG_FILE: tox.ini
+          EDITORCONFIG_FILE_NAME: ../../.editorconfig
           FILTER_REGEX_INCLUDE: maro/.*
           FILTER_REGEX_EXCLUDE: tests/.*
           DEFAULT_BRANCH: master
@@ -33,6 +33,13 @@ MARO is newborn for Reinforcement learning as a Service (RaaS) in the resource o
 # Lint with flake8.
 flake8 --config .github/linters/tox.ini
+
+# Install editorconfig-checker.
+pip install editorconfig-checker
+
+# Lint with editorconfig-checker.
+# PATH: Directory or file path of your changes.
+editorconfig-checker --config .editorconfig PATH
 
 ```
 
 - [Update Change Log](https://github.com/github-changelog-generator/github-changelog-generator#installation) (if needed)
@@ -1,2 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+
+from .action_shaper import CIMActionShaper
+from .agent_manager import DQNAgentManager
+from .config import config
+from .experience_shaper import TruncatedExperienceShaper
+from .state_shaper import CIMStateShaper
+
+__all__ = [
+    "CIMActionShaper",
+    "DQNAgentManager",
+    "config",
+    "TruncatedExperienceShaper",
+    "CIMStateShaper"
+]
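This hunk creates the package-level re-exports that drive the import-path refinement: with the shapers, agent manager, and config exposed through `__all__`, the launcher scripts below can replace five relative imports with one. A minimal sketch of the effect (illustrative, not part of the diff):

```python
# Before: one relative import per component module.
# from .components.action_shaper import CIMActionShaper
# from .components.state_shaper import CIMStateShaper

# After: the components package re-exports its public names,
# so a launcher pulls everything it needs in a single statement.
from components import CIMActionShaper, CIMStateShaper, DQNAgentManager
```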
@@ -7,8 +7,9 @@ from maro.rl import AbsAgent, ColumnBasedStore
 
 
 class CIMAgent(AbsAgent):
-    def __init__(self, name, algorithm, experience_pool: ColumnBasedStore, min_experiences_to_train,
-                 num_batches, batch_size):
+    def __init__(
+        self, name, algorithm, experience_pool: ColumnBasedStore, min_experiences_to_train, num_batches, batch_size
+    ):
         super().__init__(name, algorithm, experience_pool)
         self._min_experiences_to_train = min_experiences_to_train
         self._num_batches = num_batches
@@ -16,21 +16,31 @@ class DQNAgentManager(AbsAgentManager):
         set_seeds(config.agents.seed)
         num_actions = config.agents.algorithm.num_actions
         for agent_id in self._agent_id_list:
-            eval_model = LearningModel(decision_layers=MLPDecisionLayers(name=f'{agent_id}.policy',
-                                                                         input_dim=self._state_shaper.dim,
-                                                                         output_dim=num_actions,
-                                                                         **config.agents.algorithm.model)
-                                       )
+            eval_model = LearningModel(
+                decision_layers=MLPDecisionLayers(
+                    name=f'{agent_id}.policy',
+                    input_dim=self._state_shaper.dim,
+                    output_dim=num_actions,
+                    **config.agents.algorithm.model
+                )
+            )
 
-            algorithm = DQN(model_dict={"eval": eval_model},
-                            optimizer_opt=(RMSprop, config.agents.algorithm.optimizer),
-                            loss_func_dict={"eval": smooth_l1_loss},
-                            hyper_params=DQNHyperParams(**config.agents.algorithm.hyper_parameters,
-                                                        num_actions=num_actions))
+            algorithm = DQN(
+                model_dict={"eval": eval_model},
+                optimizer_opt=(RMSprop, config.agents.algorithm.optimizer),
+                loss_func_dict={"eval": smooth_l1_loss},
+                hyper_params=DQNHyperParams(
+                    **config.agents.algorithm.hyper_parameters,
+                    num_actions=num_actions)
+            )
 
             experience_pool = ColumnBasedStore(**config.agents.experience_pool)
-            agent_dict[agent_id] = CIMAgent(name=agent_id, algorithm=algorithm, experience_pool=experience_pool,
-                                            **config.agents.training_loop_parameters)
+            agent_dict[agent_id] = CIMAgent(
+                name=agent_id,
+                algorithm=algorithm,
+                experience_pool=experience_pool,
+                **config.agents.training_loop_parameters
+            )
 
     def store_experiences(self, experiences):
         for agent_id, exp in experiences.items():
@@ -9,8 +9,9 @@ from maro.rl import ExperienceShaper
 
 
 class TruncatedExperienceShaper(ExperienceShaper):
-    def __init__(self, *, time_window: int, time_decay_factor: float, fulfillment_factor: float,
-                 shortage_factor: float):
+    def __init__(
+        self, *, time_window: int, time_decay_factor: float, fulfillment_factor: float, shortage_factor: float
+    ):
         super().__init__(reward_func=None)
         self._time_window = time_window
         self._time_decay_factor = time_decay_factor
@@ -40,8 +41,10 @@ class TruncatedExperienceShaper(ExperienceShaper):
         # calculate tc reward
         future_fulfillment = snapshot_list["ports"][ticks::"fulfillment"]
         future_shortage = snapshot_list["ports"][ticks::"shortage"]
-        decay_list = [self._time_decay_factor ** i for i in range(end_tick - start_tick)
-                      for _ in range(future_fulfillment.shape[0] // (end_tick - start_tick))]
+        decay_list = [
+            self._time_decay_factor ** i for i in range(end_tick - start_tick)
+            for _ in range(future_fulfillment.shape[0] // (end_tick - start_tick))
+        ]
 
         tot_fulfillment = np.dot(future_fulfillment, decay_list)
         tot_shortage = np.dot(future_shortage, decay_list)
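The rewritten `decay_list` comprehension is the subtlest line in this hunk: the snapshot slice is flattened tick-major, so each decay weight must be repeated once per port before `np.dot` can apply it. A small numeric sketch with hypothetical dimensions (2 ticks, 3 ports):

```python
import numpy as np

# Hypothetical: 2 ticks and 3 ports, flattened tick-major as in the
# snapshot slice above, so shape[0] == num_ticks * num_ports.
time_decay_factor = 0.9
num_ticks = 2
future_fulfillment = np.array([1, 2, 3, 4, 5, 6], dtype=float)

decay_list = [
    time_decay_factor ** i for i in range(num_ticks)
    for _ in range(future_fulfillment.shape[0] // num_ticks)
]
# decay_list == [1.0, 1.0, 1.0, 0.9, 0.9, 0.9]

tot_fulfillment = np.dot(future_fulfillment, decay_list)  # 19.5
print(decay_list, tot_fulfillment)
```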
@@ -6,11 +6,8 @@ import numpy as np
 from maro.rl import ActorWorker, AgentMode, KStepExperienceShaper, SimpleActor, TwoPhaseLinearExplorer
 from maro.simulator import Env
 
-from .components.action_shaper import CIMActionShaper
-from .components.agent_manager import DQNAgentManager
-from .components.config import config
-from .components.experience_shaper import TruncatedExperienceShaper
-from .components.state_shaper import CIMStateShaper
+from components import CIMActionShaper, CIMStateShaper, DQNAgentManager, TruncatedExperienceShaper, config
 
 
 if __name__ == "__main__":
     env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
@@ -22,26 +19,32 @@ if __name__ == "__main__":
     else:
         experience_shaper = KStepExperienceShaper(
             reward_func=lambda mt: 1 - mt["container_shortage"] / mt["order_requirements"],
-            **config.experience_shaping.k_step)
+            **config.experience_shaping.k_step
+        )
 
-    exploration_config = {"epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
-                          "split_point_dict": {"_all_": config.exploration.split_point},
-                          "with_cache": config.exploration.with_cache
-                          }
+    exploration_config = {
+        "epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
+        "split_point_dict": {"_all_": config.exploration.split_point},
+        "with_cache": config.exploration.with_cache
+    }
     explorer = TwoPhaseLinearExplorer(agent_id_list, config.general.total_training_episodes, **exploration_config)
-    agent_manager = DQNAgentManager(name="cim_remote_actor",
-                                    agent_id_list=agent_id_list,
-                                    mode=AgentMode.INFERENCE,
-                                    state_shaper=state_shaper,
-                                    action_shaper=action_shaper,
-                                    experience_shaper=experience_shaper,
-                                    explorer=explorer)
+    agent_manager = DQNAgentManager(
+        name="cim_remote_actor",
+        agent_id_list=agent_id_list,
+        mode=AgentMode.INFERENCE,
+        state_shaper=state_shaper,
+        action_shaper=action_shaper,
+        experience_shaper=experience_shaper,
+        explorer=explorer
+    )
     proxy_params = {
         "group_name": config.distributed.group_name,
         "expected_peers": config.distributed.actor.peer,
         "redis_address": (config.distributed.redis.host_name, config.distributed.redis.port),
         "max_retries": 10
     }
-    actor_worker = ActorWorker(local_actor=SimpleActor(env=env, inference_agents=agent_manager),
-                               proxy_params=proxy_params)
+    actor_worker = ActorWorker(
+        local_actor=SimpleActor(env=env, inference_agents=agent_manager),
+        proxy_params=proxy_params
+    )
     actor_worker.launch()
@@ -3,24 +3,29 @@
 
 import os
 
-from components.agent_manager import DQNAgentManager
-from components.config import config
-from components.state_shaper import CIMStateShaper
 from maro.rl import ActorProxy, AgentMode, SimpleLearner, TwoPhaseLinearExplorer
 from maro.simulator import Env
 from maro.utils import Logger
 
+from components import CIMStateShaper, DQNAgentManager, config
+
 if __name__ == "__main__":
     env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
     agent_id_list = [str(agent_id) for agent_id in env.agent_idx_list]
     state_shaper = CIMStateShaper(**config.state_shaping)
-    exploration_config = {"epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
-                          "split_point_dict": {"_all_": config.exploration.split_point},
-                          "with_cache": config.exploration.with_cache
-                          }
+    exploration_config = {
+        "epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
+        "split_point_dict": {"_all_": config.exploration.split_point},
+        "with_cache": config.exploration.with_cache
+    }
     explorer = TwoPhaseLinearExplorer(agent_id_list, config.general.total_training_episodes, **exploration_config)
-    agent_manager = DQNAgentManager(name="cim_remote_learner", agent_id_list=agent_id_list, mode=AgentMode.TRAIN,
-                                    state_shaper=state_shaper, explorer=explorer)
+    agent_manager = DQNAgentManager(
+        name="cim_remote_learner",
+        agent_id_list=agent_id_list,
+        mode=AgentMode.TRAIN,
+        state_shaper=state_shaper,
+        explorer=explorer
+    )
 
     proxy_params = {
         "group_name": config.distributed.group_name,
@@ -28,9 +33,15 @@ if __name__ == "__main__":
         "redis_address": (config.distributed.redis.host_name, config.distributed.redis.port),
         "max_retries": 10
     }
-    learner = SimpleLearner(trainable_agents=agent_manager,
-                            actor=ActorProxy(proxy_params=proxy_params),
-                            logger=Logger("distributed_cim_learner", auto_timestamp=False))
+    learner = SimpleLearner(
+        trainable_agents=agent_manager,
+        actor=ActorProxy(proxy_params=proxy_params),
+        logger=Logger(
+            tag="distributed_cim_learner",
+            dump_folder=os.path.join(os.path.split(os.path.realpath(__file__))[0], "log"),
+            auto_timestamp=False
+        )
+    )
     learner.train(total_episodes=config.general.total_training_episodes)
     learner.test()
     learner.dump_models(os.path.join(os.getcwd(), "models"))
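Note the two path idioms in this hunk: the log folder is now anchored to the script's own directory via `__file__`, while `dump_models` still targets the current working directory. A minimal sketch of the difference, with hypothetical folder names:

```python
import os

# Anchored to this script's directory -- stable regardless of where
# the process is launched from.
script_dir = os.path.split(os.path.realpath(__file__))[0]
log_dir = os.path.join(script_dir, "log")

# Anchored to the current working directory -- moves with the shell.
model_dir = os.path.join(os.getcwd(), "models")

print(log_dir, model_dir)
```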
@@ -7,16 +7,16 @@ This script is used to debug distributed algorithm in single host multi-process
 
 import os
 
-from .components.config import config
+from components import config
 
 ACTOR_NUM = config.distributed.learner.peer["actor"]  # must be same as in config
 LEARNER_NUM = config.distributed.actor.peer["learner"]
 
-learner_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_learner.py &"
-actor_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_actor.py &"
+learner_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_learner.py"
+actor_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_actor.py"
 
 for l_num in range(LEARNER_NUM):
-    os.system(f"python {learner_path}")
+    os.system(f"python {learner_path} &")
 
 for a_num in range(ACTOR_NUM):
-    os.system(f"python {actor_path}")
+    os.system(f"python {actor_path} &")
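Moving the `&` out of the path strings and into the `os.system` command is what actually makes the launch non-blocking: the shell backgrounds each process, so the loop can start every learner and actor without waiting. An alternative sketch without shell job control (hypothetical, not the script's approach):

```python
import os
import subprocess
import sys

# subprocess.Popen returns immediately, so multiple workers can be
# started in a loop -- mirroring the os.system(f"... &") calls above.
script_dir = os.path.split(os.path.realpath(__file__))[0]
workers = [
    subprocess.Popen([sys.executable, os.path.join(script_dir, name)])
    for name in ("dist_learner.py", "dist_actor.py")
]
```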
@@ -1,20 +1,17 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
 
 import os
 
 import numpy as np
 
-from components.action_shaper import CIMActionShaper
-from components.agent_manager import DQNAgentManager
-from components.config import config
-from components.experience_shaper import TruncatedExperienceShaper
-from components.state_shaper import CIMStateShaper
 from maro.rl import AgentMode, KStepExperienceShaper, SimpleActor, SimpleLearner, TwoPhaseLinearExplorer
 from maro.simulator import Env
 from maro.utils import Logger
 
+from components import CIMActionShaper, CIMStateShaper, DQNAgentManager, TruncatedExperienceShaper, config
+
 if __name__ == "__main__":
     # Step 1: initialize a CIM environment for using a toy dataset.
     env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
@@ -32,25 +29,35 @@ if __name__ == "__main__":
             **config.experience_shaping.k_step
         )
 
-    exploration_config = {"epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
-                          "split_point_dict": {"_all_": config.exploration.split_point},
-                          "with_cache": config.exploration.with_cache
-                          }
+    exploration_config = {
+        "epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
+        "split_point_dict": {"_all_": config.exploration.split_point},
+        "with_cache": config.exploration.with_cache
+    }
     explorer = TwoPhaseLinearExplorer(agent_id_list, config.general.total_training_episodes, **exploration_config)
 
     # Step 3: create an agent manager.
-    agent_manager = DQNAgentManager(name="cim_learner",
-                                    mode=AgentMode.TRAIN_INFERENCE,
-                                    agent_id_list=agent_id_list,
-                                    state_shaper=state_shaper,
-                                    action_shaper=action_shaper,
-                                    experience_shaper=experience_shaper,
-                                    explorer=explorer)
+    agent_manager = DQNAgentManager(
+        name="cim_learner",
+        mode=AgentMode.TRAIN_INFERENCE,
+        agent_id_list=agent_id_list,
+        state_shaper=state_shaper,
+        action_shaper=action_shaper,
+        experience_shaper=experience_shaper,
+        explorer=explorer
+    )
 
     # Step 4: Create an actor and a learner to start the training process.
     actor = SimpleActor(env=env, inference_agents=agent_manager)
-    learner = SimpleLearner(trainable_agents=agent_manager, actor=actor,
-                            logger=Logger("single_host_cim_learner", auto_timestamp=False))
+    learner = SimpleLearner(
+        trainable_agents=agent_manager,
+        actor=actor,
+        logger=Logger(
+            tag="single_host_cim_learner",
+            dump_folder=os.path.join(os.path.split(os.path.realpath(__file__))[0], "log"),
+            auto_timestamp=False
+        )
+    )
 
     learner.train(total_episodes=config.general.total_training_episodes)
     learner.test()
@@ -0,0 +1,13 @@
+from .actor import ParallelActor
+from .agent_manager import SimpleAgentManger
+from .learner import GNNLearner
+from .state_shaper import GNNStateShaper
+from .utils import decision_cnt_analysis, load_config, return_scaler, save_code, save_config
+
+__all__ = [
+    "ParallelActor",
+    "SimpleAgentManger",
+    "GNNLearner",
+    "GNNStateShaper",
+    "decision_cnt_analysis", "load_config", "return_scaler", "save_code", "save_config"
+]
@@ -24,7 +24,7 @@ def organize_exp_list(experience_collections: dict, idx_mapping: dict):
     """The function assembles the experience from multiple processes into a dictionary.
 
     Args:
-        experience_collections (dict): It stores the experience in all agents. The structure is the same as what is
+        experience_collections (dict): It stores the experience in all agents. The structure is the same as what is
             defined in the SharedStructure in the ParallelActor except additional key for experience length. For
             example:
 
@@ -46,7 +46,7 @@ def organize_exp_list(experience_collections: dict, idx_mapping: dict):
             example, if agent x starts at b_x in batch index and the experience is l_x length long, the range [b_x,
             b_x + l_x) in the batch is the experience of agent x.
 
-        idx_mapping (dict): The key is the name of each agent and the value is the starting index, e.g., b_x, of the
+        idx_mapping (dict): The key is the name of each agent and the value is the starting index, e.g., b_x, of the
             storage space where the experience of the agent is stored.
     """
     result = {}
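The docstring describes one flat batch holding every agent's experience back to back, with `idx_mapping` giving each agent's starting offset. A tiny illustrative sketch of that layout (hypothetical numbers, not the actual SharedStructure):

```python
import numpy as np

# One pooled batch; agent_a owns [0, 4), agent_b owns [4, 10).
batch = np.arange(10)
idx_mapping = {"agent_a": 0, "agent_b": 4}
lengths = {"agent_a": 4, "agent_b": 6}

for name, b_x in idx_mapping.items():
    segment = batch[b_x: b_x + lengths[name]]
    print(name, segment)
```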
@@ -5,9 +5,10 @@ from torch import nn
 from torch.distributions import Categorical
 from torch.nn.utils import clip_grad
 
-from examples.cim.gnn.utils import gnn_union
 from maro.rl import AbsAlgorithm
 
+from .utils import gnn_union
+
 
 class ActorCritic(AbsAlgorithm):
     """Actor-Critic algorithm in CIM problem.
@@ -2,10 +2,11 @@ from collections import defaultdict
 
 import numpy as np
 
-from examples.cim.gnn.numpy_store import Shuffler
 from maro.rl import AbsAgent
 from maro.utils import DummyLogger
 
+from .numpy_store import Shuffler
+
 
 class TrainableAgent(AbsAgent):
     def __init__(self, name, algorithm, experience_pool, logger=DummyLogger()):
@@ -4,15 +4,16 @@ import os
 from maro.simulator import Env
 from maro.utils import Logger
 
-from .actor import ParallelActor
-from .agent_manager import SimpleAgentManger
-from .learner import GNNLearner
-from .state_shaper import GNNStateShaper
-from .utils import decision_cnt_analysis, load_config, return_scaler, save_code, save_config
+from components import (
+    GNNLearner, GNNStateShaper, ParallelActor, SimpleAgentManger,
+    decision_cnt_analysis, load_config, return_scaler, save_code, save_config
+)
 
 if __name__ == "__main__":
-    config_pth = "examples/cim/gnn/config.yml"
-    config = load_config(config_pth)
+    real_path = os.path.split(os.path.realpath(__file__))[0]
+
+    config_path = os.path.join(real_path, "config.yml")
+    config = load_config(config_path)
 
     # Generate log path.
     date_str = datetime.datetime.now().strftime("%Y%m%d")
@@ -1,5 +1,6 @@
 import heapq
 import io
+import os
 import random
 
 import yaml
@@ -8,7 +9,8 @@ from maro.simulator import Env
 from maro.simulator.scenarios.citi_bike.common import Action, DecisionEvent, DecisionType
 from maro.utils import convert_dottable
 
-with io.open("config.yml", "r") as in_file:
+config_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "config.yml")
+with io.open(config_path, "r") as in_file:
     raw_config = yaml.safe_load(in_file)
 config = convert_dottable(raw_config)
 
@@ -30,4 +30,5 @@ sphinx
 recommonmark
 sphinx_rtd_theme
 jinja2
 flake8
+editorconfig-checker
@@ -28,13 +28,13 @@ if __name__ == "__main__":
             cur_script_path = os.path.join(path, fn)
 
             spliter = "\\" if sys.platform == "win32" else "/"
 
             module_name = ".".join(os.path.relpath(cur_script_path)[0:-3].split(spliter))
 
             test_case_list.append(module_name)
 
     print("loading test cases from following module")
 
     for i, n in enumerate(test_case_list):
         print(f"{i}: {n}")
 
@@ -19,7 +19,7 @@ def run_to_end(env: Env):
 class TestEnv(unittest.TestCase):
     """
     This test will use the dummy scenario.
     """
 
     def test_builtin_scenario_with_default_parameters(self):
         """Test if the env with a built-in scenario initializes correctly"""
@@ -69,7 +69,7 @@ class TestEnv(unittest.TestCase):
         dummy_number = node_info["dummies"]["number"]
 
         self.assertEqual(10, dummy_number, msg=f"dummy should contain 10 nodes, got {dummy_number}")
 
         attributes = node_info["dummies"]["attributes"]
 
         # it will contain one attribute
@@ -96,7 +96,7 @@ class TestEnv(unittest.TestCase):
         self.assertIsNotNone(env.snapshot_list, msg="snapshot list should not be None")
 
         # reset should work
 
         dummies_ss = env.snapshot_list["dummies"]
         vals_before_reset = dummies_ss[env.frame_index::"val"]
 
@@ -123,9 +123,9 @@ class TestEnv(unittest.TestCase):
 
         # snapshot at 2, 5, 8, 9 ticks
         states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
 
         # NOTE: frame_index is the index of the frame in the snapshot list; it is 0-based,
         # so the snapshot resolution can make a tick not equal to its frame_index.
         for frame_index, tick in enumerate((2, 5, 8, 9)):
             self.assertListEqual(list(states[frame_index]), [tick] * 10, msg=f"states should be {tick}")
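The NOTE above is the crux of this test: with a snapshot resolution greater than 1, snapshots are stored only at certain ticks, so `frame_index` counts 0, 1, 2, ... while the ticks it represents stay sparse. A hypothetical sketch of that mapping (plain arithmetic, not the MARO API):

```python
# Snapshots taken at sparse ticks; frame_index just counts stored frames.
snapshot_ticks = (2, 5, 8, 9)

for frame_index, tick in enumerate(snapshot_ticks):
    print(f"frame_index={frame_index} holds the state of tick={tick}")
```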
@@ -136,14 +136,14 @@ class TestEnv(unittest.TestCase):
 
         env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick, max_snapshots=2)
 
         run_to_end(env)
 
         # we should have 2 snapshots in total with max_snapshots specified
         self.assertEqual(2, len(env.snapshot_list), msg="We should have 2 snapshots in memory")
 
         # and only 8 and 9 in snapshot
         states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
 
         # 1st should be states at tick 8
         self.assertListEqual(list(states[0]), [8] * 10, msg="1st snapshot should be at tick 8")
@@ -163,7 +163,7 @@ class TestEnv(unittest.TestCase):
 
         # and only 7 and 9 in snapshot
         states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
 
         # 1st should be states at tick 7
         self.assertListEqual(list(states[0]), [7] * 10, msg="1st snapshot should be at tick 7")
 
@@ -184,7 +184,7 @@ class TestEnv(unittest.TestCase):
 
         # available snapshots should be 7 (0-6)
         states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
 
         self.assertEqual(7, len(states), msg=f"available snapshot number should be 7, but {len(states)}")
 
         # and last one should be 6