Merge V0.3 into master: update decision event logic & rl component bundle (#569)

* updated images and refined doc

* updated images

* updated CIM-AC example

* refined proxy retry logic

* call policy update only for AbsCorePolicy

* add limitation of AbsCorePolicy in Actor.collect()

* refined actor to return only experiences for policies that received new experiences

* fix MsgKey issue in rollout_manager

* fix typo in learner

* call exit function for parallel rollout manager

* update supply chain example distributed training scripts

* 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prototype

* reformat render

* fix supply chain business engine action type problem

* reset supply chain example render figsize from 4 to 3

* Add render to all modes of supply chain example

* fix OR policy typos

* 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes

* refined parallel policy manager

* updated rl/__init__.py

* fixed lint issues and CIM local learner bugs

* deleted unwanted supply_chain test files

* revised default config for cim-dqn

* removed test_store.py as it is no longer needed

* 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms

* updated figures

* removed unwanted import

* refactored CIM-DQN example

* added MultiProcessRolloutManager and MultiProcessTrainingManager

* updated doc

* lint issue fix

* lint issue fix

* fixed import formatting

* [Feature] Prioritized Experience Replay (#355)

* added prioritized experience replay (see the illustrative sketch at the end of this feature block)

* deleted unwanted supply_chain test files

* fixed import order

* import fix

* fixed lint issues

* fixed import formatting

* added note in docstring that rank-based PER has yet to be implemented

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
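
For reference on the prioritized experience replay feature above, below is a minimal, hedged sketch of the proportional variant (the rank-based variant noted in the docstring is not covered). The class and method names here are illustrative assumptions and do not reflect MARO's actual experience store API.

```python
import numpy as np


class ProportionalReplay:
    """Illustrative proportional prioritized experience replay (not MARO's implementation)."""

    def __init__(self, capacity: int, alpha: float = 0.6, eps: float = 1e-6):
        self.capacity = capacity
        self.alpha = alpha                # how strongly priorities skew the sampling distribution
        self.eps = eps                    # keeps every priority strictly positive
        self.data = [None] * capacity
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.size, self.pos = 0, 0

    def put(self, experience) -> None:
        # New experiences get the current max priority so they are sampled at least once.
        max_prio = self.priorities[: self.size].max() if self.size else 1.0
        self.data[self.pos] = experience
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, beta: float = 0.4):
        scaled = self.priorities[: self.size] ** self.alpha
        probs = scaled / scaled.sum()
        indexes = np.random.choice(self.size, batch_size, p=probs)
        # Importance-sampling weights correct the bias introduced by non-uniform sampling.
        weights = (self.size * probs[indexes]) ** (-beta)
        weights /= weights.max()
        return indexes, [self.data[i] for i in indexes], weights

    def update(self, indexes, td_errors) -> None:
        self.priorities[indexes] = np.abs(td_errors) + self.eps
```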

* rm AbsDecisionGenerator

* small fixes

* bug fix

* reorganized training folder structure

* fixed lint issues

* fixed lint issues

* policy manager refined

* lint fix

* restructured CIM-dqn sync code

* added policy version index and used it as a measure of experience staleness
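
As an aside, a minimal sketch of the staleness idea behind this policy version index, assuming each experience batch is tagged with the version of the policy that generated it; the names below (ExperienceBatch, max_lag, etc.) echo the surrounding commit messages but are assumptions rather than MARO's exact interfaces.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ExperienceBatch:
    policy_name: str
    version: int      # version of the policy that generated these experiences
    data: Any = None


@dataclass
class PolicyRecord:
    version: int = 0  # incremented by the policy manager after every update


def filter_fresh(
    batches: List[ExperienceBatch], policies: Dict[str, PolicyRecord], max_lag: int
) -> List[ExperienceBatch]:
    """Keep only batches whose version lags the current policy version by at most max_lag."""
    return [b for b in batches if policies[b.policy_name].version - b.version <= max_lag]
```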

* lint issue fix

* lint issue fix

* switched log_dir and proxy_kwargs order

* cim example refinement

* eval schedule sorted only when it's a list

* eval schedule sorted only when it's a list

* update sc env wrapper

* added docker scripts for cim-dqn

* refactored example folder structure and added workflow templates

* fixed lint issues

* fixed lint issues

* fixed template bugs

* removed unused imports

* refactoring sc in progress

* simplified cim meta

* fixed build.sh path bug

* template refinement

* deleted obsolete svgs

* updated learner logs

* minor edits

* refactored templates for easy merge with async PR

* added component names for rollout manager and policy manager

* fixed incorrect position to add last episode to eval schedule

* added max_lag option in templates

* formatting edit in docker_compose_yml script

* moved local learner and early stopper outside sync_tools

* refactored rl toolkit folder structure

* refactored rl toolkit folder structure

* moved env_wrapper and agent_wrapper inside rl/learner

* refined scripts

* fixed typo in script

* changes needed for running sc

* removed unwanted imports

* config change for testing sc scenario

* changes for perf testing

* Asynchronous Training (#364)

* remote inference code draft

* changed actor to rollout_worker and updated init files

* removed unwanted import

* updated inits

* more async code

* added async scripts

* added async training code & scripts for CIM-dqn

* changed async to async_tools to avoid conflict with python keyword

* reverted unwanted change to dockerfile

* added doc for policy server

* addressed PR comments and fixed a bug in docker_compose_yml.py

* fixed lint issue

* resolved PR comment

* resolved merge conflicts

* added async templates

* added proxy.close() for actor and policy_server

* fixed incorrect position to add last episode to eval schedule

* reverted unwanted changes

* added missing async files

* rm unwanted echo in kill.sh

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* renamed sync to synchronous and async to asynchronous to avoid conflict with keyword

* added missing policy version increment in LocalPolicyManager

* refined rollout manager recv logic

* removed a debugging print

* added sleep in distributed launcher to avoid hanging

* updated api doc and rl toolkit doc

* refined dynamic imports using importlib

* 1. moved policy update triggers to policy manager; 2. added version control in policy manager

* fixed a few bugs and updated cim RL example

* fixed a few more bugs

* added agent wrapper instantiation to workflows

* added agent wrapper instantiation to workflows

* removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet

* fixed incorrect get_ac_policy signature for CIM

* moved exploration inside core policy

* added state to exploration call to support context-dependent exploration

* separated non_rl_policy_index and rl_policy_index in workflows

* modified sc example code according to workflow changes

* modified sc example code according to workflow changes

* added replay_agent_ids parameter to get_env_func for RL examples

* fixed a few bugs

* added maro/simulator/scenarios/supply_chain as bind mount

* added post-step, post-collect, post-eval and post-update callbacks

* fixed lint issues

* fixed lint issues

* moved instantiation of policy manager inside simple learner

* fixed env_wrapper get_reward signature

* minor edits

* removed get_experience kwargs from env_wrapper

* 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows

* added rollout exp distribution option in RL examples

* removed unwanted files

* 1. made logger internal in learner; 2. removed logger creation in abs classes

* checked out supply chain test files from v0.2_sc

* 1. added missing model.eval() to choose_action; 2. added entropy features to AC

* fixed a bug in ac entropy

* abbreviated coefficient to coeff

* removed -dqn from job name in rl example config

* added tmp patch to dev.df

* renamed image name for running rl examples

* added get_loss interface for core policies

* added policy manager in rl_toolkit.rst

* 1. env_wrapper bug fix; 2. policy manager update logic refinement

* refactored policy and algorithms

* policy interface redesigned

* refined policy interfaces

* fixed typo

* fixed bugs in refactored policy interface

* fixed some bugs

* refactoring in progress

* policy interface and policy manager redesigned

* 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts

* fixed bug in distributed policy manager

* fixed lint issues

* fixed lint issues

* added scipy in setup

* 1. trimmed rollout manager code; 2. added option to docker scripts

* updated api doc for policy manager

* 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script

* 1. simplified rl example structure; 2. fixed lint issues

* further rl toolkit code simplifications

* more numpy-based optimization in RL toolkit

* moved replay buffer inside policy

* bug fixes

* numpy optimization and associated refactoring

* extracted shaping logic out of env_sampler

* fixed bug in CIM shaping and lint issues

* preliminary implementation of parallel batch inference

* fixed bug in ddpg transition recording

* put get_state, get_env_actions, get_reward back in EnvSampler

* simplified exploration and core model interfaces

* bug fixes and doc update

* added improve() interface for RLPolicy for single-thread support

* fixed simple policy manager bug

* updated doc, rst, notebook

* updated notebook

* fixed lint issues

* fixed entropy bugs in ac.py

* reverted to simple policy manager as default

* 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit

* fixed lint issues and updated rl toolkit images

* removed obsolete images

* added back agent2policy for general workflow use

* V0.2 rl refinement dist (#377)

* Support `slice` operation in ExperienceSet

* Support naive distributed policy training by proxy

* Dynamically allocate trainers according to number of experience

* code check

* code check

* code check

* Fix a bug in distributed training with no gradient

* Code check

* Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy

* 1. call allocate_trainer() at the start of update(); 2. refine according to code review

* Code check

* Refine code with new interface

* Update docs of PolicyManager and ExperienceSet

* Add images for rl_toolkit docs

* Update diagram of PolicyManager

* Refine with new interface

* Extract allocation strategy into `allocation_strategy.py`

* add `distributed_learn()` in policies for data-parallel training

* Update doc of RL_toolkit

* Add gradient workers for data-parallel

* Refine code and update docs

* Lint check

* Refine by comments

* Rename `trainer` to `worker`

* Rename `distributed_learn` to `learn_with_data_parallel`

* Refine allocator and remove redundant code in policy_manager

* remove arguments in allocate_by_policy and so on

* added checkpointing for simple and multi-process policy managers
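
A rough sketch of what the checkpointing above amounts to, assuming each policy exposes get_state() / set_state() as in the nearby commits; the file layout and names are illustrative only, not MARO's actual behavior.

```python
import os

import torch


def save_checkpoint(policies: dict, checkpoint_dir: str, version: int) -> None:
    # Persist one state file per policy, suffixed with the current policy version.
    os.makedirs(checkpoint_dir, exist_ok=True)
    for name, policy in policies.items():
        torch.save(policy.get_state(), os.path.join(checkpoint_dir, f"{name}_{version}.pt"))


def load_checkpoint(policies: dict, checkpoint_dir: str, version: int) -> None:
    # Restore each policy's state if a matching checkpoint file exists.
    for name, policy in policies.items():
        path = os.path.join(checkpoint_dir, f"{name}_{version}.pt")
        if os.path.exists(path):
            policy.set_state(torch.load(path))
```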

* 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager

* added missing set_state and get_state for CIM policies

* removed blank line

* updated RL workflow README

* Integrate `data_parallel` arguments into `worker_allocator` (#402)

* 1. simplified workflow config; 2. added comments to CIM shaping

* lint issue fix

* 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading

* 1. moved post_step callback inside env sampler; 2. updated README for rl workflows

* refined README for CIM

* VM scheduling with RL (#375)

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* added DQN

* added get_experiences func for ac in vm scheduling

* added post_step callback to env wrapper

* moved Aiming's tracking and plotting logic into callbacks

* added eval env wrapper

* renamed AC config variable name for VM

* vm scheduling RL code finished

* updated README

* fixed various bugs and hard coding for vm_scheduling

* uncommented callbacks for VM scheduling

* Minor revision for better code style

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* vm example refactoring

* fixed bugs in vm_scheduling

* removed unwanted files from cim dir

* reverted to simple policy manager as default

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* resolved rebase conflicts

* fixed bugs in vm_scheduling

* added get_state and set_state to vm_scheduling policy models

* updated README for vm_scheduling with RL

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* SC refinement (#397)

* Refine test scripts & pending_order_daily logic

* Refactor code for better code style: complete type hint, correct typos, remove unused items.

Refactor code for better code style: complete type hint, correct typos, remove unused items.

* Polish test_supply_chain.py

* update import format

* Modify vehicle steps logic & remove outdated test case

* Optimize imports

* Optimize imports

* Lint error

* Lint error

* Lint error

* Add SupplyChainAction

* Lint error

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* refined workflow scripts

* fixed bug in ParallelAgentWrapper

* 1. fixed lint issues; 2. refined main script in workflows

* lint issue fix

* restored default config for rl example

* Update rollout.py

* refined env var processing in policy manager workflow

* added hasattr check in agent wrapper

* updated docker_compose_yml.py

* Minor refinement

* Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412)

* Prepare to merge master_mirror

* Lint error

* Minor

* Merge latest master into v0.3 (#426)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos and spelling in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* Minor

* Remove docs/source/examples/multi_agent_dqn_cim.rst

* Update .gitignore

* Update .gitignore

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>

* Change `Env.set_seed()` logic (#456)

* Change Env.set_seed() logic

* Redesign CIM reset logic; fix lint issues;

* Lint

* Seed type assertion

* Remove all SC related files (#473)

* RL Toolkit V3 (#471)

* added daemon=True for multi-process rollout, policy manager and inference

* removed obsolete files

* [REDO][PR#406]V0.2 rl refinement taskq (#408)

* Add a usable task_queue

* Rename some variables

* 1. Add ; 2. Integrate  related files; 3. Remove

* merge `data_parallel` and `num_grad_workers` into `data_parallelism`

* Fix bugs in docker_compose_yml.py and Simple/Multi-process mode.

* Move `grad_worker` into marl/rl/workflows

* 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers when possible in task_queue.

* Refine code and update docs of `TaskQueue`

* Support priority for tasks in `task_queue`

* Update diagram of policy manager and task queue.

* Add configurable `single_task_limit` and correct docstring about `data_parallelism`

* Fix lint errors in `supply chain`

* RL policy redesign (V2) (#405)

* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint

* refined messy workflow code

* merged 'scenario_dir' and 'scenario' in rl config

* 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods

* 1. temporarily renamed RLPolicy from policy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2

* merged cim and cim_v2

* lint issue fix

* refined logging logic

* lint issue fix

* reverted unwanted changes

* .

.

.

.

ReplayMemory & IndexScheduler

ReplayMemory & IndexScheduler

.

MultiReplayMemory

get_actions_with_logps

EnvSampler on the road

EnvSampler

Minor

* LearnerManager

* Use batch to transfer data & add SHAPE_CHECK_FLAG

* Rename learner to trainer

* Add property for policy._is_exploring

* CIM test scenario for V3. Manual test passed. Next step: run it, make it work.

* env_sampler.py could run

* env_sampler refine on the way

* First runnable version done

* AC could run, but the result is bad. Need to check the logic

* Refine abstract method & shape check error info.

* Docs

* Very detailed comparison. Try again.

* AC done

* DQN check done

* Minor

* DDPG, not tested

* Minors

* A rough draft of MAAC

* Cannot use CIM as the multi-agent scenario.

* Minor

* MAAC refinement on the way

* Remove ActionWithAux

* Refine batch & memory

* MAAC example works

* Reproducibility fix: policy sharing between env_sampler and trainer_manager.

* Detail refinement

* Simplify the user-configured workflow

* Minor

* Refine example codes

* Minor polish

* Migrate rollout_manager to V3

* Error on the way

* Redesign torch.device management

* Rl v3 maddpg (#418)

* Add MADDPG trainer

* Fit independent critics and shared critic modes.

* Add a new property: num_policies

* Lint

* Fix a bug in `sum(rewards)`

* Rename `MADDPG` to `DiscreteMADDPG` and fix type hint.

* Rename maddpg in examples.

* Preparation for data parallel (#420)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* Rename train worker to train ops; add placeholder for abstract methods;

* Lint

Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>

* [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* distributed training pipeline draft

* added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* fixed merge conflicts

* unified sync and async interfaces

* refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor maddpg related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Minor bug fix

* Coroutine-related bug fix ("get_policy_state") (#452)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* deleted unwanted folder

* removed unwanted changes

* resolved PR452 comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Quick fix

* Redesign experience recording logic (#453)

* Two not important fix

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint

* Calculating advantages / returns (#454)

* V1.0

* Complete DDPG

* Rl v3 hanging issue fix (#455)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* Final test & format. Ready to merge.

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Rl v3 parallel rollout (#457)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* load balancing dispatcher

* added parallel rollout

* lint

* Tracker variable type issue; rename to env_sampler_creator;

* Rl v3 parallel rollout follow ups (#458)

* AbsWorker & AbsDispatcher

* Pass env idx to AbsTrainer.record() method, and let the trainer decide how to record experiences sampled from different worlds.

* Fix policy_creator reuse bug

* Format code

* Merge AbsTrainerManager & SimpleTrainerManager

* AC test passed

* Lint

* Remove AbsTrainer.build() method. Put all initialization operations into __init__

* Redesign AC preprocess batches logic

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* MADDPG performance bug fix (#459)

* Fix MARL (MADDPG) terminal recording bug; some other minor refinements;

* Restore Trainer.build() method

* Calculate latest action in the get_actor_grad method in MADDPG.

* Share critic bug fix

* Rl v3 example update (#461)

* updated vm_scheduling example and cim notebook

* fixed bugs in vm_scheduling

* added local train method

* bug fix

* modified async client logic to fix hidden issue

* reverted to default config

* fixed PR comments and some bugs

* removed hardcode

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Done (#462)

* Rl v3 load save (#463)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL Toolkit data parallelism revamp & config utils (#464)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

* 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local

* 1. fixed rollout exit issue; 2. refined config

* removed config file from example

* fixed lint issues

* fixed lint issues

* added main.py under examples/rl

* fixed lint issues

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL doc string (#465)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* Rl config doc (#467)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* resolved more PR comments

* typo fix

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL online doc (#469)

* Model, policy, trainer

* RL workflows and env sampler doc in RST (#468)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* rewriting rl toolkit rst

* resolved more PR comments

* typo fix

* updated rst

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Default <huo53926@126.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Finish docs/source/key_components/rl_toolkit.rst

* API doc

* RL online doc image fix (#470)

* resolved some PR comments

* fix

* fixed PR comments

* added numfig=True setting in conf.py for sphinx

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Resolve PR comments

* Add example github link

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Rl v3 pr comment resolution (#474)

* added load/save feature

* 1. resolved pr comments; 2. reverted maro/cli/k8s

* fixed some bugs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* RL renaming v2 (#476)

* Change all Logger in RL to LoggerV2

* TrainerManager => TrainingManager

* Add Trainer suffix to all algorithms

* Finish docs

* Update interface names

* Minor fix

* Cherry pick latest RL (#498)

* Cherry pick

* Remove SC related files

* Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509)

* Cherry pick RL changes from sc_refinement (2a4869)

* Limit time display precision

* RL incremental refactor (#501)

* Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training (see the sketch after this block).

AC & PPO for continuous action policy; refine AC & PPO logic.

Cherry pick RL changes from GYM-DDPG

Cherry pick RL changes from GYM-SAC

Minor error in doc string

* Add min_n_sample in template and parser

* Resolve PR comments. Fix a minor issue in SAC.
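
To illustrate the "multiple sampling in one epoch" change described at the top of this block: the sampler keeps collecting rollouts until a minimum number of transitions is reached, then one training pass runs. The names below (env_sampler.sample(), min_n_sample, the trainer interface) follow the surrounding commit messages but are assumptions, not the exact MARO API.

```python
def run_epoch(env_sampler, trainer, min_n_sample: int) -> None:
    # Keep sampling rollouts until enough transitions have been collected for one training pass.
    collected = []
    while len(collected) < min_n_sample:
        result = env_sampler.sample()          # one rollout; may terminate early
        collected.extend(result["experiences"])
    trainer.record(collected)
    trainer.train_step()
```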

* RL component bundle (#513)

* CIM passed

* Update workers

* Refine annotations

* VM passed

* Code formatting.

* Minor import loop issue

* Pass batch in PPO again

* Remove Scenario

* Complete docs

* Minor

* Remove segment

* Optimize logic in RLComponentBundle

* Resolve PR comments

* Move 'post' methods from RLComponentBundle to EnvSampler

* Add method to get mapping of available tick to frame index (#415)

* add method to get mapping of available tick to frame index

* fix lint issue

* fix naming issue

* Cherry pick from sc_refinement (#527)

* Cherry pick from sc_refinement

* Cherry pick from sc_refinement

* Refine `terminal` / `next_agent_state` logic (#531)

* Optimize RL toolkit

* Fix bug in terminal/next_state generation

* Rewrite terminal/next_state logic again

* Minor renaming

* Minor bug fix

* Resolve PR comments

* Merge master into v0.3 (#536)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos and spelling in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added version numbers for the dependencies and resolved some conflicts that occur when installing. Pinning these versions makes the exact expected dependency set explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add & sort requirements.dev.txt

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Merge master into v0.3 (#545)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos and spelling in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added version numbers for the dependencies and resolved some conflicts that occur when installing. Pinning these versions makes the exact expected dependency set explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* update github workflow config

* MARO v0.3: a new design of RL Toolkit, CLI refactorization, and corresponding updates. (#539)

* refined proxy coding style

* updated images and refined doc

* updated images

* updated CIM-AC example

* refined proxy retry logic

* call policy update only for AbsCorePolicy

* add limitation of AbsCorePolicy in Actor.collect()

* refined actor to return only experiences for policies that received new experiences

* fix MsgKey issue in rollout_manager

* fix typo in learner

* call exit function for parallel rollout manager

* update supply chain example distributed training scripts

* 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prototype

* reformat render

* fix supply chain business engine action type problem

* reset supply chain example render figsize from 4 to 3

* Add render to all modes of supply chain example

* fix OR policy typos

* 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes

* refined parallel policy manager

* updated rl/__init__.py

* fixed lint issues and CIM local learner bugs

* deleted unwanted supply_chain test files

* revised default config for cim-dqn

* removed test_store.py as it is no longer needed

* 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms

* updated figures

* removed unwanted import

* refactored CIM-DQN example

* added MultiProcessRolloutManager and MultiProcessTrainingManager

* updated doc

* lint issue fix

* lint issue fix

* fixed import formatting

* [Feature] Prioritized Experience Replay (#355)

* added prioritized experience replay

* deleted unwanted supply_chain test files

* fixed import order

* import fix

* fixed lint issues

* fixed import formatting

* added note in docstring that rank-based PER has yet to be implemented

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* rm AbsDecisionGenerator

* small fixes

* bug fix

* reorganized training folder structure

* fixed lint issues

* fixed lint issues

* policy manager refined

* lint fix

* restructured CIM-dqn sync code

* added policy version index and used it as a measure of experience staleness

* lint issue fix

* lint issue fix

* switched log_dir and proxy_kwargs order

* cim example refinement

* eval schedule sorted only when it's a list

* eval schedule sorted only when it's a list

* update sc env wrapper

* added docker scripts for cim-dqn

* refactored example folder structure and added workflow templates

* fixed lint issues

* fixed lint issues

* fixed template bugs

* removed unused imports

* refactoring sc in progress

* simplified cim meta

* fixed build.sh path bug

* template refinement

* deleted obsolete svgs

* updated learner logs

* minor edits

* refactored templates for easy merge with async PR

* added component names for rollout manager and policy manager

* fixed incorrect position to add last episode to eval schedule

* added max_lag option in templates

* formatting edit in docker_compose_yml script

* moved local learner and early stopper outside sync_tools

* refactored rl toolkit folder structure

* refactored rl toolkit folder structure

* moved env_wrapper and agent_wrapper inside rl/learner

* refined scripts

* fixed typo in script

* changes needed for running sc

* removed unwanted imports

* config change for testing sc scenario

* changes for perf testing

* Asynchronous Training (#364)

* remote inference code draft

* changed actor to rollout_worker and updated init files

* removed unwanted import

* updated inits

* more async code

* added async scripts

* added async training code & scripts for CIM-dqn

* changed async to async_tools to avoid conflict with python keyword

* reverted unwanted change to dockerfile

* added doc for policy server

* addressed PR comments and fixed a bug in docker_compose_yml.py

* fixed lint issue

* resolved PR comment

* resolved merge conflicts

* added async templates

* added proxy.close() for actor and policy_server

* fixed incorrect position to add last episode to eval schedule

* reverted unwanted changes

* added missing async files

* rm unwanted echo in kill.sh

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* renamed sync to synchronous and async to asynchronous to avoid conflict with keyword

* added missing policy version increment in LocalPolicyManager

* refined rollout manager recv logic

* removed a debugging print

* added sleep in distributed launcher to avoid hanging

* updated api doc and rl toolkit doc

* refined dynamic imports using importlib

* 1. moved policy update triggers to policy manager; 2. added version control in policy manager

* fixed a few bugs and updated cim RL example

* fixed a few more bugs

* added agent wrapper instantiation to workflows

* added agent wrapper instantiation to workflows

* removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet

* fixed incorrect get_ac_policy signature for CIM

* moved exploration inside core policy

* added state to exploration call to support context-dependent exploration

* separated non_rl_policy_index and rl_policy_index in workflows

* modified sc example code according to workflow changes

* modified sc example code according to workflow changes

* added replay_agent_ids parameter to get_env_func for RL examples

* fixed a few bugs

* added maro/simulator/scenarios/supply_chain as bind mount

* added post-step, post-collect, post-eval and post-update callbacks

* fixed lint issues

* fixed lint issues

* moved instantiation of policy manager inside simple learner

* fixed env_wrapper get_reward signature

* minor edits

* removed get_experience kwargs from env_wrapper

* 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows

* added rollout exp distribution option in RL examples

* removed unwanted files

* 1. made logger internal in learner; 2. removed logger creation in abs classes

* checked out supply chain test files from v0.2_sc

* 1. added missing model.eval() to choose_action; 2. added entropy features to AC

* fixed a bug in ac entropy

* abbreviated coefficient to coeff

* removed -dqn from job name in rl example config

* added tmp patch to dev.df

* renamed image name for running rl examples

* added get_loss interface for core policies

* added policy manager in rl_toolkit.rst

* 1. env_wrapper bug fix; 2. policy manager update logic refinement

* refactored policy and algorithms

* policy interface redesigned

* refined policy interfaces

* fixed typo

* fixed bugs in refactored policy interface

* fixed some bugs

* refactoring in progress

* policy interface and policy manager redesigned

* 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts

* fixed bug in distributed policy manager

* fixed lint issues

* fixed lint issues

* added scipy in setup

* 1. trimmed rollout manager code; 2. added option to docker scripts

* updated api doc for policy manager

* 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script

* 1. simplified rl example structure; 2. fixed lint issues

* further rl toolkit code simplifications

* more numpy-based optimization in RL toolkit

* moved replay buffer inside policy

* bug fixes

* numpy optimization and associated refactoring

* extracted shaping logic out of env_sampler

* fixed bug in CIM shaping and lint issues

* preliminary implementation of parallel batch inference

* fixed bug in ddpg transition recording

* put get_state, get_env_actions, get_reward back in EnvSampler

* simplified exploration and core model interfaces

* bug fixes and doc update

* added improve() interface for RLPolicy for single-thread support

* fixed simple policy manager bug

* updated doc, rst, notebook

* updated notebook

* fixed lint issues

* fixed entropy bugs in ac.py

* reverted to simple policy manager as default

* 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit

* fixed lint issues and updated rl toolkit images

* removed obsolete images

* added back agent2policy for general workflow use

* V0.2 rl refinement dist (#377)

* Support `slice` operation in ExperienceSet

* Support naive distributed policy training by proxy

* Dynamically allocate trainers according to number of experience

* code check

* code check

* code check

* Fix a bug in distributed training with no gradient

* Code check

* Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy

* 1. call allocate_trainer() at the start of update(); 2. refine according to code review

* Code check

* Refine code with new interface

* Update docs of PolicyManager and ExperienceSet

* Add images for rl_toolkit docs

* Update diagram of PolicyManager

* Refine with new interface

* Extract allocation strategy into `allocation_strategy.py`

* add `distributed_learn()` in policies for data-parallel training

* Update doc of RL_toolkit

* Add gradient workers for data-parallel

* Refine code and update docs

* Lint check

* Refine by comments

* Rename `trainer` to `worker`

* Rename `distributed_learn` to `learn_with_data_parallel`

* Refine allocator and remove redundant code in policy_manager

* remove arguments in allocate_by_policy and so on

* added checkpointing for simple and multi-process policy managers

* 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager

* added missing set_state and get_state for CIM policies

* removed blank line

* updated RL workflow README

* Integrate `data_parallel` arguments into `worker_allocator` (#402)

* 1. simplified workflow config; 2. added comments to CIM shaping

* lint issue fix

* 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading

* 1. moved post_step callback inside env sampler; 2. updated README for rl workflows

* refined README for CIM

* VM scheduling with RL (#375)

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* added DQN

* added get_experiences func for ac in vm scheduling

* added post_step callback to env wrapper

* moved Aiming's tracking and plotting logic into callbacks

* added eval env wrapper

* renamed AC config variable name for VM

* vm scheduling RL code finished

* updated README

* fixed various bugs and hard coding for vm_scheduling

* uncommented callbacks for VM scheduling

* Minor revision for better code style

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* vm example refactoring

* fixed bugs in vm_scheduling

* removed unwanted files from cim dir

* reverted to simple policy manager as default

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* resolved rebase conflicts

* fixed bugs in vm_scheduling

* added get_state and set_state to vm_scheduling policy models

* updated README for vm_scheduling with RL

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* SC refinement (#397)

* Refine test scripts & pending_order_daily logic

* Refactor code for better code style: complete type hint, correct typos, remove unused items.

Refactor code for better code style: complete type hint, correct typos, remove unused items.

* Polish test_supply_chain.py

* update import format

* Modify vehicle steps logic & remove outdated test case

* Optimize imports

* Optimize imports

* Lint error

* Lint error

* Lint error

* Add SupplyChainAction

* Lint error

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* refined workflow scripts

* fixed bug in ParallelAgentWrapper

* 1. fixed lint issues; 2. refined main script in workflows

* lint issue fix

* restored default config for rl example

* Update rollout.py

* refined env var processing in policy manager workflow

* added hasattr check in agent wrapper

* updated docker_compose_yml.py

* Minor refinement

* Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412)

* Prepare to merge master_mirror

* Lint error

* Minor

* Merge latest master into v0.3 (#426)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos and spelling in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* Minor

* Remove docs/source/examples/multi_agent_dqn_cim.rst

* Update .gitignore

* Update .gitignore

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>

* Change `Env.set_seed()` logic (#456)

* Change Env.set_seed() logic

* Redesign CIM reset logic; fix lint issues;

* Lint

* Seed type assertion

* Remove all SC related files (#473)

* RL Toolkit V3 (#471)

* added daemon=True for multi-process rollout, policy manager and inference

* removed obsolete files

* [REDO][PR#406]V0.2 rl refinement taskq (#408)

* Add a usable task_queue

* Rename some variables

* 1. Add ; 2. Integrate  related files; 3. Remove

* merge `data_parallel` and `num_grad_workers` into `data_parallelism`

* Fix bugs in docker_compose_yml.py and Simple/Multi-process mode.

* Move `grad_worker` into marl/rl/workflows

* 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers when possible in task_queue.

* Refine code and update docs of `TaskQueue`

* Support priority for tasks in `task_queue`

* Update diagram of policy manager and task queue.

* Add configurable `single_task_limit` and correct docstring about `data_parallelism`

* Fix lint errors in `supply chain`

* RL policy redesign (V2) (#405)

* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint

* refined messy workflow code

* merged 'scenario_dir' and 'scenario' in rl config

* 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods

* 1. temporarily renamed RLPolicy from polivy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2

* merged cim and cim_v2

* lint issue fix

* refined logging logic

* lint issue fix

* reversed unwanted changes

* .

.

.

.

ReplayMemory & IndexScheduler

ReplayMemory & IndexScheduler

.

MultiReplayMemory

get_actions_with_logps

EnvSampler on the road

EnvSampler

Minor

* LearnerManager

* Use batch to transfer data & add SHAPE_CHECK_FLAG

* Rename learner to trainer

* Add property for policy._is_exploring

* CIM test scenario for V3. Manual test passed. Next step: run it, make it work.

* env_sampler.py could run

* env_sampler refine on the way

* First runnable version done

* AC could run, but the result is bad. Need to check the logic

* Refine abstract method & shape check error info.

* Docs

* Very detailed compare. Try again.

* AC done

* DQN check done

* Minor

* DDPG, not tested

* Minors

* A rough draft of MAAC

* Cannot use CIM as the multi-agent scenario.

* Minor

* MAAC refinement on the way

* Remove ActionWithAux

* Refine batch & memory

* MAAC example works

* Reproducibility fix. Policy sharing between env_sampler and trainer_manager.

* Detail refinement

* Simplify the user-configured workflow

* Minor

* Refine example codes

* Minor polishment

* Migrate rollout_manager to V3

* Error on the way

* Redesign torch.device management

* Rl v3 maddpg (#418)

* Add MADDPG trainer

* Fit independent critics and shared critic modes.

* Add a new property: num_policies

* Lint

* Fix a bug in `sum(rewards)`

* Rename `MADDPG` to `DiscreteMADDPG` and fix type hint.

* Rename maddpg in examples.

* Preparation for data parallel (#420)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minorities

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* Rename train worker to train ops; add placeholder for abstract methods;

* Lint

Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>

* [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minorities

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* distributed training pipeline draft

* added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* fixed merge conflicts

* unified sync and async interfaces

* refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor maddpg related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Minor bug fix

* Coroutine-related bug fix ("get_policy_state") (#452)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* deleted unwanted folder

* removed unwanted changes

* resolved PR452 comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Quick fix

* Redesign experience recording logic (#453)

* Two not important fix

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint

* Calculating advantages / returns (#454)

* V1.0

* Complete DDPG

* Rl v3 hanging issue fix (#455)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* Final test & format. Ready to merge.

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Rl v3 parallel rollout (#457)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* load balancing dispatcher

* added parallel rollout

* lint

* Tracker variable type issue; rename to env_sampler_creator;

* Rl v3 parallel rollout follow ups (#458)

* AbsWorker & AbsDispatcher

* Pass env idx to AbsTrainer.record() method, and let the trainer decide how to record experiences sampled from different worlds.

* Fix policy_creator reuse bug

* Format code

* Merge AbsTrainerManager & SimpleTrainerManager

* AC test passed

* Lint

* Remove AbsTrainer.build() method. Put all initialization operations into __init__

* Redesign AC preprocess batches logic

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* MADDPG performance bug fix (#459)

* Fix MARL (MADDPG) terminal recording bug; some other minor refinements;

* Restore Trainer.build() method

* Calculate latest action in the get_actor_grad method in MADDPG.

* Share critic bug fix

* Rl v3 example update (#461)

* updated vm_scheduling example and cim notebook

* fixed bugs in vm_scheduling

* added local train method

* bug fix

* modified async client logic to fix hidden issue

* reverted to default config

* fixed PR comments and some bugs

* removed hardcode

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Done (#462)

* Rl v3 load save (#463)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL Toolkit data parallelism revamp & config utils (#464)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

* 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local

* 1. fixed rollout exit issue; 2. refined config

* removed config file from example

* fixed lint issues

* fixed lint issues

* added main.py under examples/rl

* fixed lint issues

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL doc string (#465)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* Rl config doc (#467)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* resolved more PR comments

* typo fix

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL online doc (#469)

* Model, policy, trainer

* RL workflows and env sampler doc in RST (#468)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* rewriting rl toolkit rst

* resolved more PR comments

* typo fix

* updated rst

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Default <huo53926@126.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Finish docs/source/key_components/rl_toolkit.rst

* API doc

* RL online doc image fix (#470)

* resolved some PR comments

* fix

* fixed PR comments

* added numfig=True setting in conf.py for sphinx

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Resolve PR comments

* Add example github link

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Rl v3 pr comment resolution (#474)

* added load/save feature

* 1. resolved pr comments; 2. reverted maro/cli/k8s

* fixed some bugs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* RL renaming v2 (#476)

* Change all Logger in RL to LoggerV2

* TrainerManager => TrainingManager

* Add Trainer suffix to all algorithms

* Finish docs

* Update interface names

* Minor fix

* Cherry pick latest RL (#498)

* Cherry pick

* Remove SC related files

* Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509)

* Cherry pick RL changes from sc_refinement (2a4869)

* Limit time display precision

* RL incremental refactor (#501)

* Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training.

AC & PPO for continuous action policy; refine AC & PPO logic.

Cherry pick RL changes from GYM-DDPG

Cherry pick RL changes from GYM-SAC

Minor error in doc string

* Add min_n_sample in template and parser

* Resolve PR comments. Fix a minor issue in SAC.

* RL component bundle (#513)

* CIM passed

* Update workers

* Refine annotations

* VM passed

* Code formatting.

* Minor import loop issue

* Pass batch in PPO again

* Remove Scenario

* Complete docs

* Minor

* Remove segment

* Optimize logic in RLComponentBundle

* Resolve PR comments

* Move 'post' methods from RLComponentBundle to EnvSampler

* Add method to get mapping of available tick to frame index (#415)

* add method to get mapping of available tick to frame index

* fix lint issue

* fix naming issue

* Cherry pick from sc_refinement (#527)

* Cherry pick from sc_refinement

* Cherry pick from sc_refinement

* Refine `terminal` / `next_agent_state` logic (#531)

* Optimize RL toolkit

* Fix bug in terminal/next_state generation

* Rewrite terminal/next_state logic again

* Minor renaming

* Minor bug fix

* Resolve PR comments

* Merge master into v0.3 (#536)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the exampel for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fixing spelling grammarr

* fix: fix typo spelling code commented and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added versions for the dependencies to resolve conflicts that occur during installation. Pinning these version numbers makes the exact requirements explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add & sort requirements.dev.txt

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Remove random_config.py

* Remove test_trajectory_utils.py

* Pass tests

* Update rl docs

* Remove python 3.6 in test

* Update docs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: Wang.Jinyu <jinywan@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Logger bug hotfix (#543)

* Rename param

* Rename param

* Quick fix in env_data_process

* frame data precision issue fix (#544)

* fix frame precision issue

* add .xmake to .gitignore

* update frame precision lost warning message

* add assert to frame precision checking

* typo fix

* add TODO for future Long data type issue fix

* Minor cleaning

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>

* Update requirements. (#552)

* Fix several encoding issues; update requirements.

* Test & minor

* Remove torch in requirements.build.txt

* Polish

* Update README

* Resolve PR comments

* Keep working

* Keep working

* Update test requirements

* Done (#554)

* Update requirements in example and notebook (#553)

* Update requirements in example and notebook

* Remove autopep8

* Add jupyterlab packages back

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>

* Refine decision event logic (#559)

* Add DecisionEventPayload

* Change decision payload name

* Refine action logic

* Add doc for env.step

* Restore pre-commit config

* Resolve PR comments

* Refactor decision event & action

* Pre-commit

* Resolve PR comments

* Refine rl component bundle (#549)

* Config files

* Done

* Minor bugfix

* Add autoflake

* Update isort exclude; add pre-commit to requirements

* Check only isort

* Minor

* Format

* Test passed

* Run pre-commit

* Minor bugfix in rl_component_bundle

* Pass mypy

* Fix a bug in RL notebook

* A minor bug fix

* Add upper bound for numpy version in test

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Huoran Li <huo53926@126.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
This commit is contained in:
Jinyu-W 2022-12-27 17:15:46 +08:00, committed by GitHub
Parent 38eb389df1
Commit 6512879608
81 changed files: 1975 additions and 1874 deletions

View file

@@ -121,14 +121,16 @@ of user-defined functions for message auto-handling, cluster provision, and job
```sh
# Install MARO from source.
bash scripts/install_maro.sh
bash scripts/install_maro.sh;
pip install -r ./requirements.dev.txt;
```
- Windows
```powershell
# Install MARO from source.
.\scripts\install_maro.bat
.\scripts\install_maro.bat;
pip install -r ./requirements.dev.txt;
```
- *Notes: If your package is not found, remember to set your PYTHONPATH*

View file

@@ -326,236 +326,84 @@ In CIM scenario, there are 3 node types:
port
++++
capacity
********
type: int
slots: 1
The capacity of port for stocking containers.
empty
*****
type: int
slots: 1
Empty container volume on the port.
full
****
type: int
slots: 1
Laden container volume on the port.
on_shipper
**********
type: int
slots: 1
Empty containers, which are released to the shipper.
on_consignee
************
type: int
slots: 1
Laden containers, which are delivered to the consignee.
shortage
********
type: int
slots: 1
Per tick state. Shortage of empty container at current tick.
acc_storage
***********
type: int
slots: 1
Accumulated shortage number to the current tick.
booking
*******
type: int
slots: 1
Per tick state. Order booking number of a port at the current tick.
acc_booking
***********
type: int
slots: 1
Accumulated order booking number of a port to the current tick.
fulfillment
***********
type: int
slots: 1
Fulfilled order number of a port at the current tick.
acc_fulfillment
***************
type: int
slots: 1
Accumulated fulfilled order number of a port to the current tick.
transfer_cost
*************
type: float
slots: 1
Cost of transferring container, which also covers loading and discharging cost.
+------------------+-------+--------+----------------------------------------------------------------------------------+
| Field | Type | Slots | Description |
+==================+=======+========+==================================================================================+
| capacity | int | 1 | The capacity of port for stocking containers. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| empty | int | 1 | Empty container volume on the port. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| full | int | 1 | Laden container volume on the port. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| on_shipper | int | 1 | Empty containers, which are released to the shipper. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| on_consignee | int | 1 | Laden containers, which are delivered to the consignee. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| shortage | int | 1 | Per tick state. Shortage of empty container at current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| acc_storage | int | 1 | Accumulated shortage number to the current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| booking | int | 1 | Per tick state. Order booking number of a port at the current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| acc_booking | int | 1 | Accumulated order booking number of a port to the current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| fulfillment | int | 1 | Fulfilled order number of a port at the current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| acc_fulfillment | int | 1 | Accumulated fulfilled order number of a port to the current tick. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
| transfer_cost | float | 1 | Cost of transferring container, which also covers loading and discharging cost. |
+------------------+-------+--------+----------------------------------------------------------------------------------+
vessel
++++++
capacity
********
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Field | Type | Slots | Description |
+========================+========+==========+=================================================================================================================================================================================================================================+
| capacity | int | 1 | The capacity of vessel for transferring containers. NOTE: This attribute is ignored in current implementation. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| empty | int | 1 | Empty container volume on the vessel. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| full | int | 1 | Laden container volume on the vessel. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| remaining_space | int | 1 | Remaining space of the vessel. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| early_discharge | int | 1 | Discharged empty container number for loading laden containers. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| is_parking | short | 1 | Is parking or not |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| loc_port_idx | int | 1 | The port index the vessel is parking at. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| route_idx | int | 1 | Which route current vessel belongs to. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| last_loc_idx | int | 1 | Last stop port index in route, it is used to identify where is current vessel. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| next_loc_idx | int | 1 | Next stop port index in route, it is used to identify where is current vessel. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| past_stop_list | int | dynamic | NOTE: This and following attribute are special, that its slot number is determined by configuration, but different with a list attribute, its slot number is fixed at runtime. Stop indices that we have stopped in the past. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| past_stop_tick_list | int | dynamic | Ticks that we stopped at the port in the past. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| future_stop_list | int | dynamic | Stop indices that we will stop in the future. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| future_stop_tick_list | int | dynamic | Ticks that we will stop in the future. |
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
type: int
slots: 1
The capacity of vessel for transferring containers.
NOTE:
This attribute is ignored in current implementation.
empty
*****
type: int
slots: 1
Empty container volume on the vessel.
full
****
type: int
slots: 1
Laden container volume on the vessel.
remaining_space
***************
type: int
slots: 1
Remaining space of the vessel.
early_discharge
***************
type: int
slots: 1
Discharged empty container number for loading laden containers.
route_idx
*********
type: int
slots: 1
Which route current vessel belongs to.
last_loc_idx
************
type: int
slots: 1
Last stop port index in route, it is used to identify where is current vessel.
next_loc_idx
************
type: int
slots: 1
Next stop port index in route, it is used to identify where is current vessel.
past_stop_list
**************
type: int
slots: dynamic
NOTE:
This and following attribute are special, that its slot number is determined by configuration,
but different with a list attribute, its slot number is fixed at runtime.
Stop indices that we have stopped in the past.
past_stop_tick_list
*******************
type: int
slots: dynamic
Ticks that we stopped at the port in the past.
future_stop_list
****************
type: int
slots: dynamic
Stop indices that we will stop in the future.
future_stop_tick_list
*********************
type: int
slots: dynamic
Ticks that we will stop in the future.
matrices
++++++++
Matrices node is used to store big matrix for ports, vessels and containers.
full_on_ports
*************
type: int
slots: port number * port number
Distribution of full from port to port.
full_on_vessels
***************
type: int
slots: vessel number * port number
Distribution of full from vessel to port.
vessel_plans
************
type: int
slots: vessel number * port number
Planed route info for vessels.
+------------------+-------+------------------------------+---------------------------------------------+
| Field | Type | Slots | Description |
+==================+=======+==============================+=============================================+
| full_on_ports | int | port number * port number | Distribution of full from port to port. |
+------------------+-------+------------------------------+---------------------------------------------+
| full_on_vessels | int | vessel number * port number | Distribution of full from vessel to port. |
+------------------+-------+------------------------------+---------------------------------------------+
| vessel_plans | int | vessel number * port number | Planed route info for vessels. |
+------------------+-------+------------------------------+---------------------------------------------+
How to
~~~~~~
@@ -597,133 +445,47 @@ Nodes and attributes in scenario
station
+++++++
bikes
*****
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| Field | Type | Slots | Description |
+===================+=======+========+===========================================================================================================+
| bikes | int | 1 | How many bikes avaiable in current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| shortage | int | 1 | Per tick state. Lack number of bikes in current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| trip_requirement | int | 1 | Per tick states. How many requirements in current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| fulfillment | int | 1 | How many requirement is fit in current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| capacity | int | 1 | Max number of bikes this station can take. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| id | int | 1 | Id of current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| weekday | short | 1 | Weekday at current tick. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| temperature | short | 1 | Temperature at current tick. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| weather | short | 1 | Weather at current tick. (0: sunny, 1: rainy, 2: snowy 3: sleet) |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| holiday | short | 1 | If it is holidy at current tick. (0: holiday, 1: not holiday) |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| extra_cost | int | 1 | Cost after we reach the capacity after executing action, we have to move extra bikes to other stations. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| transfer_cost | int | 1 | Cost to execute action to transfer bikes to other station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| failed_return | int | 1 | Per tick state. How many bikes failed to return to current station. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
| min_bikes | int | 1 | Min bikes number in a frame. |
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
type: int
slots: 1
How many bikes avaiable in current station.
shortage
********
type: int
slots: 1
Per tick state. Lack number of bikes in current station.
trip_requirement
****************
type: int
slots: 1
Per tick states. How many requirements in current station.
fulfillment
***********
type: int
slots: 1
How many requirement is fit in current station.
capacity
********
type: int
slots: 1
Max number of bikes this station can take.
id
+++
type: int
slots: 1
Id of current station.
weekday
*******
type: short
slots: 1
Weekday at current tick.
temperature
***********
type: short
slots: 1
Temperature at current tick.
weather
*******
type: short
slots: 1
Weather at current tick.
0: sunny, 1: rainy, 2: snowy 3: sleet.
holiday
*******
type: short
slots: 1
If it is holidy at current tick.
0: holiday, 1: not holiday
extra_cost
**********
type: int
slots: 1
Cost after we reach the capacity after executing action, we have to move extra bikes
to other stations.
transfer_cost
*************
type: int
slots: 1
Cost to execute action to transfer bikes to other station.
failed_return
*************
type: int
slots: 1
Per tick state. How many bikes failed to return to current station.
min_bikes
*********
type: int
slots: 1
Min bikes number in a frame.
matrices
++++++++
trips_adj
*********
type: int
slots: station number * station number
Used to store trip requirement number between 2 stations.
+------------+-------+----------------------------------+------------------------------------------------------------+
| Field | Type | Slots | Description |
+============+=======+==================================+============================================================+
| trips_adj | int | station number * station number | Used to store trip requirement number between 2 stations. |
+------------+-------+----------------------------------+------------------------------------------------------------+
VM-scheduling
@@ -743,315 +505,121 @@ Nodes and attributes in scenario
Cluster
+++++++
id
***
type: short
slots: 1
Id of the cluster.
region_id
*********
type: short
slots: 1
Region is of current cluster.
data_center_id
**************
type: short
slots: 1
Data center id of current cluster.
total_machine_num
******************
type: int
slots: 1
Total number of machines in the cluster.
empty_machine_num
******************
type: int
slots: 1
The number of empty machines in this cluster. A empty machine means that its allocated CPU cores are 0.
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| Field | Type | Slots | Description |
+====================+=======+========+==========================================================================================================+
| id | short | 1 | Id of the cluster. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| region_id | short | 1 | Region id of current cluster. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| zond_id | short | 1 | Zone id of current cluster. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| data_center_id | short | 1 | Data center id of current cluster. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| total_machine_num | int | 1 | Total number of machines in the cluster. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
| empty_machine_num | int | 1 | The number of empty machines in this cluster. A empty machine means that its allocated CPU cores are 0. |
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
data_centers
++++++++++++
id
***
type: short
slots: 1
Id of current data center.
region_id
*********
type: short
slots: 1
Region id of current data center.
zone_id
*******
type: short
slots: 1
Zone id of current data center.
total_machine_num
*****************
type: int
slots: 1
Total number of machine in current data center.
empty_machine_num
*****************
type: int
slots: 1
The number of empty machines in current data center.
+--------------------+-------+--------+-------------------------------------------------------+
| Field | Type | Slots | Description |
+====================+=======+========+=======================================================+
| id | short | 1 | Id of current data center. |
+--------------------+-------+--------+-------------------------------------------------------+
| region_id | short | 1 | Region id of current data center. |
+--------------------+-------+--------+-------------------------------------------------------+
| zone_id | short | 1 | Zone id of current data center. |
+--------------------+-------+--------+-------------------------------------------------------+
| total_machine_num | int | 1 | Total number of machine in current data center. |
+--------------------+-------+--------+-------------------------------------------------------+
| empty_machine_num | int | 1 | The number of empty machines in current data center. |
+--------------------+-------+--------+-------------------------------------------------------+
pms
+++
Physical machine node.
id
***
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| Field | Type | Slots | Description |
+=====================+=======+========+=================================================================================+
| id | int | 1 | Id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| cpu_cores_capacity | short | 1 | Max number of cpu core can be used for current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| memory_capacity | short | 1 | Max number of memory can be used for current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| pm_type | short | 1 | Type of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| cpu_cores_allocated | short | 1 | How many cpu core is allocated. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| memory_allocated | short | 1 | How many memory is allocated. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| cpu_utilization | float | 1 | CPU utilization of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| energy_consumption | float | 1 | Energy consumption of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| oversubscribable | short | 1 | Physical machine non-oversubscribable is -1, empty: 0, oversubscribable is 1. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| region_id | short | 1 | Region id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| zone_id | short | 1 | Zone id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| data_center_id | short | 1 | Data center id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| cluster_id | short | 1 | Cluster id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
| rack_id | short | 1 | Rack id of current machine. |
+---------------------+-------+--------+---------------------------------------------------------------------------------+
type: int
slots: 1
Id of current machine.
cpu_cores_capacity
******************
type: short
slots: 1
Max number of cpu core can be used for current machine.
memory_capacity
***************
type: short
slots: 1
Max number of memory can be used for current machine.
pm_type
*******
type: short
slots: 1
Type of current machine.
cpu_cores_allocated
*******************
type: short
slots: 1
How many cpu core is allocated.
memory_allocated
****************
type: short
slots: 1
How many memory is allocated.
cpu_utilization
***************
type: float
slots: 1
CPU utilization of current machine.
energy_consumption
******************
type: float
slots: 1
Energy consumption of current machine.
oversubscribable
****************
type: short
slots: 1
Physical machine type: non-oversubscribable is -1, empty: 0, oversubscribable is 1.
region_id
*********
type: short
slots: 1
Region id of current machine.
zone_id
*******
type: short
slots: 1
Zone id of current machine.
data_center_id
**************
type: short
slots: 1
Data center id of current machine.
cluster_id
**********
type: short
slots: 1
Cluster id of current machine.
rack_id
*******
type: short
slots: 1
Rack id of current machine.
Rack
rack
++++
id
***
type: int
slots: 1
Id of current rack.
region_id
*********
type: short
slots: 1
Region id of current rack.
zone_id
*******
type: short
slots: 1
Zone id of current rack.
data_center_id
**************
type: short
slots: 1
Data center id of current rack.
cluster_id
**********
type: short
slots: 1
Cluster id of current rack.
total_machine_num
*****************
type: int
slots: 1
Total number of machines on this rack.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that not in use on this rack.
+--------------------+-------+--------+---------------------------------------------------+
| Field | Type | Slots | Description |
+====================+=======+========+===================================================+
| id | int | 1 | Id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| region_id | short | 1 | Region id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| zone_id | short | 1 | Zone id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| data_center_id | short | 1 | Data center id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| cluster_id | short | 1 | Cluster id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| total_machine_num | int | 1 | Total number of machines on this rack. |
+--------------------+-------+--------+---------------------------------------------------+
| empty_machine_num | int | 1 | Number of machines that not in use on this rack. |
+--------------------+-------+--------+---------------------------------------------------+
regions
+++++++
id
***
type: short
slots: 1
Id of curent region.
total_machine_num
*****************
type: int
slots: 1
Total number of machines in this region.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that not in use in this region.
+--------------------+-------+--------+------------------------------------------------------+
| Field | Type | Slots | Description |
+====================+=======+========+======================================================+
| id | short | 1 | Id of current region. |
+--------------------+-------+--------+------------------------------------------------------+
| total_machine_num | int | 1 | Total number of machines in this region. |
+--------------------+-------+--------+------------------------------------------------------+
| empty_machine_num | int | 1 | Number of machines that not in use in this region. |
+--------------------+-------+--------+------------------------------------------------------+
zones
+++++
id
***
type: short
slots: 1
Id of this zone.
total_machine_num
*****************
type: int
slots: 1
Total number of machines in this zone.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that not in use in this zone.
+--------------------+-------+--------+---------------------------------------------------+
| Field | Type | Slots | Description |
+====================+=======+========+===================================================+
| id | short | 1 | Id of this zone. |
+--------------------+-------+--------+---------------------------------------------------+
| region_id | short | 1 | Region id of current rack. |
+--------------------+-------+--------+---------------------------------------------------+
| total_machine_num | int | 1 | Total number of machines in this zone. |
+--------------------+-------+--------+---------------------------------------------------+
| empty_machine_num | int | 1 | Number of machines that not in use in this zone. |
+--------------------+-------+--------+---------------------------------------------------+

View file

@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .rl_component_bundle import CIMBundle as rl_component_bundle_cls
from .rl_component_bundle import rl_component_bundle
__all__ = [
"rl_component_bundle_cls",
"rl_component_bundle",
]
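
The scenario package now exports a ready-built `rl_component_bundle` object instead of the `CIMBundle` class. A minimal consumer-side check, assuming the CIM example package is importable as `examples.cim.rl` (the workflow itself locates it via `scenario_path`):

```python
# Minimal sketch: the scenario module now exposes a bundle *instance*, not a class.
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle

from examples.cim.rl import rl_component_bundle

assert isinstance(rl_component_bundle, RLComponentBundle)
```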

View file

@@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
reward_discount=0.0,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,

View file

@@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
reward_discount=0.0,
replay_memory_capacity=10000,
batch_size=32,
params=DQNParams(
reward_discount=0.0,
update_target_every=5,
num_epochs=10,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
),
)

View file

@@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
return DiscreteMADDPGTrainer(
name=name,
reward_discount=0.0,
params=DiscreteMADDPGParams(
reward_discount=0.0,
num_epoch=10,
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
shared_critic=False,

View file

@@ -16,12 +16,11 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
return PPOTrainer(
name=name,
reward_discount=0.0,
params=PPOParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=0.0,
clip_ratio=0.1,
),
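
Common to the four algorithm factories above: `reward_discount` (plus, for DQN, `replay_memory_capacity` and `batch_size`) moves out of the `*Params` dataclasses and into the trainer constructor. Reading only the new side of the DQN hunk, the factory now looks roughly like the sketch below (the import path mirrors the CIM example's algorithm modules and is an assumption here):

```python
from maro.rl.training.algorithms import DQNParams, DQNTrainer


def get_dqn(name: str) -> DQNTrainer:
    return DQNTrainer(
        name=name,
        # Generic training settings now sit on the trainer itself.
        reward_discount=0.0,
        replay_memory_capacity=10000,
        batch_size=32,
        # Algorithm-specific hyperparameters stay in DQNParams.
        params=DQNParams(
            update_target_every=5,
            num_epochs=10,
            soft_update_coef=0.1,
            double=False,
            random_overwrite=False,
        ),
    )
```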

View file

@@ -7,11 +7,6 @@ env_conf = {
"durations": 560,
}
if env_conf["topology"].startswith("toy"):
num_agents = int(env_conf["topology"].split(".")[1][0])
else:
num_agents = int(env_conf["topology"].split(".")[1][:2])
port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]

View file

@@ -1,77 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Any, Callable, Dict, Optional
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from maro.simulator import Env
from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
from .algorithms.ppo import get_ppo, get_ppo_policy
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
# Environments
learn_env = Env(**env_conf)
test_env = learn_env
class CIMBundle(RLComponentBundle):
def get_env_config(self) -> dict:
return env_conf
# Agent, policy, and trainers
num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
if algorithm == "ac":
policies = [get_ac_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
trainers = [get_ac(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "ppo":
policies = [get_ppo_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
trainers = [get_ppo(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "dqn":
policies = [get_dqn_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
trainers = [get_dqn(f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "discrete_maddpg":
policies = [get_maddpg_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
trainers = [get_maddpg(state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)]
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
def get_test_env_config(self) -> Optional[dict]:
return None
def get_env_sampler(self) -> AbsEnvSampler:
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
def get_agent2policy(self) -> Dict[Any, str]:
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
if algorithm == "ac":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "ppo":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "dqn":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "discrete_maddpg":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return policy_creator
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
}
elif algorithm == "ppo":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
}
elif algorithm == "dqn":
trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)}
elif algorithm == "discrete_maddpg":
trainer_creator = {
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return trainer_creator
# Build RLComponentBundle
rl_component_bundle = RLComponentBundle(
env_sampler=CIMEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
reward_eval_delay=reward_shaping_conf["time_window"],
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
)


@ -1 +1,3 @@
PuLP==2.1
matplotlib>=3.1.2
pulp>=2.1.0
tweepy>=4.10.0


@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for CIM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
eval_schedule: 5
logging:
stdout: INFO
file: DEBUG
rollout:
parallelism:
sampling: 3
eval: null
min_env_samples: 3
grace_factor: 0.2
controller:
host: "127.0.0.1"
port: 20000
logging:
stdout: INFO
file: DEBUG
training:
mode: parallel
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
interval: 5
proxy:
host: "127.0.0.1"
frontend: 10000
backend: 10001
num_workers: 2
logging:
stdout: INFO
file: DEBUG


@ -12,7 +12,7 @@ import yaml
from ilp_agent import IlpAgent
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import DecisionPayload
from maro.simulator.scenarios.vm_scheduling import DecisionEvent
from maro.simulator.scenarios.vm_scheduling.common import Action
from maro.utils import LogFormat, Logger, convert_dottable
@ -46,7 +46,7 @@ if __name__ == "__main__":
env.set_seed(config.env.seed)
metrics: object = None
decision_event: DecisionPayload = None
decision_event: DecisionEvent = None
is_done: bool = False
action: Action = None


@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .rl_component_bundle import VMBundle as rl_component_bundle_cls
from .rl_component_bundle import rl_component_bundle
__all__ = [
"rl_component_bundle_cls",
"rl_component_bundle",
]


@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str)
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
reward_discount=0.9,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
reward_discount=0.9,
grad_iters=100,
critic_loss_cls=torch.nn.MSELoss,
min_logp=-20,


@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
reward_discount=0.9,
replay_memory_capacity=10000,
batch_size=32,
data_parallelism=2,
params=DQNParams(
reward_discount=0.9,
update_target_every=5,
num_epochs=100,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
data_parallelism=2,
),
)


@ -5,14 +5,15 @@ import time
from collections import defaultdict
from os import makedirs
from os.path import dirname, join, realpath
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Type, Union
import numpy as np
from matplotlib import pyplot as plt
from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.rl.policy import AbsPolicy
from maro.rl.rollout import AbsAgentWrapper, AbsEnvSampler, CacheElement, SimpleAgentWrapper
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction
from .config import (
num_features,
@ -30,8 +31,25 @@ makedirs(plt_path, exist_ok=True)
class VMEnvSampler(AbsEnvSampler):
def __init__(self, learn_env: Env, test_env: Env) -> None:
super(VMEnvSampler, self).__init__(learn_env, test_env)
def __init__(
self,
learn_env: Env,
test_env: Env,
policies: List[AbsPolicy],
agent2policy: Dict[Any, str],
trainable_policies: List[str] = None,
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
reward_eval_delay: int = None,
) -> None:
super(VMEnvSampler, self).__init__(
learn_env,
test_env,
policies,
agent2policy,
trainable_policies,
agent_wrapper_cls,
reward_eval_delay,
)
self._learn_env.set_seed(seed)
self._test_env.set_seed(test_seed)
@ -44,7 +62,7 @@ class VMEnvSampler(AbsEnvSampler):
def _get_global_and_agent_state_impl(
self,
event: DecisionPayload,
event: DecisionEvent,
tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
@ -71,14 +89,14 @@ class VMEnvSampler(AbsEnvSampler):
def _translate_to_env_action(
self,
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
event: DecisionPayload,
event: DecisionEvent,
) -> Dict[Any, object]:
if action_dict["AGENT"] == self.num_pms:
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
else:
return {"AGENT": AllocateAction(vm_id=event.vm_id, pm_id=action_dict["AGENT"][0])}
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
action = env_action_dict["AGENT"]
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
if isinstance(action, PostponeAction): # postponement
@ -121,7 +139,7 @@ class VMEnvSampler(AbsEnvSampler):
],
)
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
def _get_allocation_reward(self, event: DecisionEvent, alpha: float, beta: float):
vm_unit_price = self._env.business_engine._get_unit_price(
event.vm_cpu_cores_requirement,
event.vm_memory_requirement,


@ -1,67 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Any, Callable, Dict, Optional
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from maro.simulator import Env
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
# Environments
learn_env = Env(**env_conf)
test_env = Env(**test_env_conf)
class VMBundle(RLComponentBundle):
def get_env_config(self) -> dict:
return env_conf
# Agent, policy, and trainers
action_num = num_pms + 1
agent2policy = {"AGENT": f"{algorithm}.policy"}
if algorithm == "ac":
policies = [get_ac_policy(state_dim, action_num, num_features, f"{algorithm}.policy")]
trainers = [get_ac(state_dim, num_features, algorithm)]
elif algorithm == "dqn":
policies = [get_dqn_policy(state_dim, action_num, num_features, f"{algorithm}.policy")]
trainers = [get_dqn(algorithm)]
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
def get_test_env_config(self) -> Optional[dict]:
return test_env_conf
def get_env_sampler(self) -> AbsEnvSampler:
return VMEnvSampler(self.env, self.test_env)
def get_agent2policy(self) -> Dict[Any, str]:
return {"AGENT": f"{algorithm}.policy"}
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1
if algorithm == "ac":
policy_creator = {
f"{algorithm}.policy": partial(
get_ac_policy,
state_dim,
action_num,
num_features,
f"{algorithm}.policy",
),
}
elif algorithm == "dqn":
policy_creator = {
f"{algorithm}.policy": partial(
get_dqn_policy,
state_dim,
action_num,
num_features,
f"{algorithm}.policy",
),
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return policy_creator
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
elif algorithm == "dqn":
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return trainer_creator
# Build RLComponentBundle
rl_component_bundle = RLComponentBundle(
env_sampler=VMEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
)


@ -2,7 +2,7 @@
# Licensed under the MIT license.
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction
from maro.simulator.scenarios.vm_scheduling.common import Action
@ -10,7 +10,7 @@ class VMSchedulingAgent(object):
def __init__(self, algorithm):
self._algorithm = algorithm
def choose_action(self, decision_event: DecisionPayload, env: Env) -> Action:
def choose_action(self, decision_event: DecisionEvent, env: Env) -> Action:
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM."""
valid_pm_num: int = len(decision_event.valid_pms)


@ -5,7 +5,7 @@ import numpy as np
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class BestFit(RuleBasedAlgorithm):
@ -13,7 +13,7 @@ class BestFit(RuleBasedAlgorithm):
super().__init__()
self._metric_type: str = kwargs["metric_type"]
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
# Use a rule to choose a valid PM.
chosen_idx: int = self._pick_pm_func(decision_event, env)
# Take action to allocate on the chosen PM.


@ -7,7 +7,7 @@ import numpy as np
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class BinPacking(RuleBasedAlgorithm):
@ -24,7 +24,7 @@ class BinPacking(RuleBasedAlgorithm):
self._bins = [[] for _ in range(self._pm_cpu_core_num + 1)]
self._bin_size = [0] * (self._pm_cpu_core_num + 1)
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
# Initialize the bin.
self._init_bin()


@ -4,14 +4,14 @@
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class FirstFit(RuleBasedAlgorithm):
def __init__(self, **kwargs):
super().__init__()
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
# Use a valid PM based on its order.
chosen_idx: int = decision_event.valid_pms[0]
# Take action to allocate on the chosen PM.


@ -6,14 +6,14 @@ import random
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class RandomPick(RuleBasedAlgorithm):
def __init__(self, **kwargs):
super().__init__()
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
valid_pm_num: int = len(decision_event.valid_pms)
# Randomly choose a valid PM.
chosen_idx: int = random.randint(0, valid_pm_num - 1)


@ -4,7 +4,7 @@
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class RoundRobin(RuleBasedAlgorithm):
@ -15,7 +15,7 @@ class RoundRobin(RuleBasedAlgorithm):
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].shape[0]
)
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
# Choose the valid PM whose index is next to the previously chosen PM's index
chosen_idx: int = (self._prev_idx + 1) % self._pm_num
while chosen_idx not in decision_event.valid_pms:


@ -4,13 +4,13 @@
import abc
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
class RuleBasedAlgorithm(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
"""This method will determine allocate which PM to the current VM."""
raise NotImplementedError


@ -25,7 +25,7 @@ def start_cim_dashboard(source_path: str, epoch_num: int, prefix: str):
--ports.csv: Record ports' attributes in this file.
--vessel.csv: Record vessels' attributes in this file.
--matrices.csv: Record transfer volume information in this file.
......
--epoch_{epoch_num-1}
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
--config.yml: Record the relationship between ports' index and name.


@ -24,7 +24,7 @@ def start_citi_bike_dashboard(source_path: str, epoch_num: int, prefix: str):
--stations.csv: Record stations' attributes in this file.
--matrices.csv: Record transfer volume information in this file.
--stations_summary.csv: Record the summary data of current epoch.
......
--epoch_{epoch_num-1}
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
--full_stations.json: Record the relationship between ports' index and name.


@ -28,7 +28,7 @@ def start_vis(source_path: str, force: str, **kwargs: dict):
-input_file_folder_path
--epoch_0 : Data of current epoch.
--holder_info.csv: Attributes of current epoch.
......
--epoch_{epoch_num-1}
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
--index_name_mapping file: Record the relationship between an index and its name.

maro/common.py (new file)

@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class BaseDecisionEvent:
"""Base class for all decision events.
We made this design for the convenience of users. In summary, there are two types of events in MARO:
- CascadeEvent & AtomEvent: used to drive the MARO Env / business engine.
- DecisionEvent: exposed to users as a means of communication.
The latter serves as the `payload` of the former inside the MARO Env, so the related naming can be a little tricky:
- Inside the MARO Env: `decision_event` is actually a CascadeEvent, and a DecisionEvent is its payload.
- Outside the MARO Env (for users): `decision_event` is a DecisionEvent.
"""
class BaseAction:
"""Base class for all action payloads"""


@ -4,8 +4,9 @@
import csv
from collections import defaultdict
from typing import Callable, List, Optional
from typing import Callable, List, Optional, cast
from ..common import BaseAction, BaseDecisionEvent
from .event import ActualEvent, AtomEvent, CascadeEvent
from .event_linked_list import EventLinkedList
from .event_pool import EventPool
@ -122,9 +123,7 @@ class EventBuffer:
Returns:
AtomEvent: Atom event object
"""
event = self._event_pool.gen(tick, event_type, payload, False)
assert isinstance(event, AtomEvent)
return event
return cast(AtomEvent, self._event_pool.gen(tick, event_type, payload, is_cascade=False))
def gen_cascade_event(self, tick: int, event_type: object, payload: object) -> CascadeEvent:
"""Generate an cascade event that used to hold immediate events that
@ -138,31 +137,32 @@ class EventBuffer:
Returns:
CascadeEvent: Cascade event object.
"""
event = self._event_pool.gen(tick, event_type, payload, True)
assert isinstance(event, CascadeEvent)
return event
return cast(CascadeEvent, self._event_pool.gen(tick, event_type, payload, is_cascade=True))
def gen_decision_event(self, tick: int, payload: object) -> CascadeEvent:
def gen_decision_event(self, tick: int, payload: BaseDecisionEvent) -> CascadeEvent:
"""Generate a decision event that will stop current simulation, and ask agent for action.
Args:
tick (int): Tick that the event will be processed.
payload (object): Payload of event, used to pass data to handlers.
payload (BaseDecisionEvent): Payload of event, used to pass data to handlers.
Returns:
CascadeEvent: Event object
"""
assert isinstance(payload, BaseDecisionEvent)
return self.gen_cascade_event(tick, MaroEvents.PENDING_DECISION, payload)
def gen_action_event(self, tick: int, payload: object) -> CascadeEvent:
def gen_action_event(self, tick: int, payloads: List[BaseAction]) -> CascadeEvent:
"""Generate an event that used to dispatch action to business engine.
Args:
tick (int): Tick that the event will be processed.
payload (object): Payload of event, used to pass data to handlers.
payloads (List[BaseAction]): Payloads of event, used to pass data to handlers.
Returns:
CascadeEvent: Event object
"""
return self.gen_cascade_event(tick, MaroEvents.TAKE_ACTION, payload)
assert isinstance(payloads, list)
assert all(isinstance(p, BaseAction) for p in payloads)
return self.gen_cascade_event(tick, MaroEvents.TAKE_ACTION, payloads)
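With these assertions in place, decision payloads must derive from BaseDecisionEvent and action events take a list of BaseAction. A rough usage sketch inside a business engine, reusing the hypothetical payload classes from the earlier sketch (event_buffer, tick, and insert_event are assumed context):

# Sketch only: event_buffer is an EventBuffer instance owned by the business engine.
decision = RestockDecisionEvent(sku_id=3, current_stock=12)
event_buffer.insert_event(event_buffer.gen_decision_event(tick, decision))

actions = [RestockAction(sku_id=3, quantity=5)]  # gen_action_event now expects a list of BaseAction
event_buffer.insert_event(event_buffer.gen_action_event(tick, actions))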
def register_event_handler(self, event_type: object, handler: Callable) -> None:
"""Register an event with handler, when there is an event need to be processed,


@ -1,6 +1,6 @@
pyjwt
numpy<1.20.0
Cython>=0.29.14
PyJWT>=2.4.0
numpy>=1.19.0
cython>=0.29.14
altair>=4.1.0
streamlit>=0.69.1
tqdm>=4.51.0


@ -3,8 +3,12 @@
from .abs_proxy import AbsProxy
from .abs_worker import AbsWorker
from .port_config import DEFAULT_ROLLOUT_PRODUCER_PORT, DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT
__all__ = [
"AbsProxy",
"AbsWorker",
"DEFAULT_ROLLOUT_PRODUCER_PORT",
"DEFAULT_TRAINING_FRONTEND_PORT",
"DEFAULT_TRAINING_BACKEND_PORT",
]


@ -2,6 +2,7 @@
# Licensed under the MIT license.
from abc import abstractmethod
from typing import Union
import zmq
from tornado.ioloop import IOLoop
@ -33,7 +34,7 @@ class AbsWorker(object):
super(AbsWorker, self).__init__()
self._id = f"worker.{idx}"
self._logger = logger if logger else DummyLogger()
self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger()
# ZMQ sockets and streams
self._context = Context.instance()


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
DEFAULT_ROLLOUT_PRODUCER_PORT = 20000
DEFAULT_TRAINING_FRONTEND_PORT = 10000
DEFAULT_TRAINING_BACKEND_PORT = 10001


@ -98,14 +98,15 @@ class MultiLinearExplorationScheduler(AbsExplorationScheduler):
start_ep: int = 1,
initial_value: float = None,
) -> None:
super().__init__(exploration_params, param_name, initial_value=initial_value)
# validate splits
splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)]
splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
splits.sort()
for (ep, _), (ep2, _) in zip(splits, splits[1:]):
if ep == ep2:
raise ValueError("The zeroth element of split points must be unique")
super().__init__(exploration_params, param_name, initial_value=initial_value)
self.final_value = final_value
self._splits = splits
self._ep = start_ep
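The reordering above runs the base-class initializer before the split list is assembled, so the first split point is seeded with the parameter's current value rather than the raw `initial_value` argument. A standalone sketch of the validation logic with made-up numbers:

# Made-up values; mirrors the duplicate-episode check shown above.
start_ep, last_ep = 1, 100
current_value, final_value = 0.4, 0.0
user_splits = [(20, 0.32), (60, 0.1)]

splits = [(start_ep, current_value)] + user_splits + [(last_ep, final_value)]
splits.sort()
for (ep, _), (next_ep, _) in zip(splits, splits[1:]):
    if ep == next_ep:
        raise ValueError("Split points must have unique episode numbers")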


@ -4,7 +4,7 @@
from __future__ import annotations
from abc import ABCMeta
from typing import Any, Dict, Optional
from typing import Any, Dict
import torch.nn
from torch.optim import Optimizer
@ -18,7 +18,11 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
def __init__(self) -> None:
super(AbsNet, self).__init__()
self._optim: Optional[Optimizer] = None
@property
def optim(self) -> Optimizer:
optim = getattr(self, "_optim", None)
assert isinstance(optim, Optimizer), "Each AbsNet must have an optimizer"
return optim
def step(self, loss: torch.Tensor) -> None:
"""Run a training step to update the net's parameters according to the given loss.
@ -26,9 +30,9 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
Args:
loss (torch.tensor): Loss used to update the model.
"""
self._optim.zero_grad()
self.optim.zero_grad()
loss.backward()
self._optim.step()
self.optim.step()
def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Get the gradients with respect to all parameters according to the given loss.
@ -39,7 +43,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
Returns:
Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters.
"""
self._optim.zero_grad()
self.optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}
@ -51,7 +55,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
"""
for name, param in self.named_parameters():
param.grad = grad[name]
self._optim.step()
self.optim.step()
def _forward_unimplemented(self, *input: Any) -> None:
pass
@ -64,7 +68,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
"""
return {
"network": self.state_dict(),
"optim": self._optim.state_dict(),
"optim": self.optim.state_dict(),
}
def set_state(self, net_state: dict) -> None:
@ -74,7 +78,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
net_state (dict): A dict that contains the net's state.
"""
self.load_state_dict(net_state["network"])
self._optim.load_state_dict(net_state["optim"])
self.optim.load_state_dict(net_state["optim"])
def soft_update(self, other_model: AbsNet, tau: float) -> None:
"""Soft update the net's parameters according to another net, i.e.,


@ -2,7 +2,7 @@
# Licensed under the MIT license.
from collections import OrderedDict
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Tuple, Type
import torch
import torch.nn as nn
@ -46,7 +46,7 @@ class FullyConnected(nn.Module):
skip_connection: bool = False,
dropout_p: float = None,
gradient_threshold: float = None,
name: str = None,
name: str = "NONAME",
) -> None:
super(FullyConnected, self).__init__()
self._input_dim = input_dim
@ -84,7 +84,7 @@ class FullyConnected(nn.Module):
self._name = name
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self._net(x)
out = self._net(x.float())
if self._skip_connection:
out += x
return self._softmax(out) if self._softmax else out
@ -101,12 +101,12 @@ class FullyConnected(nn.Module):
def output_dim(self) -> int:
return self._output_dim
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module:
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
"""Build a basic layer.
BN -> Linear -> Activation -> Dropout
"""
components = []
components: List[Tuple[str, nn.Module]] = []
if self._batch_norm:
components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
components.append(("linear", nn.Linear(input_dim, output_dim)))


@ -4,7 +4,7 @@
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
@ -27,14 +27,14 @@ class AbsPolicy(object, metaclass=ABCMeta):
self._trainable = trainable
@abstractmethod
def get_actions(self, states: object) -> object:
def get_actions(self, states: Union[list, np.ndarray]) -> Any:
"""Get actions according to states.
Args:
states (object): States.
states (Union[list, np.ndarray]): States.
Returns:
actions (object): Actions.
actions (Any): Actions.
"""
raise NotImplementedError
@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
def __init__(self) -> None:
super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)
def get_actions(self, states: object) -> None:
def get_actions(self, states: Union[list, np.ndarray]) -> None:
return None
def explore(self) -> None:
@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
def __init__(self, name: str) -> None:
super(RuleBasedPolicy, self).__init__(name=name, trainable=False)
def get_actions(self, states: List[object]) -> List[object]:
def get_actions(self, states: list) -> list:
return self._rule(states)
@abstractmethod
def _rule(self, states: List[object]) -> List[object]:
def _rule(self, states: list) -> list:
raise NotImplementedError
def explore(self) -> None:
@ -304,7 +304,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def get_state(self) -> object:
def get_state(self) -> dict:
"""Get the state of the policy."""
raise NotImplementedError
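Given the simplified `list -> list` signature of `_rule`, a rule-based policy sketch might look like the following; the shortage heuristic and attribute name are invented for illustration:

from maro.rl.policy import RuleBasedPolicy


class ShortageFirstPolicy(RuleBasedPolicy):
    """Toy rule: act (1) only when a state flags a shortage, otherwise no-op (0)."""

    def _rule(self, states: list) -> list:
        # The `shortage` attribute is a hypothetical field on the state objects.
        return [1 if getattr(state, "shortage", 0) > 0 else 0 for state in states]


policy = ShortageFirstPolicy(name="shortage_first")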


@ -62,12 +62,10 @@ class ContinuousRLPolicy(RLPolicy):
)
self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
assert self._lbounds is not None and self._ubounds is not None
self._policy_net = policy_net
@property
def action_bounds(self) -> Tuple[List[float], List[float]]:
def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]:
return self._lbounds, self._ubounds
@property
@ -118,7 +116,7 @@ class ContinuousRLPolicy(RLPolicy):
def train(self) -> None:
self._policy_net.train()
def get_state(self) -> object:
def get_state(self) -> dict:
return self._policy_net.get_state()
def set_state(self, policy_state: dict) -> None:


@ -85,9 +85,11 @@ class ValueBasedPolicy(DiscreteRLPolicy):
self._exploration_func = exploration_strategy[0]
self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing
self._exploration_schedulers = [
opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options
]
self._exploration_schedulers = (
[opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
if exploration_scheduling_options is not None
else []
)
self._call_cnt = 0
self._warmup = warmup
@ -219,7 +221,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
def train(self) -> None:
self._q_net.train()
def get_state(self) -> object:
def get_state(self) -> dict:
return self._q_net.get_state()
def set_state(self, policy_state: dict) -> None:


@ -1,194 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List
from maro.rl.policy import AbsPolicy
from maro.rl.policy import AbsPolicy, RLPolicy
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from maro.simulator import Env
class RLComponentBundle(object):
class RLComponentBundle:
"""Bundle of all necessary components to run a RL job in MARO.
Users should create their own subclass of `RLComponentBundle` and implement following methods:
- get_env_config()
- get_test_env_config()
- get_env_sampler()
- get_agent2policy()
- get_policy_creator()
- get_trainer_creator()
Following methods could be overwritten when necessary:
- get_device_mapping()
Please refer to the doc string of each method for detailed explanations.
env_sampler (AbsEnvSampler): Environment sampler of the scenario.
agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job. For example:
{agent1: policy1, agent2: policy1, agent3: policy2}.
policies (List[AbsPolicy]): Policies.
trainers (List[AbsTrainer]): Trainers.
device_mapping (Dict[str, str], default=None): Device mapping that identifies the device on which to place each policy.
If None, there will be no explicit device assignment.
policy_trainer_mapping (Dict[str, str], default=None): Policy-trainer mapping that identifies which trainer trains
each policy. If None, a policy's trainer name is the first segment of the policy's name, separated by a dot.
For example, "ppo_1.policy" is trained by "ppo_1". Only policies provided in the policy-trainer mapping are
considered trainable; policies not provided in the mapping will not be trained.
"""
def __init__(self) -> None:
super(RLComponentBundle, self).__init__()
def __init__(
self,
env_sampler: AbsEnvSampler,
agent2policy: Dict[Any, str],
policies: List[AbsPolicy],
trainers: List[AbsTrainer],
device_mapping: Dict[str, str] = None,
policy_trainer_mapping: Dict[str, str] = None,
) -> None:
self.env_sampler = env_sampler
self.agent2policy = agent2policy
self.policies = policies
self.trainers = trainers
self.trainer_creator: Optional[Dict[str, Callable[[], AbsTrainer]]] = None
policy_set = set([policy.name for policy in self.policies])
not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set]
if len(not_found) > 0:
raise ValueError(f"The following policies are required but cannot be found: [{', '.join(not_found)}]")
self.agent2policy: Optional[Dict[Any, str]] = None
self.trainable_agent2policy: Optional[Dict[Any, str]] = None
self.policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None
self.policy_names: Optional[List[str]] = None
self.trainable_policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None
self.trainable_policy_names: Optional[List[str]] = None
# Remove unused policies
kept_policies = []
for policy in self.policies:
if policy.name not in self.agent2policy.values():
raise Warning(f"Policy {policy.name} is removed since it is not used by any agent.")
else:
kept_policies.append(policy)
self.policies = kept_policies
policy_set = set([policy.name for policy in self.policies])
self.device_mapping: Optional[Dict[str, str]] = None
self.policy_trainer_mapping: Optional[Dict[str, str]] = None
self.device_mapping = (
{k: v for k, v in device_mapping.items() if k in policy_set} if device_mapping is not None else {}
)
self.policy_trainer_mapping = (
policy_trainer_mapping
if policy_trainer_mapping is not None
else {policy_name: policy_name.split(".")[0] for policy_name in policy_set}
)
self._policy_cache: Optional[Dict[str, AbsPolicy]] = None
# Will be created when `env_sampler()` is first called
self._env_sampler: Optional[AbsEnvSampler] = None
self._complete_resources()
########################################################################################
# Users MUST implement the following methods #
########################################################################################
@abstractmethod
def get_env_config(self) -> dict:
"""Return the environment configuration to build the MARO Env for training.
Returns:
Environment configuration.
"""
raise NotImplementedError
@abstractmethod
def get_test_env_config(self) -> Optional[dict]:
"""Return the environment configuration to build the MARO Env for testing. If returns `None`, the training
environment will be reused as testing environment.
Returns:
Environment configuration or `None`.
"""
raise NotImplementedError
@abstractmethod
def get_env_sampler(self) -> AbsEnvSampler:
"""Return the environment sampler of the scenario.
Returns:
The environment sampler of the scenario.
"""
raise NotImplementedError
@abstractmethod
def get_agent2policy(self) -> Dict[Any, str]:
"""Return agent name to policy name mapping of the RL job. This mapping identifies which policy should
the agents use. For example: {agent1: policy1, agent2: policy1, agent3: policy2}.
Returns:
Agent name to policy name mapping.
"""
raise NotImplementedError
@abstractmethod
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
"""Return policy creator. Policy creator is a dictionary that contains a group of functions that generate
policy instances. The key of this dictionary is the policy name, and the value is the function that generates
the corresponding policy instance. Note that the creation function should not take any parameters.
"""
raise NotImplementedError
@abstractmethod
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
"""Return trainer creator. Trainer creator is similar to policy creator, but is used to creator trainers."""
raise NotImplementedError
########################################################################################
# Users could overwrite the following methods #
########################################################################################
def get_device_mapping(self) -> Dict[str, str]:
"""Return the device mapping that identifying which device to put each policy.
If user does not overwrite this method, then all policies will be put on CPU by default.
"""
return {policy_name: "cpu" for policy_name in self.get_policy_creator()}
def get_policy_trainer_mapping(self) -> Dict[str, str]:
"""Return the policy-trainer mapping which identifying which trainer to train each policy.
If user does not overwrite this method, then a policy's trainer's name is the first segment of the policy's
name, seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1".
Only policies that provided in policy-trainer mapping are considered as trainable polices. Policies that
not provided in policy-trainer mapping will not be trained since we do not assign a trainer to it.
"""
return {policy_name: policy_name.split(".")[0] for policy_name in self.policy_creator}
########################################################################################
# Methods invisible to users #
########################################################################################
@property
def env_sampler(self) -> AbsEnvSampler:
if self._env_sampler is None:
self._env_sampler = self.get_env_sampler()
self._env_sampler.build(self)
return self._env_sampler
def _complete_resources(self) -> None:
"""Generate all attributes by calling user-defined logics. Do necessary checking and transformations."""
env_config = self.get_env_config()
test_env_config = self.get_test_env_config()
self.env = Env(**env_config)
self.test_env = self.env if test_env_config is None else Env(**test_env_config)
self.trainer_creator = self.get_trainer_creator()
self.device_mapping = self.get_device_mapping()
self.policy_creator = self.get_policy_creator()
self.agent2policy = self.get_agent2policy()
self.policy_trainer_mapping = self.get_policy_trainer_mapping()
required_policies = set(self.agent2policy.values())
self.policy_creator = {name: self.policy_creator[name] for name in required_policies}
# Check missing trainers
self.policy_trainer_mapping = {
name: self.policy_trainer_mapping[name] for name in required_policies if name in self.policy_trainer_mapping
policy_name: trainer_name
for policy_name, trainer_name in self.policy_trainer_mapping.items()
if policy_name in policy_set
}
self.policy_names = list(required_policies)
assert len(required_policies) == len(self.policy_creator) # Should have same size after filter
trainer_set = set([trainer.name for trainer in self.trainers])
not_found = [
trainer_name for trainer_name in self.policy_trainer_mapping.values() if trainer_name not in trainer_set
]
if len(not_found) > 0:
raise ValueError(f"The following trainers are required but cannot be found: [{', '.join(not_found)}]")
required_trainers = set(self.policy_trainer_mapping.values())
self.trainer_creator = {name: self.trainer_creator[name] for name in required_trainers}
assert len(required_trainers) == len(self.trainer_creator) # Should have same size after filter
# Remove unused trainers
kept_trainers = []
for trainer in self.trainers:
if trainer.name not in self.policy_trainer_mapping.values():
raise Warning(f"Trainer {trainer.name} if removed since no policy is trained by it.")
else:
kept_trainers.append(trainer)
self.trainers = kept_trainers
self.trainable_policy_names = list(self.policy_trainer_mapping.keys())
self.trainable_policy_creator = {
policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names
}
self.trainable_agent2policy = {
@property
def trainable_agent2policy(self) -> Dict[Any, str]:
return {
agent_name: policy_name
for agent_name, policy_name in self.agent2policy.items()
if policy_name in self.trainable_policy_names
if policy_name in self.policy_trainer_mapping
}
def pre_create_policy_instances(self) -> None:
"""Pre-create policy instances, and return the pre-created policy instances when the external callers
want to create new policies. This will ensure that each policy will have at most one reusable duplicate.
Under specific scenarios (for example, simple training & rollout), this will reduce unnecessary overheads.
"""
old_policy_creator = self.policy_creator
self._policy_cache: Dict[str, AbsPolicy] = {}
for policy_name in self.policy_names:
self._policy_cache[policy_name] = old_policy_creator[policy_name]()
def _get_policy_instance(policy_name: str) -> AbsPolicy:
return self._policy_cache[policy_name]
self.policy_creator = {
policy_name: partial(_get_policy_instance, policy_name) for policy_name in self.policy_names
}
self.trainable_policy_creator = {
policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names
}
@property
def trainable_policies(self) -> List[RLPolicy]:
policies = []
for policy in self.policies:
if policy.name in self.policy_trainer_mapping:
assert isinstance(policy, RLPolicy)
policies.append(policy)
return policies
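When `policy_trainer_mapping` is None, the constructor derives each policy's trainer from the segment before the first dot of the policy name, as shown above. A quick illustration with example names:

# Example policy names; mirrors the default mapping built in __init__.
policy_names = ["ppo_0.policy", "ppo_1.policy", "dqn_0.policy"]
policy_trainer_mapping = {name: name.split(".")[0] for name in policy_names}
print(policy_trainer_mapping)
# {'ppo_0.policy': 'ppo_0', 'ppo_1.policy': 'ppo_1', 'dqn_0.policy': 'dqn_0'}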


@ -4,12 +4,13 @@
import os
import time
from itertools import chain
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import torch
import zmq
from zmq import Context, Poller
from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT
from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes
from maro.rl.utils.objects import FILE_SUFFIX
from maro.utils import DummyLogger, LoggerV2
@ -37,19 +38,19 @@ class ParallelTaskController(object):
self._poller = Poller()
self._poller.register(self._task_endpoint, zmq.POLLIN)
self._workers = set()
self._logger = logger
self._workers: set = set()
self._logger: Union[DummyLogger, LoggerV2] = logger if logger is not None else DummyLogger()
def _wait_for_workers_ready(self, k: int) -> None:
while len(self._workers) < k:
self._workers.add(self._task_endpoint.recv_multipart()[0])
def _recv_result_for_target_index(self, index: int) -> object:
def _recv_result_for_target_index(self, index: int) -> Any:
rep = bytes_to_pyobj(self._task_endpoint.recv_multipart()[-1])
assert isinstance(rep, dict)
return rep["result"] if rep["index"] == index else None
def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: int = None) -> List[dict]:
def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: float = None) -> List[dict]:
"""Send a task request to a set of remote workers and collect the results.
Args:
@ -70,7 +71,7 @@ class ParallelTaskController(object):
min_replies = parallelism
start_time = time.time()
results = []
results: list = []
for worker_id in list(self._workers)[:parallelism]:
self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)])
self._logger.debug(f"Sent {parallelism} roll-out requests...")
@ -81,7 +82,7 @@ class ParallelTaskController(object):
results.append(result)
if grace_factor is not None:
countdown = int((time.time() - start_time) * grace_factor) * 1000 # milliseconds
countdown = int((time.time() - start_time) * grace_factor) * 1000.0 # milliseconds
self._logger.debug(f"allowing {countdown / 1000} seconds for remaining results")
while len(results) < parallelism and countdown > 0:
start = time.time()
@ -125,15 +126,18 @@ class BatchEnvSampler:
def __init__(
self,
sampling_parallelism: int,
port: int = 20000,
port: int = None,
min_env_samples: int = None,
grace_factor: float = None,
eval_parallelism: int = None,
logger: LoggerV2 = None,
) -> None:
super(BatchEnvSampler, self).__init__()
self._logger = logger if logger else DummyLogger()
self._controller = ParallelTaskController(port=port, logger=logger)
self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger()
self._controller = ParallelTaskController(
port=port if port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT,
logger=logger,
)
self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism
self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism
@ -143,11 +147,15 @@ class BatchEnvSampler:
self._ep = 0
self._end_of_episode = True
def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Optional[int] = None) -> dict:
def sample(
self,
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
num_steps: Optional[int] = None,
) -> dict:
"""Collect experiences from a set of remote roll-out workers.
Args:
policy_state (Dict[str, object]): Policy state dict. If it is not None, then we need to update all
policy_state (Dict[str, Any]): Policy state dict. If it is not None, then we need to update all
policies according to the latest policy states, then start the experience collection.
num_steps (Optional[int], default=None): Number of environment steps to collect experiences for. If
it is None, interactions with the (remote) environments will continue until the terminal state is
@ -181,7 +189,7 @@ class BatchEnvSampler:
"info": [res["info"][0] for res in results],
}
def eval(self, policy_state: Dict[str, object] = None) -> dict:
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test
results = self._controller.collect(req, self._eval_parallelism)
return {
@ -209,3 +217,11 @@ class BatchEnvSampler:
def exit(self) -> None:
self._controller.exit()
def post_collect(self, info_list: list, ep: int) -> None:
req = {"type": "post_collect", "info_list": info_list, "index": ep}
self._controller.collect(req, 1)
def post_evaluate(self, info_list: list, ep: int) -> None:
req = {"type": "post_evaluate", "info_list": info_list, "index": ep}
self._controller.collect(req, 1)
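For context, a rough construction sketch of `BatchEnvSampler` using only the keyword arguments visible in this diff; the import path is assumed and the parameter values mirror the rollout section of the CIM config above:

from maro.rl.rollout import BatchEnvSampler  # assumed import path

batch_sampler = BatchEnvSampler(
    sampling_parallelism=3,
    port=None,          # falls back to DEFAULT_ROLLOUT_PRODUCER_PORT
    min_env_samples=3,
    grace_factor=0.2,
    eval_parallelism=1,
)
result = batch_sampler.sample(policy_state=None, num_steps=None)
batch_sampler.exit()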


@ -5,7 +5,6 @@ from __future__ import annotations
import collections
import os
import typing
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
@ -18,9 +17,6 @@ from maro.rl.policy import AbsPolicy, RLPolicy
from maro.rl.utils.objects import FILE_SUFFIX
from maro.simulator import Env
if typing.TYPE_CHECKING:
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
class AbsAgentWrapper(object, metaclass=ABCMeta):
"""Agent wrapper. Used to manager agents & policies during experience collection.
@ -51,16 +47,16 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
def choose_actions(
self,
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
) -> Dict[Any, Union[np.ndarray, List[object]]]:
state_by_agent: Dict[Any, Union[np.ndarray, list]],
) -> Dict[Any, Union[np.ndarray, list]]:
"""Choose action according to the given (observable) states of all agents.
Args:
state_by_agent (Dict[Any, Union[np.ndarray, List[object]]]): Dictionary containing each agent's states.
state_by_agent (Dict[Any, Union[np.ndarray, list]]): Dictionary containing each agent's states.
If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects.
Returns:
actions (Dict[Any, Union[np.ndarray, List[object]]]): Dict that contains the action for all agents.
actions (Dict[Any, Union[np.ndarray, list]]): Dict that contains the action for all agents.
If the policy is a `RLPolicy`, its action is a Numpy array. Otherwise, its action is a list of objects.
"""
self.switch_to_eval_mode()
@ -71,8 +67,8 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
@abstractmethod
def _choose_actions_impl(
self,
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
) -> Dict[Any, Union[np.ndarray, List[object]]]:
state_by_agent: Dict[Any, Union[np.ndarray, list]],
) -> Dict[Any, Union[np.ndarray, list]]:
"""Implementation of `choose_actions`."""
raise NotImplementedError
@ -95,15 +91,15 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
class SimpleAgentWrapper(AbsAgentWrapper):
def __init__(
self,
policy_dict: Dict[str, RLPolicy], # {policy_name: RLPolicy}
policy_dict: Dict[str, AbsPolicy], # {policy_name: AbsPolicy}
agent2policy: Dict[Any, str], # {agent_name: policy_name}
) -> None:
super(SimpleAgentWrapper, self).__init__(policy_dict=policy_dict, agent2policy=agent2policy)
def _choose_actions_impl(
self,
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
) -> Dict[Any, Union[np.ndarray, List[object]]]:
state_by_agent: Dict[Any, Union[np.ndarray, list]],
) -> Dict[Any, Union[np.ndarray, list]]:
# Aggregate states by policy
states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray}
agents_by_policy = collections.defaultdict(list) # {str: list of str}
@ -112,15 +108,15 @@ class SimpleAgentWrapper(AbsAgentWrapper):
states_by_policy[policy_name].append(state)
agents_by_policy[policy_name].append(agent_name)
action_dict = {}
action_dict: dict = {}
for policy_name in agents_by_policy:
policy = self._policy_dict[policy_name]
if isinstance(policy, RLPolicy):
states = np.vstack(states_by_policy[policy_name]) # np.ndarray
else:
states = states_by_policy[policy_name] # List[object]
actions = policy.get_actions(states) # np.ndarray or List[object]
states = states_by_policy[policy_name] # list
actions: Union[np.ndarray, list] = policy.get_actions(states) # np.ndarray or list
action_dict.update(zip(agents_by_policy[policy_name], actions))
return action_dict
@ -188,7 +184,7 @@ class ExpElement:
Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this
dict is the trainer name.
"""
ret = collections.defaultdict(
ret: Dict[str, ExpElement] = collections.defaultdict(
lambda: ExpElement(
tick=self.tick,
state=self.state,
@ -213,7 +209,7 @@ class ExpElement:
@dataclass
class CacheElement(ExpElement):
event: object
event: Any
env_action_dict: Dict[Any, np.ndarray]
def make_exp_element(self) -> ExpElement:
@ -238,6 +234,9 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
Args:
learn_env (Env): Environment used for training.
test_env (Env): Environment used for testing.
policies (List[AbsPolicy]): List of policies.
agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job.
trainable_policies (List[str]): Name of trainable policies.
agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type.
reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward
for the action taken for that event. If it is None, calculate reward immediately after `step()`.
@ -247,6 +246,9 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self,
learn_env: Env,
test_env: Env,
policies: List[AbsPolicy],
agent2policy: Dict[Any, str],
trainable_policies: List[str] = None,
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
reward_eval_delay: int = None,
) -> None:
@ -255,7 +257,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._agent_wrapper_cls = agent_wrapper_cls
self._event = None
self._event: Optional[list] = None
self._end_of_episode = True
self._state: Optional[np.ndarray] = None
self._agent_state_dict: Dict[Any, np.ndarray] = {}
@ -264,31 +266,23 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache
self._reward_eval_delay = reward_eval_delay
self._info = {}
self._info: dict = {}
assert self._reward_eval_delay is None or self._reward_eval_delay >= 0
def build(
self,
rl_component_bundle: RLComponentBundle,
) -> None:
"""
Args:
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
"""
#
self._env: Optional[Env] = None
self._policy_dict = {
policy_name: rl_component_bundle.policy_creator[policy_name]()
for policy_name in rl_component_bundle.policy_names
}
self._policy_dict: Dict[str, AbsPolicy] = {policy.name: policy for policy in policies}
self._rl_policy_dict: Dict[str, RLPolicy] = {
name: policy for name, policy in self._policy_dict.items() if isinstance(policy, RLPolicy)
policy.name: policy for policy in policies if isinstance(policy, RLPolicy)
}
self._agent2policy = rl_component_bundle.agent2policy
self._agent2policy = agent2policy
self._agent_wrapper = self._agent_wrapper_cls(self._policy_dict, self._agent2policy)
self._trainable_policies = set(rl_component_bundle.trainable_policy_names)
if trainable_policies is not None:
self._trainable_policies = trainable_policies
else:
self._trainable_policies = list(self._policy_dict.keys()) # Default: all policies are trainable
self._trainable_agents = {
agent_id for agent_id, policy_name in self._agent2policy.items() if policy_name in self._trainable_policies
}
@ -297,23 +291,31 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
[policy_name in self._rl_policy_dict for policy_name in self._trainable_policies],
), "All trainable policies must be RL policies!"
@property
def env(self) -> Env:
assert self._env is not None
return self._env
def _switch_env(self, env: Env) -> None:
self._env = env
def assign_policy_to_device(self, policy_name: str, device: torch.device) -> None:
self._rl_policy_dict[policy_name].to_device(device)
def _get_global_and_agent_state(
self,
event: object,
event: Any,
tick: int = None,
) -> Tuple[Optional[object], Dict[Any, Union[np.ndarray, List[object]]]]:
) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]:
"""Get the global and individual agents' states.
Args:
event (object): Event.
event (Any): Event.
tick (int, default=None): Current tick.
Returns:
Global state (Optional[object])
Dict of agent states (Dict[Any, Union[np.ndarray, List[object]]]). If the policy is a `RLPolicy`,
Global state (Optional[Any])
Dict of agent states (Dict[Any, Union[np.ndarray, list]]). If the policy is a `RLPolicy`,
its state is a Numpy array. Otherwise, its state is a list of objects.
"""
global_state, agent_state_dict = self._get_global_and_agent_state_impl(event, tick)
@ -327,23 +329,23 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
@abstractmethod
def _get_global_and_agent_state_impl(
self,
event: object,
event: Any,
tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
raise NotImplementedError
@abstractmethod
def _translate_to_env_action(
self,
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
event: object,
) -> Dict[Any, object]:
action_dict: Dict[Any, Union[np.ndarray, list]],
event: Any,
) -> dict:
"""Translate model-generated actions into an object that can be executed by the env.
Args:
action_dict (Dict[Any, Union[np.ndarray, List[object]]]): Action for all agents. If the policy is a
action_dict (Dict[Any, Union[np.ndarray, list]]): Action for all agents. If the policy is a
`RLPolicy`, its (input) action is a Numpy array. Otherwise, its (input) action is a list of objects.
event (object): Decision event.
event (Any): Decision event.
Returns:
A dict that contains env actions for all agents.
@ -351,12 +353,12 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: int) -> Dict[Any, float]:
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
"""Get rewards according to the env actions.
Args:
env_action_dict (Dict[Any, object]): Dict that contains env actions for all agents.
event (object): Decision event.
env_action_dict (dict): Dict that contains env actions for all agents.
event (Any): Decision event.
tick (int): Current tick.
Returns:
@ -365,7 +367,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
raise NotImplementedError
def _step(self, actions: Optional[list]) -> None:
_, self._event, self._end_of_episode = self._env.step(actions)
_, self._event, self._end_of_episode = self.env.step(actions)
self._state, self._agent_state_dict = (
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event)
)
@ -403,7 +405,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._agent_last_index[agent_name] = cur_index
def _reset(self) -> None:
self._env.reset()
self.env.reset()
self._info.clear()
self._trans_cache.clear()
self._agent_last_index.clear()
@ -412,7 +414,11 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
def _select_trainable_agents(self, original_dict: dict) -> dict:
return {k: v for k, v in original_dict.items() if k in self._trainable_agents}
def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Optional[int] = None) -> dict:
def sample(
self,
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
num_steps: Optional[int] = None,
) -> dict:
"""Sample experiences.
Args:
@ -425,7 +431,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
A dict that contains the collected experiences and additional information.
"""
# Init the env
self._env = self._learn_env
self._switch_env(self._learn_env)
if self._end_of_episode:
self._reset()
@ -443,7 +449,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
# Store experiences in the cache
cache_element = CacheElement(
tick=self._env.tick,
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
@ -466,7 +472,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
steps_to_go -= 1
self._append_cache_element(None)
tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
experiences: List[ExpElement] = []
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
@ -508,8 +514,8 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
return loaded
def eval(self, policy_state: Dict[str, dict] = None) -> dict:
self._env = self._test_env
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
self._switch_env(self._test_env)
self._reset()
if policy_state is not None:
self.set_policy_state(policy_state)
@ -521,7 +527,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
# Store experiences in the cache
cache_element = CacheElement(
tick=self._env.tick,
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
@ -544,7 +550,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._append_cache_element(cache_element)
self._append_cache_element(None)
tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
if self._reward_eval_delay is not None:
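The refactored sampler keeps the same two entry points used by the workflows; below is a minimal usage sketch, assuming a pre-built RLComponentBundle named `bundle` exposes the sampler, and assuming num_steps=None runs an episode to completion (the exact keys of the returned dicts are not shown in this hunk).
env_sampler = bundle.env_sampler  # assumed: `bundle` is an RLComponentBundle instance
# Collect experiences for training; policy_state=None keeps the current policy weights.
sample_result = env_sampler.sample(policy_state=None, num_steps=None)
# Evaluate against the test environment that eval() switches to internally via _switch_env().
eval_result = env_sampler.eval(policy_state=None)
# Both calls return dicts bundling the collected experiences / metrics, as described in the docstrings above.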


@ -5,7 +5,7 @@ from __future__ import annotations
import typing
from maro.rl.distributed import AbsWorker
from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT, AbsWorker
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
from maro.utils import LoggerV2
@ -19,7 +19,7 @@ class RolloutWorker(AbsWorker):
Args:
idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}",
so that the parallel roll-out controller can keep track of its connection status.
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
producer_host (str): IP address of the parallel task controller host to connect to.
producer_port (int, default=20000): Port of the parallel task controller host to connect to.
logger (LoggerV2, default=None): The logger of the workflow.
@ -30,13 +30,13 @@ class RolloutWorker(AbsWorker):
idx: int,
rl_component_bundle: RLComponentBundle,
producer_host: str,
producer_port: int = 20000,
producer_port: int = None,
logger: LoggerV2 = None,
) -> None:
super(RolloutWorker, self).__init__(
idx=idx,
producer_host=producer_host,
producer_port=producer_port,
producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT,
logger=logger,
)
self._env_sampler = rl_component_bundle.env_sampler
@ -53,13 +53,20 @@ class RolloutWorker(AbsWorker):
else:
req = bytes_to_pyobj(msg[-1])
assert isinstance(req, dict)
assert req["type"] in {"sample", "eval", "set_policy_state"}
if req["type"] == "sample":
result = self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
elif req["type"] == "eval":
result = self._env_sampler.eval(policy_state=req["policy_state"])
else:
self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"])
result = True
assert req["type"] in {"sample", "eval", "set_policy_state", "post_collect", "post_evaluate"}
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
if req["type"] in ("sample", "eval"):
result = (
self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
if req["type"] == "sample"
else self._env_sampler.eval(policy_state=req["policy_state"])
)
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
else:
if req["type"] == "set_policy_state":
self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"])
elif req["type"] == "post_collect":
self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"])
else:
self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"])
self._stream.send(pyobj_to_bytes({"result": True, "index": req["index"]}))
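For reference, the request payloads that the extended `_compute` branch handles look roughly like the following; the producer side is not part of this diff, so the policy name and episode index below are placeholders.
sample_req = {
    "type": "sample",
    "policy_state": {"dqn.policy": {}},  # placeholder per-policy state dicts
    "num_steps": None,                   # None -> collect a full episode
    "index": 7,                          # echoed back in the reply
}
eval_req = {"type": "eval", "policy_state": {"dqn.policy": {}}, "index": 7}
post_collect_req = {"type": "post_collect", "info_list": [{}], "index": 7}
# "set_policy_state" and "post_evaluate" follow the same shape; every reply is
# pyobj_to_bytes({"result": ..., "index": req["index"]}) as in the handler above.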


@ -4,7 +4,7 @@
from .proxy import TrainingProxy
from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
from .train_ops import AbsTrainOps, RemoteOps, remote
from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams
from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer
from .training_manager import TrainingManager
from .worker import TrainOpsWorker
@ -18,9 +18,9 @@ __all__ = [
"RemoteOps",
"remote",
"AbsTrainer",
"BaseTrainerParams",
"MultiAgentTrainer",
"SingleAgentTrainer",
"TrainerParams",
"TrainingManager",
"TrainOpsWorker",
]


@ -2,7 +2,6 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Dict
from maro.rl.training.algorithms.base import ACBasedParams, ACBasedTrainer
@ -13,18 +12,8 @@ class ActorCriticParams(ACBasedParams):
for detailed information.
"""
def extract_ops_params(self) -> Dict[str, object]:
return {
"get_v_critic_net_func": self.get_v_critic_net_func,
"reward_discount": self.reward_discount,
"critic_loss_cls": self.critic_loss_cls,
"lam": self.lam,
"min_logp": self.min_logp,
"is_discrete_action": self.is_discrete_action,
}
def __post_init__(self) -> None:
assert self.get_v_critic_net_func is not None
assert self.clip_ratio is None
class ActorCriticTrainer(ACBasedTrainer):
@ -34,5 +23,20 @@ class ActorCriticTrainer(ACBasedTrainer):
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/vpg
"""
def __init__(self, name: str, params: ActorCriticParams) -> None:
super(ActorCriticTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: ActorCriticParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(ActorCriticTrainer, self).__init__(
name,
params,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
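With the memory, batch and discount settings moved from the params dataclass into the trainer constructor, building an actor-critic trainer now looks roughly like the sketch below; `MyVNet` stands in for a user-defined `VNet` subclass and the import path is an assumption.
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer  # import path assumed

ac_trainer = ActorCriticTrainer(
    name="ac",
    params=ActorCriticParams(
        get_v_critic_net_func=lambda: MyVNet(),  # MyVNet: hypothetical VNet subclass
        lam=0.9,
    ),
    replay_memory_capacity=10000,
    batch_size=128,
    reward_discount=0.99,
)
# The policy itself is attached later through register_policies(), not through the params.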


@ -3,19 +3,19 @@
from abc import ABCMeta
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple, cast
import numpy as np
import torch
from maro.rl.model import VNet
from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy
from maro.rl.training import AbsTrainOps, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.utils import TransitionBatch, discount_cumsum, get_torch_device, ndarray_to_tensor
@dataclass
class ACBasedParams(TrainerParams, metaclass=ABCMeta):
class ACBasedParams(BaseTrainerParams, metaclass=ABCMeta):
"""
Parameter bundle for Actor-Critic based algorithms (Actor-Critic & PPO)
@ -23,18 +23,16 @@ class ACBasedParams(TrainerParams, metaclass=ABCMeta):
grad_iters (int, default=1): Number of iterations to calculate gradients.
critic_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE.
lam (float, default=0.9): Lambda value for generalized advantage estimation (TD-Lambda).
min_logp (float, default=None): Lower bound for clamping logP values during learning.
min_logp (float, default=float("-inf")): Lower bound for clamping logP values during learning.
This is to prevent logP from becoming very large in magnitude and causing stability issues.
If it is None, it means no lower bound.
is_discrete_action (bool, default=True): Indicator of continuous or discrete action policy.
"""
get_v_critic_net_func: Callable[[], VNet] = None
get_v_critic_net_func: Callable[[], VNet]
grad_iters: int = 1
critic_loss_cls: Callable = None
critic_loss_cls: Optional[Callable] = None
lam: float = 0.9
min_logp: Optional[float] = None
is_discrete_action: bool = True
min_logp: float = float("-inf")
clip_ratio: Optional[float] = None
class ACBasedOps(AbsTrainOps):
@ -43,33 +41,26 @@ class ACBasedOps(AbsTrainOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
get_v_critic_net_func: Callable[[], VNet],
parallelism: int = 1,
policy: RLPolicy,
params: ACBasedParams,
reward_discount: float = 0.9,
critic_loss_cls: Callable = None,
clip_ratio: float = None,
lam: float = 0.9,
min_logp: float = None,
is_discrete_action: bool = True,
parallelism: int = 1,
) -> None:
super(ACBasedOps, self).__init__(
name=name,
policy_creator=policy_creator,
policy=policy,
parallelism=parallelism,
)
assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy)
assert isinstance(self._policy, (ContinuousRLPolicy, DiscretePolicyGradient))
self._reward_discount = reward_discount
self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss()
self._clip_ratio = clip_ratio
self._lam = lam
self._min_logp = min_logp
self._v_critic_net = get_v_critic_net_func()
self._is_discrete_action = is_discrete_action
self._device = None
self._critic_loss_func = params.critic_loss_cls() if params.critic_loss_cls is not None else torch.nn.MSELoss()
self._clip_ratio = params.clip_ratio
self._lam = params.lam
self._min_logp = params.min_logp
self._v_critic_net = params.get_v_critic_net_func()
self._is_discrete_action = isinstance(self._policy, DiscretePolicyGradient)
def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
"""Compute the critic loss of the batch.
@ -249,14 +240,32 @@ class ACBasedTrainer(SingleAgentTrainer):
https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f
"""
def __init__(self, name: str, params: ACBasedParams) -> None:
super(ACBasedTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: ACBasedParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(ACBasedTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
self._params = params
def _register_policy(self, policy: RLPolicy) -> None:
assert isinstance(policy, (ContinuousRLPolicy, DiscretePolicyGradient))
self._policy = policy
def build(self) -> None:
self._ops = self.get_ops()
self._ops = cast(ACBasedOps, self.get_ops())
self._replay_memory = FIFOReplayMemory(
capacity=self._params.replay_memory_capacity,
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
)
@ -266,10 +275,11 @@ class ACBasedTrainer(SingleAgentTrainer):
def get_local_ops(self) -> AbsTrainOps:
return ACBasedOps(
name=self._policy_name,
policy_creator=self._policy_creator,
parallelism=self._params.data_parallelism,
**self._params.extract_ops_params(),
name=self._policy.name,
policy=self._policy,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
params=self._params,
)
def _get_batch(self) -> TransitionBatch:


@ -2,19 +2,19 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict
from typing import Callable, Dict, Optional, cast
import torch
from maro.rl.model import QNet
from maro.rl.policy import ContinuousRLPolicy, RLPolicy
from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
from maro.utils import clone
@dataclass
class DDPGParams(TrainerParams):
class DDPGParams(BaseTrainerParams):
"""
get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net.
num_epochs (int, default=1): Number of training epochs per call to ``learn``.
@ -30,25 +30,14 @@ class DDPGParams(TrainerParams):
min_num_to_trigger_training (int, default=0): Minimum number of experiences required to start training.
"""
get_q_critic_net_func: Callable[[], QNet] = None
get_q_critic_net_func: Callable[[], QNet]
num_epochs: int = 1
update_target_every: int = 5
q_value_loss_cls: Callable = None
q_value_loss_cls: Optional[Callable] = None
soft_update_coef: float = 1.0
random_overwrite: bool = False
min_num_to_trigger_training: int = 0
def __post_init__(self) -> None:
assert self.get_q_critic_net_func is not None
def extract_ops_params(self) -> Dict[str, object]:
return {
"get_q_critic_net_func": self.get_q_critic_net_func,
"reward_discount": self.reward_discount,
"q_value_loss_cls": self.q_value_loss_cls,
"soft_update_coef": self.soft_update_coef,
}
class DDPGOps(AbsTrainOps):
"""DDPG algorithm implementation. Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html"""
@ -56,31 +45,31 @@ class DDPGOps(AbsTrainOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
get_q_critic_net_func: Callable[[], QNet],
reward_discount: float,
policy: RLPolicy,
params: DDPGParams,
reward_discount: float = 0.9,
parallelism: int = 1,
q_value_loss_cls: Callable = None,
soft_update_coef: float = 1.0,
) -> None:
super(DDPGOps, self).__init__(
name=name,
policy_creator=policy_creator,
policy=policy,
parallelism=parallelism,
)
assert isinstance(self._policy, ContinuousRLPolicy)
self._target_policy = clone(self._policy)
self._target_policy: ContinuousRLPolicy = clone(self._policy)
self._target_policy.set_name(f"target_{self._policy.name}")
self._target_policy.eval()
self._q_critic_net = get_q_critic_net_func()
self._q_critic_net = params.get_q_critic_net_func()
self._target_q_critic_net: QNet = clone(self._q_critic_net)
self._target_q_critic_net.eval()
self._reward_discount = reward_discount
self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss()
self._soft_update_coef = soft_update_coef
self._q_value_loss_func = (
params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss()
)
self._soft_update_coef = params.soft_update_coef
def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
"""Compute the critic loss of the batch.
@ -207,7 +196,7 @@ class DDPGOps(AbsTrainOps):
self._target_policy.soft_update(self._policy, self._soft_update_coef)
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)
def to_device(self, device: str) -> None:
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device=device)
self._policy.to_device(self._device)
self._target_policy.to_device(self._device)
@ -223,30 +212,49 @@ class DDPGTrainer(SingleAgentTrainer):
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
"""
def __init__(self, name: str, params: DDPGParams) -> None:
super(DDPGTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: DDPGParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(DDPGTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
self._params = params
self._policy_version = self._target_policy_version = 0
self._memory_size = 0
def build(self) -> None:
self._ops = self.get_ops()
self._ops = cast(DDPGOps, self.get_ops())
self._replay_memory = RandomReplayMemory(
capacity=self._params.replay_memory_capacity,
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite,
)
def _register_policy(self, policy: RLPolicy) -> None:
assert isinstance(policy, ContinuousRLPolicy)
self._policy = policy
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch
def get_local_ops(self) -> AbsTrainOps:
return DDPGOps(
name=self._policy_name,
policy_creator=self._policy_creator,
parallelism=self._params.data_parallelism,
**self._params.extract_ops_params(),
name=self._policy.name,
policy=self._policy,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
params=self._params,
)
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
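The same constructor split applies to DDPG; a hedged construction sketch where `MyQNet` is a hypothetical user-defined `QNet` subclass and the import path is assumed.
from maro.rl.training.algorithms import DDPGParams, DDPGTrainer  # import path assumed

ddpg_trainer = DDPGTrainer(
    name="ddpg",
    params=DDPGParams(
        get_q_critic_net_func=lambda: MyQNet(),  # hypothetical QNet subclass
        update_target_every=5,
        soft_update_coef=0.01,
    ),
    replay_memory_capacity=100000,
    batch_size=256,
    reward_discount=0.99,
)
# register_policies() must later hand this trainer exactly one ContinuousRLPolicy.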


@ -2,18 +2,18 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict
from typing import Dict, cast
import torch
from maro.rl.policy import RLPolicy, ValueBasedPolicy
from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
from maro.utils import clone
@dataclass
class DQNParams(TrainerParams):
class DQNParams(BaseTrainerParams):
"""
num_epochs (int, default=1): Number of training epochs.
update_target_every (int, default=5): Number of gradient steps between target model updates.
@ -33,42 +33,34 @@ class DQNParams(TrainerParams):
double: bool = False
random_overwrite: bool = False
def extract_ops_params(self) -> Dict[str, object]:
return {
"reward_discount": self.reward_discount,
"soft_update_coef": self.soft_update_coef,
"double": self.double,
}
class DQNOps(AbsTrainOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
parallelism: int = 1,
policy: RLPolicy,
params: DQNParams,
reward_discount: float = 0.9,
soft_update_coef: float = 0.1,
double: bool = False,
parallelism: int = 1,
) -> None:
super(DQNOps, self).__init__(
name=name,
policy_creator=policy_creator,
policy=policy,
parallelism=parallelism,
)
assert isinstance(self._policy, ValueBasedPolicy)
self._reward_discount = reward_discount
self._soft_update_coef = soft_update_coef
self._double = double
self._soft_update_coef = params.soft_update_coef
self._double = params.double
self._loss_func = torch.nn.MSELoss()
self._target_policy: ValueBasedPolicy = clone(self._policy)
self._target_policy.set_name(f"target_{self._policy.name}")
self._target_policy.eval()
def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]:
def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor:
"""Compute the loss of the batch.
Args:
@ -78,6 +70,8 @@ class DQNOps(AbsTrainOps):
loss (torch.Tensor): The loss of the batch.
"""
assert isinstance(batch, TransitionBatch)
assert isinstance(self._policy, ValueBasedPolicy)
self._policy.train()
states = ndarray_to_tensor(batch.states, device=self._device)
next_states = ndarray_to_tensor(batch.next_states, device=self._device)
@ -100,7 +94,7 @@ class DQNOps(AbsTrainOps):
return self._loss_func(q_values, target_q_values)
@remote
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]:
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
"""Compute the network's gradients of a batch.
Args:
@ -141,7 +135,7 @@ class DQNOps(AbsTrainOps):
"""Soft update the target policy."""
self._target_policy.soft_update(self._policy, self._soft_update_coef)
def to_device(self, device: str) -> None:
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device)
self._policy.to_device(self._device)
self._target_policy.to_device(self._device)
@ -153,29 +147,48 @@ class DQNTrainer(SingleAgentTrainer):
See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details.
"""
def __init__(self, name: str, params: DQNParams) -> None:
super(DQNTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: DQNParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(DQNTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
self._params = params
self._q_net_version = self._target_q_net_version = 0
def build(self) -> None:
self._ops = self.get_ops()
self._ops = cast(DQNOps, self.get_ops())
self._replay_memory = RandomReplayMemory(
capacity=self._params.replay_memory_capacity,
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite,
)
def _register_policy(self, policy: RLPolicy) -> None:
assert isinstance(policy, ValueBasedPolicy)
self._policy = policy
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch
def get_local_ops(self) -> AbsTrainOps:
return DQNOps(
name=self._policy_name,
policy_creator=self._policy_creator,
parallelism=self._params.data_parallelism,
**self._params.extract_ops_params(),
name=self._policy.name,
policy=self._policy,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
params=self._params,
)
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
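Likewise for DQN, where the params dataclass now carries only algorithm-specific knobs; a sketch assuming defaults for everything not shown and an assumed import path.
from maro.rl.training.algorithms import DQNParams, DQNTrainer  # import path assumed

dqn_trainer = DQNTrainer(
    name="dqn",
    params=DQNParams(double=True),  # remaining fields keep their defaults
    replay_memory_capacity=50000,
    batch_size=128,
    reward_discount=0.9,
)
# register_policies() must later hand this trainer exactly one ValueBasedPolicy.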


@ -4,7 +4,7 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
from typing import Callable, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
@ -12,14 +12,21 @@ import torch
from maro.rl.model import MultiQNet
from maro.rl.policy import DiscretePolicyGradient, RLPolicy
from maro.rl.rollout import ExpElement
from maro.rl.training import AbsTrainOps, MultiAgentTrainer, RandomMultiReplayMemory, RemoteOps, TrainerParams, remote
from maro.rl.training import (
AbsTrainOps,
BaseTrainerParams,
MultiAgentTrainer,
RandomMultiReplayMemory,
RemoteOps,
remote,
)
from maro.rl.utils import MultiTransitionBatch, get_torch_device, ndarray_to_tensor
from maro.rl.utils.objects import FILE_SUFFIX
from maro.utils import clone
@dataclass
class DiscreteMADDPGParams(TrainerParams):
class DiscreteMADDPGParams(BaseTrainerParams):
"""
get_q_critic_net_func (Callable[[], MultiQNet]): Function to get multi Q critic net.
num_epoch (int, default=10): Number of training epochs.
@ -30,44 +37,28 @@ class DiscreteMADDPGParams(TrainerParams):
shared_critic (bool, default=False): Whether the policies share a single critic or use individual critics.
"""
get_q_critic_net_func: Callable[[], MultiQNet] = None
get_q_critic_net_func: Callable[[], MultiQNet]
num_epoch: int = 10
update_target_every: int = 5
soft_update_coef: float = 0.5
q_value_loss_cls: Callable = None
q_value_loss_cls: Optional[Callable] = None
shared_critic: bool = False
def __post_init__(self) -> None:
assert self.get_q_critic_net_func is not None
def extract_ops_params(self) -> Dict[str, object]:
return {
"get_q_critic_net_func": self.get_q_critic_net_func,
"shared_critic": self.shared_critic,
"reward_discount": self.reward_discount,
"soft_update_coef": self.soft_update_coef,
"update_target_every": self.update_target_every,
"q_value_loss_func": self.q_value_loss_cls() if self.q_value_loss_cls is not None else torch.nn.MSELoss(),
}
class DiscreteMADDPGOps(AbsTrainOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
get_q_critic_net_func: Callable[[], MultiQNet],
policy: RLPolicy,
param: DiscreteMADDPGParams,
shared_critic: bool,
policy_idx: int,
parallelism: int = 1,
shared_critic: bool = False,
reward_discount: float = 0.9,
soft_update_coef: float = 0.5,
update_target_every: int = 5,
q_value_loss_func: Callable = None,
) -> None:
super(DiscreteMADDPGOps, self).__init__(
name=name,
policy_creator=policy_creator,
policy=policy,
parallelism=parallelism,
)
@ -75,23 +66,21 @@ class DiscreteMADDPGOps(AbsTrainOps):
self._shared_critic = shared_critic
# Actor
if self._policy_creator:
if self._policy:
assert isinstance(self._policy, DiscretePolicyGradient)
self._target_policy: DiscretePolicyGradient = clone(self._policy)
self._target_policy.set_name(f"target_{self._policy.name}")
self._target_policy.eval()
# Critic
self._q_critic_net: MultiQNet = get_q_critic_net_func()
self._q_critic_net: MultiQNet = param.get_q_critic_net_func()
self._target_q_critic_net: MultiQNet = clone(self._q_critic_net)
self._target_q_critic_net.eval()
self._reward_discount = reward_discount
self._q_value_loss_func = q_value_loss_func
self._update_target_every = update_target_every
self._soft_update_coef = soft_update_coef
self._device = None
self._q_value_loss_func = param.q_value_loss_cls() if param.q_value_loss_cls is not None else torch.nn.MSELoss()
self._update_target_every = param.update_target_every
self._soft_update_coef = param.soft_update_coef
def get_target_action(self, batch: MultiTransitionBatch) -> torch.Tensor:
"""Get the target policies' actions according to the batch.
@ -248,7 +237,7 @@ class DiscreteMADDPGOps(AbsTrainOps):
def soft_update_target(self) -> None:
"""Soft update the target policies and target critics."""
if self._policy_creator:
if self._policy:
self._target_policy.soft_update(self._policy, self._soft_update_coef)
if not self._shared_critic:
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)
@ -264,13 +253,13 @@ class DiscreteMADDPGOps(AbsTrainOps):
self._target_q_critic_net.set_state(ops_state_dict["target_critic"])
def get_actor_state(self) -> dict:
if self._policy_creator:
if self._policy:
return {"policy": self._policy.get_state(), "target_policy": self._target_policy.get_state()}
else:
return {}
def set_actor_state(self, ops_state_dict: dict) -> None:
if self._policy_creator:
if self._policy:
self._policy.set_state(ops_state_dict["policy"])
self._target_policy.set_state(ops_state_dict["target_policy"])
@ -280,9 +269,9 @@ class DiscreteMADDPGOps(AbsTrainOps):
def set_non_policy_state(self, state: dict) -> None:
self.set_critic_state(state)
def to_device(self, device: str) -> None:
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device)
if self._policy_creator:
if self._policy:
self._policy.to_device(self._device)
self._target_policy.to_device(self._device)
@ -296,31 +285,51 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
See https://arxiv.org/abs/1706.02275 for details.
"""
def __init__(self, name: str, params: DiscreteMADDPGParams) -> None:
super(DiscreteMADDPGTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: DiscreteMADDPGParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(DiscreteMADDPGTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
self._params = params
self._ops_params = self._params.extract_ops_params()
self._state_dim = params.get_q_critic_net_func().state_dim
self._policy_version = self._target_policy_version = 0
self._shared_critic_ops_name = f"{self._name}.shared_critic"
self._actor_ops_list = []
self._critic_ops = None
self._replay_memory = None
self._policy2agent = {}
self._actor_ops_list: List[DiscreteMADDPGOps] = []
self._critic_ops: Optional[DiscreteMADDPGOps] = None
self._policy2agent: Dict[str, str] = {}
self._ops_dict: Dict[str, DiscreteMADDPGOps] = {}
def build(self) -> None:
for policy_name in self._policy_creator:
self._ops_dict[policy_name] = self.get_ops(policy_name)
self._placeholder_policy = self._policy_dict[self._policy_names[0]]
for policy in self._policy_dict.values():
self._ops_dict[policy.name] = cast(DiscreteMADDPGOps, self.get_ops(policy.name))
self._actor_ops_list = list(self._ops_dict.values())
if self._params.shared_critic:
self._ops_dict[self._shared_critic_ops_name] = self.get_ops(self._shared_critic_ops_name)
assert self._critic_ops is not None
self._ops_dict[self._shared_critic_ops_name] = cast(
DiscreteMADDPGOps,
self.get_ops(self._shared_critic_ops_name),
)
self._critic_ops = self._ops_dict[self._shared_critic_ops_name]
self._replay_memory = RandomMultiReplayMemory(
capacity=self._params.replay_memory_capacity,
capacity=self._replay_memory_capacity,
state_dim=self._state_dim,
action_dims=[ops.policy_action_dim for ops in self._actor_ops_list],
agent_states_dims=[ops.policy_state_dim for ops in self._actor_ops_list],
@ -342,7 +351,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
rewards: List[np.ndarray] = []
agent_states: List[np.ndarray] = []
next_agent_states: List[np.ndarray] = []
for policy_name in self._policy_names:
for policy_name in self._policy_dict:
agent_name = self._policy2agent[policy_name]
actions.append(np.vstack([exp_element.action_dict[agent_name] for exp_element in exp_elements]))
rewards.append(np.array([exp_element.reward_dict[agent_name] for exp_element in exp_elements]))
@ -374,23 +383,25 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
def get_local_ops(self, name: str) -> AbsTrainOps:
if name == self._shared_critic_ops_name:
ops_params = dict(self._ops_params)
ops_params.update(
{
"policy_idx": -1,
"shared_critic": False,
},
return DiscreteMADDPGOps(
name=name,
policy=self._placeholder_policy,
param=self._params,
shared_critic=False,
policy_idx=-1,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
)
return DiscreteMADDPGOps(name=name, **ops_params)
else:
ops_params = dict(self._ops_params)
ops_params.update(
{
"policy_creator": self._policy_creator[name],
"policy_idx": self._policy_names.index(name),
},
return DiscreteMADDPGOps(
name=name,
policy=self._policy_dict[name],
param=self._params,
shared_critic=self._params.shared_critic,
policy_idx=self._policy_names.index(name),
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
)
return DiscreteMADDPGOps(name=name, **ops_params)
def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)
@ -405,6 +416,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
# Update critic
if self._params.shared_critic:
assert self._critic_ops is not None
self._critic_ops.update_critic(batch, next_actions)
critic_state_dict = self._critic_ops.get_critic_state()
# Sync latest critic to ops
@ -431,6 +443,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
# Update critic
if self._params.shared_critic:
assert self._critic_ops is not None
critic_grad = await asyncio.gather(*[self._critic_ops.get_critic_grad(batch, next_actions)])
assert isinstance(critic_grad, list) and isinstance(critic_grad[0], dict)
self._critic_ops.update_critic_with_grad(critic_grad[0])
@ -460,10 +473,11 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
for ops in self._actor_ops_list:
ops.soft_update_target()
if self._params.shared_critic:
assert self._critic_ops is not None
self._critic_ops.soft_update_target()
self._target_policy_version = self._policy_version
def get_policy_state(self) -> Dict[str, object]:
def get_policy_state(self) -> Dict[str, dict]:
self._assert_ops_exists()
ret_policy_state = {}
for ops in self._actor_ops_list:
@ -484,6 +498,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
trainer_state = {ops.name: ops.get_state() for ops in self._actor_ops_list}
if self._params.shared_critic:
assert self._critic_ops is not None
trainer_state[self._critic_ops.name] = self._critic_ops.get_state()
policy_state_dict = {ops_name: state["policy"] for ops_name, state in trainer_state.items()}
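For completeness, a hedged construction sketch for the multi-agent trainer; note that the dataclass field is `num_epoch` (singular), `MyMultiQNet` is a hypothetical `MultiQNet` subclass, and the import path is assumed.
from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrainer  # import path assumed

maddpg_trainer = DiscreteMADDPGTrainer(
    name="maddpg",
    params=DiscreteMADDPGParams(
        get_q_critic_net_func=lambda: MyMultiQNet(),  # hypothetical MultiQNet subclass
        shared_critic=True,
        num_epoch=10,
    ),
    replay_memory_capacity=100000,
    reward_discount=0.95,
)
# register_policies() later attaches one DiscretePolicyGradient policy per agent.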


@ -2,16 +2,16 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict, Tuple
from typing import Tuple
import numpy as np
import torch
from torch.distributions import Categorical
from maro.rl.model import VNet
from maro.rl.policy import DiscretePolicyGradient, RLPolicy
from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer
from maro.rl.utils import TransitionBatch, discount_cumsum, ndarray_to_tensor
from maro.utils import clone
@dataclass
@ -23,21 +23,7 @@ class PPOParams(ACBasedParams):
If it is None, the actor loss is calculated using the usual policy gradient theorem.
"""
clip_ratio: float = None
def extract_ops_params(self) -> Dict[str, object]:
return {
"get_v_critic_net_func": self.get_v_critic_net_func,
"reward_discount": self.reward_discount,
"critic_loss_cls": self.critic_loss_cls,
"clip_ratio": self.clip_ratio,
"lam": self.lam,
"min_logp": self.min_logp,
"is_discrete_action": self.is_discrete_action,
}
def __post_init__(self) -> None:
assert self.get_v_critic_net_func is not None
assert self.clip_ratio is not None
@ -45,31 +31,20 @@ class DiscretePPOWithEntropyOps(ACBasedOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
get_v_critic_net_func: Callable[[], VNet],
policy: RLPolicy,
params: ACBasedParams,
parallelism: int = 1,
reward_discount: float = 0.9,
critic_loss_cls: Callable = None,
clip_ratio: float = None,
lam: float = 0.9,
min_logp: float = None,
is_discrete_action: bool = True,
) -> None:
super(DiscretePPOWithEntropyOps, self).__init__(
name=name,
policy_creator=policy_creator,
get_v_critic_net_func=get_v_critic_net_func,
parallelism=parallelism,
reward_discount=reward_discount,
critic_loss_cls=critic_loss_cls,
clip_ratio=clip_ratio,
lam=lam,
min_logp=min_logp,
is_discrete_action=is_discrete_action,
name,
policy,
params,
reward_discount,
parallelism,
)
assert is_discrete_action
assert isinstance(self._policy, DiscretePolicyGradient)
self._policy_old = self._policy_creator()
assert self._is_discrete_action
self._policy_old: DiscretePolicyGradient = clone(policy)
self.update_policy_old()
def update_policy_old(self) -> None:
@ -172,8 +147,23 @@ class PPOTrainer(ACBasedTrainer):
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo.
"""
def __init__(self, name: str, params: PPOParams) -> None:
super(PPOTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: PPOParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(PPOTrainer, self).__init__(
name,
params,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
class DiscretePPOWithEntropyTrainer(ACBasedTrainer):
@ -182,10 +172,11 @@ class DiscretePPOWithEntropyTrainer(ACBasedTrainer):
def get_local_ops(self) -> DiscretePPOWithEntropyOps:
return DiscretePPOWithEntropyOps(
name=self._policy_name,
policy_creator=self._policy_creator,
parallelism=self._params.data_parallelism,
**self._params.extract_ops_params(),
name=self._policy.name,
policy=self._policy,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
params=self._params,
)
def train_step(self) -> None:
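PPO reuses the AC-based constructor; the only extra requirement visible here is that `clip_ratio` must be set (enforced by the `__post_init__` assertion). A hedged sketch with a hypothetical `MyVNet` and an assumed import path.
from maro.rl.training.algorithms import PPOParams, PPOTrainer  # import path assumed

ppo_trainer = PPOTrainer(
    name="ppo",
    params=PPOParams(
        get_v_critic_net_func=lambda: MyVNet(),  # hypothetical VNet subclass
        clip_ratio=0.1,                          # must not be None for PPO
        lam=0.95,
    ),
    batch_size=128,
    reward_discount=0.99,
)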


@ -2,73 +2,59 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple, cast
import torch
from maro.rl.model import QNet
from maro.rl.policy import ContinuousRLPolicy, RLPolicy
from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
from maro.utils import clone
@dataclass
class SoftActorCriticParams(TrainerParams):
get_q_critic_net_func: Callable[[], QNet] = None
class SoftActorCriticParams(BaseTrainerParams):
get_q_critic_net_func: Callable[[], QNet]
update_target_every: int = 5
random_overwrite: bool = False
entropy_coef: float = 0.1
num_epochs: int = 1
n_start_train: int = 0
q_value_loss_cls: Callable = None
q_value_loss_cls: Optional[Callable] = None
soft_update_coef: float = 1.0
def __post_init__(self) -> None:
assert self.get_q_critic_net_func is not None
def extract_ops_params(self) -> Dict[str, object]:
return {
"get_q_critic_net_func": self.get_q_critic_net_func,
"entropy_coef": self.entropy_coef,
"reward_discount": self.reward_discount,
"q_value_loss_cls": self.q_value_loss_cls,
"soft_update_coef": self.soft_update_coef,
}
class SoftActorCriticOps(AbsTrainOps):
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
get_q_critic_net_func: Callable[[], QNet],
policy: RLPolicy,
params: SoftActorCriticParams,
reward_discount: float = 0.9,
parallelism: int = 1,
*,
entropy_coef: float,
reward_discount: float,
q_value_loss_cls: Callable = None,
soft_update_coef: float = 1.0,
) -> None:
super(SoftActorCriticOps, self).__init__(
name=name,
policy_creator=policy_creator,
policy=policy,
parallelism=parallelism,
)
assert isinstance(self._policy, ContinuousRLPolicy)
self._q_net1 = get_q_critic_net_func()
self._q_net2 = get_q_critic_net_func()
self._q_net1 = params.get_q_critic_net_func()
self._q_net2 = params.get_q_critic_net_func()
self._target_q_net1: QNet = clone(self._q_net1)
self._target_q_net1.eval()
self._target_q_net2: QNet = clone(self._q_net2)
self._target_q_net2.eval()
self._entropy_coef = entropy_coef
self._soft_update_coef = soft_update_coef
self._entropy_coef = params.entropy_coef
self._soft_update_coef = params.soft_update_coef
self._reward_discount = reward_discount
self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss()
self._q_value_loss_func = (
params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss()
)
def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]:
self._q_net1.train()
@ -100,11 +86,11 @@ class SoftActorCriticOps(AbsTrainOps):
grad_q2 = self._q_net2.get_gradients(loss_q2)
return grad_q1, grad_q2
def update_critic_with_grad(self, grad_dict1: dict, grad_dict2: dict) -> None:
def update_critic_with_grad(self, grad_dicts: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None:
self._q_net1.train()
self._q_net2.train()
self._q_net1.apply_gradients(grad_dict1)
self._q_net2.apply_gradients(grad_dict2)
self._q_net1.apply_gradients(grad_dicts[0])
self._q_net2.apply_gradients(grad_dicts[1])
def update_critic(self, batch: TransitionBatch) -> None:
self._q_net1.train()
@ -154,7 +140,7 @@ class SoftActorCriticOps(AbsTrainOps):
self._target_q_net1.soft_update(self._q_net1, self._soft_update_coef)
self._target_q_net2.soft_update(self._q_net2, self._soft_update_coef)
def to_device(self, device: str) -> None:
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device=device)
self._q_net1.to(self._device)
self._q_net2.to(self._device)
@ -163,22 +149,38 @@ class SoftActorCriticOps(AbsTrainOps):
class SoftActorCriticTrainer(SingleAgentTrainer):
def __init__(self, name: str, params: SoftActorCriticParams) -> None:
super(SoftActorCriticTrainer, self).__init__(name, params)
def __init__(
self,
name: str,
params: SoftActorCriticParams,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(SoftActorCriticTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
self._params = params
self._qnet_version = self._target_qnet_version = 0
self._replay_memory: Optional[RandomReplayMemory] = None
def build(self) -> None:
self._ops = self.get_ops()
self._ops = cast(SoftActorCriticOps, self.get_ops())
self._replay_memory = RandomReplayMemory(
capacity=self._params.replay_memory_capacity,
capacity=self._replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim,
random_overwrite=self._params.random_overwrite,
)
def _register_policy(self, policy: RLPolicy) -> None:
assert isinstance(policy, ContinuousRLPolicy)
self._policy = policy
def train_step(self) -> None:
assert isinstance(self._ops, SoftActorCriticOps)
@ -218,10 +220,11 @@ class SoftActorCriticTrainer(SingleAgentTrainer):
def get_local_ops(self) -> SoftActorCriticOps:
return SoftActorCriticOps(
name=self._policy_name,
policy_creator=self._policy_creator,
parallelism=self._params.data_parallelism,
**self._params.extract_ops_params(),
name=self._policy.name,
policy=self._policy,
parallelism=self._data_parallelism,
reward_discount=self._reward_discount,
params=self._params,
)
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
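And the corresponding sketch for soft actor-critic, again with a hypothetical `MyQNet`, an assumed import path, and only fields that appear in `SoftActorCriticParams` above.
from maro.rl.training.algorithms import SoftActorCriticParams, SoftActorCriticTrainer  # import path assumed

sac_trainer = SoftActorCriticTrainer(
    name="sac",
    params=SoftActorCriticParams(
        get_q_critic_net_func=lambda: MyQNet(),  # hypothetical QNet subclass
        entropy_coef=0.2,
        n_start_train=1000,
        soft_update_coef=0.05,
    ),
    replay_memory_capacity=100000,
    batch_size=256,
)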


@ -2,8 +2,9 @@
# Licensed under the MIT license.
from collections import defaultdict, deque
from typing import Deque
from maro.rl.distributed import AbsProxy
from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT, AbsProxy
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
from maro.rl.utils.torch_utils import average_grads
from maro.utils import LoggerV2
@ -20,13 +21,16 @@ class TrainingProxy(AbsProxy):
backend_port (int, default=10001): Network port for communicating with back-end workers (task consumers).
"""
def __init__(self, frontend_port: int = 10000, backend_port: int = 10001) -> None:
super(TrainingProxy, self).__init__(frontend_port=frontend_port, backend_port=backend_port)
self._available_workers = deque()
self._worker_ready = False
self._connected_ops = set()
self._result_cache = defaultdict(list)
self._expected_num_results = {}
def __init__(self, frontend_port: int = None, backend_port: int = None) -> None:
super(TrainingProxy, self).__init__(
frontend_port=frontend_port if frontend_port is not None else DEFAULT_TRAINING_FRONTEND_PORT,
backend_port=backend_port if backend_port is not None else DEFAULT_TRAINING_BACKEND_PORT,
)
self._available_workers: Deque = deque()
self._worker_ready: bool = False
self._connected_ops: set = set()
self._result_cache: dict = defaultdict(list)
self._expected_num_results: dict = {}
self._logger = LoggerV2("TRAIN-PROXY")
def _route_request_to_compute_node(self, msg: list) -> None:
@ -48,10 +52,12 @@ class TrainingProxy(AbsProxy):
self._connected_ops.add(msg[0])
req = bytes_to_pyobj(msg[-1])
assert isinstance(req, dict)
desired_parallelism = req["desired_parallelism"]
req["args"] = list(req["args"])
batch = req["args"][0]
workers = []
workers: list = []
while len(workers) < desired_parallelism and self._available_workers:
workers.append(self._available_workers.popleft())
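With the port arguments now optional, the proxy can be created with the new defaults or with the documented 10000/10001 pair; how its serving loop is started is outside this hunk.
from maro.rl.training import TrainingProxy

proxy = TrainingProxy()  # falls back to DEFAULT_TRAINING_FRONTEND_PORT / DEFAULT_TRAINING_BACKEND_PORT
proxy_explicit = TrainingProxy(frontend_port=10000, backend_port=10001)
# Starting the proxy's message loop is handled by the workflow scripts and is not shown in this diff.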


@ -3,8 +3,9 @@
import inspect
from abc import ABCMeta, abstractmethod
from typing import Callable, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import torch
import zmq
from zmq.asyncio import Context, Poller
@ -19,24 +20,21 @@ class AbsTrainOps(object, metaclass=ABCMeta):
Args:
name (str): Name of the ops. This is usually a policy name.
policy_creator (Callable[[], RLPolicy]): Function to create a policy instance.
policy (RLPolicy): Policy instance.
parallelism (int, default=1): Desired degree of data parallelism.
"""
def __init__(
self,
name: str,
policy_creator: Callable[[], RLPolicy],
policy: RLPolicy,
parallelism: int = 1,
) -> None:
super(AbsTrainOps, self).__init__()
self._name = name
self._policy_creator = policy_creator
# Create the policy.
if self._policy_creator:
self._policy = self._policy_creator()
self._policy = policy
self._parallelism = parallelism
self._device: Optional[torch.device] = None
@property
def name(self) -> str:
@ -44,11 +42,11 @@ class AbsTrainOps(object, metaclass=ABCMeta):
@property
def policy_state_dim(self) -> int:
return self._policy.state_dim if self._policy_creator else None
return self._policy.state_dim
@property
def policy_action_dim(self) -> int:
return self._policy.action_dim if self._policy_creator else None
return self._policy.action_dim
@property
def parallelism(self) -> int:
@ -75,20 +73,20 @@ class AbsTrainOps(object, metaclass=ABCMeta):
self.set_policy_state(ops_state_dict["policy"][1])
self.set_non_policy_state(ops_state_dict["non_policy"])
def get_policy_state(self) -> Tuple[str, object]:
def get_policy_state(self) -> Tuple[str, dict]:
"""Get the policy's state.
Returns:
policy_name (str)
policy_state (object)
policy_state (Any)
"""
return self._policy.name, self._policy.get_state()
def set_policy_state(self, policy_state: object) -> None:
def set_policy_state(self, policy_state: dict) -> None:
"""Update the policy's state.
Args:
policy_state (object): The policy state.
policy_state (dict): The policy state.
"""
self._policy.set_state(policy_state)
@ -111,17 +109,17 @@ class AbsTrainOps(object, metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def to_device(self, device: str):
def to_device(self, device: str = None) -> None:
raise NotImplementedError
def remote(func) -> Callable:
def remote(func: Callable) -> Callable:
"""Annotation to indicate that a function / method can be called remotely.
This annotation takes effect only when an ``AbsTrainOps`` object is wrapped by a ``RemoteOps``.
"""
def remote_annotate(*args, **kwargs) -> object:
def remote_annotate(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
return remote_annotate
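A rough sketch of how the decorator interacts with `RemoteOps`: only `@remote`-marked methods (such as `get_batch_grad` above) are routed through the training proxy, everything else resolves on the wrapped local ops. The proxy address and the `local_ops` / `batch` variables below are assumptions.
import asyncio

from maro.rl.training import RemoteOps

ops = RemoteOps(local_ops, ("train-proxy", 10000))  # local_ops: an assumed AbsTrainOps with @remote methods

async def one_remote_grad(batch):
    # Awaiting the @remote-decorated call sends a request to the proxy and
    # returns the gradients computed on a remote worker.
    return await ops.get_batch_grad(batch)

grads = asyncio.run(one_remote_grad(batch))  # `batch` is an assumed TransitionBatch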
@ -137,7 +135,7 @@ class AsyncClient(object):
"""
def __init__(self, name: str, address: Tuple[str, int], logger: LoggerV2 = None) -> None:
self._logger = DummyLogger() if logger is None else logger
self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger()
self._name = name
host, port = address
self._proxy_ip = get_ip_address_by_hostname(host)
@ -155,7 +153,7 @@ class AsyncClient(object):
await self._socket.send(pyobj_to_bytes(req))
self._logger.debug(f"{self._name} sent request {req['func']}")
async def get_response(self) -> object:
async def get_response(self) -> Any:
"""Waits for a result in asynchronous fashion.
This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``get_response`` calls.
@ -209,15 +207,15 @@ class RemoteOps(object):
self._client = AsyncClient(self._ops.name, address, logger=logger)
self._client.connect()
def __getattribute__(self, attr_name: str) -> object:
def __getattribute__(self, attr_name: str) -> Any:
# Ignore methods that belong to the parent class
try:
return super().__getattribute__(attr_name)
except AttributeError:
pass
def remote_method(ops_state, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable:
async def remote_call(*args, **kwargs) -> object:
def remote_method(ops_state: Any, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable:
async def remote_call(*args: Any, **kwargs: Any) -> Any:
req = {
"state": ops_state,
"func": func_name,


@ -5,7 +5,7 @@ import collections
import os
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -21,37 +21,8 @@ from .train_ops import AbsTrainOps, RemoteOps
@dataclass
class TrainerParams:
"""Common trainer parameters.
replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory.
batch_size (int, default=128): Training batch size.
data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when
a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
availability of compute resources. For details on distributed deep learning and data parallelism, see
https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance
of resources available on the internet.
reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology.
"""
replay_memory_capacity: int = 10000
batch_size: int = 128
data_parallelism: int = 1
reward_discount: float = 0.9
@abstractmethod
def extract_ops_params(self) -> Dict[str, object]:
"""Extract parameters that should be passed to the train ops.
Returns:
params (Dict[str, object]): Parameter dict.
"""
raise NotImplementedError
class BaseTrainerParams:
pass
class AbsTrainer(object, metaclass=ABCMeta):
@ -64,16 +35,36 @@ class AbsTrainer(object, metaclass=ABCMeta):
Args:
name (str): Name of the trainer.
params (TrainerParams): Trainer's parameters.
replay_memory_capacity (int, default=10000): Maximum capacity of the replay memory.
batch_size (int, default=128): Training batch size.
data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when
a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
availability of compute resources. For details on distributed deep learning and data parallelism, see
https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an
abundance of resources available on the internet.
reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology.
"""
def __init__(self, name: str, params: TrainerParams) -> None:
def __init__(
self,
name: str,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
self._name = name
self._params = params
self._batch_size = self._params.batch_size
self._replay_memory_capacity = replay_memory_capacity
self._batch_size = batch_size
self._data_parallelism = data_parallelism
self._reward_discount = reward_discount
self._agent2policy: Dict[Any, str] = {}
self._proxy_address: Optional[Tuple[str, int]] = None
self._logger = None
@property
def name(self) -> str:
@ -83,13 +74,11 @@ class AbsTrainer(object, metaclass=ABCMeta):
def agent_num(self) -> int:
return len(self._agent2policy)
def register_logger(self, logger: LoggerV2) -> None:
def register_logger(self, logger: LoggerV2 = None) -> None:
self._logger = logger
def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_mapping: Dict[str, str]) -> None:
"""Register the agent-to-policy dict that corresponds to the current trainer. A valid policy name should start
with the name of its trainer. For example, "DQN.POLICY_NAME". Therefore, we can identify which policies
should be registered to the current trainer from the policy's name.
"""Register the agent-to-policy dict that corresponds to the current trainer.
Args:
agent2policy (Dict[Any, str]): Agent name to policy name mapping.
@ -102,16 +91,11 @@ class AbsTrainer(object, metaclass=ABCMeta):
}
@abstractmethod
def register_policy_creator(
self,
global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
policy_trainer_mapping: Dict[str, str],
) -> None:
"""Register the policy creator. Only keep the creators of the policies that the current trainer need to train.
def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
"""Register the policies. Only keep the policies that the current trainer needs to train.
Args:
global_policy_creator (Dict[str, Callable[[], AbsPolicy]]): Dict that contains the creators for all
policies.
policies (List[AbsPolicy]): All policies.
policy_trainer_mapping (Dict[str, str]): Policy name to trainer name mapping.
"""
raise NotImplementedError
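The assignment rule described above is purely name-based; here is a tiny self-contained illustration of the filtering that `register_policies` performs (all names are made up).
policy_trainer_mapping = {
    "dqn.consumer": "dqn",
    "dqn.producer": "dqn",
    "ppo.vehicle": "ppo",
}
trainer_name = "dqn"
# A SingleAgentTrainer must end up with exactly one match; a MultiAgentTrainer keeps all of them.
own_policy_names = [name for name, trainer in policy_trainer_mapping.items() if trainer == trainer_name]
assert own_policy_names == ["dqn.consumer", "dqn.producer"]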
@ -147,7 +131,7 @@ class AbsTrainer(object, metaclass=ABCMeta):
self._proxy_address = proxy_address
@abstractmethod
def get_policy_state(self) -> Dict[str, object]:
def get_policy_state(self) -> Dict[str, dict]:
"""Get policies' states.
Returns:
@ -171,30 +155,46 @@ class AbsTrainer(object, metaclass=ABCMeta):
class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
"""Policy trainer that trains only one policy."""
def __init__(self, name: str, params: TrainerParams) -> None:
super(SingleAgentTrainer, self).__init__(name, params)
self._policy_name: Optional[str] = None
self._policy_creator: Optional[Callable[[], RLPolicy]] = None
self._ops: Optional[AbsTrainOps] = None
self._replay_memory: Optional[ReplayMemory] = None
def __init__(
self,
name: str,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(SingleAgentTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
@property
def ops(self):
return self._ops
def ops(self) -> Union[AbsTrainOps, RemoteOps]:
ops = getattr(self, "_ops", None)
assert isinstance(ops, (AbsTrainOps, RemoteOps))
return ops
def register_policy_creator(
self,
global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
policy_trainer_mapping: Dict[str, str],
) -> None:
policy_names = [
policy_name for policy_name in global_policy_creator if policy_trainer_mapping[policy_name] == self.name
]
if len(policy_names) != 1:
@property
def replay_memory(self) -> ReplayMemory:
replay_memory = getattr(self, "_replay_memory", None)
assert isinstance(replay_memory, ReplayMemory), "Replay memory is required."
return replay_memory
def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
policies = [policy for policy in policies if policy_trainer_mapping[policy.name] == self.name]
if len(policies) != 1:
raise ValueError(f"Trainer {self._name} should have exactly one policy assigned to it")
self._policy_name = policy_names.pop()
self._policy_creator = global_policy_creator[self._policy_name]
policy = policies.pop()
assert isinstance(policy, RLPolicy)
self._register_policy(policy)
@abstractmethod
def _register_policy(self, policy: RLPolicy) -> None:
raise NotImplementedError
@abstractmethod
def get_local_ops(self) -> AbsTrainOps:
@ -216,9 +216,9 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
ops = self.get_local_ops()
return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops
def get_policy_state(self) -> Dict[str, object]:
def get_policy_state(self) -> Dict[str, dict]:
self._assert_ops_exists()
policy_name, state = self._ops.get_policy_state()
policy_name, state = self.ops.get_policy_state()
return {policy_name: state}
def load(self, path: str) -> None:
@ -227,7 +227,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
policy_state = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}"))
non_policy_state = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}"))
self._ops.set_state(
self.ops.set_state(
{
"policy": policy_state,
"non_policy": non_policy_state,
@ -237,7 +237,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
def save(self, path: str) -> None:
self._assert_ops_exists()
ops_state = self._ops.get_state()
ops_state = self.ops.get_state()
policy_state = ops_state["policy"]
non_policy_state = ops_state["non_policy"]
@ -267,46 +267,57 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
next_states=np.vstack([exp[4] for exp in exps]),
)
transition_batch = self._preprocess_batch(transition_batch)
self._replay_memory.put(transition_batch)
self.replay_memory.put(transition_batch)
@abstractmethod
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
raise NotImplementedError
def _assert_ops_exists(self) -> None:
if not self._ops:
if not self.ops:
raise ValueError("'build' needs to be called to create an ops instance first.")
async def exit(self) -> None:
self._assert_ops_exists()
if isinstance(self._ops, RemoteOps):
await self._ops.exit()
ops = self.ops
if isinstance(ops, RemoteOps):
await ops.exit()
class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta):
"""Policy trainer that trains multiple policies."""
def __init__(self, name: str, params: TrainerParams) -> None:
super(MultiAgentTrainer, self).__init__(name, params)
self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {}
self._policy_names: List[str] = []
self._ops_dict: Dict[str, AbsTrainOps] = {}
def __init__(
self,
name: str,
replay_memory_capacity: int = 10000,
batch_size: int = 128,
data_parallelism: int = 1,
reward_discount: float = 0.9,
) -> None:
super(MultiAgentTrainer, self).__init__(
name,
replay_memory_capacity,
batch_size,
data_parallelism,
reward_discount,
)
@property
def ops_dict(self):
return self._ops_dict
def ops_dict(self) -> Dict[str, AbsTrainOps]:
ops_dict = getattr(self, "_ops_dict", None)
assert isinstance(ops_dict, dict)
return ops_dict
def register_policy_creator(
self,
global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
policy_trainer_mapping: Dict[str, str],
) -> None:
self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {
policy_name: func
for policy_name, func in global_policy_creator.items()
if policy_trainer_mapping[policy_name] == self.name
}
self._policy_names = list(self._policy_creator.keys())
def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
self._policy_names: List[str] = [
policy.name for policy in policies if policy_trainer_mapping[policy.name] == self.name
]
self._policy_dict: Dict[str, RLPolicy] = {}
for policy in policies:
if policy_trainer_mapping[policy.name] == self.name:
assert isinstance(policy, RLPolicy)
self._policy_dict[policy.name] = policy
@abstractmethod
def get_local_ops(self, name: str) -> AbsTrainOps:
@ -335,7 +346,7 @@ class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta):
return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops
@abstractmethod
def get_policy_state(self) -> Dict[str, object]:
def get_policy_state(self) -> Dict[str, dict]:
raise NotImplementedError
@abstractmethod


@ -7,7 +7,6 @@ import asyncio
import collections
import os
import typing
from itertools import chain
from typing import Any, Dict, Iterable, List, Tuple
from maro.rl.rollout import ExpElement
@ -26,8 +25,8 @@ class TrainingManager(object):
Training manager. Manage and schedule all trainers to train policies.
Args:
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
explicit_assign_device (bool): Whether to assign policy to its device in the training manager.
rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
explicit_assign_device (bool, default=False): Whether to assign policy to its device in the training manager.
proxy_address (Tuple[str, int], default=None): Address of the training proxy. If it is not None,
it is registered to all trainers, which in turn create `RemoteOps` for distributed training.
logger (LoggerV2, default=None): A logger for logging key events.
@ -36,36 +35,33 @@ class TrainingManager(object):
def __init__(
self,
rl_component_bundle: RLComponentBundle,
explicit_assign_device: bool,
explicit_assign_device: bool = False,
proxy_address: Tuple[str, int] = None,
logger: LoggerV2 = None,
) -> None:
super(TrainingManager, self).__init__()
self._trainer_dict: Dict[str, AbsTrainer] = {}
self._proxy_address = proxy_address
for trainer_name, func in rl_component_bundle.trainer_creator.items():
trainer = func()
self._trainer_dict: Dict[str, AbsTrainer] = {}
for trainer in rl_component_bundle.trainers:
if self._proxy_address:
trainer.set_proxy_address(self._proxy_address)
trainer.register_agent2policy(
rl_component_bundle.trainable_agent2policy,
rl_component_bundle.policy_trainer_mapping,
agent2policy=rl_component_bundle.trainable_agent2policy,
policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping,
)
trainer.register_policy_creator(
rl_component_bundle.trainable_policy_creator,
rl_component_bundle.policy_trainer_mapping,
trainer.register_policies(
policies=rl_component_bundle.policies,
policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping,
)
trainer.register_logger(logger)
trainer.build() # `build()` must be called after `register_policy_creator()`
self._trainer_dict[trainer_name] = trainer
trainer.build() # `build()` must be called after `register_policies()`
self._trainer_dict[trainer.name] = trainer
# User-defined allocation of compute devices, i.e., GPU's to the trainer ops
if explicit_assign_device:
for policy_name, device_name in rl_component_bundle.device_mapping.items():
if policy_name not in rl_component_bundle.policy_trainer_mapping: # No need to assign device
continue
trainer = self._trainer_dict[rl_component_bundle.policy_trainer_mapping[policy_name]]
if isinstance(trainer, SingleAgentTrainer):
@ -95,13 +91,16 @@ class TrainingManager(object):
for trainer in self._trainer_dict.values():
trainer.train_step()
def get_policy_state(self) -> Dict[str, Dict[str, object]]:
def get_policy_state(self) -> Dict[str, dict]:
"""Get policies' states.
Returns:
A double-deck dict with format: {trainer_name: {policy_name: policy_state}}
"""
return dict(chain(*[trainer.get_policy_state().items() for trainer in self._trainer_dict.values()]))
policy_states: Dict[str, dict] = {}
for trainer in self._trainer_dict.values():
policy_states.update(trainer.get_policy_state())
return policy_states
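For orientation, a minimal, hedged sketch of how the updated `TrainingManager` signature is meant to be used; `rl_component_bundle` is assumed to be an `RLComponentBundle` instance (as built in the CIM notebook later in this PR), and the proxy host/port are placeholders:
```
# A sketch only, not the shipped workflow code. It contrasts "simple" mode
# (explicit device assignment, no proxy) with a distributed setup where a
# training proxy address is registered so trainers create RemoteOps.
from maro.rl.training import TrainingManager
from maro.utils import LoggerV2

logger = LoggerV2("MAIN", dump_path="./log", dump_mode="a", stdout_level="INFO")

# Simple (single-process) training: assign devices explicitly, no proxy.
simple_manager = TrainingManager(
    rl_component_bundle=rl_component_bundle,  # assumed to exist in scope
    explicit_assign_device=True,
    proxy_address=None,
    logger=logger,
)

# Distributed training: pass the proxy address (placeholder host/port).
distributed_manager = TrainingManager(
    rl_component_bundle=rl_component_bundle,
    proxy_address=("127.0.0.1", 10103),
    logger=logger,
)
```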
def record_experiences(self, experiences: List[List[ExpElement]]) -> None:
"""Record experiences collected from external modules (for example, EnvSampler).


@ -6,7 +6,7 @@ from __future__ import annotations
import typing
from typing import Dict
from maro.rl.distributed import AbsWorker
from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, AbsWorker
from maro.rl.training import SingleAgentTrainer
from maro.rl.utils.common import bytes_to_pyobj, bytes_to_string, pyobj_to_bytes
from maro.utils import LoggerV2
@ -24,7 +24,7 @@ class TrainOpsWorker(AbsWorker):
Args:
idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}",
so that the proxy can keep track of its connection status.
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
producer_host (str): IP address of the proxy host to connect to.
producer_port (int, default=10001): Port of the proxy host to connect to.
"""
@ -34,13 +34,13 @@ class TrainOpsWorker(AbsWorker):
idx: int,
rl_component_bundle: RLComponentBundle,
producer_host: str,
producer_port: int = 10001,
producer_port: int = None,
logger: LoggerV2 = None,
) -> None:
super(TrainOpsWorker, self).__init__(
idx=idx,
producer_host=producer_host,
producer_port=producer_port,
producer_port=producer_port if producer_port is not None else DEFAULT_TRAINING_BACKEND_PORT,
logger=logger,
)
@ -62,13 +62,17 @@ class TrainOpsWorker(AbsWorker):
ops_name, req = bytes_to_string(msg[0]), bytes_to_pyobj(msg[-1])
assert isinstance(req, dict)
trainer_dict: Dict[str, AbsTrainer] = {
trainer.name: trainer for trainer in self._rl_component_bundle.trainers
}
if ops_name not in self._ops_dict:
trainer_name = ops_name.split(".")[0]
trainer_name = self._rl_component_bundle.policy_trainer_mapping[ops_name]
if trainer_name not in self._trainer_dict:
trainer = self._rl_component_bundle.trainer_creator[trainer_name]()
trainer.register_policy_creator(
self._rl_component_bundle.trainable_policy_creator,
self._rl_component_bundle.policy_trainer_mapping,
trainer = trainer_dict[trainer_name]
trainer.register_policies(
policies=self._rl_component_bundle.policies,
policy_trainer_mapping=self._rl_component_bundle.policy_trainer_mapping,
)
self._trainer_dict[trainer_name] = trainer


@ -4,17 +4,17 @@
import os
import pickle
import socket
from typing import List, Optional
from typing import Any, List, Optional
def get_env(var_name: str, required: bool = True, default: object = None) -> str:
def get_env(var_name: str, required: bool = True, default: str = None) -> Optional[str]:
"""Wrapper for os.getenv() that includes a check for mandatory environment variables.
Args:
var_name (str): Variable name.
required (bool, default=True): Flag indicating whether the environment variable in questions is required.
If this is true and the environment variable is not present in ``os.environ``, a ``KeyError`` is raised.
default (object, default=None): Default value for the environment variable if it is missing in ``os.environ``
default (str, default=None): Default value for the environment variable if it is missing in ``os.environ``
and ``required`` is false. Ignored if ``required`` is True.
Returns:
@ -52,11 +52,11 @@ def bytes_to_string(bytes_: bytes) -> str:
return bytes_.decode(DEFAULT_MSG_ENCODING)
def pyobj_to_bytes(pyobj) -> bytes:
def pyobj_to_bytes(pyobj: Any) -> bytes:
return pickle.dumps(pyobj)
def bytes_to_pyobj(bytes_: bytes) -> object:
def bytes_to_pyobj(bytes_: bytes) -> Any:
return pickle.loads(bytes_)
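A short illustrative sketch (variable and payload names are made up) of how these helpers fit together:
```
# Illustrative only: get_env for required/optional variables, plus the
# pickle-based (de)serialization helpers used for proxy messages.
from maro.rl.utils.common import bytes_to_pyobj, get_env, pyobj_to_bytes

log_path = get_env("LOG_PATH")  # required by default; raises KeyError if unset
stdout_level = get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL")

payload = {"policy": "ppo_0.policy", "version": 3}  # made-up payload
raw = pyobj_to_bytes(payload)
assert bytes_to_pyobj(raw) == payload
```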


@ -55,5 +55,5 @@ def average_grads(grad_list: List[dict]) -> dict:
}
def get_torch_device(device: str = None):
def get_torch_device(device: str = None) -> torch.device:
return torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))


@ -207,7 +207,7 @@ class ConfigParser:
f"{self._validation_err_pfx}: 'training.checkpointing.interval' must be an int",
)
def _validate_logging_section(self, component, level_dict: dict) -> None:
def _validate_logging_section(self, component: str, level_dict: dict) -> None:
if any(key not in {"stdout", "file"} for key in level_dict):
raise KeyError(
f"{self._validation_err_pfx}: fields under section '{component}.logging' must be 'stdout' or 'file'",
@ -261,7 +261,7 @@ class ConfigParser:
num_episodes = self._config["main"]["num_episodes"]
main_proc = f"{self._config['job']}.main"
min_n_sample = self._config["main"].get("min_n_sample", 1)
env = {
env: dict = {
main_proc: (
os.path.join(self._get_workflow_path(containerize=containerize), "main.py"),
{


@ -6,116 +6,157 @@ import importlib
import os
import sys
import time
from typing import List, Type
from typing import List, Union
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import BatchEnvSampler, ExpElement
from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler, ExpElement
from maro.rl.training import TrainingManager
from maro.rl.utils import get_torch_device
from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none
from maro.rl.utils.training import get_latest_ep
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
def get_args() -> argparse.Namespace:
class WorkflowEnvAttributes:
def __init__(self) -> None:
# Number of training episodes
self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES")))
# Maximum number of steps in one round of sampling.
self.num_steps = int_or_none(get_env("NUM_STEPS", required=False))
# Minimum number of data samples to start a round of training. If the data samples are insufficient, re-run
# data sampling until we have at least `min_n_sample` data entries.
self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE")))
# Path to store logs.
self.log_path = get_env("LOG_PATH")
# Log levels
self.log_level_stdout = get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL")
self.log_level_file = get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL")
# Parallelism of sampling / evaluation. Used in distributed sampling.
self.env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
self.env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
# Training mode, simple or distributed
self.train_mode = get_env("TRAIN_MODE")
# Evaluation schedule.
self.eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
# Restore configurations.
self.load_path = get_env("LOAD_PATH", required=False)
self.load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
# Checkpointing configurations.
self.checkpoint_path = get_env("CHECKPOINT_PATH", required=False)
self.checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False))
# Parallel sampling configurations.
self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None
if self.parallel_rollout:
self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT")))
self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False))
self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False))
self.is_single_thread = self.train_mode == "simple" and not self.parallel_rollout
# Distributed training configurations.
if self.train_mode != "simple":
self.proxy_address = (
env_str_helper(get_env("TRAIN_PROXY_HOST")),
int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))),
)
self.logger = LoggerV2(
"MAIN",
dump_path=self.log_path,
dump_mode="a",
stdout_level=self.log_level_stdout,
file_level=self.log_level_file,
)
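As a rough illustration of what `WorkflowEnvAttributes` expects, the environment for a local "simple" run might be prepared like this before `main.py` is launched (placeholder values; in practice the config parser and the docker-compose scripts export these):
```
# Placeholder values for a local "simple" run; WorkflowEnvAttributes is the
# class defined above. Optional variables (NUM_STEPS, EVAL_SCHEDULE, ...) may
# simply be left unset, in which case the corresponding attributes stay None.
import os

os.environ.update(
    {
        "NUM_EPISODES": "30",
        "MIN_N_SAMPLE": "1",
        "LOG_PATH": "./log/my_job",
        "TRAIN_MODE": "simple",
        "LOG_LEVEL_STDOUT": "INFO",
    },
)

env_attr = WorkflowEnvAttributes()
assert env_attr.is_single_thread  # simple mode without parallel rollout
```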
def _get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="MARO RL workflow parser")
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
return parser.parse_args()
def main(rl_component_bundle: RLComponentBundle, args: argparse.Namespace) -> None:
if args.evaluate_only:
evaluate_only_workflow(rl_component_bundle)
else:
training_workflow(rl_component_bundle)
def training_workflow(rl_component_bundle: RLComponentBundle) -> None:
num_episodes = int(get_env("NUM_EPISODES"))
num_steps = int_or_none(get_env("NUM_STEPS", required=False))
min_n_sample = int_or_none(get_env("MIN_N_SAMPLE"))
logger = LoggerV2(
"MAIN",
dump_path=get_env("LOG_PATH"),
dump_mode="a",
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
)
logger.info("Start training workflow.")
env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None
train_mode = get_env("TRAIN_MODE")
is_single_thread = train_mode == "simple" and not parallel_rollout
if is_single_thread:
rl_component_bundle.pre_create_policy_instances()
if parallel_rollout:
env_sampler = BatchEnvSampler(
sampling_parallelism=env_sampling_parallelism,
port=int(get_env("ROLLOUT_CONTROLLER_PORT")),
min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)),
grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)),
eval_parallelism=env_eval_parallelism,
logger=logger,
def _get_env_sampler(
rl_component_bundle: RLComponentBundle,
env_attr: WorkflowEnvAttributes,
) -> Union[AbsEnvSampler, BatchEnvSampler]:
if env_attr.parallel_rollout:
assert env_attr.env_sampling_parallelism is not None
return BatchEnvSampler(
sampling_parallelism=env_attr.env_sampling_parallelism,
port=env_attr.port,
min_env_samples=env_attr.min_env_samples,
grace_factor=env_attr.grace_factor,
eval_parallelism=env_attr.env_eval_parallelism,
logger=env_attr.logger,
)
else:
env_sampler = rl_component_bundle.env_sampler
if train_mode != "simple":
if rl_component_bundle.device_mapping is not None:
for policy_name, device_name in rl_component_bundle.device_mapping.items():
env_sampler.assign_policy_to_device(policy_name, get_torch_device(device_name))
return env_sampler
def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes, args: argparse.Namespace) -> None:
if args.evaluate_only:
evaluate_only_workflow(rl_component_bundle, env_attr)
else:
training_workflow(rl_component_bundle, env_attr)
def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
env_attr.logger.info("Start training workflow.")
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
# evaluation schedule
eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
logger.info(f"Policy will be evaluated at the end of episodes {eval_schedule}")
env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
eval_point_index = 0
training_manager = TrainingManager(
rl_component_bundle=rl_component_bundle,
explicit_assign_device=(train_mode == "simple"),
proxy_address=None
if train_mode == "simple"
else (
get_env("TRAIN_PROXY_HOST"),
int(get_env("TRAIN_PROXY_FRONTEND_PORT")),
),
logger=logger,
explicit_assign_device=(env_attr.train_mode == "simple"),
proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
logger=env_attr.logger,
)
load_path = get_env("LOAD_PATH", required=False)
load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
if load_path:
assert isinstance(load_path, str)
if env_attr.load_path:
assert isinstance(env_attr.load_path, str)
ep = load_episode if load_episode is not None else get_latest_ep(load_path)
path = os.path.join(load_path, str(ep))
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
path = os.path.join(env_attr.load_path, str(ep))
loaded = env_sampler.load_policy_state(path)
logger.info(f"Loaded policies {loaded} into env sampler from {path}")
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
loaded = training_manager.load(path)
logger.info(f"Loaded trainers {loaded} from {path}")
env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
start_ep = ep + 1
else:
start_ep = 1
checkpoint_path = get_env("CHECKPOINT_PATH", required=False)
checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False))
# main loop
for ep in range(start_ep, num_episodes + 1):
collect_time = training_time = 0
for ep in range(start_ep, env_attr.num_episodes + 1):
collect_time = training_time = 0.0
total_experiences: List[List[ExpElement]] = []
total_info_list: List[dict] = []
n_sample = 0
while n_sample < min_n_sample:
while n_sample < env_attr.min_n_sample:
tc0 = time.time()
result = env_sampler.sample(
policy_state=training_manager.get_policy_state() if not is_single_thread else None,
num_steps=num_steps,
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
num_steps=env_attr.num_steps,
)
experiences: List[List[ExpElement]] = result["experiences"]
info_list: List[dict] = result["info"]
@ -128,23 +169,25 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None:
env_sampler.post_collect(total_info_list, ep)
logger.info(f"Roll-out completed for episode {ep}. Training started...")
env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
tu0 = time.time()
training_manager.record_experiences(total_experiences)
training_manager.train_step()
if checkpoint_path and (checkpoint_interval is None or ep % checkpoint_interval == 0):
assert isinstance(checkpoint_path, str)
pth = os.path.join(checkpoint_path, str(ep))
if env_attr.checkpoint_path and (not env_attr.checkpoint_interval or ep % env_attr.checkpoint_interval == 0):
assert isinstance(env_attr.checkpoint_path, str)
pth = os.path.join(env_attr.checkpoint_path, str(ep))
training_manager.save(pth)
logger.info(f"All trainer states saved under {pth}")
env_attr.logger.info(f"All trainer states saved under {pth}")
training_time += time.time() - tu0
# performance details
logger.info(f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds")
if eval_schedule and ep == eval_schedule[eval_point_index]:
env_attr.logger.info(
f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
)
if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
eval_point_index += 1
result = env_sampler.eval(
policy_state=training_manager.get_policy_state() if not is_single_thread else None,
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
)
env_sampler.post_evaluate(result["info"], ep)
@ -153,42 +196,19 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None:
training_manager.exit()
def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None:
logger = LoggerV2(
"MAIN",
dump_path=get_env("LOG_PATH"),
dump_mode="a",
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
)
logger.info("Start evaluate only workflow.")
def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
env_attr.logger.info("Start evaluate only workflow.")
env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
if parallel_rollout:
env_sampler = BatchEnvSampler(
sampling_parallelism=env_sampling_parallelism,
port=int(get_env("ROLLOUT_CONTROLLER_PORT")),
min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)),
grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)),
eval_parallelism=env_eval_parallelism,
logger=logger,
)
else:
env_sampler = rl_component_bundle.env_sampler
if env_attr.load_path:
assert isinstance(env_attr.load_path, str)
load_path = get_env("LOAD_PATH", required=False)
load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
if load_path:
assert isinstance(load_path, str)
ep = load_episode if load_episode is not None else get_latest_ep(load_path)
path = os.path.join(load_path, str(ep))
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
path = os.path.join(env_attr.load_path, str(ep))
loaded = env_sampler.load_policy_state(path)
logger.info(f"Loaded policies {loaded} into env sampler from {path}")
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
result = env_sampler.eval()
env_sampler.post_evaluate(result["info"], -1)
@ -198,11 +218,9 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None:
if __name__ == "__main__":
scenario_path = get_env("SCENARIO_PATH")
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
rl_component_bundle = rl_component_bundle_cls()
main(rl_component_bundle, args=get_args())
main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=_get_args())


@ -4,23 +4,22 @@
import importlib
import os
import sys
from typing import Type
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import RolloutWorker
from maro.rl.utils.common import get_env, int_or_none
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
if __name__ == "__main__":
scenario_path = get_env("SCENARIO_PATH")
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
rl_component_bundle = rl_component_bundle_cls()
rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")
worker_idx = int_or_none(get_env("ID"))
worker_idx = int(env_str_helper(get_env("ID")))
logger = LoggerV2(
f"ROLLOUT-WORKER.{worker_idx}",
dump_path=get_env("LOG_PATH"),
@ -31,7 +30,7 @@ if __name__ == "__main__":
worker = RolloutWorker(
idx=worker_idx,
rl_component_bundle=rl_component_bundle,
producer_host=get_env("ROLLOUT_CONTROLLER_HOST"),
producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")),
producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")),
logger=logger,
)


@ -4,21 +4,20 @@
import importlib
import os
import sys
from typing import Type
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training import TrainOpsWorker
from maro.rl.utils.common import get_env, int_or_none
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
if __name__ == "__main__":
scenario_path = get_env("SCENARIO_PATH")
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
rl_component_bundle = rl_component_bundle_cls()
rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")
worker_idx = int_or_none(get_env("ID"))
logger = LoggerV2(
@ -29,9 +28,9 @@ if __name__ == "__main__":
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
)
worker = TrainOpsWorker(
idx=int_or_none(get_env("ID")),
idx=int(env_str_helper(get_env("ID"))),
rl_component_bundle=rl_component_bundle,
producer_host=get_env("TRAIN_PROXY_HOST"),
producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")),
producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")),
logger=logger,
)


@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
def env_str_helper(string: Optional[str]) -> str:
assert string is not None
return string
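The helper exists mainly to narrow `Optional[str]` (the return type of `get_env`) to `str` for type checkers; a typical use, mirroring the workflow scripts in this PR:
```
# env_str_helper asserts the value is not None, so the int() cast type-checks.
from maro.rl.utils.common import get_env
from maro.rl.workflows.utils import env_str_helper

scenario_path = env_str_helper(get_env("SCENARIO_PATH"))   # str, not Optional[str]
num_episodes = int(env_str_helper(get_env("NUM_EPISODES")))
```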


@ -72,7 +72,7 @@ class AbsEnv(ABC):
return self._business_engine
@abstractmethod
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
def step(self, action) -> Tuple[Optional[dict], Optional[list], bool]:
"""Push the environment to next step with action.
Args:


@ -1,10 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections.abc import Iterable
from importlib import import_module
from inspect import getmembers, isclass
from typing import Generator, List, Optional, Tuple
from typing import Generator, List, Optional, Tuple, Union, cast
from maro.backends.frame import FrameBase, SnapshotList
from maro.data_lib.dump_csv_converter import DumpConverter
@ -12,6 +11,7 @@ from maro.event_buffer import ActualEvent, CascadeEvent, EventBuffer, EventState
from maro.streamit import streamit
from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError
from ..common import BaseAction, BaseDecisionEvent
from .abs_core import AbsEnv, DecisionMode
from .scenarios.abs_business_engine import AbsBusinessEngine
from .utils.common import tick_to_frame_index
@ -73,8 +73,8 @@ class Env(AbsEnv):
self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path)
# decision_events array for dump.
self._decision_events = []
# decision_payloads array for dump.
self._decision_payloads = []
# The generator used to push the simulator forward.
self._simulate_generator = self._simulate()
@ -89,21 +89,48 @@ class Env(AbsEnv):
self._streamit_episode = 0
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
def step(
self,
action: Union[BaseAction, List[BaseAction], None] = None,
) -> Tuple[Optional[dict], Union[BaseDecisionEvent, List[BaseDecisionEvent], None], bool]:
"""Push the environment to next step with action.
Under Sequential mode:
- If `action` is None, an empty list will be assigned to the decision event.
- Otherwise, the action(s) will be assigned to the decision event.
Under Joint mode:
- If `action` is None, no actions will be assigned to any decision event.
- If `action` is a single action, it will be assigned to the first decision event.
- If `action` is a list, actions are assigned to each decision event in order. If the number of actions
is less than the number of decision events, extra decision events will not be assigned actions. If
the number of actions is larger than the number of decision events, extra actions will be ignored.
If you want to assign multiple actions to specific event(s), please explicitly pass a list of lists. For
example:
```
env.step(action=[[a1, a2], a3, [a4, a5]])
```
This will assign `a1` & `a2` to the first decision event, `a3` to the second decision event, and `a4` & `a5`
to the third decision event.
In particular, if you only want to assign multiple actions to the first decision event, please
pass `[[a1, a2, ..., an]]` (a list of one list) instead of `[a1, a2, ..., an]` (a 1D list of n elements),
since the latter will assign the n actions to the first n decision events.
Args:
action (Action): Action(s) from agent.
action (Union[BaseAction, List[BaseAction], None]): Action(s) from agent.
Returns:
tuple: a tuple of (metrics, decision event, is_done).
"""
try:
metrics, decision_event, _is_done = self._simulate_generator.send(action)
metrics, decision_payloads, _is_done = self._simulate_generator.send(action)
except StopIteration:
return None, None, True
return metrics, decision_event, _is_done
return metrics, decision_payloads, _is_done
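To make the sequential-mode contract above concrete, here is a hedged sketch of an interaction loop using the CIM scenario (topology and durations are placeholders; `Action`, `ActionType` and `DecisionEvent` are the CIM classes shown elsewhere in this PR):
```
# Sketch of the sequential-mode loop documented above. The "policy" here just
# discharges zero containers at every decision point, purely for illustration.
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent

env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=100)

metrics, decision_event, is_done = env.step(None)  # start the episode, no action yet
while not is_done:
    assert isinstance(decision_event, DecisionEvent)
    action = Action(decision_event.vessel_idx, decision_event.port_idx, 0, ActionType.DISCHARGE)
    # A single action or a list of actions are both accepted in sequential mode.
    metrics, decision_event, is_done = env.step(action)
```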
def dump(self) -> None:
"""Dump environment for restore.
@ -131,10 +158,14 @@ class Env(AbsEnv):
self._business_engine.frame.dump(dump_folder)
self._converter.start_processing(self.configs)
self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution)
self._converter.dump_descsion_events(
self._decision_payloads,
self._start_tick,
self._snapshot_resolution,
)
self._business_engine.dump(dump_folder)
self._decision_events.clear()
self._decision_payloads.clear()
self._business_engine.reset(keep_seed)
@ -267,7 +298,29 @@ class Env(AbsEnv):
additional_options=self._additional_options,
)
def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]:
def _assign_action(
self,
action: Union[BaseAction, List[BaseAction], None],
decision_event: CascadeEvent,
) -> None:
decision_event.state = EventState.EXECUTING
if action is None:
actions = []
elif not isinstance(action, list):
actions = [action]
else:
actions = action
decision_event.add_immediate_event(self._event_buffer.gen_action_event(self._tick, actions), is_head=True)
def _simulate(
self,
) -> Generator[
Tuple[dict, Union[BaseDecisionEvent, List[BaseDecisionEvent]], bool],
Union[BaseAction, List[BaseAction], None],
None,
]:
"""This is the generator to wrap each episode process."""
self._streamit_episode += 1
@ -282,7 +335,7 @@ class Env(AbsEnv):
while True:
# Keep processing events, until no more events in this tick.
pending_events = self._event_buffer.execute(self._tick)
pending_events = cast(List[CascadeEvent], self._event_buffer.execute(self._tick))
if len(pending_events) == 0:
# We have processed all the event of current tick, lets go for next tick.
@ -292,50 +345,25 @@ class Env(AbsEnv):
self._business_engine.frame.take_snapshot(self.frame_index)
# Append source event id to decision events, to support sequential action in joint mode.
decision_events = [event.payload for event in pending_events]
decision_events = (
decision_events[0] if self._decision_mode == DecisionMode.Sequential else decision_events
)
# Yield current state first, and waiting for action.
actions = yield self._business_engine.get_metrics(), decision_events, False
# archive decision events.
self._decision_events.append(decision_events)
if actions is None:
# Make business engine easy to work.
actions = []
elif not isinstance(actions, Iterable):
actions = [actions]
decision_payloads = [event.payload for event in pending_events]
if self._decision_mode == DecisionMode.Sequential:
# Generate a new atom event first.
action_event = self._event_buffer.gen_action_event(self._tick, actions)
# NOTE: decision event always be a CascadeEvent
# We just append the action into sub event of first pending cascade event.
event = pending_events[0]
assert isinstance(event, CascadeEvent)
event.state = EventState.EXECUTING
event.add_immediate_event(action_event, is_head=True)
self._decision_payloads.append(decision_payloads[0])
action = yield self._business_engine.get_metrics(), decision_payloads[0], False
self._assign_action(action, pending_events[0])
else:
# For joint mode, we will assign actions from beginning to end.
# Then mark others pending events to finished if not sequential action mode.
for i, pending_event in enumerate(pending_events):
if i >= len(actions):
if self._decision_mode == DecisionMode.Joint:
# Ignore following pending events that have no action matched.
pending_event.state = EventState.FINISHED
else:
# Set the state as executing, so event buffer will not pop them again.
# Then insert the action to it.
action = actions[i]
pending_event.state = EventState.EXECUTING
action_event = self._event_buffer.gen_action_event(self._tick, action)
self._decision_payloads += decision_payloads
actions = yield self._business_engine.get_metrics(), decision_payloads, False
if actions is None:
actions = []
assert isinstance(actions, list)
assert isinstance(pending_event, CascadeEvent)
pending_event.add_immediate_event(action_event, is_head=True)
for action, event in zip(actions, pending_events):
self._assign_action(action, event)
if self._decision_mode == DecisionMode.Joint:
for event in pending_events[len(actions) :]:
event.state = EventState.FINISHED
# Check the end tick of the simulation to decide if we should end the simulation.
is_end_tick = self._business_engine.post_step(self._tick)


@ -714,38 +714,38 @@ class CimBusinessEngine(AbsBusinessEngine):
actions = event.payload
assert isinstance(actions, list)
if actions:
for action in actions:
vessel_idx = action.vessel_idx
port_idx = action.port_idx
move_num = action.quantity
vessel = self._vessels[vessel_idx]
port = self._ports[port_idx]
port_empty = port.empty
vessel_empty = vessel.empty
for action in actions:
assert isinstance(action, Action)
assert isinstance(action, Action)
action_type = action.action_type
vessel_idx = action.vessel_idx
port_idx = action.port_idx
move_num = action.quantity
vessel = self._vessels[vessel_idx]
port = self._ports[port_idx]
port_empty = port.empty
vessel_empty = vessel.empty
if action_type == ActionType.DISCHARGE:
assert move_num <= vessel_empty
action_type = action.action_type
port.empty = port_empty + move_num
vessel.empty = vessel_empty - move_num
else:
assert move_num <= min(port_empty, vessel.remaining_space)
if action_type == ActionType.DISCHARGE:
assert move_num <= vessel_empty
port.empty = port_empty - move_num
vessel.empty = vessel_empty + move_num
port.empty = port_empty + move_num
vessel.empty = vessel_empty - move_num
else:
assert move_num <= min(port_empty, vessel.remaining_space)
# Align the event type to make the output readable.
event.event_type = Events.DISCHARGE_EMPTY if action_type == ActionType.DISCHARGE else Events.LOAD_EMPTY
port.empty = port_empty - move_num
vessel.empty = vessel_empty + move_num
# Update transfer cost for port and metrics.
self._total_operate_num += move_num
port.transfer_cost += move_num
# Align the event type to make the output readable.
event.event_type = Events.DISCHARGE_EMPTY if action_type == ActionType.DISCHARGE else Events.LOAD_EMPTY
self._vessel_plans[vessel_idx, port_idx] += self._data_cntr.vessel_period[vessel_idx]
# Update transfer cost for port and metrics.
self._total_operate_num += move_num
port.transfer_cost += move_num
self._vessel_plans[vessel_idx, port_idx] += self._data_cntr.vessel_period[vessel_idx]
def _stream_base_info(self):
if streamit:


@ -5,6 +5,7 @@
from enum import Enum, IntEnum
from maro.backends.frame import SnapshotList
from maro.common import BaseAction, BaseDecisionEvent
class VesselState(IntEnum):
@ -21,7 +22,7 @@ class ActionType(Enum):
DISCHARGE = "discharge"
class Action:
class Action(BaseAction):
"""Action object that used to pass action from agent to business engine.
Args:
@ -68,7 +69,7 @@ class ActionScope:
return "%s {load: %r, discharge: %r}" % (self.__class__.__name__, self.load, self.discharge)
class DecisionEvent:
class DecisionEvent(BaseDecisionEvent):
"""Decision event for agent.
Args:


@ -15,7 +15,7 @@ from maro.backends.frame import FrameBase, SnapshotList
from maro.cli.data_pipeline.citi_bike import CitiBikeProcess
from maro.cli.data_pipeline.utils import chagne_file_path
from maro.data_lib import BinaryReader
from maro.event_buffer import AtomEvent, EventBuffer, MaroEvents
from maro.event_buffer import AtomEvent, CascadeEvent, EventBuffer, MaroEvents
from maro.simulator.scenarios import AbsBusinessEngine
from maro.simulator.scenarios.helpers import DocableDict
from maro.simulator.scenarios.matrix_accessor import MatrixAttributeAccessor
@ -23,7 +23,7 @@ from maro.utils.exception.cli_exception import CommandError
from maro.utils.logger import CliLogger
from .adj_loader import load_adj_from_csv
from .common import BikeReturnPayload, BikeTransferPayload, DecisionEvent
from .common import Action, BikeReturnPayload, BikeTransferPayload, DecisionEvent
from .decision_strategy import BikeDecisionStrategy
from .events import CitiBikeEvents
from .frame_builder import build_frame
@ -33,7 +33,6 @@ from .weather_table import WeatherTable
logger = CliLogger(name=__name__)
metrics_desc = """
Citi bike metrics used to provide statistics information at current point (may be in the middle of a tick).
It contains following keys:
@ -519,14 +518,15 @@ class CitibikeBusinessEngine(AbsBusinessEngine):
station.bikes = station_bikes + max_accept_number
def _on_action_received(self, evt: AtomEvent):
def _on_action_received(self, evt: CascadeEvent):
"""Callback when we get an action from agent."""
action = None
actions = evt.payload
if evt is None or evt.payload is None:
return
assert isinstance(actions, list)
for action in actions:
assert isinstance(action, Action)
for action in evt.payload:
from_station_idx: int = action.from_station_idx
to_station_idx: int = action.to_station_idx


@ -3,6 +3,8 @@
from enum import Enum
from maro.common import BaseAction, BaseDecisionEvent
class BikeTransferPayload:
"""Payload for bike transfer event.
@ -63,7 +65,7 @@ class DecisionType(Enum):
Demand = "demand"
class DecisionEvent:
class DecisionEvent(BaseDecisionEvent):
"""Citi bike scenario decision event that contains station information for agent to choose action.
Args:
@ -127,7 +129,7 @@ class DecisionEvent:
)
class Action:
class Action(BaseAction):
"""Citi bike scenario action object, that used to pass action from agent to business engine.
Args:


@ -25,7 +25,7 @@ class Station(NodeBase):
# avg temp
temperature = NodeAttribute("i2")
# 0: sunny, 1: rainy, 2: snowy 3: sleet
# 0: sunny, 1: rainy, 2: snowy, 3: sleet
weather = NodeAttribute("i2")
# 0: holiday, 1: not holiday


@ -2,7 +2,7 @@
# Licensed under the MIT license.
from .business_engine import VmSchedulingBusinessEngine
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
from .common import AllocateAction, DecisionEvent, Latency, PostponeAction, VmRequestPayload
from .cpu_reader import CpuReader
from .enums import Events, PmState, PostponeType, VmCategory
from .physical_machine import PhysicalMachine
@ -12,7 +12,7 @@ __all__ = [
"VmSchedulingBusinessEngine",
"AllocateAction",
"PostponeAction",
"DecisionPayload",
"DecisionEvent",
"Latency",
"VmRequestPayload",
"CpuReader",


@ -17,7 +17,7 @@ from maro.simulator.scenarios.helpers import DocableDict
from maro.utils.logger import CliLogger
from maro.utils.utils import convert_dottable
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
from .common import Action, AllocateAction, DecisionEvent, Latency, PostponeAction, VmRequestPayload
from .cpu_reader import CpuReader
from .enums import Events, PmState, PostponeType, VmCategory
from .frame_builder import build_frame
@ -528,7 +528,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
"""dict: Event payload details of current scenario."""
return {
Events.REQUEST.name: VmRequestPayload.summary_key,
MaroEvents.PENDING_DECISION.name: DecisionPayload.summary_key,
MaroEvents.PENDING_DECISION.name: DecisionEvent.summary_key,
}
def get_agent_idx_list(self) -> List[int]:
@ -820,7 +820,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
if len(valid_pm_list) > 0:
# Generate pending decision.
decision_payload = DecisionPayload(
decision_payload = DecisionEvent(
frame_index=self.frame_index(tick=self._tick),
valid_pms=valid_pm_list,
vm_id=vm_info.id,
@ -846,20 +846,24 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
def _on_action_received(self, event: CascadeEvent):
"""Callback when we get an action from the agent."""
action = None
if event is None or event.payload is None:
actions = event.payload
assert isinstance(actions, list)
if len(actions) == 0:
self._pending_vm_request_payload.pop(self._pending_action_vm_id)
return
cur_tick: int = event.tick
for action in actions:
assert isinstance(action, Action)
cur_tick: int = event.tick
for action in event.payload:
vm_id: int = action.vm_id
if vm_id not in self._pending_vm_request_payload:
raise Exception(f"The VM id: '{vm_id}' sent by agent is invalid.")
if type(action) == AllocateAction:
if isinstance(action, AllocateAction):
pm_id = action.pm_id
vm: VirtualMachine = self._pending_vm_request_payload[vm_id].vm_info
lifetime = vm.lifetime
@ -899,7 +903,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
)
self._successful_allocation += 1
elif type(action) == PostponeAction:
elif isinstance(action, PostponeAction):
postpone_step = action.postpone_step
remaining_buffer_time = self._pending_vm_request_payload[vm_id].remaining_buffer_time
# Either postpone the requirement event or failed.


@ -3,10 +3,12 @@
from typing import List
from maro.common import BaseAction, BaseDecisionEvent
from .virtual_machine import VirtualMachine
class Action:
class Action(BaseAction):
"""VM Scheduling scenario action object, which was used to pass action from agent to business engine.
Args:
@ -74,7 +76,7 @@ class VmRequestPayload:
)
class DecisionPayload:
class DecisionEvent(BaseDecisionEvent):
"""Decision event in VM Scheduling scenario that contains information for agent to choose action.
Args:


@ -27,6 +27,11 @@ def get_available_envs():
return envs
def scenario_not_empty(scenario_path: str) -> bool:
_, _, files = next(os.walk(scenario_path))
return "business_engine.py" in files
def get_scenarios() -> List[str]:
"""Get built-in scenario name list.
@ -35,7 +40,13 @@ def get_scenarios() -> List[str]:
"""
try:
_, scenarios, _ = next(os.walk(scenarios_root_folder))
scenarios = sorted([s for s in scenarios if not s.startswith("__")])
scenarios = sorted(
[
s
for s in scenarios
if not s.startswith("__") and scenario_not_empty(os.path.join(scenarios_root_folder, s))
],
)
except StopIteration:
return []


@ -0,0 +1,452 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start\n",
"\n",
"This notebook demonstrates the use of MARO's RL toolkit to optimize container inventory management (CIM). The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary packages\n",
"from typing import Any, Dict, List, Tuple, Union\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch.optim import Adam, RMSprop\n",
"\n",
"from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet\n",
"from maro.rl.policy import DiscretePolicyGradient\n",
"from maro.rl.rl_component.rl_component_bundle import RLComponentBundle\n",
"from maro.rl.rollout import AbsEnvSampler, CacheElement, ExpElement\n",
"from maro.rl.training import TrainingManager\n",
"from maro.rl.training.algorithms import PPOParams, PPOTrainer\n",
"from maro.simulator import Env\n",
"from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# env and shaping config\n",
"reward_shaping_conf = {\n",
" \"time_window\": 99,\n",
" \"fulfillment_factor\": 1.0,\n",
" \"shortage_factor\": 1.0,\n",
" \"time_decay\": 0.97,\n",
"}\n",
"state_shaping_conf = {\n",
" \"look_back\": 7,\n",
" \"max_ports_downstream\": 2,\n",
"}\n",
"port_attributes = [\"empty\", \"full\", \"on_shipper\", \"on_consignee\", \"booking\", \"shortage\", \"fulfillment\"]\n",
"vessel_attributes = [\"empty\", \"full\", \"remaining_space\"]\n",
"action_shaping_conf = {\n",
" \"action_space\": [(i - 10) / 10 for i in range(21)],\n",
" \"finite_vessel_space\": True,\n",
" \"has_early_discharge\": True,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Sampler\n",
"\n",
"An environment sampler defines state, action and reward shaping logic so that policies can interact with the environment."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class CIMEnvSampler(AbsEnvSampler):\n",
" def _get_global_and_agent_state_impl(\n",
" self, event: DecisionEvent, tick: int = None,\n",
" ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:\n",
" tick = self._env.tick\n",
" vessel_snapshots, port_snapshots = self._env.snapshot_list[\"vessels\"], self._env.snapshot_list[\"ports\"]\n",
" port_idx, vessel_idx = event.port_idx, event.vessel_idx\n",
" ticks = [max(0, tick - rt) for rt in range(state_shaping_conf[\"look_back\"] - 1)]\n",
" future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')\n",
" state = np.concatenate([\n",
" port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],\n",
" vessel_snapshots[tick: vessel_idx: vessel_attributes]\n",
" ])\n",
" return state, {port_idx: state}\n",
"\n",
" def _translate_to_env_action(\n",
" self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,\n",
" ) -> Dict[Any, object]:\n",
" action_space = action_shaping_conf[\"action_space\"]\n",
" finite_vsl_space = action_shaping_conf[\"finite_vessel_space\"]\n",
" has_early_discharge = action_shaping_conf[\"has_early_discharge\"]\n",
"\n",
" port_idx, model_action = list(action_dict.items()).pop()\n",
"\n",
" vsl_idx, action_scope = event.vessel_idx, event.action_scope\n",
" vsl_snapshots = self._env.snapshot_list[\"vessels\"]\n",
" vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float(\"inf\")\n",
"\n",
" percent = abs(action_space[model_action[0]])\n",
" zero_action_idx = len(action_space) / 2 # index corresponding to value zero.\n",
" if model_action < zero_action_idx:\n",
" action_type = ActionType.LOAD\n",
" actual_action = min(round(percent * action_scope.load), vsl_space)\n",
" elif model_action > zero_action_idx:\n",
" action_type = ActionType.DISCHARGE\n",
" early_discharge = vsl_snapshots[self._env.tick:vsl_idx:\"early_discharge\"][0] if has_early_discharge else 0\n",
" plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge\n",
" actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)\n",
" else:\n",
" actual_action, action_type = 0, None\n",
"\n",
" return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}\n",
"\n",
" def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:\n",
" start_tick = tick + 1\n",
" ticks = list(range(start_tick, start_tick + reward_shaping_conf[\"time_window\"]))\n",
"\n",
" # Get the ports that took actions at the given tick\n",
" ports = [int(port) for port in list(env_action_dict.keys())]\n",
" port_snapshots = self._env.snapshot_list[\"ports\"]\n",
" future_fulfillment = port_snapshots[ticks:ports:\"fulfillment\"].reshape(len(ticks), -1)\n",
" future_shortage = port_snapshots[ticks:ports:\"shortage\"].reshape(len(ticks), -1)\n",
"\n",
" decay_list = [reward_shaping_conf[\"time_decay\"] ** i for i in range(reward_shaping_conf[\"time_window\"])]\n",
" rewards = np.float32(\n",
" reward_shaping_conf[\"fulfillment_factor\"] * np.dot(future_fulfillment.T, decay_list)\n",
" - reward_shaping_conf[\"shortage_factor\"] * np.dot(future_shortage.T, decay_list)\n",
" )\n",
" return {agent_id: reward for agent_id, reward in zip(ports, rewards)}\n",
"\n",
" def _post_step(self, cache_element: CacheElement) -> None:\n",
" self._info[\"env_metric\"] = self._env.metrics\n",
"\n",
" def _post_eval_step(self, cache_element: CacheElement) -> None:\n",
" self._post_step(cache_element)\n",
"\n",
" def post_collect(self, info_list: list, ep: int) -> None:\n",
" # print the env metric from each rollout worker\n",
" for info in info_list:\n",
" print(f\"env summary (episode {ep}): {info['env_metric']}\")\n",
"\n",
" # print the average env metric\n",
" if len(info_list) > 1:\n",
" metric_keys, num_envs = info_list[0][\"env_metric\"].keys(), len(info_list)\n",
" avg_metric = {key: sum(info[\"env_metric\"][key] for info in info_list) / num_envs for key in metric_keys}\n",
" print(f\"average env summary (episode {ep}): {avg_metric}\")\n",
"\n",
" def post_evaluate(self, info_list: list, ep: int) -> None:\n",
" self.post_collect(info_list, ep)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Policies & Trainers"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"state_dim = (\n",
" (state_shaping_conf[\"look_back\"] + 1) * (state_shaping_conf[\"max_ports_downstream\"] + 1) * len(port_attributes)\n",
" + len(vessel_attributes)\n",
")\n",
"action_num = len(action_shaping_conf[\"action_space\"])\n",
"\n",
"actor_net_conf = {\n",
" \"hidden_dims\": [256, 128, 64],\n",
" \"activation\": torch.nn.Tanh,\n",
" \"softmax\": True,\n",
" \"batch_norm\": False,\n",
" \"head\": True,\n",
"}\n",
"critic_net_conf = {\n",
" \"hidden_dims\": [256, 128, 64],\n",
" \"output_dim\": 1,\n",
" \"activation\": torch.nn.LeakyReLU,\n",
" \"softmax\": False,\n",
" \"batch_norm\": True,\n",
" \"head\": True,\n",
"}\n",
"\n",
"actor_learning_rate = 0.001\n",
"critic_learning_rate = 0.001\n",
"\n",
"class MyActorNet(DiscreteACBasedNet):\n",
" def __init__(self, state_dim: int, action_num: int) -> None:\n",
" super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)\n",
" self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)\n",
" self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)\n",
"\n",
" def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:\n",
" return self._actor(states)\n",
"\n",
"\n",
"class MyCriticNet(VNet):\n",
" def __init__(self, state_dim: int) -> None:\n",
" super(MyCriticNet, self).__init__(state_dim=state_dim)\n",
" self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)\n",
" self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)\n",
"\n",
" def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:\n",
" return self._critic(states).squeeze(-1)\n",
"\n",
"def get_ppo_trainer(state_dim: int, name: str) -> PPOTrainer:\n",
" return PPOTrainer(\n",
" name=name,\n",
" reward_discount=.0,\n",
" params=PPOParams(\n",
" get_v_critic_net_func=lambda: MyCriticNet(state_dim),\n",
" grad_iters=10,\n",
" critic_loss_cls=torch.nn.SmoothL1Loss,\n",
" min_logp=None,\n",
" lam=.0,\n",
" clip_ratio=0.1,\n",
" ),\n",
" )\n",
"\n",
"learn_env = Env(scenario=\"cim\", topology=\"toy.4p_ssdd_l0.0\", durations=500)\n",
"test_env = learn_env\n",
"num_agents = len(learn_env.agent_idx_list)\n",
"agent2policy = {agent: f\"ppo_{agent}.policy\"for agent in learn_env.agent_idx_list}\n",
"policies = [DiscretePolicyGradient(name=f\"ppo_{i}.policy\", policy_net=MyActorNet(state_dim, action_num)) for i in range(num_agents)]\n",
"trainers = [get_ppo_trainer(state_dim, f\"ppo_{i}\") for i in range(num_agents)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RL component bundle\n",
"\n",
"An RL component bundle integrates all necessary resources to launch a learning loop."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"rl_component_bundle = RLComponentBundle(\n",
" env_sampler=CIMEnvSampler(\n",
" learn_env=learn_env,\n",
" test_env=test_env,\n",
" policies=policies,\n",
" agent2policy=agent2policy,\n",
" reward_eval_delay=reward_shaping_conf[\"time_window\"],\n",
" ),\n",
" agent2policy=agent2policy,\n",
" policies=policies,\n",
" trainers=trainers,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Loop\n",
"\n",
"This code cell demonstrates a typical single-threaded training workflow."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting result:\n",
"env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 688632, 'operation_number': 1940226}\n",
"\n",
"Collecting result:\n",
"env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 601337, 'operation_number': 2030600}\n",
"\n",
"Collecting result:\n",
"env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 544572, 'operation_number': 1737291}\n",
"\n",
"Collecting result:\n",
"env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 545506, 'operation_number': 2008160}\n",
"\n",
"Collecting result:\n",
"env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 442000, 'operation_number': 1935439}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 533699, 'operation_number': 891248}\n",
"\n",
"Collecting result:\n",
"env summary (episode 6): {'order_requirements': 1000000, 'container_shortage': 448461, 'operation_number': 1918664}\n",
"\n",
"Collecting result:\n",
"env summary (episode 7): {'order_requirements': 1000000, 'container_shortage': 469874, 'operation_number': 1745973}\n",
"\n",
"Collecting result:\n",
"env summary (episode 8): {'order_requirements': 1000000, 'container_shortage': 364469, 'operation_number': 1974592}\n",
"\n",
"Collecting result:\n",
"env summary (episode 9): {'order_requirements': 1000000, 'container_shortage': 425449, 'operation_number': 1821885}\n",
"\n",
"Collecting result:\n",
"env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 386687, 'operation_number': 1798356}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 950000, 'operation_number': 0}\n",
"\n",
"Collecting result:\n",
"env summary (episode 11): {'order_requirements': 1000000, 'container_shortage': 403236, 'operation_number': 1742253}\n",
"\n",
"Collecting result:\n",
"env summary (episode 12): {'order_requirements': 1000000, 'container_shortage': 373426, 'operation_number': 1682848}\n",
"\n",
"Collecting result:\n",
"env summary (episode 13): {'order_requirements': 1000000, 'container_shortage': 357254, 'operation_number': 1845318}\n",
"\n",
"Collecting result:\n",
"env summary (episode 14): {'order_requirements': 1000000, 'container_shortage': 215681, 'operation_number': 1969606}\n",
"\n",
"Collecting result:\n",
"env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 288347, 'operation_number': 1739670}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 639517, 'operation_number': 680980}\n",
"\n",
"Collecting result:\n",
"env summary (episode 16): {'order_requirements': 1000000, 'container_shortage': 258659, 'operation_number': 1747509}\n",
"\n",
"Collecting result:\n",
"env summary (episode 17): {'order_requirements': 1000000, 'container_shortage': 202262, 'operation_number': 1982958}\n",
"\n",
"Collecting result:\n",
"env summary (episode 18): {'order_requirements': 1000000, 'container_shortage': 209018, 'operation_number': 1765574}\n",
"\n",
"Collecting result:\n",
"env summary (episode 19): {'order_requirements': 1000000, 'container_shortage': 256471, 'operation_number': 1764379}\n",
"\n",
"Collecting result:\n",
"env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 259231, 'operation_number': 1737222}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 9000, 'operation_number': 1974766}\n",
"\n",
"Collecting result:\n",
"env summary (episode 21): {'order_requirements': 1000000, 'container_shortage': 268553, 'operation_number': 1697234}\n",
"\n",
"Collecting result:\n",
"env summary (episode 22): {'order_requirements': 1000000, 'container_shortage': 212987, 'operation_number': 1788601}\n",
"\n",
"Collecting result:\n",
"env summary (episode 23): {'order_requirements': 1000000, 'container_shortage': 234729, 'operation_number': 1803468}\n",
"\n",
"Collecting result:\n",
"env summary (episode 24): {'order_requirements': 1000000, 'container_shortage': 224261, 'operation_number': 1736261}\n",
"\n",
"Collecting result:\n",
"env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 191424, 'operation_number': 1952505}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 606940, 'operation_number': 710472}\n",
"\n",
"Collecting result:\n",
"env summary (episode 26): {'order_requirements': 1000000, 'container_shortage': 223272, 'operation_number': 1895614}\n",
"\n",
"Collecting result:\n",
"env summary (episode 27): {'order_requirements': 1000000, 'container_shortage': 427395, 'operation_number': 1351830}\n",
"\n",
"Collecting result:\n",
"env summary (episode 28): {'order_requirements': 1000000, 'container_shortage': 266455, 'operation_number': 1924877}\n",
"\n",
"Collecting result:\n",
"env summary (episode 29): {'order_requirements': 1000000, 'container_shortage': 362452, 'operation_number': 1747022}\n",
"\n",
"Collecting result:\n",
"env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 320532, 'operation_number': 1506639}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 639581, 'operation_number': 681708}\n",
"\n"
]
}
],
"source": [
"env_sampler = rl_component_bundle.env_sampler\n",
"\n",
"num_episodes = 30\n",
"eval_schedule = [5, 10, 15, 20, 25, 30]\n",
"eval_point_index = 0\n",
"\n",
"training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)\n",
"\n",
"# main loop\n",
"for ep in range(1, num_episodes + 1):\n",
" result = env_sampler.sample()\n",
" experiences: List[List[ExpElement]] = result[\"experiences\"]\n",
" info_list: List[dict] = result[\"info\"]\n",
" \n",
" print(\"Collecting result:\")\n",
" env_sampler.post_collect(info_list, ep)\n",
" print()\n",
"\n",
" training_manager.record_experiences(experiences)\n",
" training_manager.train_step()\n",
"\n",
" if ep == eval_schedule[eval_point_index]:\n",
" eval_point_index += 1\n",
" result = env_sampler.eval()\n",
" \n",
" print(\"Evaluation result:\")\n",
" env_sampler.post_evaluate(result[\"info\"], ep)\n",
" print()\n",
"\n",
"training_manager.exit()"
]
}
],
"metadata": {
"interpreter": {
"hash": "8f57a09d39b50edfb56e79199ef40583334d721b06ead0e38a39e7e79092073c"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
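
The hard-coded eval_schedule = [5, 10, 15, 20, 25, 30] in the notebook above works because its last entry equals num_episodes, so eval_point_index can never run past the end of the list. Below is a minimal sketch, in plain Python with no MARO-specific API, of deriving such a schedule from an evaluation interval while preserving that invariant; the helper name make_eval_schedule is hypothetical and not part of the codebase.

from typing import List

def make_eval_schedule(num_episodes: int, eval_interval: int) -> List[int]:
    """Evaluation episodes [interval, 2*interval, ...], always ending at num_episodes."""
    schedule = list(range(eval_interval, num_episodes + 1, eval_interval))
    if not schedule or schedule[-1] != num_episodes:
        schedule.append(num_episodes)  # guarantee the final episode is evaluated
    return schedule

# Reproduces the schedule used in the notebook above.
assert make_eval_schedule(30, 5) == [5, 10, 15, 20, 25, 30]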

View file

@@ -1,4 +1,5 @@
jupyter==1.0.0
ipython-genutils>=0.2.0
ipython>=7.16.3
jupyter-client
jupyter-console
jupyter-contrib-core
@@ -7,18 +8,10 @@ jupyter-core
jupyter-highlight-selected-word
jupyter-latex-envs
jupyter-nbextensions-configurator
jupyter>=1.0.0
jupyterlab
jupyterlab-server
jupyterthemes
isort==4.3.21
autopep8==1.4.4
isort==4.3.21
pandas==0.25.3
matplotlib==3.1.2
seaborn==0.9.0
ipython==7.16.3
ipython-genutils==0.2.0
shap==0.32.1
seaborn==0.9.0
numpy<1.20.0
numba==0.46.0
matplotlib>=3.1.2
seaborn>=0.9.0
shap>=0.32.1

View file

@@ -1,7 +1,5 @@
add-trailing-comma
altair==4.1.0
aria2p==0.9.1
astroid==2.3.3
altair>=4.1.0
aria2p>=0.9.1
azure-identity
azure-mgmt-authorization
azure-mgmt-containerservice
@@ -10,60 +8,51 @@ azure-mgmt-storage
azure-storage-file-share
black==22.3.0
certifi==2022.12.7
cryptography==36.0.1
cycler==0.10.0
cryptography>=36.0.1
Cython==0.29.14
deepdiff==5.7.0
deepdiff>=5.7.0
docker
editorconfig-checker==2.4.0
flake8==4.0.1
flask-cors==3.0.10
flask==1.1.2
flask_cors==3.0.10
flask_socketio==5.2.0
Flask>=1.1.2
Flask_Cors>=3.0.10
Flask_SocketIO>=5.2.0
flloat==0.3.0
geopy==2.0.0
guppy3==3.0.9
holidays==0.10.3
isort==4.3.21
jinja2==2.11.3
kiwisolver==1.1.0
kubernetes==21.7.0
lazy-object-proxy==1.4.3
geopy>=2.0.0
holidays>=0.10.3
Jinja2>=2.11.3
kubernetes>=21.7.0
markupsafe==2.0.1
matplotlib==3.5.2
mccabe==0.6.1
networkx==2.4
numpy<1.20.0
numpy>=1.19.5
palettable==3.3.0
pandas==0.25.3
pandas>=0.25.3
paramiko>=2.9.2
pre-commit==2.19.0
prompt_toolkit==2.0.10
psutil==5.8.0
ptvsd==4.3.2
prompt_toolkit>=2.0.10
psutil>=5.8.0
ptvsd>=4.3.2
pulp==2.6.0
pyaml==20.4.0
PyJWT==2.4.0
pyparsing==2.4.5
python-dateutil==2.8.1
PyYAML==5.4.1
pyzmq==19.0.2
PyJWT>=2.4.0
python_dateutil>=2.8.1
PyYAML>=5.4.1
pyzmq>=19.0.2
recommonmark~=0.6.0
redis==3.5.3
requests==2.25.1
scipy==1.7.0
requests>=2.25.1
scipy>=1.7.0
setuptools==58.0.4
six==1.13.0
sphinx==1.8.6
sphinx_rtd_theme==1.0.0
streamlit==1.11.1
stringcase==1.2.0
tabulate==0.8.5
streamlit>=0.69.1
stringcase>=1.2.0
tabulate>=0.8.5
termgraph==0.5.3
torch==1.6.0
torchsummary==1.5.1
tqdm==4.51.0
tornado>=6.1
tqdm>=4.51.0
urllib3==1.26.5
wrapt==1.11.2
zmq==0.0.0

View file

@@ -135,25 +135,20 @@ setup(
],
install_requires=[
# TODO: use a helper function to collect these
"numpy<1.20.0",
"scipy<=1.7.0",
"torch<1.8.0",
"holidays>=0.10.3",
"pyaml>=20.4.0",
"numpy>=1.19.5",
"pandas>=0.25.3",
"paramiko>=2.9.2",
"ptvsd>=4.3.2",
"python_dateutil>=2.8.1",
"PyYAML>=5.4.1",
"pyzmq>=19.0.2",
"redis>=3.5.3",
"pyzmq<22.1.0",
"requests<=2.26.0",
"psutil<5.9.0",
"deepdiff>=5.2.2",
"azure-storage-blob<12.9.0",
"azure-storage-common",
"geopy>=2.0.0",
"pandas<1.2",
"PyYAML<5.5.0",
"paramiko>=2.7.2",
"kubernetes>=12.0.1",
"prompt_toolkit<3.1.0",
"stringcase>=1.2.0",
"requests>=2.25.1",
"scipy>=1.7.0",
"tabulate>=0.8.5",
"torch>=1.6.0, <1.8.0",
"tornado>=6.1",
]
+ specific_requires,
entry_points={

View file

@@ -1,23 +1,17 @@
matplotlib>=3.1.2
geopy
pandas<1.2
numpy<1.20.0
holidays>=0.10.3
pyaml>=20.4.0
redis>=3.5.3
pyzmq<22.1.0
influxdb
requests<=2.26.0
psutil<5.9.0
deepdiff>=5.2.2
altair>=4.1.0
azure-storage-blob<12.9.0
azure-storage-common
torch<1.8.0
pytest
coverage
termgraph
paramiko>=2.7.2
pytz==2019.3
aria2p==0.9.1
kubernetes>=12.0.1
PyYAML<5.5.0
coverage>=6.4.1
deepdiff>=5.7.0
geopy>=2.0.0
holidays>=0.10.3
kubernetes>=21.7.0
numpy>=1.19.5,<1.24.0
pandas>=0.25.3
paramiko>=2.9.2
pytest>=7.1.2
PyYAML>=5.4.1
pyzmq>=19.0.2
redis>=3.5.3
streamlit>=0.69.1
termgraph>=0.5.3

View file

@@ -8,7 +8,7 @@ from maro.simulator.core import Env
from maro.simulator.utils import get_available_envs, get_scenarios, get_topologies
from maro.simulator.utils.common import frame_index_to_ticks, tick_to_frame_index
from .dummy.dummy_business_engine import DummyEngine
from tests.dummy.dummy_business_engine import DummyEngine
from tests.utils import backends_to_test
@@ -326,7 +326,7 @@ class TestEnv(unittest.TestCase):
msg=f"env should stop at tick 6, but {env.tick}",
)
# avaiable snapshot should be 7 (0-6)
# available snapshot should be 7 (0-6)
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
self.assertEqual(
@@ -376,24 +376,17 @@ class TestEnv(unittest.TestCase):
with self.assertRaises(FileNotFoundError) as ctx:
Env("cim", "None", 100)
def test_get_avaiable_envs(self):
scenario_names = get_scenarios()
def test_get_available_envs(self):
scenario_names = sorted(get_scenarios())
# we have 3 built-in scenarios
self.assertEqual(3, len(scenario_names))
self.assertTrue("cim" in scenario_names)
self.assertTrue("citi_bike" in scenario_names)
cim_topoloies = get_topologies("cim")
citi_bike_topologies = get_topologies("citi_bike")
vm_topoloties = get_topologies("vm_scheduling")
self.assertListEqual(scenario_names, ["cim", "citi_bike", "vm_scheduling"])
env_list = get_available_envs()
self.assertEqual(
len(env_list),
len(cim_topoloies) + len(citi_bike_topologies) + len(vm_topoloties) + len(get_topologies("supply_chain")),
sum(len(get_topologies(s)) for s in scenario_names),
)
def test_frame_index_to_ticks(self):
@@ -404,7 +397,7 @@ class TestEnv(unittest.TestCase):
self.assertListEqual([0, 1], ticks[0])
self.assertListEqual([8, 9], ticks[4])
def test_get_avalible_frame_index_to_ticks_with_default_resolution(self):
def test_get_available_frame_index_to_ticks_with_default_resolution(self):
for backend_name in backends_to_test:
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
@@ -425,7 +418,7 @@ class TestEnv(unittest.TestCase):
self.assertListEqual([t for t in t2f_mapping.keys()], [t for t in range(max_tick)])
self.assertListEqual([f for f in t2f_mapping.values()], [f for f in range(max_tick)])
def test_get_avalible_frame_index_to_ticks_with_resolution2(self):
def test_get_available_frame_index_to_ticks_with_resolution2(self):
for backend_name in backends_to_test:
os.environ["DEFAULT_BACKEND_NAME"] = backend_name

View file

@@ -6,6 +6,7 @@ import time
import unittest
from typing import Optional
from maro.common import BaseDecisionEvent
from maro.event_buffer import ActualEvent, AtomEvent, CascadeEvent, DummyEvent, EventBuffer, EventState, MaroEvents
from maro.event_buffer.event_linked_list import EventLinkedList
from maro.event_buffer.event_pool import EventPool
@@ -251,9 +252,9 @@ class TestEventBuffer(unittest.TestCase):
self.eb.execute(1)
def test_sub_events_with_decision(self):
evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
sub2 = self.eb.gen_decision_event(1, (3, 3, 3))
evt1 = self.eb.gen_decision_event(1, BaseDecisionEvent())
sub1 = self.eb.gen_decision_event(1, BaseDecisionEvent())
sub2 = self.eb.gen_decision_event(1, BaseDecisionEvent())
evt1.add_immediate_event(sub1, is_head=True)
evt1.add_immediate_event(sub2)