V0.3: Upgrade RL Workflow; Add RL Benchmarks; Update Package Version (#588)

* call policy update only for AbsCorePolicy

* add limitation of AbsCorePolicy in Actor.collect()

* refined actor to return only experiences for policies that received new experiences

* fix MsgKey issue in rollout_manager

* fix typo in learner

* call exit function for parallel rollout manager

* update supply chain example distributed training scripts

* 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prototype

* reformat render

* fix supply chain business engine action type problem

* reset supply chain example render figsize from 4 to 3

* Add render to all modes of supply chain example

* fix or_policy typos

* 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes

* refined parallel policy manager

* updated rl/__init__.py

* fixed lint issues and CIM local learner bugs

* deleted unwanted supply_chain test files

* revised default config for cim-dqn

* removed test_store.py as it is no longer needed

* 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms

* updated figures

* removed unwanted import

* refactored CIM-DQN example

* added MultiProcessRolloutManager and MultiProcessTrainingManager

* updated doc

* lint issue fix

* lint issue fix

* fixed import formatting

* [Feature] Prioritized Experience Replay (#355)

* added prioritized experience replay

* deleted unwanted supply_chain test files

* fixed import order

* import fix

* fixed lint issues

* fixed import formatting

* added note in docstring that rank-based PER has yet to be implemented

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* rm AbsDecisionGenerator

* small fixes

* bug fix

* reorganized training folder structure

* fixed lint issues

* fixed lint issues

* policy manager refined

* lint fix

* restructured CIM-dqn sync code

* added policy version index and used it as a measure of experience staleness

* lint issue fix

* lint issue fix

* switched log_dir and proxy_kwargs order

* cim example refinement

* eval schedule sorted only when it's a list

* eval schedule sorted only when it's a list

* update sc env wrapper

* added docker scripts for cim-dqn

* refactored example folder structure and added workflow templates

* fixed lint issues

* fixed lint issues

* fixed template bugs

* removed unused imports

* refactoring sc in progress

* simplified cim meta

* fixed build.sh path bug

* template refinement

* deleted obsolete svgs

* updated learner logs

* minor edits

* refactored templates for easy merge with async PR

* added component names for rollout manager and policy manager

* fixed incorrect position to add last episode to eval schedule

* added max_lag option in templates

* formatting edit in docker_compose_yml script

* moved local learner and early stopper outside sync_tools

* refactored rl toolkit folder structure

* refactored rl toolkit folder structure

* moved env_wrapper and agent_wrapper inside rl/learner

* refined scripts

* fixed typo in script

* changes needed for running sc

* removed unwanted imports

* config change for testing sc scenario

* changes for perf testing

* Asynchronous Training (#364)

* remote inference code draft

* changed actor to rollout_worker and updated init files

* removed unwanted import

* updated inits

* more async code

* added async scripts

* added async training code & scripts for CIM-dqn

* changed async to async_tools to avoid conflict with python keyword

* reverted unwanted change to dockerfile

* added doc for policy server

* addressed PR comments and fixed a bug in docker_compose_yml.py

* fixed lint issue

* resolved PR comment

* resolved merge conflicts

* added async templates

* added proxy.close() for actor and policy_server

* fixed incorrect position to add last episode to eval schedule

* reverted unwanted changes

* added missing async files

* rm unwanted echo in kill.sh

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* renamed sync to synchronous and async to asynchronous to avoid conflict with keyword

* added missing policy version increment in LocalPolicyManager

* refined rollout manager recv logic

* removed a debugging print

* added sleep in distributed launcher to avoid hanging

* updated api doc and rl toolkit doc

* refined dynamic imports using importlib

* 1. moved policy update triggers to policy manager; 2. added version control in policy manager

* fixed a few bugs and updated cim RL example

* fixed a few more bugs

* added agent wrapper instantiation to workflows

* added agent wrapper instantiation to workflows

* removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet

* fixed incorrect get_ac_policy signature for CIM

* moved exploration inside core policy

* added state to exploration call to support context-dependent exploration

* separated non_rl_policy_index and rl_policy_index in workflows

* modified sc example code according to workflow changes

* modified sc example code according to workflow changes

* added replay_agent_ids parameter to get_env_func for RL examples

* fixed a few bugs

* added maro/simulator/scenarios/supply_chain as bind mount

* added post-step, post-collect, post-eval and post-update callbacks

* fixed lint issues

* fixed lint issues

* moved instantiation of policy manager inside simple learner

* fixed env_wrapper get_reward signature

* minor edits

* removed get_experience kwargs from env_wrapper

* 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows

* added rollout exp distribution option in RL examples

* removed unwanted files

* 1. made logger internal in learner; 2. removed logger creation in abs classes

* checked out supply chain test files from v0.2_sc

* 1. added missing model.eval() to choose_action; 2. added entropy features to AC

* fixed a bug in ac entropy

* abbreviated coefficient to coeff

* removed -dqn from job name in rl example config

* added tmp patch to dev.df

* renamed image name for running rl examples

* added get_loss interface for core policies

* added policy manager in rl_toolkit.rst

* 1. env_wrapper bug fix; 2. policy manager update logic refinement

* refactored policy and algorithms

* policy interface redesigned

* refined policy interfaces

* fixed typo

* fixed bugs in refactored policy interface

* fixed some bugs

* refactoring in progress

* policy interface and policy manager redesigned

* 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts

* fixed bug in distributed policy manager

* fixed lint issues

* fixed lint issues

* added scipy in setup

* 1. trimmed rollout manager code; 2. added option to docker scripts

* updated api doc for policy manager

* 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script

* 1. simplified rl example structure; 2. fixed lint issues

* further rl toolkit code simplifications

* more numpy-based optimization in RL toolkit

* moved replay buffer inside policy

* bug fixes

* numpy optimization and associated refactoring

* extracted shaping logic out of env_sampler

* fixed bug in CIM shaping and lint issues

* preliminary implementation of parallel batch inference

* fixed bug in ddpg transition recording

* put get_state, get_env_actions, get_reward back in EnvSampler

* simplified exploration and core model interfaces

* bug fixes and doc update

* added improve() interface for RLPolicy for single-thread support

* fixed simple policy manager bug

* updated doc, rst, notebook

* updated notebook

* fixed lint issues

* fixed entropy bugs in ac.py

* reverted to simple policy manager as default

* 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit

* fixed lint issues and updated rl toolkit images

* removed obsolete images

* added back agent2policy for general workflow use

* V0.2 rl refinement dist (#377)

* Support `slice` operation in ExperienceSet

* Support naive distributed policy training by proxy

* Dynamically allocate trainers according to the number of experiences

* code check

* code check

* code check

* Fix a bug in distributed training with no gradient

* Code check

* Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy

* 1. call allocate_trainer() at the start of update(); 2. refine according to code review

* Code check

* Refine code with new interface

* Update docs of PolicyManager and ExperienceSet

* Add images for rl_toolkit docs

* Update diagram of PolicyManager

* Refine with new interface

* Extract allocation strategy into `allocation_strategy.py`

* add `distributed_learn()` in policies for data-parallel training

* Update doc of RL_toolkit

* Add gradient workers for data-parallel

* Refine code and update docs

* Lint check

* Refine by comments

* Rename `trainer` to `worker`

* Rename `distributed_learn` to `learn_with_data_parallel`

* Refine allocator and remove redundant code in policy_manager

* remove arguments in allocate_by_policy and so on

* added checkpointing for simple and multi-process policy managers

* 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager

* added missing set_state and get_state for CIM policies

* removed blank line

* updated RL workflow README

* Integrate `data_parallel` arguments into `worker_allocator` (#402)

* 1. simplified workflow config; 2. added comments to CIM shaping

* lint issue fix

* 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading

* 1. moved post_step callback inside env sampler; 2. updated README for rl workflows

* refined README for CIM

* VM scheduling with RL (#375)

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* added DQN

* added get_experiences func for ac in vm scheduling

* added post_step callback to env wrapper

* moved Aiming's tracking and plotting logic into callbacks

* added eval env wrapper

* renamed AC config variable name for VM

* vm scheduling RL code finished

* updated README

* fixed various bugs and hard coding for vm_scheduling

* uncommented callbacks for VM scheduling

* Minor revision for better code style

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* vm example refactoring

* fixed bugs in vm_scheduling

* removed unwanted files from cim dir

* reverted to simple policy manager as default

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* resolved rebase conflicts

* fixed bugs in vm_scheduling

* added get_state and set_state to vm_scheduling policy models

* updated README for vm_scheduling with RL

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* SC refinement (#397)

* Refine test scripts & pending_order_daily logic

* Refactor code for better code style: complete type hint, correct typos, remove unused items.

Refactor code for better code style: complete type hint, correct typos, remove unused items.

* Polish test_supply_chain.py

* update import format

* Modify vehicle steps logic & remove outdated test case

* Optimize imports

* Optimize imports

* Lint error

* Lint error

* Lint error

* Add SupplyChainAction

* Lint error

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* refined workflow scripts

* fixed bug in ParallelAgentWrapper

* 1. fixed lint issues; 2. refined main script in workflows

* lint issue fix

* restored default config for rl example

* Update rollout.py

* refined env var processing in policy manager workflow

* added hasattr check in agent wrapper

* updated docker_compose_yml.py

* Minor refinement

* Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412)

* Prepare to merge master_mirror

* Lint error

* Minor

* Merge latest master into v0.3 (#426)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix spelling typos in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random-related logic.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* Minor

* Remove docs/source/examples/multi_agent_dqn_cim.rst

* Update .gitignore

* Update .gitignore

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>

* Change `Env.set_seed()` logic (#456)

* Change Env.set_seed() logic

* Redesign CIM reset logic; fix lint issues;

* Lint

* Seed type assertion

* Remove all SC related files (#473)

* RL Toolkit V3 (#471)

* added daemon=True for multi-process rollout, policy manager and inference

* removed obsolete files

* [REDO][PR#406]V0.2 rl refinement taskq (#408)

* Add a usable task_queue

* Rename some variables

* 1. Add ; 2. Integrate  related files; 3. Remove

* merge `data_parallel` and `num_grad_workers` into `data_parallelism`

* Fix bugs in docker_compose_yml.py and Simple/Multi-process mode.

* Move `grad_worker` into marl/rl/workflows

* 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers when possible in task_queue.

* Refine code and update docs of `TaskQueue`

* Support priority for tasks in `task_queue`

* Update diagram of policy manager and task queue.

* Add configurable `single_task_limit` and correct docstring about `data_parallelism`

* Fix lint errors in `supply chain`

* RL policy redesign (V2) (#405)

* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint

* refined messy workflow code

* merged 'scenario_dir' and 'scenario' in rl config

* 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods

* 1. temporarily renamed RLPolicy from policy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2

* merged cim and cim_v2

* lint issue fix

* refined logging logic

* lint issue fix

* reverted unwanted changes

* .

.

.

.

ReplayMemory & IndexScheduler

ReplayMemory & IndexScheduler

.

MultiReplayMemory

get_actions_with_logps

EnvSampler on the road

EnvSampler

Minor

* LearnerManager

* Use batch to transfer data & add SHAPE_CHECK_FLAG

* Rename learner to trainer

* Add property for policy._is_exploring

* CIM test scenario for V3. Manual test passed. Next step: run it, make it work.

* env_sampler.py could run

* env_sampler refine on the way

* First runnable version done

* AC could run, but the result is bad. Need to check the logic

* Refine abstract method & shape check error info.

* Docs

* Very detailed compare. Try again.

* AC done

* DQN check done

* Minor

* DDPG, not tested

* Minors

* A rough draft of MAAC

* Cannot use CIM as the multi-agent scenario.

* Minor

* MAAC refinement on the way

* Remove ActionWithAux

* Refine batch & memory

* MAAC example works

* Reproducibility fix: share policies between env_sampler and trainer_manager.

* Detail refinement

* Simplify the user-configured workflow

* Minor

* Refine example codes

* Minor polish

* Migrate rollout_manager to V3

* Error on the way

* Redesign torch.device management

* Rl v3 maddpg (#418)

* Add MADDPG trainer

* Fit independent critics and shared critic modes.

* Add a new property: num_policies

* Lint

* Fix a bug in `sum(rewards)`

* Rename `MADDPG` to `DiscreteMADDPG` and fix type hint.

* Rename maddpg in examples.

* Preparation for data parallel (#420)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* Rename train worker to train ops; add placeholder for abstract methods;

* Lint

Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>

* [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* distributed training pipeline draft

* added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* fixed merge conflicts

* unified sync and async interfaces

* refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor maddpg related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Minor bug fix

* Coroutine-related bug fix ("get_policy_state") (#452)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* deleted unwanted folder

* removed unwanted changes

* resolved PR452 comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Quick fix

* Redesign experience recording logic (#453)

* Two unimportant fixes

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint

* Calculating advantages / returns (#454)

* V1.0

* Complete DDPG

* Rl v3 hanging issue fix (#455)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* Final test & format. Ready to merge.

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Rl v3 parallel rollout (#457)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* load balancing dispatcher

* added parallel rollout

* lint

* Tracker variable type issue; rename to env_sampler_creator;

* Rl v3 parallel rollout follow ups (#458)

* AbsWorker & AbsDispatcher

* Pass env idx to AbsTrainer.record() method, and let the trainer decide how to record experiences sampled from different worlds.

* Fix policy_creator reuse bug

* Format code

* Merge AbsTrainerManager & SimpleTrainerManager

* AC test passed

* Lint

* Remove AbsTrainer.build() method. Put all initialization operations into __init__

* Redesign AC preprocess batches logic

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* MADDPG performance bug fix (#459)

* Fix MARL (MADDPG) terminal recording bug; some other minor refinements;

* Restore Trainer.build() method

* Calculate latest action in the get_actor_grad method in MADDPG.

* Share critic bug fix

* Rl v3 example update (#461)

* updated vm_scheduling example and cim notebook

* fixed bugs in vm_scheduling

* added local train method

* bug fix

* modified async client logic to fix hidden issue

* reverted to default config

* fixed PR comments and some bugs

* removed hardcode

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Done (#462)

* Rl v3 load save (#463)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL Toolkit data parallelism revamp & config utils (#464)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

* 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local

* 1. fixed rollout exit issue; 2. refined config

* removed config file from example

* fixed lint issues

* fixed lint issues

* added main.py under examples/rl

* fixed lint issues

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL doc string (#465)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* Rl config doc (#467)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* resolved more PR comments

* typo fix

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL online doc (#469)

* Model, policy, trainer

* RL workflows and env sampler doc in RST (#468)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* rewriting rl toolkit rst

* resolved more PR comments

* typo fix

* updated rst

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Default <huo53926@126.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Finish docs/source/key_components/rl_toolkit.rst

* API doc

* RL online doc image fix (#470)

* resolved some PR comments

* fix

* fixed PR comments

* added numfig=True setting in conf.py for sphinx

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Resolve PR comments

* Add example github link

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Rl v3 pr comment resolution (#474)

* added load/save feature

* 1. resolved pr comments; 2. reverted maro/cli/k8s

* fixed some bugs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* RL renaming v2 (#476)

* Change all Logger in RL to LoggerV2

* TrainerManager => TrainingManager

* Add Trainer suffix to all algorithms

* Finish docs

* Update interface names

* Minor fix

* Cherry pick latest RL (#498)

* Cherry pick

* Remove SC related files

* Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509)

* Cherry pick RL changes from sc_refinement (2a4869)

* Limit time display precision

* RL incremental refactor (#501)

* Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training.

AC & PPO for continuous action policy; refine AC & PPO logic.

Cherry pick RL changes from GYM-DDPG

Cherry pick RL changes from GYM-SAC

Minor error in doc string

* Add min_n_sample in template and parser

* Resolve PR comments. Fix a minor issue in SAC.

* RL component bundle (#513)

* CIM passed

* Update workers

* Refine annotations

* VM passed

* Code formatting.

* Minor import loop issue

* Pass batch in PPO again

* Remove Scenario

* Complete docs

* Minor

* Remove segment

* Optimize logic in RLComponentBundle

* Resolve PR comments

* Move 'post' methods from RLComponentBundle to EnvSampler

* Add method to get mapping of available tick to frame index (#415)

* add method to get mapping of available tick to frame index

* fix lint issue

* fix naming issue

* Cherry pick from sc_refinement (#527)

* Cherry pick from sc_refinement

* Cherry pick from sc_refinement

* Refine `terminal` / `next_agent_state` logic (#531)

* Optimize RL toolkit

* Fix bug in terminal/next_state generation

* Rewrite terminal/next_state logic again

* Minor renaming

* Minor bug fix

* Resolve PR comments

* Merge master into v0.3 (#536)

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added version numbers for the dependencies and resolved some conflicts that occur when installing. Pinning these versions makes the exact requirements explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add & sort requirements.dev.txt

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Merge master into v0.3 (#545)

* update github workflow config

* MARO v0.3: a new design of RL Toolkit, CLI refactorization, and corresponding updates. (#539)

* refined proxy coding style

* updated images and refined doc

* updated images

* updated CIM-AC example

* refined proxy retry logic

* call policy update only for AbsCorePolicy

* add limitation of AbsCorePolicy in Actor.collect()

* refined actor to return only experiences for policies that received new experiences

* fix MsgKey issue in rollout_manager

* fix typo in learner

* call exit function for parallel rollout manager

* update supply chain example distributed training scripts

* 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prototype

* reformat render

* fix supply chain business engine action type problem

* reset supply chain example render figsize from 4 to 3

* Add render to all modes of supply chain example

* fix or_policy typos

* 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes

* refined parallel policy manager

* updated rl/__init__.py

* fixed lint issues and CIM local learner bugs

* deleted unwanted supply_chain test files

* revised default config for cim-dqn

* removed test_store.py as it is no longer needed

* 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms

* updated figures

* removed unwanted import

* refactored CIM-DQN example

* added MultiProcessRolloutManager and MultiProcessTrainingManager

* updated doc

* lint issue fix

* lint issue fix

* fixed import formatting

* [Feature] Prioritized Experience Replay (#355)

* added prioritized experience replay

* deleted unwanted supply_chain test files

* fixed import order

* import fix

* fixed lint issues

* fixed import formatting

* added note in docstring that rank-based PER has yet to be implemented

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* rm AbsDecisionGenerator

* small fixes

* bug fix

* reorganized training folder structure

* fixed lint issues

* fixed lint issues

* policy manager refined

* lint fix

* restructured CIM-dqn sync code

* added policy version index and used it as a measure of experience staleness

* lint issue fix

* lint issue fix

* switched log_dir and proxy_kwargs order

* cim example refinement

* eval schedule sorted only when it's a list

* eval schedule sorted only when it's a list

* update sc env wrapper

* added docker scripts for cim-dqn

* refactored example folder structure and added workflow templates

* fixed lint issues

* fixed lint issues

* fixed template bugs

* removed unused imports

* refactoring sc in progress

* simplified cim meta

* fixed build.sh path bug

* template refinement

* deleted obsolete svgs

* updated learner logs

* minor edits

* refactored templates for easy merge with async PR

* added component names for rollout manager and policy manager

* fixed incorrect position to add last episode to eval schedule

* added max_lag option in templates

* formatting edit in docker_compose_yml script

* moved local learner and early stopper outside sync_tools

* refactored rl toolkit folder structure

* refactored rl toolkit folder structure

* moved env_wrapper and agent_wrapper inside rl/learner

* refined scripts

* fixed typo in script

* changes needed for running sc

* removed unwanted imports

* config change for testing sc scenario

* changes for perf testing

* Asynchronous Training (#364)

* remote inference code draft

* changed actor to rollout_worker and updated init files

* removed unwanted import

* updated inits

* more async code

* added async scripts

* added async training code & scripts for CIM-dqn

* changed async to async_tools to avoid conflict with python keyword

* reverted unwanted change to dockerfile

* added doc for policy server

* addressed PR comments and fixed a bug in docker_compose_yml.py

* fixed lint issue

* resolved PR comment

* resolved merge conflicts

* added async templates

* added proxy.close() for actor and policy_server

* fixed incorrect position to add last episode to eval schedule

* reverted unwanted changes

* added missing async files

* rm unwanted echo in kill.sh

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* renamed sync to synchronous and async to asynchronous to avoid conflict with keyword

* added missing policy version increment in LocalPolicyManager

* refined rollout manager recv logic

* removed a debugging print

* added sleep in distributed launcher to avoid hanging

* updated api doc and rl toolkit doc

* refined dynamic imports using importlib

* 1. moved policy update triggers to policy manager; 2. added version control in policy manager

* fixed a few bugs and updated cim RL example

* fixed a few more bugs

* added agent wrapper instantiation to workflows

* added agent wrapper instantiation to workflows

* removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet

* fixed incorrect get_ac_policy signature for CIM

* moved exploration inside core policy

* added state to exploration call to support context-dependent exploration

* separated non_rl_policy_index and rl_policy_index in workflows

* modified sc example code according to workflow changes

* modified sc example code according to workflow changes

* added replay_agent_ids parameter to get_env_func for RL examples

* fixed a few bugs

* added maro/simulator/scenarios/supply_chain as bind mount

* added post-step, post-collect, post-eval and post-update callbacks

* fixed lint issues

* fixed lint issues

* moved instantiation of policy manager inside simple learner

* fixed env_wrapper get_reward signature

* minor edits

* removed get_experience kwargs from env_wrapper

* 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows

* added rollout exp distribution option in RL examples

* removed unwanted files

* 1. made logger internal in learner; 2. removed logger creation in abs classes

* checked out supply chain test files from v0.2_sc

* 1. added missing model.eval() to choose_action; 2. added entropy features to AC

* fixed a bug in ac entropy

* abbreviated coefficient to coeff

* removed -dqn from job name in rl example config

* added tmp patch to dev.df

* renamed image name for running rl examples

* added get_loss interface for core policies

* added policy manager in rl_toolkit.rst

* 1. env_wrapper bug fix; 2. policy manager update logic refinement

* refactored policy and algorithms

* policy interface redesigned

* refined policy interfaces

* fixed typo

* fixed bugs in refactored policy interface

* fixed some bugs

* refactoring in progress

* policy interface and policy manager redesigned

* 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts

* fixed bug in distributed policy manager

* fixed lint issues

* fixed lint issues

* added scipy in setup

* 1. trimmed rollout manager code; 2. added option to docker scripts

* updated api doc for policy manager

* 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script

* 1. simplified rl example structure; 2. fixed lint issues

* further rl toolkit code simplifications

* more numpy-based optimization in RL toolkit

* moved replay buffer inside policy

* bug fixes

* numpy optimization and associated refactoring

* extracted shaping logic out of env_sampler

* fixed bug in CIM shaping and lint issues

* preliminary implementation of parallel batch inference

* fixed bug in ddpg transition recording

* put get_state, get_env_actions, get_reward back in EnvSampler

* simplified exploration and core model interfaces

* bug fixes and doc update

* added improve() interface for RLPolicy for single-thread support

* fixed simple policy manager bug

* updated doc, rst, notebook

* updated notebook

* fixed lint issues

* fixed entropy bugs in ac.py

* reverted to simple policy manager as default

* 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit

* fixed lint issues and updated rl toolkit images

* removed obsolete images

* added back agent2policy for general workflow use

* V0.2 rl refinement dist (#377)

* Support `slice` operation in ExperienceSet

* Support naive distributed policy training by proxy

* Dynamically allocate trainers according to the number of experiences

* code check

* code check

* code check

* Fix a bug in distributed training with no gradient

* Code check

* Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy

* 1. call allocate_trainer() at the start of update(); 2. refine according to code review

* Code check

* Refine code with new interface

* Update docs of PolicyManager and ExperienceSet

* Add images for rl_toolkit docs

* Update diagram of PolicyManager

* Refine with new interface

* Extract allocation strategy into `allocation_strategy.py`

* add `distributed_learn()` in policies for data-parallel training

* Update doc of RL_toolkit

* Add gradient workers for data-parallel

* Refine code and update docs

* Lint check

* Refine by comments

* Rename `trainer` to `worker`

* Rename `distributed_learn` to `learn_with_data_parallel`

* Refine allocator and remove redundant code in policy_manager

* remove arguments in allocate_by_policy and so on

* added checkpointing for simple and multi-process policy managers

* 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager

* added missing set_state and get_state for CIM policies

* removed blank line

* updated RL workflow README

* Integrate `data_parallel` arguments into `worker_allocator` (#402)

* 1. simplified workflow config; 2. added comments to CIM shaping

* lint issue fix

* 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading

* 1. moved post_step callback inside env sampler; 2. updated README for rl workflows

* refined README for CIM

* VM scheduling with RL (#375)

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* added DQN

* added get_experiences func for ac in vm scheduling

* added post_step callback to env wrapper

* moved Aiming's tracking and plotting logic into callbacks

* added eval env wrapper

* renamed AC config variable name for VM

* vm scheduling RL code finished

* updated README

* fixed various bugs and hard coding for vm_scheduling

* uncommented callbacks for VM scheduling

* Minor revision for better code style

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* vm example refactoring

* fixed bugs in vm_scheduling

* removed unwanted files from cim dir

* reverted to simple policy manager as default

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* resolved rebase conflicts

* fixed bugs in vm_scheduling

* added get_state and set_state to vm_scheduling policy models

* updated README for vm_scheduling with RL

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* SC refinement (#397)

* Refine test scripts & pending_order_daily logic

* Refactor code for better code style: complete type hint, correct typos, remove unused items.

Refactor code for better code style: complete type hint, correct typos, remove unused items.

* Polish test_supply_chain.py

* update import format

* Modify vehicle steps logic & remove outdated test case

* Optimize imports

* Optimize imports

* Lint error

* Lint error

* Lint error

* Add SupplyChainAction

* Lint error

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* refined workflow scripts

* fixed bug in ParallelAgentWrapper

* 1. fixed lint issues; 2. refined main script in workflows

* lint issue fix

* restored default config for rl example

* Update rollout.py

* refined env var processing in policy manager workflow

* added hasattr check in agent wrapper

* updated docker_compose_yml.py

* Minor refinement

* Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412)

* Prepare to merge master_mirror

* Lint error

* Minor

* Merge latest master into v0.3 (#426)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix spelling typos in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random-related logic.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* Minor

* Remove docs/source/examples/multi_agent_dqn_cim.rst

* Update .gitignore

* Update .gitignore

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>

* Change `Env.set_seed()` logic (#456)

* Change Env.set_seed() logic

* Redesign CIM reset logic; fix lint issues;

* Lint

* Seed type assertion

* Remove all SC related files (#473)

* RL Toolkit V3 (#471)

* added daemon=True for multi-process rollout, policy manager and inference

* removed obsolete files

* [REDO][PR#406]V0.2 rl refinement taskq (#408)

* Add a usable task_queue

* Rename some variables

* 1. Add ; 2. Integrate  related files; 3. Remove

* merge `data_parallel` and `num_grad_workers` into `data_parallelism`

* Fix bugs in docker_compose_yml.py and Simple/Multi-process mode.

* Move `grad_worker` into marl/rl/workflows

* 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers when possible in task_queue.

* Refine code and update docs of `TaskQueue`

* Support priority for tasks in `task_queue`

* Update diagram of policy manager and task queue.

* Add configurable `single_task_limit` and correct docstring about `data_parallelism`

* Fix lint errors in `supply chain`

* RL policy redesign (V2) (#405)

* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint

* refined messy workflow code

* merged 'scenario_dir' and 'scenario' in rl config

* 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods

* 1. temporarily renamed RLPolicy from policy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2

* merged cim and cim_v2

* lint issue fix

* refined logging logic

* lint issue fix

* reversed unwanted changes

* .

.

.

.

ReplayMemory & IndexScheduler

ReplayMemory & IndexScheduler

.

MultiReplayMemory

get_actions_with_logps

EnvSampler on the road

EnvSampler

Minor

* LearnerManager

* Use batch to transfer data & add SHAPE_CHECK_FLAG

* Rename learner to trainer

* Add property for policy._is_exploring

* CIM test scenario for V3. Manual test passed. Next step: run it, make it work.

* env_sampler.py could run

* env_sampler refine on the way

* First runnable version done

* AC could run, but the result is bad. Need to check the logic

* Refine abstract method & shape check error info.

* Docs

* Very detailed comparison. Try again.

* AC done

* DQN check done

* Minor

* DDPG, not tested

* Minors

* A rough draft of MAAC

* Cannot use CIM as the multi-agent scenario.

* Minor

* MAAC refinement on the way

* Remove ActionWithAux

* Refine batch & memory

* MAAC example works

* Reproducibility fix. Policy sharing between env_sampler and trainer_manager.

* Detail refinement

* Simplify the user-configured workflow

* Minor

* Refine example codes

* Minor polishing

* Migrate rollout_manager to V3

* Error on the way

* Redesign torch.device management

* Rl v3 maddpg (#418)

* Add MADDPG trainer

* Fit independent critics and shared critic modes.

* Add a new property: num_policies

* Lint

* Fix a bug in `sum(rewards)`

* Rename `MADDPG` to `DiscreteMADDPG` and fix type hint.

* Rename maddpg in examples.

* Preparation for data parallel (#420)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data-parallel-related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* Rename train worker to train ops; add placeholder for abstract methods;

* Lint

Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>

* [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data-parallel-related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* distributed training pipeline draft

* added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* fixed merge conflicts

* unified sync and async interfaces

* refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor maddpg related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Minor bug fix

* Coroutine-related bug fix ("get_policy_state") (#452)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* deleted unwanted folder

* removed unwanted changes

* resolved PR452 comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Quick fix

* Redesign experience recording logic (#453)

* Two not important fix

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint

* Calculating advantages / returns (#454)

* V1.0

* Complete DDPG

* Rl v3 hanging issue fix (#455)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* Final test & format. Ready to merge.

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Rl v3 parallel rollout (#457)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* load balancing dispatcher

* added parallel rollout

* lint

* Tracker variable type issue; rename to env_sampler_creator;

* Rl v3 parallel rollout follow ups (#458)

* AbsWorker & AbsDispatcher

* Pass env idx to AbsTrainer.record() method, and let the trainer to decide how to record experiences sampled from different worlds.

* Fix policy_creator reuse bug

* Format code

* Merge AbsTrainerManager & SimpleTrainerManager

* AC test passed

* Lint

* Remove AbsTrainer.build() method. Put all initialization operations into __init__

* Redesign AC preprocess batches logic

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
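
Regarding the "pass env idx to `AbsTrainer.record()`" change above, a toy sketch of the idea (illustrative only; the real `AbsTrainer` interface is defined in the RL toolkit):

```python
from collections import defaultdict
from typing import Any, Dict, List


class ToyTrainer:
    """Keep experiences from different rollout worlds separate, keyed by env_idx."""

    def __init__(self) -> None:
        self._memories: Dict[int, List[Any]] = defaultdict(list)

    def record(self, env_idx: int, exp_element: Any) -> None:
        # Each parallel environment gets its own buffer, so trajectories sampled
        # from different worlds are never stitched into one sequence by mistake.
        self._memories[env_idx].append(exp_element)


trainer = ToyTrainer()
trainer.record(env_idx=0, exp_element={"state": [0.1], "action": 1, "reward": 0.5})
trainer.record(env_idx=1, exp_element={"state": [0.3], "action": 0, "reward": 1.0})
```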

* MADDPG performance bug fix (#459)

* Fix MARL (MADDPG) terminal recording bug; some other minor refinements;

* Restore Trainer.build() method

* Calculate latest action in the get_actor_grad method in MADDPG.

* Share critic bug fix

* Rl v3 example update (#461)

* updated vm_scheduling example and cim notebook

* fixed bugs in vm_scheduling

* added local train method

* bug fix

* modified async client logic to fix hidden issue

* reverted to default config

* fixed PR comments and some bugs

* removed hardcode

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Done (#462)

* Rl v3 load save (#463)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL Toolkit data parallelism revamp & config utils (#464)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

* 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local

* 1. fixed rollout exit issue; 2. refined config

* removed config file from example

* fixed lint issues

* fixed lint issues

* added main.py under examples/rl

* fixed lint issues

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL doc string (#465)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* Rl config doc (#467)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* resolved more PR comments

* typo fix

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL online doc (#469)

* Model, policy, trainer

* RL workflows and env sampler doc in RST (#468)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* rewriting rl toolkit rst

* resolved more PR comments

* typo fix

* updated rst

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Default <huo53926@126.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Finish docs/source/key_components/rl_toolkit.rst

* API doc

* RL online doc image fix (#470)

* resolved some PR comments

* fix

* fixed PR comments

* added numfig=True setting in conf.py for sphinx

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Resolve PR comments

* Add example github link

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Rl v3 pr comment resolution (#474)

* added load/save feature

* 1. resolved pr comments; 2. reverted maro/cli/k8s

* fixed some bugs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* RL renaming v2 (#476)

* Change all Logger in RL to LoggerV2

* TrainerManager => TrainingManager

* Add Trainer suffix to all algorithms

* Finish docs

* Update interface names

* Minor fix

* Cherry pick latest RL (#498)

* Cherry pick

* Remove SC related files

* Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509)

* Cherry pick RL changes from sc_refinement (2a4869)

* Limit time display precision

* RL incremental refactor (#501)

* Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training.

AC & PPO for continuous action policy; refine AC & PPO logic.

Cherry pick RL changes from GYM-DDPG

Cherry pick RL changes from GYM-SAC

Minor error in doc string

* Add min_n_sample in template and parser

* Resolve PR comments. Fix a minor issue in SAC.
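
A rough sketch of the "sample multiple times per epoch until enough data" idea behind `min_n_sample` (hypothetical helper; the return structure of `sample()` is assumed):

```python
def collect_at_least(env_sampler, min_n_sample: int, num_steps: int = -1) -> list:
    """Keep calling sample() until at least `min_n_sample` experiences are gathered."""
    experiences = []
    while len(experiences) < min_n_sample:
        result = env_sampler.sample(num_steps=num_steps)  # one rollout segment
        experiences.extend(result["experiences"])         # key name assumed
    return experiences
```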

* RL component bundle (#513)

* CIM passed

* Update workers

* Refine annotations

* VM passed

* Code formatting.

* Minor import loop issue

* Pass batch in PPO again

* Remove Scenario

* Complete docs

* Minor

* Remove segment

* Optimize logic in RLComponentBundle

* Resolve PR comments

* Move 'post' methods from RLComponentBundle to EnvSampler

* Add method to get mapping of available tick to frame index (#415)

* add method to get mapping of available tick to frame index

* fix lint issue

* fix naming issue
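
The tick-to-frame-index mapping from #415 can be pictured with this toy helper (names are assumptions; the real accessor is exposed by the simulator):

```python
from typing import Dict, List, Tuple


def build_tick_to_frame_index(snapshot_records: List[Tuple[int, int]]) -> Dict[int, int]:
    """Map each available tick to the frame index at which its snapshot was stored."""
    return {tick: frame_index for tick, frame_index in snapshot_records}


# With a snapshot taken every 2 ticks, ticks 0/2/4 map to frame indexes 0/1/2.
mapping = build_tick_to_frame_index([(0, 0), (2, 1), (4, 2)])
assert mapping[4] == 2
```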

* Cherry pick from sc_refinement (#527)

* Cherry pick from sc_refinement

* Cherry pick from sc_refinement

* Refine `terminal` / `next_agent_state` logic (#531)

* Optimize RL toolkit

* Fix bug in terminal/next_state generation

* Rewrite terminal/next_state logic again

* Minor renaming

* Minor bug fix

* Resolve PR comments

* Merge master into v0.3 (#536)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix spelling typos in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random-related logic.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added version numbers for the dependencies and resolved some conflicts that occur during installation. Pinning these versions makes the expected versions explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add & sort requirements.dev.txt

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Remove random_config.py

* Remove test_trajectory_utils.py

* Pass tests

* Update rl docs

* Remove python 3.6 in test

* Update docs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: Wang.Jinyu <jinywan@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Logger bug hotfix (#543)

* Rename param

* Rename param

* Quick fix in env_data_process

* frame data precision issue fix (#544)

* fix frame precision issue

* add .xmake to .gitignore

* update frame precision lost warning message

* add assert to frame precision checking

* typo fix

* add TODO for future Long data type issue fix

* Minor cleaning

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>

* Update requirements. (#552)

* Fix several encoding issues; update requirements.

* Test & minor

* Remove torch in requirements.build.txt

* Polish

* Update README

* Resolve PR comments

* Keep working

* Keep working

* Update test requirements

* Done (#554)

* Update requirements in example and notebook (#553)

* Update requirements in example and notebook

* Remove autopep8

* Add jupyterlab packages back

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>

* Refine decision event logic (#559)

* Add DecisionEventPayload

* Change decision payload name

* Refine action logic

* Add doc for env.step

* Restore pre-commit config

* Resolve PR comments

* Refactor decision event & action

* Pre-commit

* Resolve PR comments
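
For context, the decision-event/action refactor in #559 applies to the standard MARO interaction loop sketched below (generic pattern with a placeholder "do nothing" action; the new payload/action classes are whatever the PR defines):

```python
from maro.simulator import Env

env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=100)

metrics, decision_event, is_done = env.step(None)
while not is_done:
    # `decision_event` carries the payload describing what the scenario is asking for;
    # a real agent would inspect it and build a scenario-specific action here.
    action = None  # placeholder: skip the decision
    metrics, decision_event, is_done = env.step(action)

print(metrics)
```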

* Refine rl component bundle (#549)

* Config files

* Done

* Minor bugfix

* Add autoflake

* Update isort exclude; add pre-commit to requirements

* Check only isort

* Minor

* Format

* Test passed

* Run pre-commit

* Minor bugfix in rl_component_bundle

* Pass mypy

* Fix a bug in RL notebook

* A minor bug fix

* Add upper bound for numpy version in test

* Remove numpy data type (#571)

* Change numpy data type; change test requirements.

* Lint

* RL benchmark on GYM (#575)

* PPO, SAC, DDPG passed

* Explore in SAC

* Test GYM on server

* Sync server changes

* pre-commit

* Ready to try on server

* .

* .

* .

* .

* .

* Performance OK

* Move to tests

* Remove old versions

* PPO done

* Start to test AC

* Start to test SAC

* SAC test passed

* update for some PR comments; Add a MARKDOWN file (#576)

Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>

* Use FullyConnected to replace mlp

* Update action bound

* Pre-commit

---------

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>

* Refine RL workflow & tune RL models under GYM (#577)

* PPO, SAC, DDPG passed

* Explore in SAC

* Test GYM on server

* Sync server changes

* pre-commit

* Ready to try on server

* .

* .

* .

* .

* .

* Performance OK

* Move to tests

* Remove old versions

* PPO done

* Start to test AC

* Start to test SAC

* SAC test passed

* Multiple round in evaluation

* Modify config.yml

* Add Callbacks

* [wip] SAC performance not good

* [wip] still not good

* update for some PR comments; Add a MARKDOWN file (#576)

Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>

* Use FullyConnected to replace mlp

* Update action bound

* ???

* Change gym env wrapper metrics logic

* Change gym env wrapper metrics logic

* refine env_sampler.sample under step mode

* Add DDPG. Performance not good...

* Add DDPG. Performance not good...

* wip

* Sounds like SAC works

* Refactor file structure

* Refactor file structure

* Refactor file structure

* Pre-commit

* Pre commit

* Minor refinement of CIM RL

* Jinyu/rl workflow refine (#578)

* remove useless files; add device mapping; update pdoc

* add default checkpoint path; fix distributed worker log path issue; update example log path

* update performance doc

* remove tests/rl/algorithms folder

* Resolve PR comments

* Compare PPO with spinning up (#579)

* [wip] compare PPO

* PPO matching

* Revert unnecessary changes

* Minor

* Minor

* SAC Test parameters update (#580)

* fix sac to_device issue; update sac gym test parameters

* add rl test performance plot func

* update sac eval interval config

* update sac checkpoint interval config

* fix callback issue

* update plot func

* update plot func

* update plot func

* update performance doc; upload performance images

* Minor fix in callbacks; refine plot.py format.

* Add n_interactions. Use n_interactions to plot curves.

* pre-commit

---------

Co-authored-by: Huoran Li <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Episode truncation & early stopping (#581)

* Add truncated logic

* (To be tested) early stop

* Early stop test passed

* Test passed

* Random action. To be tested.

* Warmup OK

* Pre-commit

* random seed

* Revert pre-commit config

---------

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>
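
A minimal sketch of the patience-based early stop introduced here, consistent with the `early_stop_patience: 5` key added to the example config (class name and exact rule are assumptions):

```python
class EarlyStopper:
    """Stop training once the monitored metric has not improved for `patience` checks."""

    def __init__(self, patience: int = 5) -> None:
        self._patience = patience
        self._best = float("-inf")
        self._bad_rounds = 0

    def update(self, metric: float) -> bool:
        """Return True if training should stop. A higher metric is better."""
        if metric > self._best:
            self._best, self._bad_rounds = metric, 0
        else:
            self._bad_rounds += 1
        return self._bad_rounds >= self._patience


stopper = EarlyStopper(patience=5)
for ep, score in enumerate([1.0, 1.2, 1.1, 1.1, 1.0, 1.1, 0.9, 1.0]):
    if stopper.update(score):
        print(f"early stop at evaluation round {ep}")
        break
```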

* DDPG parameters update (#583)

* Tune params

* fix conflict

* remove duplicated seed setting

---------

Co-authored-by: Huoran Li <huo53926@126.com>

* Update RL Benchmarks (#584)

* update plot func for rl tests

* Refine seed setting logic

* Refine metrics logic; add warmup to ddpg.

* Complete ddpg config

* Minor refinement of GymEnvSampler and plot.py

* update rl benchmark performance results

* Lint

---------

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Huoran Li <huo53926@126.com>

* Update Input Template of RL Policy to Improve Module Flexibility (#589)

* add customized_callbacks to RLComponentBundle

* add env.tick to replace the default None in AbsEnvSampler._get_global_and_agent_state()

* fix rl algorithms to_device issue

* add kwargs to RL models' forward funcs and _shape_check()

* add kwargs to RL policies' get_action related funcs and _post_check()

* add detached loss to the return value of update_critic() and update_actor() of current TrainOps; add default False early_stop to update_actor() of current TrainOps

* add kwargs to choose_actions of AbsEnvSampler; keep it None in the current sample() and eval()

* fix line length issue

* fix line break issue

---------

Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>
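
The "detached loss in the return value" bullet above refers to the common PyTorch pattern of handing back a plain float for logging, sketched here generically (not the exact TrainOps code):

```python
import torch


def update_critic(critic_loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> float:
    """Take one optimizer step and return the loss detached from the autograd graph."""
    optimizer.zero_grad()
    critic_loss.backward()
    optimizer.step()
    return critic_loss.detach().cpu().item()  # safe to log or aggregate; keeps no graph alive
```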

* update code version to 0.3.2a1

---------

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Huoran Li <huo53926@126.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
This commit is contained in:
Jinyu-W, 2023-03-30 09:57:25 +08:00, committed by GitHub
Parent: 94548c7be8
Commit: ef2a358a1d
76 changed files: 2223 additions and 550 deletions


@ -11,6 +11,7 @@ from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
actor_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.Tanh,
"output_activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True,
@ -19,6 +20,7 @@ critic_net_conf = {
"hidden_dims": [256, 128, 64],
"output_dim": 1,
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True,


@ -12,6 +12,7 @@ from maro.rl.training.algorithms import DQNParams, DQNTrainer
q_net_conf = {
"hidden_dims": [256, 128, 64, 32],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,


@ -14,6 +14,7 @@ from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrai
actor_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.Tanh,
"output_activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True,
@ -22,6 +23,7 @@ critic_net_conf = {
"hidden_dims": [256, 128, 64],
"output_dim": 1,
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True,


@ -90,11 +90,25 @@ class CIMEnvSampler(AbsEnvSampler):
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")
# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
self.metrics.update(avg_metric)
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
def post_evaluate(self, info_list: list, ep: int) -> None:
self.post_collect(info_list, ep)
# print the env metric from each rollout worker
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")
# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
self.metrics.update({"val/" + k: v for k, v in avg_metric.items()})
def monitor_metrics(self) -> float:
return -self.metrics["val/container_shortage"]


@ -13,7 +13,7 @@ from examples.cim.rl.env_sampler import CIMEnvSampler
# Environments
learn_env = Env(**env_conf)
test_env = learn_env
test_env = Env(**env_conf)
# Agent, policy, and trainers
num_agents = len(learn_env.agent_idx_list)


@ -7,7 +7,7 @@ This folder contains scenarios that employ reinforcement learning. MARO's RL too
The entrance of a RL workflow is a YAML config file. For readers' convenience, we call this config file `config.yml` in the rest part of this doc. `config.yml` specifies the path of all necessary resources, definitions, and configurations to run the job. MARO provides a comprehensive template of the config file with detailed explanations (`maro/maro/rl/workflows/config/template.yml`). Meanwhile, MARO also provides several simple examples of `config.yml` under the current folder.
There are two ways to start the RL job:
- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run_rl_example.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`.
- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`.
- (**Require install MARO from source**) You could also start the job through MARO CLI. Use the command `maro local run [-c] path/to/your/config` to run in containerized (with `-c`) or non-containerized (without `-c`) environments. Similar, you could add `--evaluate_only` if you only need to run the evaluation workflow.
## Create Your Own Scenarios


@ -5,16 +5,17 @@
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
# - python ./examples/rl/run.py ./examples/rl/cim.yml
# - (Requires installing MARO from source) maro local run ./examples/rl/cim.yml
job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
log_path: "log/cim_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
eval_schedule: 5
early_stop_patience: 5
logging:
stdout: INFO
file: DEBUG
@ -27,7 +28,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
path: "log/cim_rl/checkpoints"
interval: 5
logging:
stdout: INFO


@ -1,16 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for CIM scenario.
# Example RL config file for CIM scenario (distributed version).
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
# - python ./examples/rl/run.py ./examples/rl/cim_distributed.yml
# - (Requires installing MARO from source) maro local run ./examples/rl/cim_distributed.yml
job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
log_path: "log/cim_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
@ -35,7 +35,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
path: "log/cim_rl/checkpoints"
interval: 5
proxy:
host: "127.0.0.1"


@ -5,12 +5,12 @@
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\vm_scheduling.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\vm_scheduling.yml
# - python ./examples/rl/run.py ./examples/rl/vm_scheduling.yml
# - (Requires installing MARO from source) maro local run ./examples/rl/vm_scheduling.yml
job: vm_scheduling_rl_workflow
scenario_path: "examples/vm_scheduling/rl"
log_path: "log/rl_job/vm_scheduling.txt"
log_path: "log/vm_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
@ -27,7 +27,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/vm_scheduling"
path: "log/vm_rl/checkpoints"
interval: 5
logging:
stdout: INFO


@ -11,6 +11,7 @@ from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
actor_net_conf = {
"hidden_dims": [64, 32, 32],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": True,
"batch_norm": False,
"head": True,
@ -19,6 +20,7 @@ actor_net_conf = {
critic_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": False,
"head": True,


@ -14,6 +14,7 @@ from maro.rl.training.algorithms import DQNParams, DQNTrainer
q_net_conf = {
"hidden_dims": [64, 128, 256],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": False,
"skip_connection": False,


@ -2,6 +2,6 @@
# Licensed under the MIT license.
__version__ = "0.3.1a2"
__version__ = "0.3.2a1"
__data_version__ = "0.2"


@ -8,7 +8,6 @@ import zipfile
from enum import Enum
import geopy.distance
import numpy as np
import pandas as pd
from yaml import safe_load
@ -320,7 +319,7 @@ class CitiBikePipeline(DataPipeline):
0,
index=station_info["station_index"],
columns=station_info["station_index"],
dtype=np.float,
dtype=float,
)
look_up_df = station_info[["latitude", "longitude"]]
return distance_adj.apply(
@ -617,7 +616,7 @@ class CitiBikeToyPipeline(DataPipeline):
0,
index=station_init["station_index"],
columns=station_init["station_index"],
dtype=np.float,
dtype=float,
)
look_up_df = station_init[["latitude", "longitude"]]
distance_df = distance_adj.apply(


@ -61,7 +61,7 @@ def get_redis_conn(port=None):
# Functions executed on CLI commands
def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False, **kwargs):
def run(conf_path: str, containerize: bool = False, seed: int = None, evaluate_only: bool = False, **kwargs):
# Load job configuration file
parser = ConfigParser(conf_path)
if containerize:
@ -71,13 +71,14 @@ def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False,
LOCAL_MARO_ROOT,
DOCKERFILE_PATH,
DOCKER_IMAGE_NAME,
seed=seed,
evaluate_only=evaluate_only,
)
except KeyboardInterrupt:
stop_rl_job_with_docker_compose(parser.config["job"], LOCAL_MARO_ROOT)
else:
try:
start_rl_job(parser, LOCAL_MARO_ROOT, evaluate_only=evaluate_only)
start_rl_job(parser, LOCAL_MARO_ROOT, seed=seed, evaluate_only=evaluate_only)
except KeyboardInterrupt:
sys.exit(1)


@ -4,7 +4,7 @@
import os
import subprocess
from copy import deepcopy
from typing import List
from typing import List, Optional
import docker
import yaml
@ -110,12 +110,15 @@ def exec(cmd: str, env: dict, debug: bool = False) -> subprocess.Popen:
def start_rl_job(
parser: ConfigParser,
maro_root: str,
seed: Optional[int],
evaluate_only: bool,
background: bool = False,
) -> List[subprocess.Popen]:
procs = [
exec(
f"python {script}" + ("" if not evaluate_only else " --evaluate_only"),
f"python {script}"
+ ("" if not evaluate_only else " --evaluate_only")
+ ("" if seed is None else f" --seed {seed}"),
format_env_vars({**env, "PYTHONPATH": maro_root}, mode="proc"),
debug=not background,
)
@ -169,6 +172,7 @@ def start_rl_job_with_docker_compose(
context: str,
dockerfile_path: str,
image_name: str,
seed: Optional[int],
evaluate_only: bool,
) -> None:
common_spec = {
@ -185,7 +189,9 @@ def start_rl_job_with_docker_compose(
**deepcopy(common_spec),
**{
"container_name": component,
"command": f"python3 {script}" + ("" if not evaluate_only else " --evaluate_only"),
"command": f"python3 {script}"
+ ("" if not evaluate_only else " --evaluate_only")
+ ("" if seed is None else f" --seed {seed}"),
"environment": format_env_vars(env, mode="docker-compose"),
},
}


@ -4,7 +4,7 @@
from __future__ import annotations
from abc import ABCMeta
from typing import Any, Dict
from typing import Any, Dict, Optional
import torch.nn
from torch.optim import Optimizer
@ -18,6 +18,8 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
def __init__(self) -> None:
super(AbsNet, self).__init__()
self._device: Optional[torch.device] = None
@property
def optim(self) -> Optimizer:
optim = getattr(self, "_optim", None)
@ -119,3 +121,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
"""Unfreeze all parameters."""
for p in self.parameters():
p.requires_grad = True
def to_device(self, device: torch.device) -> None:
self._device = device
self.to(device)


@ -43,14 +43,23 @@ class ContinuousACBasedNet(ContinuousPolicyNet, metaclass=ABCMeta):
- set_state(self, net_state: dict) -> None:
"""
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
actions, _ = self._get_actions_with_logps_impl(states, exploring)
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
actions, _ = self._get_actions_with_logps_impl(states, exploring, **kwargs)
return actions
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_probs_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Not used in Actor-Critic or PPO
pass
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in Actor-Critic or PPO
pass
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in Actor-Critic or PPO
pass


@ -25,18 +25,32 @@ class ContinuousDDPGNet(ContinuousPolicyNet, metaclass=ABCMeta):
- set_state(self, net_state: dict) -> None:
"""
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_probs_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Not used in DDPG
pass
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_logps_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Not used in DDPG
pass
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in DDPG
pass
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in DDPG
pass
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in DDPG
pass


@ -25,18 +25,23 @@ class ContinuousSACNet(ContinuousPolicyNet, metaclass=ABCMeta):
- set_state(self, net_state: dict) -> None:
"""
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
actions, _ = self._get_actions_with_logps_impl(states, exploring)
return actions
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_probs_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Not used in SAC
pass
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in SAC
pass
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
# Not used in SAC
pass


@ -39,7 +39,8 @@ class FullyConnected(nn.Module):
input_dim: int,
output_dim: int,
hidden_dims: List[int],
activation: Optional[Type[torch.nn.Module]] = nn.ReLU,
activation: Optional[Type[torch.nn.Module]] = None,
output_activation: Optional[Type[torch.nn.Module]] = None,
head: bool = False,
softmax: bool = False,
batch_norm: bool = False,
@ -54,7 +55,8 @@ class FullyConnected(nn.Module):
self._output_dim = output_dim
# network features
self._activation = activation() if activation else None
self._activation = activation if activation else None
self._output_activation = output_activation if output_activation else None
self._head = head
self._softmax = nn.Softmax(dim=1) if softmax else None
self._batch_norm = batch_norm
@ -70,9 +72,13 @@ class FullyConnected(nn.Module):
# build the net
dims = [self._input_dim] + self._hidden_dims
layers = [self._build_layer(in_dim, out_dim) for in_dim, out_dim in zip(dims, dims[1:])]
layers = [
self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])
]
# top layer
layers.append(self._build_layer(dims[-1], self._output_dim, head=self._head))
layers.append(
self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation),
)
self._net = nn.Sequential(*layers)
@ -101,7 +107,13 @@ class FullyConnected(nn.Module):
def output_dim(self) -> int:
return self._output_dim
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
def _build_layer(
self,
input_dim: int,
output_dim: int,
head: bool = False,
activation: Type[torch.nn.Module] = None,
) -> nn.Module:
"""Build a basic layer.
BN -> Linear -> Activation -> Dropout
@ -110,8 +122,8 @@ class FullyConnected(nn.Module):
if self._batch_norm:
components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
components.append(("linear", nn.Linear(input_dim, output_dim)))
if not head and self._activation is not None:
components.append(("activation", self._activation))
if not head and activation is not None:
components.append(("activation", activation()))
if not head and self._dropout_p:
components.append(("dropout", nn.Dropout(p=self._dropout_p)))
return nn.Sequential(OrderedDict(components))
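
A hedged usage sketch of `FullyConnected` after this change, mirroring the actor configuration shown earlier in the diff (import path assumed; activations are now passed as classes and instantiated per layer):

```python
import torch

from maro.rl.model import FullyConnected  # import path assumed

actor_net = FullyConnected(
    input_dim=50,                     # example state dimension
    output_dim=21,                    # example action count
    hidden_dims=[256, 128, 64],
    activation=torch.nn.Tanh,         # class, not instance: each hidden layer builds its own Tanh()
    output_activation=torch.nn.Tanh,  # new parameter introduced by this change
    softmax=True,
    batch_norm=False,
    head=True,
)

probs = actor_net(torch.rand(4, 50))  # -> tensor of shape [4, 21]
```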


@ -37,7 +37,7 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
def agent_num(self) -> int:
return len(self._action_dims)
def _shape_check(self, states: torch.Tensor, actions: List[torch.Tensor] = None) -> bool:
def _shape_check(self, states: torch.Tensor, actions: List[torch.Tensor] = None, **kwargs) -> bool:
"""Check whether the states and actions have valid shapes.
Args:
@ -61,7 +61,7 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
return False
return True
def q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
def q_values(self, states: torch.Tensor, actions: List[torch.Tensor], **kwargs) -> torch.Tensor:
"""Get Q-values according to states and actions.
Args:
@ -71,8 +71,8 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
Returns:
q (torch.Tensor): Q-values with shape [batch_size].
"""
assert self._shape_check(states, actions)
q = self._get_q_values(states, actions)
assert self._shape_check(states, actions, **kwargs)
q = self._get_q_values(states, actions, **kwargs)
assert match_shape(
q,
(states.shape[0],),
@ -80,6 +80,6 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
return q
@abstractmethod
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor], **kwargs) -> torch.Tensor:
"""Implementation of `q_values`."""
raise NotImplementedError


@ -33,93 +33,121 @@ class PolicyNet(AbsNet, metaclass=ABCMeta):
def action_dim(self) -> int:
return self._action_dim
def get_actions(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
def get_actions(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions = self._get_actions_impl(states, exploring)
actions = self._get_actions_impl(states, exploring, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
return actions
def get_actions_with_probs(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def get_actions_with_probs(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions, probs = self._get_actions_with_probs_impl(states, exploring)
actions, probs = self._get_actions_with_probs_impl(states, exploring, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
return actions, probs
def get_actions_with_logps(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def get_actions_with_logps(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions, logps = self._get_actions_with_logps_impl(states, exploring)
actions, logps = self._get_actions_with_logps_impl(states, exploring, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
return actions, logps
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
probs = self._get_states_actions_probs_impl(states, actions)
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
return probs
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
logps = self._get_states_actions_logps_impl(states, actions)
logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
return logps
@abstractmethod
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_probs_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_logps_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool:
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None, **kwargs) -> bool:
"""Check whether the states and actions have valid shapes.
Args:
@ -160,7 +188,7 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
def action_num(self) -> int:
return self._action_num
def get_action_probs(self, states: torch.Tensor) -> torch.Tensor:
def get_action_probs(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get the probabilities for all possible actions in the action space.
Args:
@ -171,8 +199,9 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
"""
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
action_probs = self._get_action_probs_impl(states)
action_probs = self._get_action_probs_impl(states, **kwargs)
assert match_shape(action_probs, (states.shape[0], self.action_num)), (
f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
f"actual: {action_probs.shape}."
@ -180,16 +209,21 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
return action_probs
@abstractmethod
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
def _get_action_probs_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Implementation of `get_action_probs`. The core logic of a discrete policy net should be implemented here."""
raise NotImplementedError
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
actions, _ = self._get_actions_with_probs_impl(states, exploring)
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
actions, _ = self._get_actions_with_probs_impl(states, exploring, **kwargs)
return actions
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
probs = self.get_action_probs(states)
def _get_actions_with_probs_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = self.get_action_probs(states, **kwargs)
if exploring:
distribution = Categorical(probs)
actions = distribution.sample().unsqueeze(1)
@ -198,16 +232,21 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
probs, actions = probs.max(dim=1)
return actions.unsqueeze(1), probs
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
actions, probs = self._get_actions_with_probs_impl(states, exploring)
def _get_actions_with_logps_impl(
self,
states: torch.Tensor,
exploring: bool,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
actions, probs = self._get_actions_with_probs_impl(states, exploring, **kwargs)
return actions, torch.log(probs)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
probs = self.get_action_probs(states)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
probs = self.get_action_probs(states, **kwargs)
return probs.gather(1, actions).squeeze(-1)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
probs = self._get_states_actions_probs_impl(states, actions)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
return torch.log(probs)
@ -221,3 +260,18 @@ class ContinuousPolicyNet(PolicyNet, metaclass=ABCMeta):
def __init__(self, state_dim: int, action_dim: int) -> None:
super(ContinuousPolicyNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
def get_random_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
actions = self._get_random_actions_impl(states, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
return actions
@abstractmethod
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError


@ -31,7 +31,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
def action_dim(self) -> int:
return self._action_dim
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool:
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None, **kwargs) -> bool:
"""Check whether the states and actions have valid shapes.
Args:
@ -52,7 +52,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
return False
return True
def q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get Q-values according to states and actions.
Args:
@ -62,12 +62,12 @@ class QNet(AbsNet, metaclass=ABCMeta):
Returns:
q (torch.Tensor): Q-values with shape [batch_size].
"""
assert self._shape_check(states=states, actions=actions), (
assert self._shape_check(states=states, actions=actions, **kwargs), (
f"States or action shape check failed. Expecting: "
f"states = {('BATCH_SIZE', self.state_dim)}, action = {('BATCH_SIZE', self.action_dim)}. "
f"Actual: states = {states.shape}, action = {actions.shape}."
)
q = self._get_q_values(states, actions)
q = self._get_q_values(states, actions, **kwargs)
assert match_shape(
q,
(states.shape[0],),
@ -75,7 +75,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
return q
@abstractmethod
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
"""Implementation of `q_values`."""
raise NotImplementedError
@ -96,7 +96,7 @@ class DiscreteQNet(QNet, metaclass=ABCMeta):
def action_num(self) -> int:
return self._action_num
def q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
def q_values_for_all_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get Q-values for all actions according to states.
Args:
@ -107,20 +107,21 @@ class DiscreteQNet(QNet, metaclass=ABCMeta):
"""
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
q = self._get_q_values_for_all_actions(states)
q = self._get_q_values_for_all_actions(states, **kwargs)
assert match_shape(q, (states.shape[0], self.action_num)), (
f"Q-value matrix shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
f"actual: {q.shape}."
) # [B, action_num]
return q
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
q = self.q_values_for_all_actions(states) # [B, action_num]
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
q = self.q_values_for_all_actions(states, **kwargs) # [B, action_num]
return q.gather(1, actions.long()).reshape(-1) # [B, action_num] + [B, 1] => [B]
@abstractmethod
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
def _get_q_values_for_all_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Implementation of `q_values_for_all_actions`."""
raise NotImplementedError
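As a standalone illustration (toy tensors, not MARO code) of the gather call used in _get_q_values above, which picks one Q-value per row:

import torch

q = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # [B=2, action_num=3]
actions = torch.tensor([[2], [0]])    # [B, 1], one action index per sample
picked = q.gather(1, actions.long()).reshape(-1)
print(picked)                         # tensor([3., 4.]) -> shape [B]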


@ -25,7 +25,7 @@ class VNet(AbsNet, metaclass=ABCMeta):
def state_dim(self) -> int:
return self._state_dim
def _shape_check(self, states: torch.Tensor) -> bool:
def _shape_check(self, states: torch.Tensor, **kwargs) -> bool:
"""Check whether the states have valid shapes.
Args:
@ -39,7 +39,7 @@ class VNet(AbsNet, metaclass=ABCMeta):
else:
return states.shape[0] > 0 and match_shape(states, (None, self.state_dim))
def v_values(self, states: torch.Tensor) -> torch.Tensor:
def v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get V-values according to states.
Args:
@ -50,8 +50,9 @@ class VNet(AbsNet, metaclass=ABCMeta):
"""
assert self._shape_check(
states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
v = self._get_v_values(states)
v = self._get_v_values(states, **kwargs)
assert match_shape(
v,
(states.shape[0],),
@ -59,6 +60,6 @@ class VNet(AbsNet, metaclass=ABCMeta):
return v
@abstractmethod
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
def _get_v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Implementation of `v_values`."""
raise NotImplementedError


@ -27,7 +27,7 @@ class AbsPolicy(object, metaclass=ABCMeta):
self._trainable = trainable
@abstractmethod
def get_actions(self, states: Union[list, np.ndarray]) -> Any:
def get_actions(self, states: Union[list, np.ndarray], **kwargs) -> Any:
"""Get actions according to states.
Args:
@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
def __init__(self) -> None:
super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)
def get_actions(self, states: Union[list, np.ndarray]) -> None:
def get_actions(self, states: Union[list, np.ndarray], **kwargs) -> None:
return None
def explore(self) -> None:
@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
def __init__(self, name: str) -> None:
super(RuleBasedPolicy, self).__init__(name=name, trainable=False)
def get_actions(self, states: list) -> list:
return self._rule(states)
def get_actions(self, states: list, **kwargs) -> list:
return self._rule(states, **kwargs)
@abstractmethod
def _rule(self, states: list) -> list:
def _rule(self, states: list, **kwargs) -> list:
raise NotImplementedError
def explore(self) -> None:
@ -129,6 +129,8 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
state_dim (int): Dimension of states.
action_dim (int): Dimension of actions.
trainable (bool, default=True): Whether this policy is trainable.
warmup (int, default=0): Number of steps for uniform-random action selection before the real policy takes over.
Helps exploration.
"""
def __init__(
@ -138,6 +140,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
action_dim: int,
is_discrete_action: bool,
trainable: bool = True,
warmup: int = 0,
) -> None:
super(RLPolicy, self).__init__(name=name, trainable=trainable)
self._state_dim = state_dim
@ -145,6 +148,8 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
self._is_exploring = False
self._device: Optional[torch.device] = None
self._warmup = warmup
self._call_count = 0
self.is_discrete_action = is_discrete_action
@ -199,94 +204,122 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
"""
raise NotImplementedError
def get_actions(self, states: np.ndarray) -> np.ndarray:
actions = self.get_actions_tensor(ndarray_to_tensor(states, device=self._device))
def get_actions(self, states: np.ndarray, **kwargs) -> np.ndarray:
self._call_count += 1
if self._call_count <= self._warmup:
actions = self.get_random_actions_tensor(ndarray_to_tensor(states, device=self._device), **kwargs)
else:
actions = self.get_actions_tensor(ndarray_to_tensor(states, device=self._device), **kwargs)
return actions.detach().cpu().numpy()
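With the new warmup counter, the first `warmup` calls to get_actions return uniformly random actions before the learned policy is used; a usage sketch (the Q-net and states are placeholders):

policy = ValueBasedPolicy(
    name="dqn.policy",
    q_net=my_q_net,      # placeholder: any DiscreteQNet instance
    warmup=1000,         # first 1000 get_actions calls return random actions
)
actions = policy.get_actions(states)  # states: np.ndarray of shape [B, state_dim]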
def get_actions_tensor(self, states: torch.Tensor) -> torch.Tensor:
def get_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions = self._get_actions_impl(states)
actions = self._get_actions_impl(states, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
return actions
def get_actions_with_probs(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._shape_check(
states=states,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions, probs = self._get_actions_with_probs_impl(states)
def get_random_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
actions = self._get_random_actions_impl(states, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
return actions
def get_actions_with_probs(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions, probs = self._get_actions_with_probs_impl(states, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
return actions, probs
def get_actions_with_logps(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def get_actions_with_logps(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
actions, logps = self._get_actions_with_logps_impl(states)
actions, logps = self._get_actions_with_logps_impl(states, **kwargs)
assert self._shape_check(
states=states,
actions=actions,
**kwargs,
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
return actions, logps
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
probs = self._get_states_actions_probs_impl(states, actions)
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
return probs
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
logps = self._get_states_actions_logps_impl(states, actions)
logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
return logps
@abstractmethod
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
@ -327,6 +360,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
self,
states: torch.Tensor,
actions: torch.Tensor = None,
**kwargs,
) -> bool:
"""Check whether the states and actions have valid shapes.
@ -352,7 +386,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
return True
@abstractmethod
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
"""Check whether the generated action tensor is valid, i.e., has matching shape with states tensor.
Args:


@ -42,6 +42,8 @@ class ContinuousRLPolicy(RLPolicy):
the bound for every dimension. If it is a float, it will be broadcast to all dimensions.
policy_net (ContinuousPolicyNet): The core net of this policy.
trainable (bool, default=True): Whether this policy is trainable.
warmup (int, default=0): Number of steps for uniform-random action selection before the real policy takes over.
Helps exploration.
"""
def __init__(
@ -50,6 +52,7 @@ class ContinuousRLPolicy(RLPolicy):
action_range: Tuple[Union[float, List[float]], Union[float, List[float]]],
policy_net: ContinuousPolicyNet,
trainable: bool = True,
warmup: int = 0,
) -> None:
assert isinstance(policy_net, ContinuousPolicyNet)
@ -59,6 +62,7 @@ class ContinuousRLPolicy(RLPolicy):
action_dim=policy_net.action_dim,
trainable=trainable,
is_discrete_action=False,
warmup=warmup,
)
self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
@ -72,7 +76,7 @@ class ContinuousRLPolicy(RLPolicy):
def policy_net(self) -> ContinuousPolicyNet:
return self._policy_net
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
return all(
[
(np.array(self._lbounds) <= actions.detach().cpu().numpy()).all(),
@ -80,20 +84,23 @@ class ContinuousRLPolicy(RLPolicy):
],
)
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_actions(states, self._is_exploring)
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_actions(states, self._is_exploring, **kwargs)
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_probs(states, self._is_exploring)
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_random_actions(states, **kwargs)
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_logps(states, self._is_exploring)
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_probs(states, self._is_exploring, **kwargs)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_states_actions_probs(states, actions)
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_logps(states, self._is_exploring, **kwargs)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_states_actions_logps(states, actions)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_states_actions_probs(states, actions, **kwargs)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_states_actions_logps(states, actions, **kwargs)
def train_step(self, loss: torch.Tensor) -> None:
self._policy_net.step(loss)
@ -117,14 +124,22 @@ class ContinuousRLPolicy(RLPolicy):
self._policy_net.train()
def get_state(self) -> dict:
return self._policy_net.get_state()
return {
"net": self._policy_net.get_state(),
"policy": {
"warmup": self._warmup,
"call_count": self._call_count,
},
}
def set_state(self, policy_state: dict) -> None:
self._policy_net.set_state(policy_state)
self._policy_net.set_state(policy_state["net"])
self._warmup = policy_state["policy"]["warmup"]
self._call_count = policy_state["policy"]["call_count"]
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
assert isinstance(other_policy, ContinuousRLPolicy)
self._policy_net.soft_update(other_policy.policy_net, tau)
def _to_device_impl(self, device: torch.device) -> None:
self._policy_net.to(device)
self._policy_net.to_device(device)
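Since get_state now returns a nested dict ({"net": ..., "policy": {...}}), a checkpoint round trip might look like this sketch (the file path and policy variable are placeholders):

import torch

torch.save(policy.get_state(), "policy_checkpoint.pt")   # saves net weights plus warmup / call_count
policy.set_state(torch.load("policy_checkpoint.pt"))     # expects the same nested structure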


@ -23,6 +23,8 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
state_dim (int): Dimension of states.
action_num (int): Number of actions.
trainable (bool, default=True): Whether this policy is trainable.
warmup (int, default=0): Number of steps for uniform-random action selection before the real policy takes over.
Helps exploration.
"""
def __init__(
@ -31,6 +33,7 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
state_dim: int,
action_num: int,
trainable: bool = True,
warmup: int = 0,
) -> None:
assert action_num >= 1
@ -40,6 +43,7 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
action_dim=1,
trainable=trainable,
is_discrete_action=True,
warmup=warmup,
)
self._action_num = action_num
@ -48,9 +52,15 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
def action_num(self) -> int:
return self._action_num
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
return all([0 <= action < self.action_num for action in actions.cpu().numpy().flatten()])
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return ndarray_to_tensor(
np.random.randint(self.action_num, size=(states.shape[0], 1)),
device=self._device,
)
class ValueBasedPolicy(DiscreteRLPolicy):
"""Valued-based policy.
@ -61,7 +71,8 @@ class ValueBasedPolicy(DiscreteRLPolicy):
trainable (bool, default=True): Whether this policy is trainable.
exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy.
exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options.
warmup (int, default=50000): Minimum number of experiences to warm up this policy.
warmup (int, default=50000): Number of steps for uniform-random action selection before the real policy takes over.
Helps exploration.
"""
def __init__(
@ -80,6 +91,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
state_dim=q_net.state_dim,
action_num=q_net.action_num,
trainable=trainable,
warmup=warmup,
)
self._q_net = q_net
@ -91,16 +103,13 @@ class ValueBasedPolicy(DiscreteRLPolicy):
else []
)
self._call_cnt = 0
self._warmup = warmup
self._softmax = torch.nn.Softmax(dim=1)
@property
def q_net(self) -> DiscreteQNet:
return self._q_net
def q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray:
def q_values_for_all_actions(self, states: np.ndarray, **kwargs) -> np.ndarray:
"""Generate a matrix containing the Q-values for all actions for the given states.
Args:
@ -109,9 +118,16 @@ class ValueBasedPolicy(DiscreteRLPolicy):
Returns:
q_values (np.ndarray): Q-matrix.
"""
return self.q_values_for_all_actions_tensor(ndarray_to_tensor(states, device=self._device)).cpu().numpy()
return (
self.q_values_for_all_actions_tensor(
ndarray_to_tensor(states, device=self._device),
**kwargs,
)
.cpu()
.numpy()
)
def q_values_for_all_actions_tensor(self, states: torch.Tensor) -> torch.Tensor:
def q_values_for_all_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Generate a matrix containing the Q-values for all actions for the given states.
Args:
@ -120,12 +136,12 @@ class ValueBasedPolicy(DiscreteRLPolicy):
Returns:
q_values (torch.Tensor): Q-matrix.
"""
assert self._shape_check(states=states)
q_values = self._q_net.q_values_for_all_actions(states)
assert self._shape_check(states=states, **kwargs)
q_values = self._q_net.q_values_for_all_actions(states, **kwargs)
assert match_shape(q_values, (states.shape[0], self.action_num)) # [B, action_num]
return q_values
def q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray:
def q_values(self, states: np.ndarray, actions: np.ndarray, **kwargs) -> np.ndarray:
"""Generate the Q values for given state-action pairs.
Args:
@ -139,12 +155,13 @@ class ValueBasedPolicy(DiscreteRLPolicy):
self.q_values_tensor(
ndarray_to_tensor(states, device=self._device),
ndarray_to_tensor(actions, device=self._device),
**kwargs,
)
.cpu()
.numpy()
)
def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
"""Generate the Q values for given state-action pairs.
Args:
@ -154,50 +171,46 @@ class ValueBasedPolicy(DiscreteRLPolicy):
Returns:
q_values (torch.Tensor): Q-values.
"""
assert self._shape_check(states=states, actions=actions) # actions: [B, 1]
q_values = self._q_net.q_values(states, actions)
assert self._shape_check(states=states, actions=actions, **kwargs) # actions: [B, 1]
q_values = self._q_net.q_values(states, actions, **kwargs)
assert match_shape(q_values, (states.shape[0],)) # [B]
return q_values
def explore(self) -> None:
pass # Overwrite the base method and turn off explore mode.
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
actions, _ = self._get_actions_with_probs_impl(states)
return actions
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return self._get_actions_with_probs_impl(states, **kwargs)[0]
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self._call_cnt += 1
if self._call_cnt <= self._warmup:
actions = ndarray_to_tensor(
np.random.randint(self.action_num, size=(states.shape[0], 1)),
device=self._device,
)
probs = torch.ones(states.shape[0]).float() * (1.0 / self.action_num)
return actions, probs
q_matrix = self.q_values_for_all_actions_tensor(states) # [B, action_num]
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
q_matrix = self.q_values_for_all_actions_tensor(states, **kwargs) # [B, action_num]
q_matrix_softmax = self._softmax(q_matrix)
_, actions = q_matrix.max(dim=1) # [B], [B]
if self._is_exploring:
actions = self._exploration_func(states, actions.cpu().numpy(), self.action_num, **self._exploration_params)
actions = self._exploration_func(
states,
actions.cpu().numpy(),
self.action_num,
**self._exploration_params,
**kwargs,
)
actions = ndarray_to_tensor(actions, device=self._device)
actions = actions.unsqueeze(1)
return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1]
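The exploration function is invoked as func(states, greedy_actions, action_num, **params, **kwargs); a minimal epsilon-greedy sketch compatible with that call convention (an illustration, not MARO's built-in epsilon_greedy):

import numpy as np

def epsilon_greedy_sketch(states, actions, num_actions, *, epsilon: float = 0.1, **kwargs):
    # With probability epsilon, replace each greedy action with a uniformly random one.
    random_actions = np.random.randint(num_actions, size=actions.shape)
    explore_mask = np.random.rand(actions.shape[0]) < epsilon
    return np.where(explore_mask, random_actions, actions)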
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
actions, probs = self._get_actions_with_probs_impl(states)
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
actions, probs = self._get_actions_with_probs_impl(states, **kwargs)
return actions, torch.log(probs)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
q_matrix = self.q_values_for_all_actions_tensor(states)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
q_matrix = self.q_values_for_all_actions_tensor(states, **kwargs)
q_matrix_softmax = self._softmax(q_matrix)
return q_matrix_softmax.gather(1, actions).squeeze(-1) # [B]
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
probs = self._get_states_actions_probs_impl(states, actions)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
return torch.log(probs)
def train_step(self, loss: torch.Tensor) -> None:
@ -222,17 +235,25 @@ class ValueBasedPolicy(DiscreteRLPolicy):
self._q_net.train()
def get_state(self) -> dict:
return self._q_net.get_state()
return {
"net": self._q_net.get_state(),
"policy": {
"warmup": self._warmup,
"call_count": self._call_count,
},
}
def set_state(self, policy_state: dict) -> None:
self._q_net.set_state(policy_state)
self._warmup = policy_state["policy"]["warmup"]
self._call_count = policy_state["policy"]["call_count"]
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
assert isinstance(other_policy, ValueBasedPolicy)
self._q_net.soft_update(other_policy.q_net, tau)
def _to_device_impl(self, device: torch.device) -> None:
self._q_net.to(device)
self._q_net.to_device(device)
class DiscretePolicyGradient(DiscreteRLPolicy):
@ -242,6 +263,8 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
name (str): Name of the policy.
policy_net (DiscretePolicyNet): The core net of this policy.
trainable (bool, default=True): Whether this policy is trainable.
warmup (int, default=0): Number of steps for uniform-random action selection before the real policy takes over.
Helps exploration.
"""
def __init__(
@ -249,6 +272,7 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
name: str,
policy_net: DiscretePolicyNet,
trainable: bool = True,
warmup: int = 0,
) -> None:
assert isinstance(policy_net, DiscretePolicyNet)
@ -257,6 +281,7 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
state_dim=policy_net.state_dim,
action_num=policy_net.action_num,
trainable=trainable,
warmup=warmup,
)
self._policy_net = policy_net
@ -265,20 +290,20 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
def policy_net(self) -> DiscretePolicyNet:
return self._policy_net
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_actions(states, self._is_exploring)
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_actions(states, self._is_exploring, **kwargs)
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_probs(states, self._is_exploring)
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_probs(states, self._is_exploring, **kwargs)
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_logps(states, self._is_exploring)
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return self._policy_net.get_actions_with_logps(states, self._is_exploring, **kwargs)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_states_actions_probs(states, actions)
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_states_actions_probs(states, actions, **kwargs)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._policy_net.get_states_actions_logps(states, actions)
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
return self._policy_net.get_states_actions_logps(states, actions, **kwargs)
def train_step(self, loss: torch.Tensor) -> None:
self._policy_net.step(loss)
@ -302,16 +327,24 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
self._policy_net.train()
def get_state(self) -> dict:
return self._policy_net.get_state()
return {
"net": self._policy_net.get_state(),
"policy": {
"warmup": self._warmup,
"call_count": self._call_count,
},
}
def set_state(self, policy_state: dict) -> None:
self._policy_net.set_state(policy_state)
self._warmup = policy_state["policy"]["warmup"]
self._call_count = policy_state["policy"]["call_count"]
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
assert isinstance(other_policy, DiscretePolicyGradient)
self._policy_net.soft_update(other_policy.policy_net, tau)
def get_action_probs(self, states: torch.Tensor) -> torch.Tensor:
def get_action_probs(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get the probabilities for all actions according to states.
Args:
@ -322,15 +355,16 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
"""
assert self._shape_check(
states=states,
**kwargs,
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
action_probs = self._policy_net.get_action_probs(states)
action_probs = self._policy_net.get_action_probs(states, **kwargs)
assert match_shape(action_probs, (states.shape[0], self.action_num)), (
f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
f"actual: {action_probs.shape}."
)
return action_probs
def get_action_logps(self, states: torch.Tensor) -> torch.Tensor:
def get_action_logps(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
"""Get the log-probabilities for all actions according to states.
Args:
@ -339,15 +373,15 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
Returns:
action_logps (torch.Tensor): Action log-probabilities with shape [batch_size, action_num].
"""
return torch.log(self.get_action_probs(states))
return torch.log(self.get_action_probs(states, **kwargs))
def _get_state_action_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
action_probs = self.get_action_probs(states)
def _get_state_action_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
action_probs = self.get_action_probs(states, **kwargs)
return action_probs.gather(1, actions).squeeze(-1) # [B]
def _get_state_action_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
action_logps = self.get_action_logps(states)
def _get_state_action_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
action_logps = self.get_action_logps(states, **kwargs)
return action_logps.gather(1, actions).squeeze(-1) # [B]
def _to_device_impl(self, device: torch.device) -> None:
self._policy_net.to(device)
self._policy_net.to_device(device)
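The per-action log-probabilities exposed above are what a policy-gradient loss consumes; a REINFORCE-style sketch under the assumption that states, actions and returns are pre-built tensors:

import torch

logps = policy.get_states_actions_logps(states, actions)  # [B] log-probs of the taken actions
loss = -(logps * returns).mean()                          # maximize expected return
policy.train_step(loss)                                   # one gradient step on the policy net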


@ -6,6 +6,7 @@ from typing import Any, Dict, List
from maro.rl.policy import AbsPolicy, RLPolicy
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from maro.rl.workflows.callback import Callback
class RLComponentBundle:
@ -20,7 +21,7 @@ class RLComponentBundle:
If None, there will be no explicit device assignment.
policy_trainer_mapping (Dict[str, str], default=None): Policy-trainer mapping which identifying which trainer to
train each policy. If None, then a policy's trainer's name is the first segment of the policy's name,
seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies that provided in
separated by dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies that are provided in the
policy-trainer mapping are considered trainable policies. Policies that are not provided in the policy-trainer
mapping will not be trained.
"""
@ -33,11 +34,13 @@ class RLComponentBundle:
trainers: List[AbsTrainer],
device_mapping: Dict[str, str] = None,
policy_trainer_mapping: Dict[str, str] = None,
customized_callbacks: List[Callback] = [],
) -> None:
self.env_sampler = env_sampler
self.agent2policy = agent2policy
self.policies = policies
self.trainers = trainers
self.customized_callbacks = customized_callbacks
policy_set = set([policy.name for policy in self.policies])
not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set]
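A sketch of passing custom callbacks through the new customized_callbacks argument; MyCallback is hypothetical, and which hooks the Callback base class expects to be overridden depends on its definition:

from maro.rl.workflows.callback import Callback

class MyCallback(Callback):
    pass  # hypothetical no-op callback; override the hooks you need

bundle = RLComponentBundle(
    env_sampler=env_sampler,
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
    customized_callbacks=[MyCallback()],
)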


@ -189,8 +189,13 @@ class BatchEnvSampler:
"info": [res["info"][0] for res in results],
}
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
req = {
"type": "eval",
"policy_state": policy_state,
"index": self._ep,
"num_eval_episodes": num_episodes,
} # -1 signals test
results = self._controller.collect(req, self._eval_parallelism)
return {
"info": [res["info"][0] for res in results],


@ -48,6 +48,7 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
def choose_actions(
self,
state_by_agent: Dict[Any, Union[np.ndarray, list]],
**kwargs,
) -> Dict[Any, Union[np.ndarray, list]]:
"""Choose action according to the given (observable) states of all agents.
@ -61,13 +62,14 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
"""
self.switch_to_eval_mode()
with torch.no_grad():
ret = self._choose_actions_impl(state_by_agent)
ret = self._choose_actions_impl(state_by_agent, **kwargs)
return ret
@abstractmethod
def _choose_actions_impl(
self,
state_by_agent: Dict[Any, Union[np.ndarray, list]],
**kwargs,
) -> Dict[Any, Union[np.ndarray, list]]:
"""Implementation of `choose_actions`."""
raise NotImplementedError
@ -99,6 +101,7 @@ class SimpleAgentWrapper(AbsAgentWrapper):
def _choose_actions_impl(
self,
state_by_agent: Dict[Any, Union[np.ndarray, list]],
**kwargs,
) -> Dict[Any, Union[np.ndarray, list]]:
# Aggregate states by policy
states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray}
@ -116,7 +119,7 @@ class SimpleAgentWrapper(AbsAgentWrapper):
states = np.vstack(states_by_policy[policy_name]) # np.ndarray
else:
states = states_by_policy[policy_name] # list
actions: Union[np.ndarray, list] = policy.get_actions(states) # np.ndarray or list
actions: Union[np.ndarray, list] = policy.get_actions(states, **kwargs) # np.ndarray or list
action_dict.update(zip(agents_by_policy[policy_name], actions))
return action_dict
@ -146,6 +149,7 @@ class ExpElement:
terminal_dict: Dict[Any, bool]
next_state: Optional[np.ndarray]
next_agent_state_dict: Dict[Any, np.ndarray]
truncated: bool
@property
def agent_names(self) -> list:
@ -171,6 +175,7 @@ class ExpElement:
}
if self.next_agent_state_dict is not None and agent_name in self.next_agent_state_dict
else {},
truncated=self.truncated,
)
return ret
@ -194,6 +199,7 @@ class ExpElement:
terminal_dict={},
next_state=self.next_state,
next_agent_state_dict=None if self.next_agent_state_dict is None else {},
truncated=self.truncated,
),
)
for agent_name, trainer_name in agent2trainer.items():
@ -225,6 +231,7 @@ class CacheElement(ExpElement):
terminal_dict=self.terminal_dict,
next_state=self.next_state,
next_agent_state_dict=self.next_agent_state_dict,
truncated=self.truncated,
)
@ -240,6 +247,8 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type.
reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward
for the action taken for that event. If it is None, calculate reward immediately after `step()`.
max_episode_length (int, default=None): Maximum number of steps in one episode during sampling.
When this limit is reached, the episode is truncated and the environment is reset.
"""
def __init__(
@ -251,7 +260,10 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
trainable_policies: List[str] = None,
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
reward_eval_delay: int = None,
max_episode_length: int = None,
) -> None:
assert learn_env is not test_env, "Please use different envs for training and testing."
self._learn_env = learn_env
self._test_env = test_env
@ -262,11 +274,14 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
self._state: Optional[np.ndarray] = None
self._agent_state_dict: Dict[Any, np.ndarray] = {}
self._trans_cache: List[CacheElement] = []
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache
self._transition_cache: List[CacheElement] = []
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._transition_cache
self._reward_eval_delay = reward_eval_delay
self._max_episode_length = max_episode_length
self._current_episode_length = 0
self._info: dict = {}
self.metrics: dict = {}
assert self._reward_eval_delay is None or self._reward_eval_delay >= 0
@ -291,11 +306,17 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
[policy_name in self._rl_policy_dict for policy_name in self._trainable_policies],
), "All trainable policies must be RL policies!"
self._total_number_interactions = 0
@property
def env(self) -> Env:
assert self._env is not None
return self._env
def monitor_metrics(self) -> float:
"""Metrics watched by early stopping."""
return float(self._total_number_interactions)
def _switch_env(self, env: Env) -> None:
self._env = env
@ -369,7 +390,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
def _step(self, actions: Optional[list]) -> None:
_, self._event, self._end_of_episode = self.env.step(actions)
self._state, self._agent_state_dict = (
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event)
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event, self.env.tick)
)
def _calc_reward(self, cache_element: CacheElement) -> None:
@ -383,37 +404,37 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None:
"""`cache_element` == None means we are processing the last element in trans_cache"""
if cache_element is None:
if len(self._trans_cache) > 0:
self._trans_cache[-1].next_state = self._trans_cache[-1].state
for agent_name, i in self._agent_last_index.items():
e = self._trans_cache[i]
e = self._transition_cache[i]
e.terminal_dict[agent_name] = self._end_of_episode
e.next_agent_state_dict[agent_name] = e.agent_state_dict[agent_name]
else:
self._trans_cache.append(cache_element)
self._transition_cache.append(cache_element)
if len(self._trans_cache) > 0:
self._trans_cache[-1].next_state = cache_element.state
cur_index = len(self._trans_cache) - 1
cur_index = len(self._transition_cache) - 1
for agent_name in cache_element.agent_names:
if agent_name in self._agent_last_index:
i = self._agent_last_index[agent_name]
self._trans_cache[i].terminal_dict[agent_name] = False
self._trans_cache[i].next_agent_state_dict[agent_name] = cache_element.agent_state_dict[agent_name]
e = self._transition_cache[i]
e.terminal_dict[agent_name] = False
e.next_agent_state_dict[agent_name] = cache_element.agent_state_dict[agent_name]
self._agent_last_index[agent_name] = cur_index
def _reset(self) -> None:
self.env.reset()
self._current_episode_length = 0
self._info.clear()
self._trans_cache.clear()
self._transition_cache.clear()
self._agent_last_index.clear()
self._step(None)
def _select_trainable_agents(self, original_dict: dict) -> dict:
return {k: v for k, v in original_dict.items() if k in self._trainable_agents}
@property
def truncated(self) -> bool:
return self._max_episode_length == self._current_episode_length
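A construction sketch showing how the new truncation limit is intended to be used; the sampler subclass and environments are placeholders, and the remaining constructor arguments are assumed to follow the usual AbsEnvSampler signature:

sampler = MyEnvSampler(             # hypothetical AbsEnvSampler subclass
    learn_env=learn_env,
    test_env=test_env,
    policies=policies,
    agent2policy=agent2policy,
    max_episode_length=500,         # episodes are truncated (truncated=True) and reset after 500 steps
)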
def sample(
self,
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
@ -430,65 +451,88 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
Returns:
A dict that contains the collected experiences and additional information.
"""
# Init the env
self._switch_env(self._learn_env)
steps_to_go = num_steps if num_steps is not None else float("inf")
if policy_state is not None: # Update policy state if necessary
self.set_policy_state(policy_state)
self._switch_env(self._learn_env) # Init the env
self._agent_wrapper.explore() # Collect experience
if self._end_of_episode:
self._reset()
# Update policy state if necessary
if policy_state is not None:
self.set_policy_state(policy_state)
# If num_steps is None, run until the end of episode or the episode is truncated
# If num_steps is not None, run until we collect required number of steps
total_experiences = []
# Collect experience
self._agent_wrapper.explore()
steps_to_go = float("inf") if num_steps is None else num_steps
while not self._end_of_episode and steps_to_go > 0:
# Get agent actions and translate them to env actions
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)
while not any(
[
num_steps is None and (self._end_of_episode or self.truncated),
num_steps is not None and steps_to_go == 0,
],
):
if self._end_of_episode or self.truncated:
self._reset()
# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)
while not any(
[
self._end_of_episode,
self.truncated,
steps_to_go == 0,
],
):
# Get agent actions and translate them to env actions
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)
# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))
self._total_number_interactions += 1
self._current_episode_length += 1
steps_to_go -= 1
if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_step(cache_element)
self._append_cache_element(cache_element)
steps_to_go -= 1
self._append_cache_element(None)
# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated/updated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
truncated=self.truncated,
)
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
experiences: List[ExpElement] = []
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
# !: By this point at least `reward_eval_delay` ticks have passed since this event, so the reward must be computed here.
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_step(cache_element)
experiences.append(cache_element.make_exp_element())
# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))
cache_element.next_state = self._state
self._agent_last_index = {
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
}
if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_step(cache_element)
self._append_cache_element(cache_element)
self._append_cache_element(None)
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
experiences: List[ExpElement] = []
while len(self._transition_cache) > 0 and self._transition_cache[0].tick <= tick_bound:
cache_element = self._transition_cache.pop(0)
# !: By this point at least `reward_eval_delay` ticks have passed since this event, so the reward must be computed here.
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_step(cache_element)
experiences.append(cache_element.make_exp_element())
self._agent_last_index = {
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
}
total_experiences += experiences
return {
"end_of_episode": self._end_of_episode,
"experiences": [experiences],
"experiences": [total_experiences],
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
}
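A usage sketch of the reworked sample call; with num_steps given, collection now continues across episode ends and truncations until the requested number of steps has been taken:

result = env_sampler.sample(num_steps=128)     # env_sampler: a concrete AbsEnvSampler (placeholder)
experiences = result["experiences"][0]         # flat list of ExpElement objects collected so far
print(result["end_of_episode"], len(experiences))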
@ -514,50 +558,57 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
return loaded
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
self._switch_env(self._test_env)
self._reset()
if policy_state is not None:
self.set_policy_state(policy_state)
info_list = []
self._agent_wrapper.exploit()
while not self._end_of_episode:
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)
for _ in range(num_episodes):
self._reset()
if policy_state is not None:
self.set_policy_state(policy_state)
# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)
self._agent_wrapper.exploit()
while not self._end_of_episode:
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)
# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))
# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
truncated=False, # No truncation in evaluation
)
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))
cache_element.next_state = self._state
self._append_cache_element(cache_element)
self._append_cache_element(None)
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
self._append_cache_element(cache_element)
self._append_cache_element(None)
return {"info": [self._info]}
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
while len(self._transition_cache) > 0 and self._transition_cache[0].tick <= tick_bound:
cache_element = self._transition_cache.pop(0)
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)
info_list.append(self._info)
return {"info": info_list}
@abstractmethod
def _post_step(self, cache_element: CacheElement) -> None:


@ -59,7 +59,7 @@ class RolloutWorker(AbsWorker):
result = (
self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
if req["type"] == "sample"
else self._env_sampler.eval(policy_state=req["policy_state"])
else self._env_sampler.eval(policy_state=req["policy_state"], num_episodes=req["num_eval_episodes"])
)
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
else:


@ -90,13 +90,19 @@ class ACBasedOps(AbsTrainOps):
"""
return self._v_critic_net.get_gradients(self._get_critic_loss(batch))
def update_critic(self, batch: TransitionBatch) -> None:
def update_critic(self, batch: TransitionBatch) -> float:
"""Update the critic network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (float): The detached loss of this batch.
"""
self._v_critic_net.step(self._get_critic_loss(batch))
self._v_critic_net.train()
loss = self._get_critic_loss(batch)
self._v_critic_net.step(loss)
return loss.detach().cpu().numpy().item()
def update_critic_with_grad(self, grad_dict: dict) -> None:
"""Update the critic network with remotely computed gradients.
@ -148,24 +154,26 @@ class ACBasedOps(AbsTrainOps):
batch (TransitionBatch): Batch.
Returns:
grad (torch.Tensor): The actor gradient of the batch.
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
early_stop (bool): Early stop indicator.
"""
loss, early_stop = self._get_actor_loss(batch)
return self._policy.get_gradients(loss), early_stop
def update_actor(self, batch: TransitionBatch) -> bool:
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
"""Update the actor network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (float): The detached loss of this batch.
early_stop (bool): Early stop indicator.
"""
self._policy.train()
loss, early_stop = self._get_actor_loss(batch)
self._policy.train_step(loss)
return early_stop
return loss.detach().cpu().numpy().item(), early_stop
def update_actor_with_grad(self, grad_dict_and_early_stop: Tuple[dict, bool]) -> bool:
"""Update the actor network with remotely computed gradients.
@ -202,6 +210,9 @@ class ACBasedOps(AbsTrainOps):
# Preprocess advantages
states = ndarray_to_tensor(batch.states, device=self._device) # s
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
terminals = ndarray_to_tensor(batch.terminals, device=self._device)
truncated = ndarray_to_tensor(batch.truncated, device=self._device)
next_states = ndarray_to_tensor(batch.next_states, device=self._device)
if self._is_discrete_action:
actions = actions.long()
@ -209,11 +220,34 @@ class ACBasedOps(AbsTrainOps):
self._v_critic_net.eval()
self._policy.eval()
values = self._v_critic_net.v_values(states).detach().cpu().numpy()
values = np.concatenate([values, np.zeros(1)])
rewards = np.concatenate([batch.rewards, np.zeros(1)])
deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1]
batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam)
batch.returns = np.zeros(batch.size, dtype=np.float32)
batch.advantages = np.zeros(batch.size, dtype=np.float32)
i = 0
while i < batch.size:
j = i
while j < batch.size - 1 and not (terminals[j] or truncated[j]):
j += 1
last_val = (
0.0
if terminals[j]
else self._v_critic_net.v_values(
next_states[j].unsqueeze(dim=0),
)
.detach()
.cpu()
.numpy()
.item()
)
cur_values = np.append(values[i : j + 1], last_val)
cur_rewards = np.append(batch.rewards[i : j + 1], last_val)
# delta = r + gamma * v(s') - v(s)
cur_deltas = cur_rewards[:-1] + self._reward_discount * cur_values[1:] - cur_values[:-1]
batch.returns[i : j + 1] = discount_cumsum(cur_rewards, self._reward_discount)[:-1]
batch.advantages[i : j + 1] = discount_cumsum(cur_deltas, self._reward_discount * self._lam)
i = j + 1
if self._clip_ratio is not None:
batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy()
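The segmented computation above only bootstraps the next-state value when a segment ends by truncation rather than a true terminal. For reference, a standalone sketch of the discount_cumsum building block it relies on (MARO's own implementation may differ, e.g. a filter-based version):

import numpy as np

def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    # out[t] = sum_k discount**k * x[t + k]
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out

rewards = np.array([1.0, 1.0, 1.0])
print(discount_cumsum(rewards, 0.9))  # [2.71, 1.9, 1.0]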
@ -229,7 +263,7 @@ class ACBasedOps(AbsTrainOps):
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device)
self._policy.to_device(self._device)
self._v_critic_net.to(self._device)
self._v_critic_net.to_device(self._device)
class ACBasedTrainer(SingleAgentTrainer):
@ -291,21 +325,25 @@ class ACBasedTrainer(SingleAgentTrainer):
assert isinstance(self._ops, ACBasedOps)
batch = self._get_batch()
for _ in range(self._params.grad_iters):
self._ops.update_critic(batch)
for _ in range(self._params.grad_iters):
early_stop = self._ops.update_actor(batch)
if early_stop:
break
for _ in range(self._params.grad_iters):
self._ops.update_critic(batch)
async def train_step_as_task(self) -> None:
assert isinstance(self._ops, RemoteOps)
batch = self._get_batch()
for _ in range(self._params.grad_iters):
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
for _ in range(self._params.grad_iters):
if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)): # early stop
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
self._ops.update_actor_with_grad(grad_dict)
if early_stop:
break
for _ in range(self._params.grad_iters):
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))


@ -2,7 +2,7 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Tuple, cast
import torch
@ -27,7 +27,7 @@ class DDPGParams(BaseTrainerParams):
random_overwrite (bool, default=False): This specifies overwrite behavior when the replay memory capacity
is reached. If True, overwrite positions will be selected randomly. Otherwise, overwrites will occur
sequentially with wrap-around.
min_num_to_trigger_training (int, default=0): Minimum number required to start training.
n_start_train (int, default=0): Minimum number required to start training.
"""
get_q_critic_net_func: Callable[[], QNet]
@ -36,7 +36,7 @@ class DDPGParams(BaseTrainerParams):
q_value_loss_cls: Optional[Callable] = None
soft_update_coef: float = 1.0
random_overwrite: bool = False
min_num_to_trigger_training: int = 0
n_start_train: int = 0
class DDPGOps(AbsTrainOps):
@ -93,9 +93,9 @@ class DDPGOps(AbsTrainOps):
states=next_states, # s'
actions=self._target_policy.get_actions_tensor(next_states), # miu_targ(s')
) # Q_targ(s', miu_targ(s'))
# y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', miu_targ(s'))
target_q_values = (rewards + self._reward_discount * (1.0 - terminals.float()) * next_q_values).detach()
# y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', miu_targ(s'))
target_q_values = (rewards + self._reward_discount * (1 - terminals.long()) * next_q_values).detach()
q_values = self._q_critic_net.q_values(states=states, actions=actions) # Q(s, a)
return self._q_value_loss_func(q_values, target_q_values) # MSE(Q(s, a), y(r, s', d))
@ -120,16 +120,21 @@ class DDPGOps(AbsTrainOps):
self._q_critic_net.train()
self._q_critic_net.apply_gradients(grad_dict)
def update_critic(self, batch: TransitionBatch) -> None:
def update_critic(self, batch: TransitionBatch) -> float:
"""Update the critic network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (float): The detached loss of this batch.
"""
self._q_critic_net.train()
self._q_critic_net.step(self._get_critic_loss(batch))
loss = self._get_critic_loss(batch)
self._q_critic_net.step(loss)
return loss.detach().cpu().numpy().item()
def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]:
"""Compute the actor loss of the batch.
Args:
@ -137,6 +142,7 @@ class DDPGOps(AbsTrainOps):
Returns:
loss (torch.Tensor): The actor loss of the batch.
early_stop (bool): The early stop indicator, set to False in current implementation.
"""
assert isinstance(batch, TransitionBatch)
self._policy.train()
@ -147,19 +153,23 @@ class DDPGOps(AbsTrainOps):
actions=self._policy.get_actions_tensor(states), # miu(s)
).mean() # -Q(s, miu(s))
return policy_loss
early_stop = False
return policy_loss, early_stop
@remote
def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
def get_actor_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
"""Compute the actor network's gradients of a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
grad (torch.Tensor): The actor gradient of the batch.
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
early_stop (bool): Early stop indicator.
"""
return self._policy.get_gradients(self._get_actor_loss(batch))
loss, early_stop = self._get_actor_loss(batch)
return self._policy.get_gradients(loss), early_stop
def update_actor_with_grad(self, grad_dict: dict) -> None:
"""Update the actor network with remotely computed gradients.
@ -170,14 +180,20 @@ class DDPGOps(AbsTrainOps):
self._policy.train()
self._policy.apply_gradients(grad_dict)
def update_actor(self, batch: TransitionBatch) -> None:
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
"""Update the actor network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (float): The detached loss of this batch.
early_stop (bool): Early stop indicator.
"""
self._policy.train()
self._policy.train_step(self._get_actor_loss(batch))
loss, early_stop = self._get_actor_loss(batch)
self._policy.train_step(loss)
return loss.detach().cpu().numpy().item(), early_stop
def get_non_policy_state(self) -> dict:
return {
@ -200,8 +216,8 @@ class DDPGOps(AbsTrainOps):
self._device = get_torch_device(device=device)
self._policy.to_device(self._device)
self._target_policy.to_device(self._device)
self._q_critic_net.to(self._device)
self._target_q_critic_net.to(self._device)
self._q_critic_net.to_device(self._device)
self._target_q_critic_net.to_device(self._device)
class DDPGTrainer(SingleAgentTrainer):
@ -263,10 +279,10 @@ class DDPGTrainer(SingleAgentTrainer):
def train_step(self) -> None:
assert isinstance(self._ops, DDPGOps)
if self._replay_memory.n_sample < self._params.min_num_to_trigger_training:
if self._replay_memory.n_sample < self._params.n_start_train:
print(
f"Skip this training step due to lack of experiences "
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})",
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})",
)
return
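A configuration sketch for the renamed parameter; the Q-critic factory is a placeholder and the other DDPGParams fields are assumed to keep their defaults:

params = DDPGParams(
    get_q_critic_net_func=lambda: MyQCriticNet(),  # placeholder QNet factory
    n_start_train=10_000,                          # skip training until 10k transitions are stored
)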
@ -280,19 +296,21 @@ class DDPGTrainer(SingleAgentTrainer):
async def train_step_as_task(self) -> None:
assert isinstance(self._ops, RemoteOps)
if self._replay_memory.n_sample < self._params.min_num_to_trigger_training:
if self._replay_memory.n_sample < self._params.n_start_train:
print(
f"Skip this training step due to lack of experiences "
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})",
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})",
)
return
for _ in range(self._params.num_epochs):
batch = self._get_batch()
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch))
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
self._ops.update_actor_with_grad(grad_dict)
self._try_soft_update_target()
if early_stop:
break
def _try_soft_update_target(self) -> None:
"""Soft update the target policy and target critic."""


@ -161,15 +161,20 @@ class DiscreteMADDPGOps(AbsTrainOps):
"""
return self._q_critic_net.get_gradients(self._get_critic_loss(batch, next_actions))
def update_critic(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> None:
def update_critic(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> float:
"""Update the critic network using a batch.
Args:
batch (MultiTransitionBatch): Batch.
next_actions (List[torch.Tensor]): List of next actions of all policies.
Returns:
loss (float): The detached loss of this batch.
"""
self._q_critic_net.train()
self._q_critic_net.step(self._get_critic_loss(batch, next_actions))
loss = self._get_critic_loss(batch, next_actions)
self._q_critic_net.step(loss)
return loss.detach().cpu().numpy().item()
def update_critic_with_grad(self, grad_dict: dict) -> None:
"""Update the critic network with remotely computed gradients.
@ -180,7 +185,7 @@ class DiscreteMADDPGOps(AbsTrainOps):
self._q_critic_net.train()
self._q_critic_net.apply_gradients(grad_dict)
def _get_actor_loss(self, batch: MultiTransitionBatch) -> torch.Tensor:
def _get_actor_loss(self, batch: MultiTransitionBatch) -> Tuple[torch.Tensor, bool]:
"""Compute the actor loss of the batch.
Args:
@ -188,11 +193,13 @@ class DiscreteMADDPGOps(AbsTrainOps):
Returns:
loss (torch.Tensor): The actor loss of the batch.
early_stop (bool): The early stop indicator, set to False in current implementation.
"""
latest_action, latest_action_logp = self.get_latest_action(batch)
states = ndarray_to_tensor(batch.states, device=self._device) # x
actions = [ndarray_to_tensor(action, device=self._device) for action in batch.actions] # a
actions[self._policy_idx] = latest_action
self._policy.train()
self._q_critic_net.freeze()
actor_loss = -(
@ -203,28 +210,39 @@ class DiscreteMADDPGOps(AbsTrainOps):
* latest_action_logp
).mean() # Q(x, a^j_1, ..., a_i, ..., a^j_N)
self._q_critic_net.unfreeze()
return actor_loss
early_stop = False
return actor_loss, early_stop
@remote
def get_actor_grad(self, batch: MultiTransitionBatch) -> Dict[str, torch.Tensor]:
def get_actor_grad(self, batch: MultiTransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
"""Compute the actor network's gradients of a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
early_stop (bool): Early stop indicator.
"""
loss, early_stop = self._get_actor_loss(batch)
return self._policy.get_gradients(loss), early_stop
def update_actor(self, batch: MultiTransitionBatch) -> Tuple[float, bool]:
"""Update the actor network using a batch.
Args:
batch (MultiTransitionBatch): Batch.
Returns:
grad (torch.Tensor): The actor gradient of the batch.
"""
return self._policy.get_gradients(self._get_actor_loss(batch))
def update_actor(self, batch: MultiTransitionBatch) -> None:
"""Update the actor network using a batch.
Args:
batch (MultiTransitionBatch): Batch.
loss (float): The detached loss of this batch.
early_stop (bool): Early stop indicator.
"""
self._policy.train()
self._policy.train_step(self._get_actor_loss(batch))
loss, early_stop = self._get_actor_loss(batch)
self._policy.train_step(loss)
return loss.detach().cpu().numpy().item(), early_stop
def update_actor_with_grad(self, grad_dict: dict) -> None:
"""Update the critic network with remotely computed gradients.
@ -275,8 +293,8 @@ class DiscreteMADDPGOps(AbsTrainOps):
self._policy.to_device(self._device)
self._target_policy.to_device(self._device)
self._q_critic_net.to(self._device)
self._target_q_critic_net.to(self._device)
self._q_critic_net.to_device(self._device)
self._target_q_critic_net.to_device(self._device)
class DiscreteMADDPGTrainer(MultiAgentTrainer):
@ -378,6 +396,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
agent_states=agent_states,
next_agent_states=next_agent_states,
terminals=np.array(terminal_flags),
truncated=np.array([exp_element.truncated for exp_element in exp_elements]),
)
self._replay_memory.put(transition_batch)
@ -459,7 +478,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
ops.update_critic_with_grad(critic_grad)
# Update actors
actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch) for ops in self._actor_ops_list])
actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch)[0] for ops in self._actor_ops_list])
for ops, actor_grad in zip(self._actor_ops_list, actor_grad_list):
ops.update_actor_with_grad(actor_grad)


@ -22,7 +22,7 @@ class SoftActorCriticParams(BaseTrainerParams):
num_epochs: int = 1
n_start_train: int = 0
q_value_loss_cls: Optional[Callable] = None
soft_update_coef: float = 1.0
soft_update_coef: float = 0.05
class SoftActorCriticOps(AbsTrainOps):
@ -58,6 +58,7 @@ class SoftActorCriticOps(AbsTrainOps):
def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]:
self._q_net1.train()
self._q_net2.train()
states = ndarray_to_tensor(batch.states, device=self._device) # s
next_states = ndarray_to_tensor(batch.next_states, device=self._device) # s'
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
@ -67,11 +68,13 @@ class SoftActorCriticOps(AbsTrainOps):
assert isinstance(self._policy, ContinuousRLPolicy)
with torch.no_grad():
next_actions, next_logps = self._policy.get_actions_with_logps(states)
q1 = self._target_q_net1.q_values(next_states, next_actions)
q2 = self._target_q_net2.q_values(next_states, next_actions)
q = torch.min(q1, q2)
y = rewards + self._reward_discount * (1.0 - terminals.float()) * (q - self._entropy_coef * next_logps)
next_actions, next_logps = self._policy.get_actions_with_logps(next_states)
target_q1 = self._target_q_net1.q_values(next_states, next_actions)
target_q2 = self._target_q_net2.q_values(next_states, next_actions)
target_q = torch.min(target_q1, target_q2)
y = rewards + self._reward_discount * (1.0 - terminals.float()) * (
target_q - self._entropy_coef * next_logps
)
q1 = self._q_net1.q_values(states, actions)
q2 = self._q_net2.q_values(states, actions)
@ -92,14 +95,36 @@ class SoftActorCriticOps(AbsTrainOps):
self._q_net1.apply_gradients(grad_dicts[0])
self._q_net2.apply_gradients(grad_dicts[1])
def update_critic(self, batch: TransitionBatch) -> None:
def update_critic(self, batch: TransitionBatch) -> Tuple[float, float]:
"""Update the critic network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss_q1 (float): The detached q_net1 loss of this batch.
loss_q2 (float): The detached q_net2 loss of this batch.
"""
self._q_net1.train()
self._q_net2.train()
loss_q1, loss_q2 = self._get_critic_loss(batch)
self._q_net1.step(loss_q1)
self._q_net2.step(loss_q2)
return loss_q1.detach().cpu().numpy().item(), loss_q2.detach().cpu().numpy().item()
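The critic change above fixes the Bellman target: the entropy-regularized backup must sample the policy at the next state `s'`, not at `s`. A standalone sketch of the target computation (duck-typed arguments and illustrative names, not the MARO internals):

```python
import torch

# Minimal sketch of the SAC critic target, assuming `policy` exposes get_actions_with_logps()
# and the target critics expose q_values(), as in the ops class above.
def sac_critic_target(
    rewards: torch.Tensor,      # r
    next_states: torch.Tensor,  # s'
    terminals: torch.Tensor,    # d
    policy,
    target_q_net1,
    target_q_net2,
    reward_discount: float,
    entropy_coef: float,
) -> torch.Tensor:
    with torch.no_grad():
        # The fix above: sample the policy at s' (next_states), not at s.
        next_actions, next_logps = policy.get_actions_with_logps(next_states)
        target_q = torch.min(
            target_q_net1.q_values(next_states, next_actions),
            target_q_net2.q_values(next_states, next_actions),
        )
        return rewards + reward_discount * (1.0 - terminals.float()) * (target_q - entropy_coef * next_logps)
```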
def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]:
"""Compute the actor loss of the batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (torch.Tensor): The actor loss of the batch.
early_stop (bool): The early stop indicator, set to False in current implementation.
"""
self._q_net1.freeze()
self._q_net2.freeze()
self._policy.train()
states = ndarray_to_tensor(batch.states, device=self._device) # s
actions, logps = self._policy.get_actions_with_logps(states)
@ -108,19 +133,46 @@ class SoftActorCriticOps(AbsTrainOps):
q = torch.min(q1, q2)
loss = (self._entropy_coef * logps - q).mean()
return loss
self._q_net1.unfreeze()
self._q_net2.unfreeze()
early_stop = False
return loss, early_stop
@remote
def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
return self._policy.get_gradients(self._get_actor_loss(batch))
def get_actor_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
"""Compute the actor network's gradients of a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
early_stop (bool): Early stop indicator.
"""
loss, early_stop = self._get_actor_loss(batch)
return self._policy.get_gradients(loss), early_stop
def update_actor_with_grad(self, grad_dict: dict) -> None:
self._policy.train()
self._policy.apply_gradients(grad_dict)
def update_actor(self, batch: TransitionBatch) -> None:
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
"""Update the actor network using a batch.
Args:
batch (TransitionBatch): Batch.
Returns:
loss (float): The detached loss of this batch.
early_stop (bool): Early stop indicator.
"""
self._policy.train()
self._policy.train_step(self._get_actor_loss(batch))
loss, early_stop = self._get_actor_loss(batch)
self._policy.train_step(loss)
return loss.detach().cpu().numpy().item(), early_stop
def get_non_policy_state(self) -> dict:
return {
@ -142,10 +194,13 @@ class SoftActorCriticOps(AbsTrainOps):
def to_device(self, device: str = None) -> None:
self._device = get_torch_device(device=device)
self._q_net1.to(self._device)
self._q_net2.to(self._device)
self._target_q_net1.to(self._device)
self._target_q_net2.to(self._device)
self._policy.to_device(self._device)
self._q_net1.to_device(self._device)
self._q_net2.to_device(self._device)
self._target_q_net1.to_device(self._device)
self._target_q_net2.to_device(self._device)
class SoftActorCriticTrainer(SingleAgentTrainer):
@ -211,9 +266,11 @@ class SoftActorCriticTrainer(SingleAgentTrainer):
for _ in range(self._params.num_epochs):
batch = self._get_batch()
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch))
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
self._ops.update_actor_with_grad(grad_dict)
self._try_soft_update_target()
if early_stop:
break
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
return transition_batch


@ -35,29 +35,18 @@ class AbsIndexScheduler(object, metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Generate a list of indexes that can be used to retrieve items from the replay memory.
Args:
batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience
item is present are returned.
forbid_last (bool, default=False): Whether the latest element is allowed to be sampled.
If this is true, the last index will always be excluded from the result.
Returns:
indexes (np.ndarray): The list of indexes.
"""
raise NotImplementedError
@abstractmethod
def get_last_index(self) -> int:
"""Get the index of the latest element in the memory.
Returns:
index (int): The index of the latest element in the memory.
"""
raise NotImplementedError
class RandomIndexScheduler(AbsIndexScheduler):
"""Index scheduler that returns random indexes when sampling.
@ -93,14 +82,11 @@ class RandomIndexScheduler(AbsIndexScheduler):
self._size = min(self._size + batch_size, self._capacity)
return indexes
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
assert self._size > 0, "Cannot sample from an empty memory."
return np.random.choice(self._size, size=batch_size, replace=True)
def get_last_index(self) -> int:
raise NotImplementedError
class FIFOIndexScheduler(AbsIndexScheduler):
"""First-in-first-out index scheduler.
@ -135,19 +121,15 @@ class FIFOIndexScheduler(AbsIndexScheduler):
self._head = (self._head + overwrite) % self._capacity
return self.get_put_indexes(batch_size)
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
indexes = (
np.arange(self._head, tmp)
if tmp > self._head
else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)])
np.arange(self._head, self._tail)
if self._tail > self._head
else np.concatenate([np.arange(self._head, self._capacity), np.arange(self._tail)])
)
self._head = tmp
self._head = self._tail
return indexes
def get_last_index(self) -> int:
return (self._tail - 1) % self._capacity
class AbsReplayMemory(object, metaclass=ABCMeta):
"""Abstract replay memory class with basic interfaces.
@ -176,9 +158,9 @@ class AbsReplayMemory(object, metaclass=ABCMeta):
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_put_indexes(batch_size)
def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
"""Please refer to the doc string in AbsIndexScheduler."""
return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last)
return self._idx_scheduler.get_sample_indexes(batch_size)
class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
@ -204,7 +186,8 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
self._states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._actions = np.zeros((self._capacity, self._action_dim), dtype=np.float32)
self._rewards = np.zeros(self._capacity, dtype=np.float32)
self._terminals = np.zeros(self._capacity, dtype=np.bool)
self._terminals = np.zeros(self._capacity, dtype=bool)
self._truncated = np.zeros(self._capacity, dtype=bool)
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._returns = np.zeros(self._capacity, dtype=np.float32)
self._advantages = np.zeros(self._capacity, dtype=np.float32)
@ -233,6 +216,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
assert match_shape(transition_batch.actions, (batch_size, self._action_dim))
assert match_shape(transition_batch.rewards, (batch_size,))
assert match_shape(transition_batch.terminals, (batch_size,))
assert match_shape(transition_batch.truncated, (batch_size,))
assert match_shape(transition_batch.next_states, (batch_size, self._state_dim))
if transition_batch.returns is not None:
match_shape(transition_batch.returns, (batch_size,))
@ -255,6 +239,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
self._actions[indexes] = transition_batch.actions
self._rewards[indexes] = transition_batch.rewards
self._terminals[indexes] = transition_batch.terminals
self._truncated[indexes] = transition_batch.truncated
self._next_states[indexes] = transition_batch.next_states
if transition_batch.returns is not None:
self._returns[indexes] = transition_batch.returns
@ -273,7 +258,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
Returns:
batch (TransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)
def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
@ -292,16 +277,13 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
actions=self._actions[indexes],
rewards=self._rewards[indexes],
terminals=self._terminals[indexes],
truncated=self._truncated[indexes],
next_states=self._next_states[indexes],
returns=self._returns[indexes],
advantages=self._advantages[indexes],
old_logps=self._old_logps[indexes],
)
@abstractmethod
def _get_forbid_last(self) -> bool:
raise NotImplementedError
class RandomReplayMemory(ReplayMemory):
def __init__(
@ -318,15 +300,11 @@ class RandomReplayMemory(ReplayMemory):
RandomIndexScheduler(capacity, random_overwrite),
)
self._random_overwrite = random_overwrite
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
@property
def random_overwrite(self) -> bool:
return self._random_overwrite
def _get_forbid_last(self) -> bool:
return False
class FIFOReplayMemory(ReplayMemory):
def __init__(
@ -342,9 +320,6 @@ class FIFOReplayMemory(ReplayMemory):
FIFOIndexScheduler(capacity),
)
def _get_forbid_last(self) -> bool:
return not self._terminals[self._idx_scheduler.get_last_index()]
class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
"""In-memory experience storage facility for a multi trainer.
@ -373,7 +348,8 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
self._actions = [np.zeros((self._capacity, action_dim), dtype=np.float32) for action_dim in self._action_dims]
self._rewards = [np.zeros(self._capacity, dtype=np.float32) for _ in range(self.agent_num)]
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._terminals = np.zeros(self._capacity, dtype=np.bool)
self._terminals = np.zeros(self._capacity, dtype=bool)
self._truncated = np.zeros(self._capacity, dtype=bool)
assert len(agent_states_dims) == self.agent_num
self._agent_states_dims = agent_states_dims
@ -408,6 +384,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
assert match_shape(transition_batch.rewards[i], (batch_size,))
assert match_shape(transition_batch.terminals, (batch_size,))
assert match_shape(transition_batch.truncated, (batch_size,))
assert match_shape(transition_batch.next_states, (batch_size, self._state_dim))
assert len(transition_batch.agent_states) == self.agent_num
@ -430,6 +407,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
self._actions[i][indexes] = transition_batch.actions[i]
self._rewards[i][indexes] = transition_batch.rewards[i]
self._terminals[indexes] = transition_batch.terminals
self._truncated[indexes] = transition_batch.truncated
self._next_states[indexes] = transition_batch.next_states
for i in range(self.agent_num):
@ -446,7 +424,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
Returns:
batch (MultiTransitionBatch): The sampled batch.
"""
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
indexes = self._get_sample_indexes(batch_size)
return self.sample_by_indexes(indexes)
def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
@ -465,15 +443,12 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
actions=[action[indexes] for action in self._actions],
rewards=[reward[indexes] for reward in self._rewards],
terminals=self._terminals[indexes],
truncated=self._truncated[indexes],
next_states=self._next_states[indexes],
agent_states=[state[indexes] for state in self._agent_states],
next_agent_states=[state[indexes] for state in self._next_agent_states],
)
@abstractmethod
def _get_forbid_last(self) -> bool:
raise NotImplementedError
class RandomMultiReplayMemory(MultiReplayMemory):
def __init__(
@ -492,15 +467,11 @@ class RandomMultiReplayMemory(MultiReplayMemory):
agent_states_dims,
)
self._random_overwrite = random_overwrite
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
@property
def random_overwrite(self) -> bool:
return self._random_overwrite
def _get_forbid_last(self) -> bool:
return False
class FIFOMultiReplayMemory(MultiReplayMemory):
def __init__(
@ -517,6 +488,3 @@ class FIFOMultiReplayMemory(MultiReplayMemory):
FIFOIndexScheduler(capacity),
agent_states_dims,
)
def _get_forbid_last(self) -> bool:
return not self._terminals[self._idx_scheduler.get_last_index()]
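With `forbid_last` and `get_last_index` gone, `FIFOIndexScheduler.get_sample_indexes` simply drains every index from head to tail, wrapping around the ring buffer when the write position has lapped the head. A self-contained sketch of that index arithmetic (the function name is illustrative, not a MARO API):

```python
import numpy as np

def fifo_sample_indexes(head: int, tail: int, capacity: int) -> np.ndarray:
    """Return every index from head (inclusive) to tail (exclusive), wrapping at capacity."""
    if tail > head:  # contiguous region, no wrap-around
        return np.arange(head, tail)
    # the write position has wrapped past the end of the buffer
    return np.concatenate([np.arange(head, capacity), np.arange(tail)])

# e.g. head=7, tail=3, capacity=10 -> [7 8 9 0 1 2]
print(fifo_sample_indexes(7, 3, 10))
```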


@ -254,6 +254,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
exp_element.action_dict[agent_name],
exp_element.reward_dict[agent_name],
exp_element.terminal_dict[agent_name],
exp_element.truncated,
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
),
)
@ -264,7 +265,8 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
actions=np.vstack([exp[1] for exp in exps]),
rewards=np.array([exp[2] for exp in exps]),
terminals=np.array([exp[3] for exp in exps]),
next_states=np.vstack([exp[4] for exp in exps]),
truncated=np.array([exp[4] for exp in exps]),
next_states=np.vstack([exp[5] for exp in exps]),
)
transition_batch = self._preprocess_batch(transition_batch)
self.replay_memory.put(transition_batch)


@ -19,6 +19,7 @@ class TransitionBatch:
rewards: np.ndarray # 1D
next_states: np.ndarray # 2D
terminals: np.ndarray # 1D
truncated: np.ndarray # 1D
returns: np.ndarray = None # 1D
advantages: np.ndarray = None # 1D
old_logps: np.ndarray = None # 1D
@ -34,6 +35,7 @@ class TransitionBatch:
assert len(self.rewards.shape) == 1 and self.rewards.shape[0] == self.states.shape[0]
assert self.next_states.shape == self.states.shape
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]
assert len(self.truncated.shape) == 1 and self.truncated.shape[0] == self.states.shape[0]
def make_kth_sub_batch(self, i: int, k: int) -> TransitionBatch:
return TransitionBatch(
@ -42,6 +44,7 @@ class TransitionBatch:
rewards=self.rewards[i::k],
next_states=self.next_states[i::k],
terminals=self.terminals[i::k],
truncated=self.truncated[i::k],
returns=self.returns[i::k] if self.returns is not None else None,
advantages=self.advantages[i::k] if self.advantages is not None else None,
old_logps=self.old_logps[i::k] if self.old_logps is not None else None,
@ -60,7 +63,7 @@ class MultiTransitionBatch:
agent_states: List[np.ndarray] # List of 2D
next_agent_states: List[np.ndarray] # List of 2D
terminals: np.ndarray # 1D
truncated: np.ndarray # 1D
returns: Optional[List[np.ndarray]] = None # List of 1D
advantages: Optional[List[np.ndarray]] = None # List of 1D
@ -81,6 +84,7 @@ class MultiTransitionBatch:
assert self.agent_states[i].shape[0] == self.states.shape[0]
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]
assert len(self.truncated.shape) == 1 and self.truncated.shape[0] == self.states.shape[0]
assert self.next_states.shape == self.states.shape
assert len(self.next_agent_states) == len(self.agent_states)
@ -98,6 +102,7 @@ class MultiTransitionBatch:
agent_states = [state[i::k] for state in self.agent_states]
next_agent_states = [state[i::k] for state in self.next_agent_states]
terminals = self.terminals[i::k]
truncated = self.truncated[i::k]
returns = None if self.returns is None else [r[i::k] for r in self.returns]
advantages = None if self.advantages is None else [advantage[i::k] for advantage in self.advantages]
return MultiTransitionBatch(
@ -108,6 +113,7 @@ class MultiTransitionBatch:
agent_states,
next_agent_states,
terminals,
truncated,
returns,
advantages,
)
@ -123,6 +129,7 @@ def merge_transition_batches(batch_list: List[TransitionBatch]) -> TransitionBat
rewards=np.concatenate([batch.rewards for batch in batch_list], axis=0),
next_states=np.concatenate([batch.next_states for batch in batch_list], axis=0),
terminals=np.concatenate([batch.terminals for batch in batch_list]),
truncated=np.concatenate([batch.truncated for batch in batch_list]),
returns=np.concatenate([batch.returns for batch in batch_list]),
advantages=np.concatenate([batch.advantages for batch in batch_list]),
old_logps=None
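The new `truncated` array distinguishes transitions from episodes that were cut short (for example by a step limit) from those that genuinely terminated. A minimal sketch of building a `TransitionBatch` with the field and splitting it, with made-up values and an assumed import path:

```python
import numpy as np

from maro.rl.utils import TransitionBatch  # assumed import path

batch = TransitionBatch(
    states=np.zeros((4, 3), dtype=np.float32),
    actions=np.zeros((4, 1), dtype=np.float32),
    rewards=np.ones(4, dtype=np.float32),
    next_states=np.zeros((4, 3), dtype=np.float32),
    terminals=np.array([False, False, False, True]),
    truncated=np.array([False, False, True, False]),  # cut off at step 3, not terminated
)
sub_batch = batch.make_kth_sub_batch(0, 2)  # every 2nd transition starting from index 0
print(sub_batch.truncated)  # -> [False  True]
```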


@ -0,0 +1,182 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import copy
import os
import typing
from typing import Dict, List, Optional, Union
import pandas as pd
from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler
from maro.rl.training import TrainingManager
from maro.utils import LoggerV2
if typing.TYPE_CHECKING:
from maro.rl.workflows.main import TrainingWorkflow
EnvSampler = Union[AbsEnvSampler, BatchEnvSampler]
class Callback(object):
def __init__(self) -> None:
self.workflow: Optional[TrainingWorkflow] = None
self.env_sampler: Optional[EnvSampler] = None
self.training_manager: Optional[TrainingManager] = None
self.logger: Optional[LoggerV2] = None
def on_episode_start(self, ep: int) -> None:
pass
def on_episode_end(self, ep: int) -> None:
pass
def on_training_start(self, ep: int) -> None:
pass
def on_training_end(self, ep: int) -> None:
pass
def on_validation_start(self, ep: int) -> None:
pass
def on_validation_end(self, ep: int) -> None:
pass
def on_test_start(self, ep: int) -> None:
pass
def on_test_end(self, ep: int) -> None:
pass
class EarlyStopping(Callback):
def __init__(self, patience: int) -> None:
super(EarlyStopping, self).__init__()
self._patience = patience
self._best_ep: int = -1
self._best: float = float("-inf")
def on_validation_end(self, ep: int) -> None:
cur = self.env_sampler.monitor_metrics()
if cur > self._best:
self._best_ep = ep
self._best = cur
self.logger.info(f"Current metric: {cur} @ ep {ep}. Best metric: {self._best} @ ep {self._best_ep}")
if ep - self._best_ep > self._patience:
self.workflow.early_stop = True
self.logger.info(
f"Validation metric has not been updated for {ep - self._best_ep} "
f"epochs (patience = {self._patience} epochs). Early stop.",
)
class Checkpoint(Callback):
def __init__(self, path: str, interval: int) -> None:
super(Checkpoint, self).__init__()
self._path = path
self._interval = interval
def on_training_end(self, ep: int) -> None:
if ep % self._interval == 0:
self.training_manager.save(os.path.join(self._path, str(ep)))
self.logger.info(f"[Episode {ep}] All trainer states saved under {self._path}")
class MetricsRecorder(Callback):
def __init__(self, path: str) -> None:
super(MetricsRecorder, self).__init__()
self._full_metrics: Dict[int, dict] = {}
self._valid_metrics: Dict[int, dict] = {}
self._path = path
def _dump_metric_history(self) -> None:
if len(self._full_metrics) > 0:
metric_list = [self._full_metrics[ep] for ep in sorted(self._full_metrics.keys())]
df = pd.DataFrame.from_records(metric_list)
df.to_csv(os.path.join(self._path, "metrics_full.csv"), index=True)
if len(self._valid_metrics) > 0:
metric_list = [self._valid_metrics[ep] for ep in sorted(self._valid_metrics.keys())]
df = pd.DataFrame.from_records(metric_list)
df.to_csv(os.path.join(self._path, "metrics_valid.csv"), index=True)
def on_training_end(self, ep: int) -> None:
if len(self.env_sampler.metrics) > 0:
metrics = copy.deepcopy(self.env_sampler.metrics)
metrics["ep"] = ep
if ep in self._full_metrics:
self._full_metrics[ep].update(metrics)
else:
self._full_metrics[ep] = metrics
self._dump_metric_history()
def on_validation_end(self, ep: int) -> None:
if len(self.env_sampler.metrics) > 0:
metrics = copy.deepcopy(self.env_sampler.metrics)
metrics["ep"] = ep
if ep in self._full_metrics:
self._full_metrics[ep].update(metrics)
else:
self._full_metrics[ep] = metrics
if ep in self._valid_metrics:
self._valid_metrics[ep].update(metrics)
else:
self._valid_metrics[ep] = metrics
self._dump_metric_history()
class CallbackManager(object):
def __init__(
self,
workflow: TrainingWorkflow,
callbacks: List[Callback],
env_sampler: EnvSampler,
training_manager: TrainingManager,
logger: LoggerV2,
) -> None:
super(CallbackManager, self).__init__()
self._callbacks = callbacks
for callback in self._callbacks:
callback.workflow = workflow
callback.env_sampler = env_sampler
callback.training_manager = training_manager
callback.logger = logger
def on_episode_start(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_episode_start(ep)
def on_episode_end(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_episode_end(ep)
def on_training_start(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_training_start(ep)
def on_training_end(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_training_end(ep)
def on_validation_start(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_validation_start(ep)
def on_validation_end(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_validation_end(ep)
def on_test_start(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_test_start(ep)
def on_test_end(self, ep: int) -> None:
for callback in self._callbacks:
callback.on_test_end(ep)
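User-defined callbacks go through the same hooks: subclass `Callback`, override what you need, and hand the instance over via `rl_component_bundle.customized_callbacks`, which the training workflow below appends to the built-in callback list. A hypothetical example that just times each episode:

```python
import time

class EpisodeTimer(Callback):  # hypothetical user callback, not part of this PR
    def __init__(self) -> None:
        super(EpisodeTimer, self).__init__()
        self._start_time = None

    def on_episode_start(self, ep: int) -> None:
        self._start_time = time.time()

    def on_episode_end(self, ep: int) -> None:
        if self._start_time is not None and self.logger is not None:
            self.logger.info(f"Episode {ep} took {time.time() - self._start_time:.2f} seconds")
```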


@ -76,6 +76,11 @@ class ConfigParser:
f"positive ints",
)
early_stop_patience = self._config["main"].get("early_stop_patience", None)
if early_stop_patience is not None:
if not isinstance(early_stop_patience, int) or early_stop_patience <= 0:
raise ValueError(f"Invalid early stop patience: {early_stop_patience}. Should be a positive integer.")
if "logging" in self._config["main"]:
self._validate_logging_section("main", self._config["main"]["logging"])
@ -196,9 +201,10 @@ class ConfigParser:
raise TypeError(f"{self._validation_err_pfx}: 'training.proxy.backend' must be an int")
def _validate_checkpointing_section(self, section: dict) -> None:
if "path" not in section:
raise KeyError(f"{self._validation_err_pfx}: missing field 'path' under section 'checkpointing'")
if not isinstance(section["path"], str):
ckpt_path = section.get("path", None)
if ckpt_path is None:
section["path"] = os.path.join(self._config["log_path"], "checkpoints")
elif not isinstance(section["path"], str):
raise TypeError(f"{self._validation_err_pfx}: 'training.checkpointing.path' must be a string")
if "interval" in section:
@ -231,10 +237,9 @@ class ConfigParser:
local/log/path -> "/logs"
Defaults to False.
"""
log_dir = os.path.dirname(self._config["log_path"])
path_map = {
self._config["scenario_path"]: "/scenario" if containerize else self._config["scenario_path"],
log_dir: "/logs" if containerize else log_dir,
self._config["log_path"]: "/logs" if containerize else self._config["log_path"],
}
load_path = self._config["training"].get("load_path", None)
@ -286,12 +291,16 @@ class ConfigParser:
else:
main_proc_env["EVAL_SCHEDULE"] = " ".join([str(val) for val in sorted(sch)])
main_proc_env["NUM_EVAL_EPISODES"] = str(self._config["main"].get("num_eval_episodes", 1))
if "early_stop_patience" in self._config["main"]:
main_proc_env["EARLY_STOP_PATIENCE"] = str(self._config["main"]["early_stop_patience"])
load_path = self._config["training"].get("load_path", None)
if load_path is not None:
env["main"]["LOAD_PATH"] = path_mapping[load_path]
main_proc_env["LOAD_PATH"] = path_mapping[load_path]
load_episode = self._config["training"].get("load_episode", None)
if load_episode is not None:
env["main"]["LOAD_EPISODE"] = str(load_episode)
main_proc_env["LOAD_EPISODE"] = str(load_episode)
if "checkpointing" in self._config["training"]:
conf = self._config["training"]["checkpointing"]
@ -385,9 +394,8 @@ class ConfigParser:
)
# All components write logs to the same file
log_dir, log_file = os.path.split(self._config["log_path"])
for _, vars in env.values():
vars["LOG_PATH"] = os.path.join(path_mapping[log_dir], log_file)
vars["LOG_PATH"] = path_mapping[self._config["log_path"]]
return env


@ -24,6 +24,8 @@ main:
# A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are
# 1-based.
eval_schedule: 10
early_stop_patience: 10 # Number of epochs to wait for a better validation metric before stopping. Could be `null`.
num_eval_episodes: 10 # Number of episodes to run in evaluation.
# Minimum number of samples to start training in one epoch. The workflow will re-run experience collection
# until we have at least `min_n_sample` of experiences.
min_n_sample: 1
@ -68,8 +70,9 @@ training:
checkpointing:
# Directory to save trainer snapshots under. Snapshot files created at different episodes will be saved under
# separate folders named using episode numbers. For example, if a snapshot is created for a trainer named "dqn"
# at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt".
path: "/path/to/your/checkpoint/folder"
# at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt". If null, the
# default checkpoint folder would be created under `log_path`.
path: "/path/to/your/checkpoint/folder" # or `null`
interval: 10 # Interval at which trained policies / models are persisted to disk.
proxy: # Proxy settings. Ignored if training.mode is "simple".
host: "127.0.0.1" # Proxy service host's IP address. Ignored if run in containerized environments.


@ -14,21 +14,21 @@ from maro.rl.training import TrainingManager
from maro.rl.utils import get_torch_device
from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none
from maro.rl.utils.training import get_latest_ep
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
from maro.rl.workflows.callback import CallbackManager, Checkpoint, EarlyStopping, MetricsRecorder
from maro.utils import LoggerV2, set_seeds
class WorkflowEnvAttributes:
def __init__(self) -> None:
# Number of training episodes
self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES")))
self.num_episodes = int(get_env("NUM_EPISODES"))
# Maximum number of steps in on round of sampling.
self.num_steps = int_or_none(get_env("NUM_STEPS", required=False))
# Minimum number of data samples to start a round of training. If the data samples are insufficient, re-run
# data sampling until we have at least `min_n_sample` data entries.
self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE")))
self.min_n_sample = int(get_env("MIN_N_SAMPLE"))
# Path to store logs.
self.log_path = get_env("LOG_PATH")
@ -46,6 +46,8 @@ class WorkflowEnvAttributes:
# Evaluating schedule.
self.eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
self.early_stop_patience = int_or_none(get_env("EARLY_STOP_PATIENCE", required=False))
self.num_eval_episodes = int_or_none(get_env("NUM_EVAL_EPISODES", required=False))
# Restore configurations.
self.load_path = get_env("LOAD_PATH", required=False)
@ -58,7 +60,7 @@ class WorkflowEnvAttributes:
# Parallel sampling configurations.
self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None
if self.parallel_rollout:
self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT")))
self.port = int(get_env("ROLLOUT_CONTROLLER_PORT"))
self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False))
self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False))
@ -67,13 +69,13 @@ class WorkflowEnvAttributes:
# Distributed training configurations.
if self.train_mode != "simple":
self.proxy_address = (
env_str_helper(get_env("TRAIN_PROXY_HOST")),
int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))),
str(get_env("TRAIN_PROXY_HOST")),
int(get_env("TRAIN_PROXY_FRONTEND_PORT")),
)
self.logger = LoggerV2(
"MAIN",
dump_path=self.log_path,
dump_path=os.path.join(self.log_path, "log.txt"),
dump_mode="a",
stdout_level=self.log_level_stdout,
file_level=self.log_level_file,
@ -83,6 +85,7 @@ class WorkflowEnvAttributes:
def _get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="MARO RL workflow parser")
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
parser.add_argument("--seed", type=int, help="The random seed set before running this job")
return parser.parse_args()
@ -112,88 +115,112 @@ def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes
if args.evaluate_only:
evaluate_only_workflow(rl_component_bundle, env_attr)
else:
training_workflow(rl_component_bundle, env_attr)
TrainingWorkflow().run(rl_component_bundle, env_attr)
def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
env_attr.logger.info("Start training workflow.")
class TrainingWorkflow(object):
def run(self, rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
env_attr.logger.info("Start training workflow.")
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
# evaluation schedule
env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
eval_point_index = 0
# evaluation schedule
env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
eval_point_index = 0
training_manager = TrainingManager(
rl_component_bundle=rl_component_bundle,
explicit_assign_device=(env_attr.train_mode == "simple"),
proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
logger=env_attr.logger,
)
if env_attr.load_path:
assert isinstance(env_attr.load_path, str)
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
path = os.path.join(env_attr.load_path, str(ep))
loaded = env_sampler.load_policy_state(path)
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
loaded = training_manager.load(path)
env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
start_ep = ep + 1
else:
start_ep = 1
# main loop
for ep in range(start_ep, env_attr.num_episodes + 1):
collect_time = training_time = 0.0
total_experiences: List[List[ExpElement]] = []
total_info_list: List[dict] = []
n_sample = 0
while n_sample < env_attr.min_n_sample:
tc0 = time.time()
result = env_sampler.sample(
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
num_steps=env_attr.num_steps,
)
experiences: List[List[ExpElement]] = result["experiences"]
info_list: List[dict] = result["info"]
n_sample += len(experiences[0])
total_experiences.extend(experiences)
total_info_list.extend(info_list)
collect_time += time.time() - tc0
env_sampler.post_collect(total_info_list, ep)
env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
tu0 = time.time()
training_manager.record_experiences(total_experiences)
training_manager.train_step()
if env_attr.checkpoint_path and (not env_attr.checkpoint_interval or ep % env_attr.checkpoint_interval == 0):
assert isinstance(env_attr.checkpoint_path, str)
pth = os.path.join(env_attr.checkpoint_path, str(ep))
training_manager.save(pth)
env_attr.logger.info(f"All trainer states saved under {pth}")
training_time += time.time() - tu0
# performance details
env_attr.logger.info(
f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
training_manager = TrainingManager(
rl_component_bundle=rl_component_bundle,
explicit_assign_device=(env_attr.train_mode == "simple"),
proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
logger=env_attr.logger,
)
if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
eval_point_index += 1
result = env_sampler.eval(
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
)
env_sampler.post_evaluate(result["info"], ep)
if isinstance(env_sampler, BatchEnvSampler):
env_sampler.exit()
training_manager.exit()
callbacks = [MetricsRecorder(path=env_attr.log_path)]
if env_attr.checkpoint_path is not None:
callbacks.append(
Checkpoint(
path=env_attr.checkpoint_path,
interval=1 if env_attr.checkpoint_interval is None else env_attr.checkpoint_interval,
),
)
if env_attr.early_stop_patience is not None:
callbacks.append(EarlyStopping(patience=env_attr.early_stop_patience))
callbacks.extend(rl_component_bundle.customized_callbacks)
cbm = CallbackManager(self, callbacks, env_sampler, training_manager, env_attr.logger)
if env_attr.load_path:
assert isinstance(env_attr.load_path, str)
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
path = os.path.join(env_attr.load_path, str(ep))
loaded = env_sampler.load_policy_state(path)
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
loaded = training_manager.load(path)
env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
start_ep = ep + 1
else:
start_ep = 1
# main loop
self.early_stop = False
for ep in range(start_ep, env_attr.num_episodes + 1):
if self.early_stop: # Might be set in `cbm.on_validation_end()`
break
cbm.on_episode_start(ep)
collect_time = training_time = 0.0
total_experiences: List[List[ExpElement]] = []
total_info_list: List[dict] = []
n_sample = 0
while n_sample < env_attr.min_n_sample:
tc0 = time.time()
result = env_sampler.sample(
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
num_steps=env_attr.num_steps,
)
experiences: List[List[ExpElement]] = result["experiences"]
info_list: List[dict] = result["info"]
n_sample += len(experiences[0])
total_experiences.extend(experiences)
total_info_list.extend(info_list)
collect_time += time.time() - tc0
env_sampler.post_collect(total_info_list, ep)
tu0 = time.time()
env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
cbm.on_training_start(ep)
training_manager.record_experiences(total_experiences)
training_manager.train_step()
cbm.on_training_end(ep)
training_time += time.time() - tu0
# performance details
env_attr.logger.info(
f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
)
if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
cbm.on_validation_start(ep)
eval_point_index += 1
result = env_sampler.eval(
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
num_episodes=env_attr.num_eval_episodes,
)
env_sampler.post_evaluate(result["info"], ep)
cbm.on_validation_end(ep)
cbm.on_episode_end(ep)
if isinstance(env_sampler, BatchEnvSampler):
env_sampler.exit()
training_manager.exit()
def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
@ -210,7 +237,7 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor
loaded = env_sampler.load_policy_state(path)
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
result = env_sampler.eval()
result = env_sampler.eval(num_episodes=env_attr.num_eval_episodes)
env_sampler.post_evaluate(result["info"], -1)
if isinstance(env_sampler, BatchEnvSampler):
@ -218,9 +245,13 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor
if __name__ == "__main__":
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
args = _get_args()
if args.seed is not None:
set_seeds(seed=args.seed)
scenario_path = get_env("SCENARIO_PATH")
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=_get_args())
main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=args)


@ -8,21 +8,20 @@ import sys
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import RolloutWorker
from maro.rl.utils.common import get_env, int_or_none
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
if __name__ == "__main__":
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
scenario_path = get_env("SCENARIO_PATH")
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")
worker_idx = int(env_str_helper(get_env("ID")))
worker_idx = int(get_env("ID"))
logger = LoggerV2(
f"ROLLOUT-WORKER.{worker_idx}",
dump_path=get_env("LOG_PATH"),
dump_path=os.path.join(get_env("LOG_PATH"), f"ROLLOUT-WORKER.{worker_idx}.txt"),
dump_mode="a",
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
@ -30,7 +29,7 @@ if __name__ == "__main__":
worker = RolloutWorker(
idx=worker_idx,
rl_component_bundle=rl_component_bundle,
producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")),
producer_host=get_env("ROLLOUT_CONTROLLER_HOST"),
producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")),
logger=logger,
)


@ -8,11 +8,10 @@ import sys
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training import TrainOpsWorker
from maro.rl.utils.common import get_env, int_or_none
from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2
if __name__ == "__main__":
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
scenario_path = get_env("SCENARIO_PATH")
scenario_path = os.path.normpath(scenario_path)
sys.path.insert(0, os.path.dirname(scenario_path))
module = importlib.import_module(os.path.basename(scenario_path))
@ -22,15 +21,15 @@ if __name__ == "__main__":
worker_idx = int_or_none(get_env("ID"))
logger = LoggerV2(
f"TRAIN-WORKER.{worker_idx}",
dump_path=get_env("LOG_PATH"),
dump_path=os.path.join(get_env("LOG_PATH"), f"TRAIN-WORKER.{worker_idx}.txt"),
dump_mode="a",
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
)
worker = TrainOpsWorker(
idx=int(env_str_helper(get_env("ID"))),
idx=int(get_env("ID")),
rl_component_bundle=rl_component_bundle,
producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")),
producer_host=get_env("TRAIN_PROXY_HOST"),
producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")),
logger=logger,
)


@ -1,9 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
def env_str_helper(string: Optional[str]) -> str:
assert string is not None
return string


@ -61,7 +61,7 @@ def is_float_type(v_type: type):
Returns:
bool: True if an float type.
"""
return v_type is float or v_type is np.float or v_type is np.float32 or v_type is np.float64
return v_type is float or v_type is np.float16 or v_type is np.float32 or v_type is np.float64
def parse_value(value: object):


@ -6,7 +6,7 @@ deepdiff>=5.7.0
geopy>=2.0.0
holidays>=0.10.3
kubernetes>=21.7.0
numpy>=1.19.5,<1.24.0
numpy>=1.19.5
pandas>=0.25.3
paramiko>=2.9.2
pytest>=7.1.2


@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
from maro.simulator import Env
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
env_conf = {
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
"start_tick": 0,
"durations": 100000, # Set a very large number
"options": {},
}
learn_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
test_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
num_agents = len(learn_env.agent_idx_list)
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
gym_action_space = gym_env.action_space
gym_state_dim = gym_env.observation_space.shape[0]
gym_action_dim = gym_action_space.shape[0]
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
action_limit = gym_action_space.high[0]


@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, List, Tuple, Type, Union
import numpy as np
from maro.rl.policy.abs_policy import AbsPolicy
from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.rl.rollout.env_sampler import AbsAgentWrapper, SimpleAgentWrapper
from maro.simulator.core import Env
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
from tests.rl.gym_wrapper.simulator.common import Action, DecisionEvent
class GymEnvSampler(AbsEnvSampler):
def __init__(
self,
learn_env: Env,
test_env: Env,
policies: List[AbsPolicy],
agent2policy: Dict[Any, str],
trainable_policies: List[str] = None,
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
reward_eval_delay: int = None,
max_episode_length: int = None,
) -> None:
super(GymEnvSampler, self).__init__(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
trainable_policies=trainable_policies,
agent_wrapper_cls=agent_wrapper_cls,
reward_eval_delay=reward_eval_delay,
max_episode_length=max_episode_length,
)
self._sample_rewards = []
self._eval_rewards = []
def _get_global_and_agent_state_impl(
self,
event: DecisionEvent,
tick: int = None,
) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
return None, {0: event.state}
def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict:
return {k: Action(v) for k, v in action_dict.items()}
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
be = self._env.business_engine
assert isinstance(be, GymBusinessEngine)
return {0: be.get_reward_at_tick(tick)}
def _post_step(self, cache_element: CacheElement) -> None:
if not (self._end_of_episode or self.truncated):
return
rewards = list(self._env.metrics["reward_record"].values())
self._sample_rewards.append((len(rewards), np.sum(rewards)))
def _post_eval_step(self, cache_element: CacheElement) -> None:
if not (self._end_of_episode or self.truncated):
return
rewards = list(self._env.metrics["reward_record"].values())
self._eval_rewards.append((len(rewards), np.sum(rewards)))
def post_collect(self, info_list: list, ep: int) -> None:
if len(self._sample_rewards) > 0:
cur = {
"n_steps": sum([n for n, _ in self._sample_rewards]),
"n_segment": len(self._sample_rewards),
"avg_reward": np.mean([r for _, r in self._sample_rewards]),
"avg_n_steps": np.mean([n for n, _ in self._sample_rewards]),
"max_n_steps": np.max([n for n, _ in self._sample_rewards]),
"n_interactions": self._total_number_interactions,
}
self.metrics.update(cur)
# clear validation metrics
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
self._sample_rewards.clear()
else:
self.metrics = {"n_interactions": self._total_number_interactions}
def post_evaluate(self, info_list: list, ep: int) -> None:
if len(self._eval_rewards) > 0:
cur = {
"val/n_steps": sum([n for n, _ in self._eval_rewards]),
"val/n_segment": len(self._eval_rewards),
"val/avg_reward": np.mean([r for _, r in self._eval_rewards]),
"val/avg_n_steps": np.mean([n for n, _ in self._eval_rewards]),
"val/max_n_steps": np.max([n for n, _ in self._eval_rewards]),
}
self.metrics.update(cur)
self._eval_rewards.clear()
else:
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}


@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Optional, cast
import gym
import numpy as np
from maro.backends.frame import FrameBase, SnapshotList
from maro.event_buffer import CascadeEvent, EventBuffer, MaroEvents
from maro.simulator.scenarios import AbsBusinessEngine
from .common import Action, DecisionEvent
class GymBusinessEngine(AbsBusinessEngine):
def __init__(
self,
event_buffer: EventBuffer,
topology: Optional[str],
start_tick: int,
max_tick: int,
snapshot_resolution: int,
max_snapshots: Optional[int],
additional_options: dict = None,
) -> None:
super(GymBusinessEngine, self).__init__(
scenario_name="gym",
event_buffer=event_buffer,
topology=topology,
start_tick=start_tick,
max_tick=max_tick,
snapshot_resolution=snapshot_resolution,
max_snapshots=max_snapshots,
additional_options=additional_options,
)
self._gym_scenario_name = topology
self._gym_env = gym.make(self._gym_scenario_name)
self.reset()
self._frame: FrameBase = FrameBase()
self._snapshots: SnapshotList = self._frame.snapshots
self._register_events()
@property
def gym_env(self) -> gym.Env:
return self._gym_env
@property
def frame(self) -> FrameBase:
return self._frame
@property
def snapshots(self) -> SnapshotList:
return self._snapshots
def _register_events(self) -> None:
self._event_buffer.register_event_handler(MaroEvents.TAKE_ACTION, self._on_action_received)
def _on_action_received(self, event: CascadeEvent) -> None:
action = cast(Action, cast(list, event.payload)[0]).action
self._last_obs, reward, self._is_done, self._truncated, info = self._gym_env.step(action)
self._reward_record[event.tick] = reward
self._info_record[event.tick] = info
def step(self, tick: int) -> None:
self._event_buffer.insert_event(self._event_buffer.gen_decision_event(tick, DecisionEvent(self._last_obs)))
@property
def configs(self) -> dict:
return {}
def get_reward_at_tick(self, tick: int) -> float:
return self._reward_record[tick]
def get_info_at_tick(self, tick: int) -> object: # TODO
return self._info_record[tick]
def reset(self, keep_seed: bool = False) -> None:
self._last_obs = self._gym_env.reset(seed=np.random.randint(low=0, high=4096))[0]
self._is_done = False
self._truncated = False
self._reward_record = {}
self._info_record = {}
def post_step(self, tick: int) -> bool:
return self._is_done or self._truncated or tick + 1 == self._max_tick
def get_agent_idx_list(self) -> List[int]:
return [0]
def get_metrics(self) -> dict:
return {
"reward_record": {k: v for k, v in self._reward_record.items()},
}
def set_seed(self, seed: int) -> None:
pass


@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
from maro.common import BaseAction, BaseDecisionEvent
class Action(BaseAction):
def __init__(self, action: np.ndarray) -> None:
self.action = action
class DecisionEvent(BaseDecisionEvent):
def __init__(self, state: np.ndarray) -> None:
self.state = state

New binary files under tests/rl/log/ (performance curve images referenced by tests/rl/performance.md):
tests/rl/log/Ant_1.png (139 KiB), tests/rl/log/Ant_11.png (101 KiB),
tests/rl/log/HalfCheetah_1.png (131 KiB), tests/rl/log/HalfCheetah_11.png (86 KiB),
tests/rl/log/Hopper_1.png (172 KiB), tests/rl/log/Hopper_11.png (130 KiB),
tests/rl/log/Swimmer_1.png (120 KiB), tests/rl/log/Swimmer_11.png (81 KiB),
tests/rl/log/Walker2d_1.png (165 KiB), tests/rl/log/Walker2d_11.png (113 KiB)

tests/rl/performance.md (new file, 54 lines)

@ -0,0 +1,54 @@
# Performance for Gym Task Suite
We benchmarked the MARO RL Toolkit implementation on the Gym task suite. Some results are compared to the benchmarks in
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). We have tried to align the
hyper-parameters with those benchmarks, but because of differences in environment versions there may be some gaps
between the performance reported here and the Spinning Up benchmarks. Generally speaking, the performance is comparable.
## Experimental Setting
The hyper-parameters are set to align with those used in
[Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#experiment-details):
**Batch Size**:
- For on-policy algorithms: 4000 steps of interaction per batch update;
- For off-policy algorithms: size 100 for each gradient descent step;
**Network**:
- For on-policy algorithms: size (64, 32) with tanh units for both policy and value function;
- For off-policy algorithms: size (256, 256) with relu units;
**Performance metric**:
- For on-policy algorithms: measured as the average trajectory return across the batch collected at each epoch;
- For off-policy algorithms: measured once every 10,000 steps by running the deterministic policy (or, in the case of SAC, the mean policy) without action noise for ten trajectories, and reporting the average return over those test trajectories;
**Total timesteps**: set to 4M for all task suites and algorithms.
More details about the parameters can be found in *tests/rl/tasks/*.
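For reference, a PyTorch sketch of the two network layouts described above (illustrative only; the actual models are built with MARO's model wrappers under *tests/rl/tasks/*):

```python
import torch.nn as nn

def on_policy_mlp(state_dim: int, out_dim: int) -> nn.Module:
    # (64, 32) hidden layers with tanh, shared layout for policy and value function
    return nn.Sequential(
        nn.Linear(state_dim, 64), nn.Tanh(),
        nn.Linear(64, 32), nn.Tanh(),
        nn.Linear(32, out_dim),
    )

def off_policy_mlp(state_dim: int, out_dim: int) -> nn.Module:
    # (256, 256) hidden layers with ReLU
    return nn.Sequential(
        nn.Linear(state_dim, 256), nn.ReLU(),
        nn.Linear(256, 256), nn.ReLU(),
        nn.Linear(256, out_dim),
    )
```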
## Performance
Five environments from the MuJoCo Gym task suite are reported in Spinning Up: HalfCheetah, Hopper, Walker2d,
Swimmer, and Ant. The commit id of the code used to conduct the experiments for the MARO RL benchmarks is ee25ce1e97.
The commands used are:
```sh
# Step 1: Set up the MuJoCo Environment in file tests/rl/gym_wrapper/common.py
# Step 2: Use the command below to run experiment with ALGORITHM (ddpg, ppo, sac) and random seed SEED.
python tests/rl/run.py tests/rl/tasks/ALGORITHM/config.yml --seed SEED
# Step 3: Plot performance curves by environment with specific smooth window size WINDOWSIZE.
python tests/rl/plot.py --smooth WINDOWSIZE
```
| **Env** | **Spinning Up** | **MARO RL w/o Smooth** | **MARO RL w/ Smooth** |
|:---------------:|:---------------:|:----------------------:|:---------------------:|
| [**HalfCheetah**](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) | ![Hab](https://spinningup.openai.com/en/latest/_images/pytorch_halfcheetah_performance.svg) | ![Ha1](./log/HalfCheetah_1.png) | ![Ha11](./log/HalfCheetah_11.png) |
| [**Hopper**](https://gymnasium.farama.org/environments/mujoco/hopper/) | ![Hob](https://spinningup.openai.com/en/latest/_images/pytorch_hopper_performance.svg) | ![Ho1](./log/Hopper_1.png) | ![Ho11](./log/Hopper_11.png) |
| [**Walker2d**](https://gymnasium.farama.org/environments/mujoco/walker2d/) | ![Wab](https://spinningup.openai.com/en/latest/_images/pytorch_walker2d_performance.svg) | ![Wa1](./log/Walker2d_1.png) | ![Wa11](./log/Walker2d_11.png) |
| [**Swimmer**](https://gymnasium.farama.org/environments/mujoco/swimmer/) | ![Swb](https://spinningup.openai.com/en/latest/_images/pytorch_swimmer_performance.svg) | ![Sw1](./log/Swimmer_1.png) | ![Sw11](./log/Swimmer_11.png) |
| [**Ant**](https://gymnasium.farama.org/environments/mujoco/ant/) | ![Anb](https://spinningup.openai.com/en/latest/_images/pytorch_ant_performance.svg) | ![An1](./log/Ant_1.png) | ![An11](./log/Ant_11.png) |

tests/rl/plot.py (new file, 100 lines)

@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import os
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
LOG_DIR = "tests/rl/log"
color_map = {
"ppo": "green",
"sac": "goldenrod",
"ddpg": "firebrick",
"vpg": "cornflowerblue",
"td3": "mediumpurple",
}
def smooth(data: np.ndarray, window_size: int) -> np.ndarray:
if window_size > 1:
"""
smooth data with moving window average.
that is,
smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
where the "smooth" param is width of that window (2k+1)
"""
y = np.ones(window_size)
x = np.asarray(data)
z = np.ones_like(x)
smoothed_x = np.convolve(x, y, "same") / np.convolve(z, y, "same")
return smoothed_x
else:
return data
def get_off_policy_data(log_dir: str) -> Tuple[np.ndarray, np.ndarray]:
file_path = os.path.join(log_dir, "metrics_full.csv")
df = pd.read_csv(file_path)
x, y = df["n_interactions"], df["val/avg_reward"]
mask = ~np.isnan(y)
x, y = x[mask], y[mask]
return x, y
def get_on_policy_data(log_dir: str) -> Tuple[np.ndarray, np.ndarray]:
file_path = os.path.join(log_dir, "metrics_full.csv")
df = pd.read_csv(file_path)
x, y = df["n_interactions"], df["avg_reward"]
return x, y
def plot_performance_curves(title: str, dir_names: List[str], smooth_window_size: int) -> None:
for algorithm in color_map.keys():
if algorithm in ["ddpg", "sac", "td3"]:
func = get_off_policy_data
elif algorithm in ["ppo", "vpg"]:
func = get_on_policy_data
log_dirs = [os.path.join(LOG_DIR, name) for name in dir_names if algorithm in name]
series = [func(log_dir) for log_dir in log_dirs if os.path.exists(log_dir)]
if len(series) == 0:
continue
x = series[0][0]
        assert all(len(_x) == len(x) for _x, _ in series), "Input data should share the same length!"
ys = np.array([smooth(y, smooth_window_size) for _, y in series])
y_mean = np.mean(ys, axis=0)
y_std = np.std(ys, axis=0)
plt.plot(x, y_mean, label=algorithm, color=color_map[algorithm])
plt.fill_between(x, y_mean - y_std, y_mean + y_std, color=color_map[algorithm], alpha=0.2)
plt.legend()
plt.grid()
plt.title(title)
plt.xlabel("Total Env Interactions")
plt.ylabel(f"Average Trajectory Return \n(moving average with window size = {smooth_window_size})")
plt.savefig(os.path.join(LOG_DIR, f"{title}_{smooth_window_size}.png"), bbox_inches="tight")
plt.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--smooth", "-s", type=int, default=11, help="smooth window size")
args = parser.parse_args()
for env_name in ["HalfCheetah", "Hopper", "Walker2d", "Swimmer", "Ant"]:
plot_performance_curves(
title=env_name,
dir_names=[
f"{algorithm}_{env_name.lower()}_{seed}"
for algorithm in ["ppo", "sac", "ddpg"]
for seed in [42, 729, 1024, 2023, 3500]
],
smooth_window_size=args.smooth,
)
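
As a quick usage sketch for the `smooth` helper above (not part of the file; it assumes the repository root is on `PYTHONPATH` and that numpy, pandas, and matplotlib are installed so that `tests.rl.plot` imports cleanly):

```python
# Usage sketch for smooth(): a moving-window average where edge points are
# normalized by the number of samples actually inside the window.
import numpy as np

from tests.rl.plot import smooth

series = np.arange(10, dtype=float)
print(smooth(series, 3))  # [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5]
print(smooth(series, 1))  # window size 1 returns the data unchanged
```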

tests/rl/run.py (new file, 19 lines)

@@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from maro.cli.local.commands import run
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("conf_path", help="Path of the job deployment")
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
parser.add_argument("--seed", type=int, help="The random seed set before running this job")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
run(conf_path=args.conf_path, containerize=False, seed=args.seed, evaluate_only=args.evaluate_only)
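
For reference, a minimal sketch of the programmatic equivalent of the command line shown in tests/rl/performance.md; the config path and seed below are example values only:

```python
# Sketch only: invokes the same entry point that run.py wraps, with example arguments.
from maro.cli.local.commands import run

run(
    conf_path="tests/rl/tasks/ppo/config.yml",  # one of the task configs added in this PR
    containerize=False,
    seed=42,              # arbitrary example seed
    evaluate_only=False,  # set True to run only the evaluation part of the workflow
)
```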

@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Tuple
import numpy as np
import torch
from torch.distributions import Normal
from torch.optim import Adam
from maro.rl.model import ContinuousACBasedNet, VNet
from maro.rl.model.fc_block import FullyConnected
from maro.rl.policy import ContinuousRLPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
from tests.rl.gym_wrapper.common import (
action_lower_bound,
action_upper_bound,
gym_action_dim,
gym_state_dim,
learn_env,
num_agents,
test_env,
)
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
actor_net_conf = {
"hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
critic_net_conf = {
"hidden_dims": [64, 32],
"activation": torch.nn.Tanh,
}
actor_learning_rate = 3e-4
critic_learning_rate = 1e-3
class MyContinuousACBasedNet(ContinuousACBasedNet):
def __init__(self, state_dim: int, action_dim: int) -> None:
super(MyContinuousACBasedNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
log_std = -0.5 * np.ones(action_dim, dtype=np.float32)
self._log_std = torch.nn.Parameter(torch.as_tensor(log_std))
self._mu_net = FullyConnected(
input_dim=state_dim,
hidden_dims=actor_net_conf["hidden_dims"],
output_dim=action_dim,
activation=actor_net_conf["activation"],
)
self._optim = Adam(self.parameters(), lr=actor_learning_rate)
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
distribution = self._distribution(states)
actions = distribution.sample()
logps = distribution.log_prob(actions).sum(axis=-1)
return actions, logps
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
distribution = self._distribution(states)
logps = distribution.log_prob(actions).sum(axis=-1)
return logps
def _distribution(self, states: torch.Tensor) -> Normal:
mu = self._mu_net(states.float())
std = torch.exp(self._log_std)
return Normal(mu, std)
class MyVCriticNet(VNet):
def __init__(self, state_dim: int) -> None:
super(MyVCriticNet, self).__init__(state_dim=state_dim)
self._critic = FullyConnected(
input_dim=state_dim,
output_dim=1,
hidden_dims=critic_net_conf["hidden_dims"],
activation=critic_net_conf["activation"],
)
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
return self._critic(states.float()).squeeze(-1)
def get_ac_policy(
name: str,
action_lower_bound: list,
action_upper_bound: list,
gym_state_dim: int,
gym_action_dim: int,
) -> ContinuousRLPolicy:
return ContinuousRLPolicy(
name=name,
action_range=(action_lower_bound, action_upper_bound),
policy_net=MyContinuousACBasedNet(gym_state_dim, gym_action_dim),
)
def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
reward_discount=0.99,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
grad_iters=80,
lam=0.97,
),
)
algorithm = "ac"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [
get_ac_policy(f"{algorithm}_{i}.policy", action_lower_bound, action_upper_bound, gym_state_dim, gym_action_dim)
for i in range(num_agents)
]
trainers = [get_ac_trainer(f"{algorithm}_{i}", gym_state_dim) for i in range(num_agents)]
device_mapping = None
if torch.cuda.is_available():
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
rl_component_bundle = RLComponentBundle(
env_sampler=GymEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
device_mapping=device_mapping,
)
__all__ = ["rl_component_bundle"]

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for GYM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
job: gym_rl_workflow
scenario_path: "tests/rl/tasks/ac"
log_path: "tests/rl/log/ac"
main:
num_episodes: 1000
num_steps: null
eval_schedule: 5
num_eval_episodes: 10
min_n_sample: 5000
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: null
interval: 5
logging:
stdout: INFO
file: DEBUG

@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from gym import spaces
from torch.optim import Adam
from maro.rl.model import QNet
from maro.rl.model.algorithm_nets.ddpg import ContinuousDDPGNet
from maro.rl.model.fc_block import FullyConnected
from maro.rl.policy import ContinuousRLPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training.algorithms import DDPGParams, DDPGTrainer
from maro.rl.utils import ndarray_to_tensor
from tests.rl.gym_wrapper.common import (
action_limit,
action_lower_bound,
action_upper_bound,
gym_action_dim,
gym_action_space,
gym_state_dim,
learn_env,
num_agents,
test_env,
)
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
actor_net_conf = {
"hidden_dims": [256, 256],
"activation": torch.nn.ReLU,
"output_activation": torch.nn.Tanh,
}
critic_net_conf = {
"hidden_dims": [256, 256],
"activation": torch.nn.ReLU,
}
actor_learning_rate = 1e-3
critic_learning_rate = 1e-3
class MyContinuousDDPGNet(ContinuousDDPGNet):
def __init__(
self,
state_dim: int,
action_dim: int,
action_limit: float,
action_space: spaces.Space,
noise_scale: float = 0.1,
) -> None:
super(MyContinuousDDPGNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
self._net = FullyConnected(
input_dim=state_dim,
output_dim=action_dim,
hidden_dims=actor_net_conf["hidden_dims"],
activation=actor_net_conf["activation"],
output_activation=actor_net_conf["output_activation"],
)
        self._optim = Adam(self._net.parameters(), lr=actor_learning_rate)  # actor network uses the actor learning rate
self._action_limit = action_limit
self._noise_scale = noise_scale
self._action_space = action_space
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
action = self._net(states) * self._action_limit
if exploring:
noise = torch.randn(self.action_dim) * self._noise_scale
action += noise.to(action.device)
action = torch.clamp(action, -self._action_limit, self._action_limit)
return action
def _get_random_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
return torch.stack(
[ndarray_to_tensor(self._action_space.sample(), device=self._device) for _ in range(states.shape[0])],
)
class MyQCriticNet(QNet):
def __init__(self, state_dim: int, action_dim: int) -> None:
super(MyQCriticNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
self._critic = FullyConnected(
input_dim=state_dim + action_dim,
output_dim=1,
hidden_dims=critic_net_conf["hidden_dims"],
activation=critic_net_conf["activation"],
)
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._critic(torch.cat([states, actions], dim=1).float()).squeeze(-1)
def get_ddpg_policy(
name: str,
action_lower_bound: list,
action_upper_bound: list,
gym_state_dim: int,
gym_action_dim: int,
action_limit: float,
) -> ContinuousRLPolicy:
return ContinuousRLPolicy(
name=name,
action_range=(action_lower_bound, action_upper_bound),
policy_net=MyContinuousDDPGNet(gym_state_dim, gym_action_dim, action_limit, gym_action_space),
warmup=10000,
)
def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer:
return DDPGTrainer(
name=name,
reward_discount=0.99,
replay_memory_capacity=1000000,
batch_size=100,
params=DDPGParams(
get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
num_epochs=50,
n_start_train=1000,
soft_update_coef=0.005,
update_target_every=1,
),
)
algorithm = "ddpg"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [
get_ddpg_policy(
f"{algorithm}_{i}.policy",
action_lower_bound,
action_upper_bound,
gym_state_dim,
gym_action_dim,
action_limit,
)
for i in range(num_agents)
]
trainers = [get_ddpg_trainer(f"{algorithm}_{i}", gym_state_dim, gym_action_dim) for i in range(num_agents)]
device_mapping = None
if torch.cuda.is_available():
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
rl_component_bundle = RLComponentBundle(
env_sampler=GymEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
device_mapping=device_mapping,
)
__all__ = ["rl_component_bundle"]

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for GYM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
job: gym_rl_workflow
scenario_path: "tests/rl/tasks/ddpg"
log_path: "tests/rl/log/ddpg_walker2d"
main:
num_episodes: 80000
num_steps: 50
eval_schedule: 200
num_eval_episodes: 10
min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: null
interval: 200
logging:
stdout: INFO
file: DEBUG

@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training.algorithms.ppo import PPOParams, PPOTrainer
from tests.rl.gym_wrapper.common import (
action_lower_bound,
action_upper_bound,
gym_action_dim,
gym_state_dim,
learn_env,
num_agents,
test_env,
)
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
from tests.rl.tasks.ac import MyVCriticNet, get_ac_policy
get_ppo_policy = get_ac_policy
def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
return PPOTrainer(
name=name,
reward_discount=0.99,
replay_memory_capacity=4000,
batch_size=4000,
params=PPOParams(
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
grad_iters=80,
lam=0.97,
clip_ratio=0.2,
),
)
algorithm = "ppo"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [
get_ppo_policy(f"{algorithm}_{i}.policy", action_lower_bound, action_upper_bound, gym_state_dim, gym_action_dim)
for i in range(num_agents)
]
trainers = [get_ppo_trainer(f"{algorithm}_{i}", gym_state_dim) for i in range(num_agents)]
device_mapping = None
if torch.cuda.is_available():
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
rl_component_bundle = RLComponentBundle(
env_sampler=GymEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
max_episode_length=1000,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
device_mapping=device_mapping,
)
__all__ = ["rl_component_bundle"]

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for GYM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
job: gym_rl_workflow
scenario_path: "tests/rl/tasks/ppo"
log_path: "tests/rl/log/ppo_walker2d"
main:
num_episodes: 1000
num_steps: 4000
eval_schedule: 5
num_eval_episodes: 10
min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: null
interval: 5
logging:
stdout: INFO
file: DEBUG

@@ -0,0 +1,168 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from gym import spaces
from torch.distributions import Normal
from torch.optim import Adam
from maro.rl.model import ContinuousSACNet, QNet
from maro.rl.model.fc_block import FullyConnected
from maro.rl.policy import ContinuousRLPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training.algorithms import SoftActorCriticParams, SoftActorCriticTrainer
from maro.rl.utils import ndarray_to_tensor
from tests.rl.gym_wrapper.common import (
action_limit,
action_lower_bound,
action_upper_bound,
gym_action_dim,
gym_action_space,
gym_state_dim,
learn_env,
num_agents,
test_env,
)
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
actor_net_conf = {
"hidden_dims": [256, 256],
"activation": torch.nn.ReLU,
}
critic_net_conf = {
"hidden_dims": [256, 256],
"activation": torch.nn.ReLU,
}
actor_learning_rate = 1e-3
critic_learning_rate = 1e-3
LOG_STD_MAX = 2
LOG_STD_MIN = -20
class MyContinuousSACNet(ContinuousSACNet):
def __init__(self, state_dim: int, action_dim: int, action_limit: float, action_space: spaces.Space) -> None:
super(MyContinuousSACNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
self._net = FullyConnected(
input_dim=state_dim,
output_dim=actor_net_conf["hidden_dims"][-1],
hidden_dims=actor_net_conf["hidden_dims"][:-1],
activation=actor_net_conf["activation"],
output_activation=actor_net_conf["activation"],
)
self._mu = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
self._log_std = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
self._action_limit = action_limit
self._optim = Adam(self.parameters(), lr=actor_learning_rate)
self._action_space = action_space
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
net_out = self._net(states.float())
mu = self._mu(net_out)
log_std = torch.clamp(self._log_std(net_out), LOG_STD_MIN, LOG_STD_MAX)
std = torch.exp(log_std)
pi_distribution = Normal(mu, std)
pi_action = pi_distribution.rsample() if exploring else mu
logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=1)
pi_action = torch.tanh(pi_action) * self._action_limit
return pi_action, logp_pi
def _get_random_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
return torch.stack(
[ndarray_to_tensor(self._action_space.sample(), device=self._device) for _ in range(states.shape[0])],
)
class MyQCriticNet(QNet):
def __init__(self, state_dim: int, action_dim: int) -> None:
super(MyQCriticNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
self._critic = FullyConnected(
input_dim=state_dim + action_dim,
output_dim=1,
hidden_dims=critic_net_conf["hidden_dims"],
activation=critic_net_conf["activation"],
)
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
return self._critic(torch.cat([states, actions], dim=1).float()).squeeze(-1)
def get_sac_policy(
name: str,
action_lower_bound: list,
action_upper_bound: list,
gym_state_dim: int,
gym_action_dim: int,
action_limit: float,
) -> ContinuousRLPolicy:
return ContinuousRLPolicy(
name=name,
action_range=(action_lower_bound, action_upper_bound),
policy_net=MyContinuousSACNet(gym_state_dim, gym_action_dim, action_limit, action_space=gym_action_space),
warmup=10000,
)
def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCriticTrainer:
return SoftActorCriticTrainer(
name=name,
reward_discount=0.99,
replay_memory_capacity=1000000,
batch_size=100,
params=SoftActorCriticParams(
get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
update_target_every=1,
entropy_coef=0.2,
num_epochs=50,
n_start_train=1000,
soft_update_coef=0.005,
),
)
algorithm = "sac"
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [
get_sac_policy(
f"{algorithm}_{i}.policy",
action_lower_bound,
action_upper_bound,
gym_state_dim,
gym_action_dim,
action_limit,
)
for i in range(num_agents)
]
trainers = [get_sac_trainer(f"{algorithm}_{i}", gym_state_dim, gym_action_dim) for i in range(num_agents)]
device_mapping = None
if torch.cuda.is_available():
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
rl_component_bundle = RLComponentBundle(
env_sampler=GymEnvSampler(
learn_env=learn_env,
test_env=test_env,
policies=policies,
agent2policy=agent2policy,
),
agent2policy=agent2policy,
policies=policies,
trainers=trainers,
device_mapping=device_mapping,
)
__all__ = ["rl_component_bundle"]

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for GYM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
job: gym_rl_workflow
scenario_path: "tests/rl/tasks/sac"
log_path: "tests/rl/log/sac_walker2d"
main:
num_episodes: 80000
num_steps: 50
eval_schedule: 200
num_eval_episodes: 10
min_n_sample: 1
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: null
interval: 200
logging:
stdout: INFO
file: DEBUG

@@ -311,7 +311,7 @@ class TestFrame(unittest.TestCase):
         self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])
         # 2 padding (NAN) in the end
-        self.assertTrue((states[-2:].astype(np.int) == 0).all())
+        self.assertTrue((states[-2:].astype(int) == 0).all())
         states = static_snapshot[1::"a3"]