V0.3: Upgrade RL Workflow; Add RL Benchmarks; Update Package Version (#588)
* call policy update only for AbsCorePolicy * add limitation of AbsCorePolicy in Actor.collect() * refined actor to return only experiences for policies that received new experiences * fix MsgKey issue in rollout_manager * fix typo in learner * call exit function for parallel rollout manager * update supply chain example distributed training scripts * 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prorotype * reformat render * fix supply chain business engine action type problem * reset supply chain example render figsize from 4 to 3 * Add render to all modes of supply chain example * fix or policy typos * 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes * refined parallel policy manager * updated rl/__init__/py * fixed lint issues and CIM local learner bugs * deleted unwanted supply_chain test files * revised default config for cim-dqn * removed test_store.py as it is no longer needed * 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms * updated figures * removed unwanted import * refactored CIM-DQN example * added MultiProcessRolloutManager and MultiProcessTrainingManager * updated doc * lint issue fix * lint issue fix * fixed import formatting * [Feature] Prioritized Experience Replay (#355) * added prioritized experience replay * deleted unwanted supply_chain test files * fixed import order * import fix * fixed lint issues * fixed import formatting * added note in docstring that rank-based PER has yet to be implemented Co-authored-by: ysqyang <v-yangqi@microsoft.com> * rm AbsDecisionGenerator * small fixes * bug fix * reorganized training folder structure * fixed lint issues * fixed lint issues * policy manager refined * lint fix * restructured CIM-dqn sync code * added policy version index and used it as a measure of experience staleness * lint issue fix * lint issue fix * switched log_dir and proxy_kwargs order * cim example refinement * eval schedule sorted only when it's a list * eval schedule sorted only when it's a list * update sc env wrapper * added docker scripts for cim-dqn * refactored example folder structure and added workflow templates * fixed lint issues * fixed lint issues * fixed template bugs * removed unused imports * refactoring sc in progress * simplified cim meta * fixed build.sh path bug * template refinement * deleted obsolete svgs * updated learner logs * minor edits * refactored templates for easy merge with async PR * added component names for rollout manager and policy manager * fixed incorrect position to add last episode to eval schedule * added max_lag option in templates * formatting edit in docker_compose_yml script * moved local learner and early stopper outside sync_tools * refactored rl toolkit folder structure * refactored rl toolkit folder structure * moved env_wrapper and agent_wrapper inside rl/learner * refined scripts * fixed typo in script * changes needed for running sc * removed unwanted imports * config change for testing sc scenario * changes for perf testing * Asynchronous Training (#364) * remote inference code draft * changed actor to rollout_worker and updated init files * removed unwanted import * updated inits * more async code * added async scripts * added async training code & scripts for CIM-dqn * changed async to async_tools to avoid conflict with python keyword * reverted unwanted change to dockerfile * added doc for policy server * addressed PR comments and fixed a bug in docker_compose_yml.py * fixed lint issue * resolved PR comment * resolved merge conflicts * added async templates * added proxy.close() for actor and policy_server * fixed incorrect position to add last episode to eval schedule * reverted unwanted changes * added missing async files * rm unwanted echo in kill.sh Co-authored-by: ysqyang <v-yangqi@microsoft.com> * renamed sync to synchronous and async to asynchronous to avoid conflict with keyword * added missing policy version increment in LocalPolicyManager * refined rollout manager recv logic * removed a debugging print * added sleep in distributed launcher to avoid hanging * updated api doc and rl toolkit doc * refined dynamic imports using importlib * 1. moved policy update triggers to policy manager; 2. added version control in policy manager * fixed a few bugs and updated cim RL example * fixed a few more bugs * added agent wrapper instantiation to workflows * added agent wrapper instantiation to workflows * removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet * fixed incorrect get_ac_policy signature for CIM * moved exploration inside core policy * added state to exploration call to support context-dependent exploration * separated non_rl_policy_index and rl_policy_index in workflows * modified sc example code according to workflow changes * modified sc example code according to workflow changes * added replay_agent_ids parameter to get_env_func for RL examples * fixed a few bugs * added maro/simulator/scenarios/supply_chain as bind mount * added post-step, post-collect, post-eval and post-update callbacks * fixed lint issues * fixed lint issues * moved instantiation of policy manager inside simple learner * fixed env_wrapper get_reward signature * minor edits * removed get_eperience kwargs from env_wrapper * 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows * added rollout exp disribution option in RL examples * removed unwanted files * 1. made logger internal in learner; 2 removed logger creation in abs classes * checked out supply chain test files from v0.2_sc * 1. added missing model.eval() to choose_action; 2.added entropy features to AC * fixed a bug in ac entropy * abbreviated coefficient to coeff * removed -dqn from job name in rl example config * added tmp patch to dev.df * renamed image name for running rl examples * added get_loss interface for core policies * added policy manager in rl_toolkit.rst * 1. env_wrapper bug fix; 2. policy manager update logic refinement * refactored policy and algorithms * policy interface redesigned * refined policy interfaces * fixed typo * fixed bugs in refactored policy interface * fixed some bugs * refactoring in progress * policy interface and policy manager redesigned * 1. fixed bugs in ac and pg; 2. fixed bugs rl workflow scripts * fixed bug in distributed policy manager * fixed lint issues * fixed lint issues * added scipy in setup * 1. trimmed rollout manager code; 2. added option to docker scripts * updated api doc for policy manager * 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script * 1. simplified rl example structure; 2. fixed lint issues * further rl toolkit code simplifications * more numpy-based optimization in RL toolkit * moved replay buffer inside policy * bug fixes * numpy optimization and associated refactoring * extracted shaping logic out of env_sampler * fixed bug in CIM shaping and lint issues * preliminary implemetation of parallel batch inference * fixed bug in ddpg transition recording * put get_state, get_env_actions, get_reward back in EnvSampler * simplified exploration and core model interfaces * bug fixes and doc update * added improve() interface for RLPolicy for single-thread support * fixed simple policy manager bug * updated doc, rst, notebook * updated notebook * fixed lint issues * fixed entropy bugs in ac.py * reverted to simple policy manager as default * 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit * fixed lint issues and updated rl toolkit images * removed obsolete images * added back agent2policy for general workflow use * V0.2 rl refinement dist (#377) * Support `slice` operation in ExperienceSet * Support naive distributed policy training by proxy * Dynamically allocate trainers according to number of experience * code check * code check * code check * Fix a bug in distributed trianing with no gradient * Code check * Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy * 1.call allocate_trainer() at first of update(); 2.refine according to code review * Code check * Refine code with new interface * Update docs of PolicyManger and ExperienceSet * Add images for rl_toolkit docs * Update diagram of PolicyManager * Refine with new interface * Extract allocation strategy into `allocation_strategy.py` * add `distributed_learn()` in policies for data-parallel training * Update doc of RL_toolkit * Add gradient workers for data-parallel * Refine code and update docs * Lint check * Refine by comments * Rename `trainer` to `worker` * Rename `distributed_learn` to `learn_with_data_parallel` * Refine allocator and remove redundant code in policy_manager * remove arugments in allocate_by_policy and so on * added checkpointing for simple and multi-process policy managers * 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager * added missing set_state and get_state for CIM policies * removed blank line * updated RL workflow README * Integrate `data_parallel` arguments into `worker_allocator` (#402) * 1. simplified workflow config; 2. added comments to CIM shaping * lint issue fix * 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading * 1. moved post_step callback inside env sampler; 2. updated README for rl workflows * refined READEME for CIM * VM scheduling with RL (#375) * added part of vm scheduling RL code * refined vm env_wrapper code style * added DQN * added get_experiences func for ac in vm scheduling * added post_step callback to env wrapper * moved Aiming's tracking and plotting logic into callbacks * added eval env wrapper * renamed AC config variable name for VM * vm scheduling RL code finished * updated README * fixed various bugs and hard coding for vm_scheduling * uncommented callbacks for VM scheduling * Minor revision for better code style * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * vm example refactoring * fixed bugs in vm_scheduling * removed unwanted files from cim dir * reverted to simple policy manager as default * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * resolved rebase conflicts * fixed bugs in vm_scheduling * added get_state and set_state to vm_scheduling policy models * updated README for vm_scheduling with RL Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * SC refinement (#397) * Refine test scripts & pending_order_daily logic * Refactor code for better code style: complete type hint, correct typos, remove unused items. Refactor code for better code style: complete type hint, correct typos, remove unused items. * Polish test_supply_chain.py * update import format * Modify vehicle steps logic & remove outdated test case * Optimize imports * Optimize imports * Lint error * Lint error * Lint error * Add SupplyChainAction * Lint error Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * refined workflow scripts * fixed bug in ParallelAgentWrapper * 1. fixed lint issues; 2. refined main script in workflows * lint issue fix * restored default config for rl example * Update rollout.py * refined env var processing in policy manager workflow * added hasattr check in agent wrapper * updated docker_compose_yml.py * Minor refinement * Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412) * Prepare to merge master_mirror * Lint error * Minor * Merge latest master into v0.3 (#426) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * Minor * Remove docs/source/examples/multi_agent_dqn_cim.rst * Update .gitignore * Update .gitignore Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> * Change `Env.set_seed()` logic (#456) * Change Env.set_seed() logic * Redesign CIM reset logic; fix lint issues; * Lint * Seed type assertion * Remove all SC related files (#473) * RL Toolkit V3 (#471) * added daemon=True for multi-process rollout, policy manager and inference * removed obsolete files * [REDO][PR#406]V0.2 rl refinement taskq (#408) * Add a usable task_queue * Rename some variables * 1. Add ; 2. Integrate related files; 3. Remove * merge `data_parallel` and `num_grad_workers` into `data_parallelism` * Fix bugs in docker_compose_yml.py and Simple/Multi-process mode. * Move `grad_worker` into marl/rl/workflows * 1.Merge data_parallel and num_workers into data_parallelism in config; 2.Assign recently used workers as possible in task_queue. * Refine code and update docs of `TaskQueue` * Support priority for tasks in `task_queue` * Update diagram of policy manager and task queue. * Add configurable `single_task_limit` and correct docstring about `data_parallelism` * Fix lint errors in `supply chain` * RL policy redesign (V2) (#405) * Drafi v2.0 for V2 * Polish models with more comments * Polish policies with more comments * Lint * Lint * Add developer doc for models. * Add developer doc for policies. * Remove policy manager V2 since it is not used and out-of-date * Lint * Lint * refined messy workflow code * merged 'scenario_dir' and 'scenario' in rl config * 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods * 1. temporarily renamed RLPolicy from polivy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2 * merged cim and cim_v2 * lint issue fix * refined logging logic * lint issue fix * reversed unwanted changes * . . . . ReplayMemory & IndexScheduler ReplayMemory & IndexScheduler . MultiReplayMemory get_actions_with_logps EnvSampler on the road EnvSampler Minor * LearnerManager * Use batch to transfer data & add SHAPE_CHECK_FLAG * Rename learner to trainer * Add property for policy._is_exploring * CIM test scenario for V3. Manual test passed. Next step: run it, make it works. * env_sampler.py could run * env_sampler refine on the way * First runnable version done * AC could run, but the result is bad. Need to check the logic * Refine abstract method & shape check error info. * Docs * Very detailed compare. Try again. * AC done * DQN check done * Minor * DDPG, not tested * Minors * A rough draft of MAAC * Cannot use CIM as the multi-agent scenario. * Minor * MAAC refinement on the way * Remove ActionWithAux * Refine batch & memory * MAAC example works * Reproduce-able fix. Policy share between env_sampler and trainer_manager. * Detail refinement * Simplify the user configed workflow * Minor * Refine example codes * Minor polishment * Migrate rollout_manager to V3 * Error on the way * Redesign torch.device management * Rl v3 maddpg (#418) * Add MADDPG trainer * Fit independent critics and shared critic modes. * Add a new property: num_policies * Lint * Fix a bug in `sum(rewards)` * Rename `MADDPG` to `DiscreteMADDPG` and fix type hint. * Rename maddpg in examples. * Preparation for data parallel (#420) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * Rename train worker to train ops; add placeholder for abstract methods; * Lint Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> * [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * dsitributed training pipeline draft * added temporary test files for review purposes * Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py * fixed merge conflicts * unified sync and async interfaces * refactored rl_v3; refinement in progress * Finish the runnable pipeline under new design * Remove outdated files; refine class names; optimize imports; * Lint * Minor maddpg related refinement * Lint Co-authored-by: Default <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Miner bug fix * Coroutine-related bug fix ("get_policy_state") (#452) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * deleted unwanted folder * removed unwanted changes * resolved PR452 comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Quick fix * Redesign experience recording logic (#453) * Two not important fix * Temp draft. Prepare to WFH * Done * Lint * Lint * Calculating advantages / returns (#454) * V1.0 * Complete DDPG * Rl v3 hanging issue fix (#455) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * Final test & format. Ready to merge. Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * Rl v3 parallel rollout (#457) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * load balancing dispatcher * added parallel rollout * lint * Tracker variable type issue; rename to env_sampler_creator; * Rl v3 parallel rollout follow ups (#458) * AbsWorker & AbsDispatcher * Pass env idx to AbsTrainer.record() method, and let the trainer to decide how to record experiences sampled from different worlds. * Fix policy_creator reuse bug * Format code * Merge AbsTrainerManager & SimpleTrainerManager * AC test passed * Lint * Remove AbsTrainer.build() method. Put all initialization operations into __init__ * Redesign AC preprocess batches logic Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * MADDPG performance bug fix (#459) * Fix MARL (MADDPG) terminal recording bug; some other minor refinements; * Restore Trainer.build() method * Calculate latest action in the get_actor_grad method in MADDPG. * Share critic bug fix * Rl v3 example update (#461) * updated vm_scheduling example and cim notebook * fixed bugs in vm_scheduling * added local train method * bug fix * modified async client logic to fix hidden issue * reverted to default config * fixed PR comments and some bugs * removed hardcode Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Done (#462) * Rl v3 load save (#463) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL Toolkit data parallelism revamp & config utils (#464) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments * 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local * 1. fixed rollout exit issue; 2. refined config * removed config file from example * fixed lint issues * fixed lint issues * added main.py under examples/rl * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL doc string (#465) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * Rl config doc (#467) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * added detailed doc * lint * wording refined * resolved some PR comments * resolved more PR comments * typo fix Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL online doc (#469) * Model, policy, trainer * RL workflows and env sampler doc in RST (#468) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * 1. type-sensitive env variable getter; 2. updated READMEs for examples * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * added detailed doc * lint * wording refined * resolved some PR comments * rewriting rl toolkit rst * resolved more PR comments * typo fix * updated rst Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Default <huo53926@126.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Finish docs/source/key_components/rl_toolkit.rst * API doc * RL online doc image fix (#470) * resolved some PR comments * fix * fixed PR comments * added numfig=True setting in conf.py for sphinx Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Resolve PR comments * Add example github link Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Rl v3 pr comment resolution (#474) * added load/save feature * 1. resolved pr comments; 2. reverted maro/cli/k8s * fixed some bugs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * RL renaming v2 (#476) * Change all Logger in RL to LoggerV2 * TrainerManager => TrainingManager * Add Trainer suffix to all algorithms * Finish docs * Update interface names * Minor fix * Cherry pick latest RL (#498) * Cherry pick * Remove SC related files * Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509) * Cherry pick RL changes from sc_refinement (2a4869) * Limit time display precision * RL incremental refactor (#501) * Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training. AC & PPO for continuous action policy; refine AC & PPO logic. Cherry pick RL changes from GYM-DDPG Cherry pick RL changes from GYM-SAC Minor error in doc string * Add min_n_sample in template and parser * Resolve PR comments. Fix a minor issue in SAC. * RL component bundle (#513) * CIM passed * Update workers * Refine annotations * VM passed * Code formatting. * Minor import loop issue * Pass batch in PPO again * Remove Scenario * Complete docs * Minor * Remove segment * Optimize logic in RLComponentBundle * Resolve PR comments * Move 'post methods from RLComponenetBundle to EnvSampler * Add method to get mapping of available tick to frame index (#415) * add method to get mapping of available tick to frame index * fix lint issue * fix naming issue * Cherry pick from sc_refinement (#527) * Cherry pick from sc_refinement * Cherry pick from sc_refinement * Refine `terminal` / `next_agent_state` logic (#531) * Optimize RL toolkit * Fix bug in terminal/next_state generation * Rewrite terminal/next_state logic again * Minor renaming * Minor bug fix * Resolve PR comments * Merge master into v0.3 (#536) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolve some conflicts occurs when installing. By adding these version number it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add & sort requirements.dev.txt Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Merge master into v0.3 (#545) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolve some conflicts occurs when installing. By adding these version number it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * update github woorkflow config * MARO v0.3: a new design of RL Toolkit, CLI refactorization, and corresponding updates. (#539) * refined proxy coding style * updated images and refined doc * updated images * updated CIM-AC example * refined proxy retry logic * call policy update only for AbsCorePolicy * add limitation of AbsCorePolicy in Actor.collect() * refined actor to return only experiences for policies that received new experiences * fix MsgKey issue in rollout_manager * fix typo in learner * call exit function for parallel rollout manager * update supply chain example distributed training scripts * 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prorotype * reformat render * fix supply chain business engine action type problem * reset supply chain example render figsize from 4 to 3 * Add render to all modes of supply chain example * fix or policy typos * 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes * refined parallel policy manager * updated rl/__init__/py * fixed lint issues and CIM local learner bugs * deleted unwanted supply_chain test files * revised default config for cim-dqn * removed test_store.py as it is no longer needed * 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms * updated figures * removed unwanted import * refactored CIM-DQN example * added MultiProcessRolloutManager and MultiProcessTrainingManager * updated doc * lint issue fix * lint issue fix * fixed import formatting * [Feature] Prioritized Experience Replay (#355) * added prioritized experience replay * deleted unwanted supply_chain test files * fixed import order * import fix * fixed lint issues * fixed import formatting * added note in docstring that rank-based PER has yet to be implemented Co-authored-by: ysqyang <v-yangqi@microsoft.com> * rm AbsDecisionGenerator * small fixes * bug fix * reorganized training folder structure * fixed lint issues * fixed lint issues * policy manager refined * lint fix * restructured CIM-dqn sync code * added policy version index and used it as a measure of experience staleness * lint issue fix * lint issue fix * switched log_dir and proxy_kwargs order * cim example refinement * eval schedule sorted only when it's a list * eval schedule sorted only when it's a list * update sc env wrapper * added docker scripts for cim-dqn * refactored example folder structure and added workflow templates * fixed lint issues * fixed lint issues * fixed template bugs * removed unused imports * refactoring sc in progress * simplified cim meta * fixed build.sh path bug * template refinement * deleted obsolete svgs * updated learner logs * minor edits * refactored templates for easy merge with async PR * added component names for rollout manager and policy manager * fixed incorrect position to add last episode to eval schedule * added max_lag option in templates * formatting edit in docker_compose_yml script * moved local learner and early stopper outside sync_tools * refactored rl toolkit folder structure * refactored rl toolkit folder structure * moved env_wrapper and agent_wrapper inside rl/learner * refined scripts * fixed typo in script * changes needed for running sc * removed unwanted imports * config change for testing sc scenario * changes for perf testing * Asynchronous Training (#364) * remote inference code draft * changed actor to rollout_worker and updated init files * removed unwanted import * updated inits * more async code * added async scripts * added async training code & scripts for CIM-dqn * changed async to async_tools to avoid conflict with python keyword * reverted unwanted change to dockerfile * added doc for policy server * addressed PR comments and fixed a bug in docker_compose_yml.py * fixed lint issue * resolved PR comment * resolved merge conflicts * added async templates * added proxy.close() for actor and policy_server * fixed incorrect position to add last episode to eval schedule * reverted unwanted changes * added missing async files * rm unwanted echo in kill.sh Co-authored-by: ysqyang <v-yangqi@microsoft.com> * renamed sync to synchronous and async to asynchronous to avoid conflict with keyword * added missing policy version increment in LocalPolicyManager * refined rollout manager recv logic * removed a debugging print * added sleep in distributed launcher to avoid hanging * updated api doc and rl toolkit doc * refined dynamic imports using importlib * 1. moved policy update triggers to policy manager; 2. added version control in policy manager * fixed a few bugs and updated cim RL example * fixed a few more bugs * added agent wrapper instantiation to workflows * added agent wrapper instantiation to workflows * removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet * fixed incorrect get_ac_policy signature for CIM * moved exploration inside core policy * added state to exploration call to support context-dependent exploration * separated non_rl_policy_index and rl_policy_index in workflows * modified sc example code according to workflow changes * modified sc example code according to workflow changes * added replay_agent_ids parameter to get_env_func for RL examples * fixed a few bugs * added maro/simulator/scenarios/supply_chain as bind mount * added post-step, post-collect, post-eval and post-update callbacks * fixed lint issues * fixed lint issues * moved instantiation of policy manager inside simple learner * fixed env_wrapper get_reward signature * minor edits * removed get_eperience kwargs from env_wrapper * 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows * added rollout exp disribution option in RL examples * removed unwanted files * 1. made logger internal in learner; 2 removed logger creation in abs classes * checked out supply chain test files from v0.2_sc * 1. added missing model.eval() to choose_action; 2.added entropy features to AC * fixed a bug in ac entropy * abbreviated coefficient to coeff * removed -dqn from job name in rl example config * added tmp patch to dev.df * renamed image name for running rl examples * added get_loss interface for core policies * added policy manager in rl_toolkit.rst * 1. env_wrapper bug fix; 2. policy manager update logic refinement * refactored policy and algorithms * policy interface redesigned * refined policy interfaces * fixed typo * fixed bugs in refactored policy interface * fixed some bugs * refactoring in progress * policy interface and policy manager redesigned * 1. fixed bugs in ac and pg; 2. fixed bugs rl workflow scripts * fixed bug in distributed policy manager * fixed lint issues * fixed lint issues * added scipy in setup * 1. trimmed rollout manager code; 2. added option to docker scripts * updated api doc for policy manager * 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script * 1. simplified rl example structure; 2. fixed lint issues * further rl toolkit code simplifications * more numpy-based optimization in RL toolkit * moved replay buffer inside policy * bug fixes * numpy optimization and associated refactoring * extracted shaping logic out of env_sampler * fixed bug in CIM shaping and lint issues * preliminary implemetation of parallel batch inference * fixed bug in ddpg transition recording * put get_state, get_env_actions, get_reward back in EnvSampler * simplified exploration and core model interfaces * bug fixes and doc update * added improve() interface for RLPolicy for single-thread support * fixed simple policy manager bug * updated doc, rst, notebook * updated notebook * fixed lint issues * fixed entropy bugs in ac.py * reverted to simple policy manager as default * 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit * fixed lint issues and updated rl toolkit images * removed obsolete images * added back agent2policy for general workflow use * V0.2 rl refinement dist (#377) * Support `slice` operation in ExperienceSet * Support naive distributed policy training by proxy * Dynamically allocate trainers according to number of experience * code check * code check * code check * Fix a bug in distributed trianing with no gradient * Code check * Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy * 1.call allocate_trainer() at first of update(); 2.refine according to code review * Code check * Refine code with new interface * Update docs of PolicyManger and ExperienceSet * Add images for rl_toolkit docs * Update diagram of PolicyManager * Refine with new interface * Extract allocation strategy into `allocation_strategy.py` * add `distributed_learn()` in policies for data-parallel training * Update doc of RL_toolkit * Add gradient workers for data-parallel * Refine code and update docs * Lint check * Refine by comments * Rename `trainer` to `worker` * Rename `distributed_learn` to `learn_with_data_parallel` * Refine allocator and remove redundant code in policy_manager * remove arugments in allocate_by_policy and so on * added checkpointing for simple and multi-process policy managers * 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager * added missing set_state and get_state for CIM policies * removed blank line * updated RL workflow README * Integrate `data_parallel` arguments into `worker_allocator` (#402) * 1. simplified workflow config; 2. added comments to CIM shaping * lint issue fix * 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading * 1. moved post_step callback inside env sampler; 2. updated README for rl workflows * refined READEME for CIM * VM scheduling with RL (#375) * added part of vm scheduling RL code * refined vm env_wrapper code style * added DQN * added get_experiences func for ac in vm scheduling * added post_step callback to env wrapper * moved Aiming's tracking and plotting logic into callbacks * added eval env wrapper * renamed AC config variable name for VM * vm scheduling RL code finished * updated README * fixed various bugs and hard coding for vm_scheduling * uncommented callbacks for VM scheduling * Minor revision for better code style * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * vm example refactoring * fixed bugs in vm_scheduling * removed unwanted files from cim dir * reverted to simple policy manager as default * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * resolved rebase conflicts * fixed bugs in vm_scheduling * added get_state and set_state to vm_scheduling policy models * updated README for vm_scheduling with RL Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * SC refinement (#397) * Refine test scripts & pending_order_daily logic * Refactor code for better code style: complete type hint, correct typos, remove unused items. Refactor code for better code style: complete type hint, correct typos, remove unused items. * Polish test_supply_chain.py * update import format * Modify vehicle steps logic & remove outdated test case * Optimize imports * Optimize imports * Lint error * Lint error * Lint error * Add SupplyChainAction * Lint error Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * refined workflow scripts * fixed bug in ParallelAgentWrapper * 1. fixed lint issues; 2. refined main script in workflows * lint issue fix * restored default config for rl example * Update rollout.py * refined env var processing in policy manager workflow * added hasattr check in agent wrapper * updated docker_compose_yml.py * Minor refinement * Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412) * Prepare to merge master_mirror * Lint error * Minor * Merge latest master into v0.3 (#426) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * Minor * Remove docs/source/examples/multi_agent_dqn_cim.rst * Update .gitignore * Update .gitignore Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> * Change `Env.set_seed()` logic (#456) * Change Env.set_seed() logic * Redesign CIM reset logic; fix lint issues; * Lint * Seed type assertion * Remove all SC related files (#473) * RL Toolkit V3 (#471) * added daemon=True for multi-process rollout, policy manager and inference * removed obsolete files * [REDO][PR#406]V0.2 rl refinement taskq (#408) * Add a usable task_queue * Rename some variables * 1. Add ; 2. Integrate related files; 3. Remove * merge `data_parallel` and `num_grad_workers` into `data_parallelism` * Fix bugs in docker_compose_yml.py and Simple/Multi-process mode. * Move `grad_worker` into marl/rl/workflows * 1.Merge data_parallel and num_workers into data_parallelism in config; 2.Assign recently used workers as possible in task_queue. * Refine code and update docs of `TaskQueue` * Support priority for tasks in `task_queue` * Update diagram of policy manager and task queue. * Add configurable `single_task_limit` and correct docstring about `data_parallelism` * Fix lint errors in `supply chain` * RL policy redesign (V2) (#405) * Drafi v2.0 for V2 * Polish models with more comments * Polish policies with more comments * Lint * Lint * Add developer doc for models. * Add developer doc for policies. * Remove policy manager V2 since it is not used and out-of-date * Lint * Lint * refined messy workflow code * merged 'scenario_dir' and 'scenario' in rl config * 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods * 1. temporarily renamed RLPolicy from polivy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2 * merged cim and cim_v2 * lint issue fix * refined logging logic * lint issue fix * reversed unwanted changes * . . . . ReplayMemory & IndexScheduler ReplayMemory & IndexScheduler . MultiReplayMemory get_actions_with_logps EnvSampler on the road EnvSampler Minor * LearnerManager * Use batch to transfer data & add SHAPE_CHECK_FLAG * Rename learner to trainer * Add property for policy._is_exploring * CIM test scenario for V3. Manual test passed. Next step: run it, make it works. * env_sampler.py could run * env_sampler refine on the way * First runnable version done * AC could run, but the result is bad. Need to check the logic * Refine abstract method & shape check error info. * Docs * Very detailed compare. Try again. * AC done * DQN check done * Minor * DDPG, not tested * Minors * A rough draft of MAAC * Cannot use CIM as the multi-agent scenario. * Minor * MAAC refinement on the way * Remove ActionWithAux * Refine batch & memory * MAAC example works * Reproduce-able fix. Policy share between env_sampler and trainer_manager. * Detail refinement * Simplify the user configed workflow * Minor * Refine example codes * Minor polishment * Migrate rollout_manager to V3 * Error on the way * Redesign torch.device management * Rl v3 maddpg (#418) * Add MADDPG trainer * Fit independent critics and shared critic modes. * Add a new property: num_policies * Lint * Fix a bug in `sum(rewards)` * Rename `MADDPG` to `DiscreteMADDPG` and fix type hint. * Rename maddpg in examples. * Preparation for data parallel (#420) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * Rename train worker to train ops; add placeholder for abstract methods; * Lint Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> * [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * dsitributed training pipeline draft * added temporary test files for review purposes * Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py * fixed merge conflicts * unified sync and async interfaces * refactored rl_v3; refinement in progress * Finish the runnable pipeline under new design * Remove outdated files; refine class names; optimize imports; * Lint * Minor maddpg related refinement * Lint Co-authored-by: Default <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Miner bug fix * Coroutine-related bug fix ("get_policy_state") (#452) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * deleted unwanted folder * removed unwanted changes * resolved PR452 comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Quick fix * Redesign experience recording logic (#453) * Two not important fix * Temp draft. Prepare to WFH * Done * Lint * Lint * Calculating advantages / returns (#454) * V1.0 * Complete DDPG * Rl v3 hanging issue fix (#455) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * Final test & format. Ready to merge. Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * Rl v3 parallel rollout (#457) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * load balancing dispatcher * added parallel rollout * lint * Tracker variable type issue; rename to env_sampler_creator; * Rl v3 parallel rollout follow ups (#458) * AbsWorker & AbsDispatcher * Pass env idx to AbsTrainer.record() method, and let the trainer to decide how to record experiences sampled from different worlds. * Fix policy_creator reuse bug * Format code * Merge AbsTrainerManager & SimpleTrainerManager * AC test passed * Lint * Remove AbsTrainer.build() method. Put all initialization operations into __init__ * Redesign AC preprocess batches logic Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * MADDPG performance bug fix (#459) * Fix MARL (MADDPG) terminal recording bug; some other minor refinements; * Restore Trainer.build() method * Calculate latest action in the get_actor_grad method in MADDPG. * Share critic bug fix * Rl v3 example update (#461) * updated vm_scheduling example and cim notebook * fixed bugs in vm_scheduling * added local train method * bug fix * modified async client logic to fix hidden issue * reverted to default config * fixed PR comments and some bugs * removed hardcode Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Done (#462) * Rl v3 load save (#463) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL Toolkit data parallelism revamp & config utils (#464) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments * 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local * 1. fixed rollout exit issue; 2. refined config * removed config file from example * fixed lint issues * fixed lint issues * added main.py under examples/rl * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL doc string (#465) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * Rl config doc (#467) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * added detailed doc * lint * wording refined * resolved some PR comments * resolved more PR comments * typo fix Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL online doc (#469) * Model, policy, trainer * RL workflows and env sampler doc in RST (#468) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * 1. type-sensitive env variable getter; 2. updated READMEs for examples * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * added detailed doc * lint * wording refined * resolved some PR comments * rewriting rl toolkit rst * resolved more PR comments * typo fix * updated rst Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Default <huo53926@126.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Finish docs/source/key_components/rl_toolkit.rst * API doc * RL online doc image fix (#470) * resolved some PR comments * fix * fixed PR comments * added numfig=True setting in conf.py for sphinx Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Resolve PR comments * Add example github link Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Rl v3 pr comment resolution (#474) * added load/save feature * 1. resolved pr comments; 2. reverted maro/cli/k8s * fixed some bugs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * RL renaming v2 (#476) * Change all Logger in RL to LoggerV2 * TrainerManager => TrainingManager * Add Trainer suffix to all algorithms * Finish docs * Update interface names * Minor fix * Cherry pick latest RL (#498) * Cherry pick * Remove SC related files * Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509) * Cherry pick RL changes from sc_refinement (2a4869) * Limit time display precision * RL incremental refactor (#501) * Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training. AC & PPO for continuous action policy; refine AC & PPO logic. Cherry pick RL changes from GYM-DDPG Cherry pick RL changes from GYM-SAC Minor error in doc string * Add min_n_sample in template and parser * Resolve PR comments. Fix a minor issue in SAC. * RL component bundle (#513) * CIM passed * Update workers * Refine annotations * VM passed * Code formatting. * Minor import loop issue * Pass batch in PPO again * Remove Scenario * Complete docs * Minor * Remove segment * Optimize logic in RLComponentBundle * Resolve PR comments * Move 'post methods from RLComponenetBundle to EnvSampler * Add method to get mapping of available tick to frame index (#415) * add method to get mapping of available tick to frame index * fix lint issue * fix naming issue * Cherry pick from sc_refinement (#527) * Cherry pick from sc_refinement * Cherry pick from sc_refinement * Refine `terminal` / `next_agent_state` logic (#531) * Optimize RL toolkit * Fix bug in terminal/next_state generation * Rewrite terminal/next_state logic again * Minor renaming * Minor bug fix * Resolve PR comments * Merge master into v0.3 (#536) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolve some conflicts occurs when installing. By adding these version number it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add & sort requirements.dev.txt Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Remove random_config.py * Remove test_trajectory_utils.py * Pass tests * Update rl docs * Remove python 3.6 in test * Update docs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wang.Jinyu <jinywan@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Logger bug hotfix (#543) * Rename param * Rename param * Quick fix in env_data_process * frame data precision issue fix (#544) * fix frame precision issue * add .xmake to .gitignore * update frame precision lost warning message * add assert to frame precision checking * typo fix * add TODO for future Long data type issue fix * Minor cleaning Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * Update requirements. (#552) * Fix several encoding issues; update requirements. * Test & minor * Remove torch in requirements.build.txt * Polish * Update README * Resolve PR comments * Keep working * Keep working * Update test requirements * Done (#554) * Update requirements in example and notebook (#553) * Update requirements in example and notebook * Remove autopep8 * Add jupyterlab packages back Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Refine decision event logic (#559) * Add DecisionEventPayload * Change decision payload name * Refine action logic * Add doc for env.step * Restore pre-commit config * Resolve PR comments * Refactor decision event & action * Pre-commit * Resolve PR comments * Refine rl component bundle (#549) * Config files * Done * Minor bugfix * Add autoflake * Update isort exclude; add pre-commit to requirements * Check only isort * Minor * Format * Test passed * Run pre-commit * Minor bugfix in rl_component_bundle * Pass mypy * Fix a bug in RL notebook * A minor bug fix * Add upper bound for numpy version in test * Remove numpy data type (#571) * Change numpy data type; change test requirements. * Lint * RL benchmark on GYM (#575) * PPO, SAC, DDPG passed * Explore in SAC * Test GYM on server * Sync server changes * pre-commit * Ready to try on server * . * . * . * . * . * Performance OK * Move to tests * Remove old versions * PPO done * Start to test AC * Start to test SAC * SAC test passed * update for some PR comments; Add a MARKDOWN file (#576) Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * Use FullyConnected to replace mlp * Update action bound * Pre-commit --------- Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * Refine RL workflow & tune RL models under GYM (#577) * PPO, SAC, DDPG passed * Explore in SAC * Test GYM on server * Sync server changes * pre-commit * Ready to try on server * . * . * . * . * . * Performance OK * Move to tests * Remove old versions * PPO done * Start to test AC * Start to test SAC * SAC test passed * Multiple round in evaluation * Modify config.yml * Add Callbacks * [wip] SAC performance not good * [wip] still not good * update for some PR comments; Add a MARKDOWN file (#576) Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * Use FullyConnected to replace mlp * Update action bound * ??? * Change gym env wrapper metrics logci * Change gym env wrapper metrics logci * refine env_sampler.sample under step mode * Add DDPG. Performance not good... * Add DDPG. Performance not good... * wip * Sounds like sac works * Refactor file structure * Refactor file structure * Refactor file structure * Pre-commit * Pre commit * Minor refinement of CIM RL * Jinyu/rl workflow refine (#578) * remove useless files; add device mapping; update pdoc * add default checkpoint path; fix distributed worker log path issue; update example log path * update performance doc * remove tests/rl/algorithms folder * Resolve PR comments * Compare PPO with spinning up (#579) * [wip] compare PPO * PPO matching * Revert unnecessary changes * Minor * Minor * SAC Test parameters update (#580) * fix sac to_device issue; update sac gym test parameters * add rl test performance plot func * update sac eval interval config * update sac checkpoint interval config * fix callback issue * update plot func * update plot func * update plot func * update performance doc; upload performance images * Minor fix in callbacks; refine plot.py format. * Add n_interactions. Use n_interactions to plot curves. * pre-commit --------- Co-authored-by: Huoran Li <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * Episode truncation & early stopping (#581) * Add truncated logic * (To be tested) early stop * Early stop test passed * Test passed * Random action. To be tested. * Warmup OK * Pre-commit * random seed * Revert pre-commit config --------- Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * DDPG parameters update (#583) * Tune params * fix conflict * remove duplicated seed setting --------- Co-authored-by: Huoran Li <huo53926@126.com> * Update RL Benchmarks (#584) * update plot func for rl tests * Refine seed setting logic * Refine metrics logic; add warmup to ddpg. * Complete ddpg config * Minor refinement of GymEnvSampler and plot.py * update rl benchmark performance results * Lint --------- Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Huoran Li <huo53926@126.com> * Update Input Template of RL Policy to Improve Module Flexisiblity (#589) * add customized_callbacks to RLComponentBundle * add env.tick to replace the default None in AbsEnvSampler._get_global_and_agent_state() * fix rl algorithms to_device issue * add kwargs to RL models' forward funcs and _shape_check() * add kwargs to RL policies' get_action related funcs and _post_check() * add detached loss to the return value of update_critic() and update_actor() of current TrainOps; add default False early_stop to update_actor() of current TrainOps * add kwargs to choose_actions of AbsEnvSampler; remain it None in current sample() and eval() * ufix line length issue * fix line break issue --------- Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * update code version to 0.3.2a1 --------- Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Huoran Li <huo53926@126.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
|
@ -11,6 +11,7 @@ from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
|
|||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.Tanh,
|
||||
"output_activation": torch.nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
|
@ -19,6 +20,7 @@ critic_net_conf = {
|
|||
"hidden_dims": [256, 128, 64],
|
||||
"output_dim": 1,
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True,
|
||||
|
|
|
@ -12,6 +12,7 @@ from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
|||
q_net_conf = {
|
||||
"hidden_dims": [256, 128, 64, 32],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"skip_connection": False,
|
||||
|
|
|
@ -14,6 +14,7 @@ from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrai
|
|||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.Tanh,
|
||||
"output_activation": torch.nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
|
@ -22,6 +23,7 @@ critic_net_conf = {
|
|||
"hidden_dims": [256, 128, 64],
|
||||
"output_dim": 1,
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True,
|
||||
|
|
|
@ -90,11 +90,25 @@ class CIMEnvSampler(AbsEnvSampler):
|
|||
for info in info_list:
|
||||
print(f"env summary (episode {ep}): {info['env_metric']}")
|
||||
|
||||
# print the average env metric
|
||||
if len(info_list) > 1:
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env summary (episode {ep}): {avg_metric}")
|
||||
# average env metric
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env summary (episode {ep}): {avg_metric}")
|
||||
|
||||
self.metrics.update(avg_metric)
|
||||
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
|
||||
|
||||
def post_evaluate(self, info_list: list, ep: int) -> None:
|
||||
self.post_collect(info_list, ep)
|
||||
# print the env metric from each rollout worker
|
||||
for info in info_list:
|
||||
print(f"env summary (episode {ep}): {info['env_metric']}")
|
||||
|
||||
# average env metric
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env summary (episode {ep}): {avg_metric}")
|
||||
|
||||
self.metrics.update({"val/" + k: v for k, v in avg_metric.items()})
|
||||
|
||||
def monitor_metrics(self) -> float:
|
||||
return -self.metrics["val/container_shortage"]
|
||||
|
|
|
@ -13,7 +13,7 @@ from examples.cim.rl.env_sampler import CIMEnvSampler
|
|||
|
||||
# Environments
|
||||
learn_env = Env(**env_conf)
|
||||
test_env = learn_env
|
||||
test_env = Env(**env_conf)
|
||||
|
||||
# Agent, policy, and trainers
|
||||
num_agents = len(learn_env.agent_idx_list)
|
||||
|
|
|
@ -7,7 +7,7 @@ This folder contains scenarios that employ reinforcement learning. MARO's RL too
|
|||
The entrance of a RL workflow is a YAML config file. For readers' convenience, we call this config file `config.yml` in the rest part of this doc. `config.yml` specifies the path of all necessary resources, definitions, and configurations to run the job. MARO provides a comprehensive template of the config file with detailed explanations (`maro/maro/rl/workflows/config/template.yml`). Meanwhile, MARO also provides several simple examples of `config.yml` under the current folder.
|
||||
|
||||
There are two ways to start the RL job:
|
||||
- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run_rl_example.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`.
|
||||
- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`.
|
||||
- (**Require install MARO from source**) You could also start the job through MARO CLI. Use the command `maro local run [-c] path/to/your/config` to run in containerized (with `-c`) or non-containerized (without `-c`) environments. Similar, you could add `--evaluate_only` if you only need to run the evaluation workflow.
|
||||
|
||||
## Create Your Own Scenarios
|
||||
|
|
|
@ -5,16 +5,17 @@
|
|||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
|
||||
# - python ./examples/rl/run.py ./examples/rl/cim.yml
|
||||
# - (Requires installing MARO from source) maro local run ./examples/rl/cim.yml
|
||||
|
||||
job: cim_rl_workflow
|
||||
scenario_path: "examples/cim/rl"
|
||||
log_path: "log/rl_job/cim.txt"
|
||||
log_path: "log/cim_rl/"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
eval_schedule: 5
|
||||
early_stop_patience: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
|
@ -27,7 +28,7 @@ training:
|
|||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/cim"
|
||||
path: "log/cim_rl/checkpoints"
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for CIM scenario.
|
||||
# Example RL config file for CIM scenario (distributed version).
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
|
||||
# - python ./examples/rl/run.py ./examples/rl/cim_distributed.yml
|
||||
# - (Requires installing MARO from source) maro local run ./examples/rl/cim_distributed.yml
|
||||
|
||||
job: cim_rl_workflow
|
||||
scenario_path: "examples/cim/rl"
|
||||
log_path: "log/rl_job/cim.txt"
|
||||
log_path: "log/cim_rl/"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
|
@ -35,7 +35,7 @@ training:
|
|||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/cim"
|
||||
path: "log/cim_rl/checkpoints"
|
||||
interval: 5
|
||||
proxy:
|
||||
host: "127.0.0.1"
|
||||
|
|
|
@ -5,12 +5,12 @@
|
|||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\vm_scheduling.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\vm_scheduling.yml
|
||||
# - python ./examples/rl/run.py ./examples/rl/vm_scheduling.yml
|
||||
# - (Requires installing MARO from source) maro local run ./examples/rl/vm_scheduling.yml
|
||||
|
||||
job: vm_scheduling_rl_workflow
|
||||
scenario_path: "examples/vm_scheduling/rl"
|
||||
log_path: "log/rl_job/vm_scheduling.txt"
|
||||
log_path: "log/vm_rl/"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
|
@ -27,7 +27,7 @@ training:
|
|||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/vm_scheduling"
|
||||
path: "log/vm_rl/checkpoints"
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
|
|
|
@ -11,6 +11,7 @@ from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
|
|||
actor_net_conf = {
|
||||
"hidden_dims": [64, 32, 32],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
|
@ -19,6 +20,7 @@ actor_net_conf = {
|
|||
critic_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
|
|
|
@ -14,6 +14,7 @@ from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
|||
q_net_conf = {
|
||||
"hidden_dims": [64, 128, 256],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"output_activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": False,
|
||||
"skip_connection": False,
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
__version__ = "0.3.1a2"
|
||||
__version__ = "0.3.2a1"
|
||||
|
||||
__data_version__ = "0.2"
|
||||
|
|
|
@ -8,7 +8,6 @@ import zipfile
|
|||
from enum import Enum
|
||||
|
||||
import geopy.distance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from yaml import safe_load
|
||||
|
||||
|
@ -320,7 +319,7 @@ class CitiBikePipeline(DataPipeline):
|
|||
0,
|
||||
index=station_info["station_index"],
|
||||
columns=station_info["station_index"],
|
||||
dtype=np.float,
|
||||
dtype=float,
|
||||
)
|
||||
look_up_df = station_info[["latitude", "longitude"]]
|
||||
return distance_adj.apply(
|
||||
|
@ -617,7 +616,7 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
0,
|
||||
index=station_init["station_index"],
|
||||
columns=station_init["station_index"],
|
||||
dtype=np.float,
|
||||
dtype=float,
|
||||
)
|
||||
look_up_df = station_init[["latitude", "longitude"]]
|
||||
distance_df = distance_adj.apply(
|
||||
|
|
|
@ -61,7 +61,7 @@ def get_redis_conn(port=None):
|
|||
|
||||
|
||||
# Functions executed on CLI commands
|
||||
def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False, **kwargs):
|
||||
def run(conf_path: str, containerize: bool = False, seed: int = None, evaluate_only: bool = False, **kwargs):
|
||||
# Load job configuration file
|
||||
parser = ConfigParser(conf_path)
|
||||
if containerize:
|
||||
|
@ -71,13 +71,14 @@ def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False,
|
|||
LOCAL_MARO_ROOT,
|
||||
DOCKERFILE_PATH,
|
||||
DOCKER_IMAGE_NAME,
|
||||
seed=seed,
|
||||
evaluate_only=evaluate_only,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
stop_rl_job_with_docker_compose(parser.config["job"], LOCAL_MARO_ROOT)
|
||||
else:
|
||||
try:
|
||||
start_rl_job(parser, LOCAL_MARO_ROOT, evaluate_only=evaluate_only)
|
||||
start_rl_job(parser, LOCAL_MARO_ROOT, seed=seed, evaluate_only=evaluate_only)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import os
|
||||
import subprocess
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import docker
|
||||
import yaml
|
||||
|
@ -110,12 +110,15 @@ def exec(cmd: str, env: dict, debug: bool = False) -> subprocess.Popen:
|
|||
def start_rl_job(
|
||||
parser: ConfigParser,
|
||||
maro_root: str,
|
||||
seed: Optional[int],
|
||||
evaluate_only: bool,
|
||||
background: bool = False,
|
||||
) -> List[subprocess.Popen]:
|
||||
procs = [
|
||||
exec(
|
||||
f"python {script}" + ("" if not evaluate_only else " --evaluate_only"),
|
||||
f"python {script}"
|
||||
+ ("" if not evaluate_only else " --evaluate_only")
|
||||
+ ("" if seed is None else f" --seed {seed}"),
|
||||
format_env_vars({**env, "PYTHONPATH": maro_root}, mode="proc"),
|
||||
debug=not background,
|
||||
)
|
||||
|
@ -169,6 +172,7 @@ def start_rl_job_with_docker_compose(
|
|||
context: str,
|
||||
dockerfile_path: str,
|
||||
image_name: str,
|
||||
seed: Optional[int],
|
||||
evaluate_only: bool,
|
||||
) -> None:
|
||||
common_spec = {
|
||||
|
@ -185,7 +189,9 @@ def start_rl_job_with_docker_compose(
|
|||
**deepcopy(common_spec),
|
||||
**{
|
||||
"container_name": component,
|
||||
"command": f"python3 {script}" + ("" if not evaluate_only else " --evaluate_only"),
|
||||
"command": f"python3 {script}"
|
||||
+ ("" if not evaluate_only else " --evaluate_only")
|
||||
+ ("" if seed is None else f" --seed {seed}"),
|
||||
"environment": format_env_vars(env, mode="docker-compose"),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch.nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -18,6 +18,8 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
def __init__(self) -> None:
|
||||
super(AbsNet, self).__init__()
|
||||
|
||||
self._device: Optional[torch.device] = None
|
||||
|
||||
@property
|
||||
def optim(self) -> Optimizer:
|
||||
optim = getattr(self, "_optim", None)
|
||||
|
@ -119,3 +121,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
"""Unfreeze all parameters."""
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
def to_device(self, device: torch.device) -> None:
|
||||
self._device = device
|
||||
self.to(device)
|
||||
|
|
|
@ -43,14 +43,23 @@ class ContinuousACBasedNet(ContinuousPolicyNet, metaclass=ABCMeta):
|
|||
- set_state(self, net_state: dict) -> None:
|
||||
"""
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_logps_impl(states, exploring)
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_logps_impl(states, exploring, **kwargs)
|
||||
return actions
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_probs_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Not used in Actor-Critic or PPO
|
||||
pass
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in Actor-Critic or PPO
|
||||
pass
|
||||
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in Actor-Critic or PPO
|
||||
pass
|
||||
|
|
|
@ -25,18 +25,32 @@ class ContinuousDDPGNet(ContinuousPolicyNet, metaclass=ABCMeta):
|
|||
- set_state(self, net_state: dict) -> None:
|
||||
"""
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_probs_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Not used in DDPG
|
||||
pass
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_logps_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Not used in DDPG
|
||||
pass
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in DDPG
|
||||
pass
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in DDPG
|
||||
pass
|
||||
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in DDPG
|
||||
pass
|
||||
|
|
|
@ -25,18 +25,23 @@ class ContinuousSACNet(ContinuousPolicyNet, metaclass=ABCMeta):
|
|||
- set_state(self, net_state: dict) -> None:
|
||||
"""
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_logps_impl(states, exploring)
|
||||
return actions
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_probs_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Not used in SAC
|
||||
pass
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in SAC
|
||||
pass
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Not used in SAC
|
||||
pass
|
||||
|
|
|
@ -39,7 +39,8 @@ class FullyConnected(nn.Module):
|
|||
input_dim: int,
|
||||
output_dim: int,
|
||||
hidden_dims: List[int],
|
||||
activation: Optional[Type[torch.nn.Module]] = nn.ReLU,
|
||||
activation: Optional[Type[torch.nn.Module]] = None,
|
||||
output_activation: Optional[Type[torch.nn.Module]] = None,
|
||||
head: bool = False,
|
||||
softmax: bool = False,
|
||||
batch_norm: bool = False,
|
||||
|
@ -54,7 +55,8 @@ class FullyConnected(nn.Module):
|
|||
self._output_dim = output_dim
|
||||
|
||||
# network features
|
||||
self._activation = activation() if activation else None
|
||||
self._activation = activation if activation else None
|
||||
self._output_activation = output_activation if output_activation else None
|
||||
self._head = head
|
||||
self._softmax = nn.Softmax(dim=1) if softmax else None
|
||||
self._batch_norm = batch_norm
|
||||
|
@ -70,9 +72,13 @@ class FullyConnected(nn.Module):
|
|||
|
||||
# build the net
|
||||
dims = [self._input_dim] + self._hidden_dims
|
||||
layers = [self._build_layer(in_dim, out_dim) for in_dim, out_dim in zip(dims, dims[1:])]
|
||||
layers = [
|
||||
self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:])
|
||||
]
|
||||
# top layer
|
||||
layers.append(self._build_layer(dims[-1], self._output_dim, head=self._head))
|
||||
layers.append(
|
||||
self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation),
|
||||
)
|
||||
|
||||
self._net = nn.Sequential(*layers)
|
||||
|
||||
|
@ -101,7 +107,13 @@ class FullyConnected(nn.Module):
|
|||
def output_dim(self) -> int:
|
||||
return self._output_dim
|
||||
|
||||
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
|
||||
def _build_layer(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
head: bool = False,
|
||||
activation: Type[torch.nn.Module] = None,
|
||||
) -> nn.Module:
|
||||
"""Build a basic layer.
|
||||
|
||||
BN -> Linear -> Activation -> Dropout
|
||||
|
@ -110,8 +122,8 @@ class FullyConnected(nn.Module):
|
|||
if self._batch_norm:
|
||||
components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
|
||||
components.append(("linear", nn.Linear(input_dim, output_dim)))
|
||||
if not head and self._activation is not None:
|
||||
components.append(("activation", self._activation))
|
||||
if not head and activation is not None:
|
||||
components.append(("activation", activation()))
|
||||
if not head and self._dropout_p:
|
||||
components.append(("dropout", nn.Dropout(p=self._dropout_p)))
|
||||
return nn.Sequential(OrderedDict(components))
|
||||
|
|
|
@ -37,7 +37,7 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
|
|||
def agent_num(self) -> int:
|
||||
return len(self._action_dims)
|
||||
|
||||
def _shape_check(self, states: torch.Tensor, actions: List[torch.Tensor] = None) -> bool:
|
||||
def _shape_check(self, states: torch.Tensor, actions: List[torch.Tensor] = None, **kwargs) -> bool:
|
||||
"""Check whether the states and actions have valid shapes.
|
||||
|
||||
Args:
|
||||
|
@ -61,7 +61,7 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
|
|||
return False
|
||||
return True
|
||||
|
||||
def q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
|
||||
def q_values(self, states: torch.Tensor, actions: List[torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Get Q-values according to states and actions.
|
||||
|
||||
Args:
|
||||
|
@ -71,8 +71,8 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
|
|||
Returns:
|
||||
q (torch.Tensor): Q-values with shape [batch_size].
|
||||
"""
|
||||
assert self._shape_check(states, actions)
|
||||
q = self._get_q_values(states, actions)
|
||||
assert self._shape_check(states, actions, **kwargs)
|
||||
q = self._get_q_values(states, actions, **kwargs)
|
||||
assert match_shape(
|
||||
q,
|
||||
(states.shape[0],),
|
||||
|
@ -80,6 +80,6 @@ class MultiQNet(AbsNet, metaclass=ABCMeta):
|
|||
return q
|
||||
|
||||
@abstractmethod
|
||||
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
|
||||
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Implementation of `q_values`."""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -33,93 +33,121 @@ class PolicyNet(AbsNet, metaclass=ABCMeta):
|
|||
def action_dim(self) -> int:
|
||||
return self._action_dim
|
||||
|
||||
def get_actions(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
def get_actions(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions = self._get_actions_impl(states, exploring)
|
||||
actions = self._get_actions_impl(states, exploring, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
|
||||
return actions
|
||||
|
||||
def get_actions_with_probs(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def get_actions_with_probs(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions, probs = self._get_actions_with_probs_impl(states, exploring)
|
||||
actions, probs = self._get_actions_with_probs_impl(states, exploring, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
|
||||
|
||||
return actions, probs
|
||||
|
||||
def get_actions_with_logps(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def get_actions_with_logps(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions, logps = self._get_actions_with_logps_impl(states, exploring)
|
||||
actions, logps = self._get_actions_with_logps_impl(states, exploring, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
|
||||
|
||||
return actions, logps
|
||||
|
||||
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
probs = self._get_states_actions_probs_impl(states, actions)
|
||||
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
|
||||
|
||||
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
|
||||
|
||||
return probs
|
||||
|
||||
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
logps = self._get_states_actions_logps_impl(states, actions)
|
||||
logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
|
||||
|
||||
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
|
||||
|
||||
return logps
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_probs_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_logps_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool:
|
||||
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None, **kwargs) -> bool:
|
||||
"""Check whether the states and actions have valid shapes.
|
||||
|
||||
Args:
|
||||
|
@ -160,7 +188,7 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
|
|||
def action_num(self) -> int:
|
||||
return self._action_num
|
||||
|
||||
def get_action_probs(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def get_action_probs(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get the probabilities for all possible actions in the action space.
|
||||
|
||||
Args:
|
||||
|
@ -171,8 +199,9 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
|
|||
"""
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
action_probs = self._get_action_probs_impl(states)
|
||||
action_probs = self._get_action_probs_impl(states, **kwargs)
|
||||
assert match_shape(action_probs, (states.shape[0], self.action_num)), (
|
||||
f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
|
||||
f"actual: {action_probs.shape}."
|
||||
|
@ -180,16 +209,21 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
|
|||
return action_probs
|
||||
|
||||
@abstractmethod
|
||||
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def _get_action_probs_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Implementation of `get_action_probs`. The core logic of a discrete policy net should be implemented here."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_probs_impl(states, exploring)
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_probs_impl(states, exploring, **kwargs)
|
||||
return actions
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = self.get_action_probs(states)
|
||||
def _get_actions_with_probs_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = self.get_action_probs(states, **kwargs)
|
||||
if exploring:
|
||||
distribution = Categorical(probs)
|
||||
actions = distribution.sample().unsqueeze(1)
|
||||
|
@ -198,16 +232,21 @@ class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta):
|
|||
probs, actions = probs.max(dim=1)
|
||||
return actions.unsqueeze(1), probs
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
actions, probs = self._get_actions_with_probs_impl(states, exploring)
|
||||
def _get_actions_with_logps_impl(
|
||||
self,
|
||||
states: torch.Tensor,
|
||||
exploring: bool,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
actions, probs = self._get_actions_with_probs_impl(states, exploring, **kwargs)
|
||||
return actions, torch.log(probs)
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
probs = self.get_action_probs(states)
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
probs = self.get_action_probs(states, **kwargs)
|
||||
return probs.gather(1, actions).squeeze(-1)
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
probs = self._get_states_actions_probs_impl(states, actions)
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
|
||||
return torch.log(probs)
|
||||
|
||||
|
||||
|
@ -221,3 +260,18 @@ class ContinuousPolicyNet(PolicyNet, metaclass=ABCMeta):
|
|||
|
||||
def __init__(self, state_dim: int, action_dim: int) -> None:
|
||||
super(ContinuousPolicyNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
|
||||
def get_random_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
actions = self._get_random_actions_impl(states, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
|
||||
return actions
|
||||
|
||||
@abstractmethod
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -31,7 +31,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
|
|||
def action_dim(self) -> int:
|
||||
return self._action_dim
|
||||
|
||||
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool:
|
||||
def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None, **kwargs) -> bool:
|
||||
"""Check whether the states and actions have valid shapes.
|
||||
|
||||
Args:
|
||||
|
@ -52,7 +52,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
|
|||
return False
|
||||
return True
|
||||
|
||||
def q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get Q-values according to states and actions.
|
||||
|
||||
Args:
|
||||
|
@ -62,12 +62,12 @@ class QNet(AbsNet, metaclass=ABCMeta):
|
|||
Returns:
|
||||
q (torch.Tensor): Q-values with shape [batch_size].
|
||||
"""
|
||||
assert self._shape_check(states=states, actions=actions), (
|
||||
assert self._shape_check(states=states, actions=actions, **kwargs), (
|
||||
f"States or action shape check failed. Expecting: "
|
||||
f"states = {('BATCH_SIZE', self.state_dim)}, action = {('BATCH_SIZE', self.action_dim)}. "
|
||||
f"Actual: states = {states.shape}, action = {actions.shape}."
|
||||
)
|
||||
q = self._get_q_values(states, actions)
|
||||
q = self._get_q_values(states, actions, **kwargs)
|
||||
assert match_shape(
|
||||
q,
|
||||
(states.shape[0],),
|
||||
|
@ -75,7 +75,7 @@ class QNet(AbsNet, metaclass=ABCMeta):
|
|||
return q
|
||||
|
||||
@abstractmethod
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Implementation of `q_values`."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -96,7 +96,7 @@ class DiscreteQNet(QNet, metaclass=ABCMeta):
|
|||
def action_num(self) -> int:
|
||||
return self._action_num
|
||||
|
||||
def q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def q_values_for_all_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get Q-values for all actions according to states.
|
||||
|
||||
Args:
|
||||
|
@ -107,20 +107,21 @@ class DiscreteQNet(QNet, metaclass=ABCMeta):
|
|||
"""
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
q = self._get_q_values_for_all_actions(states)
|
||||
q = self._get_q_values_for_all_actions(states, **kwargs)
|
||||
assert match_shape(q, (states.shape[0], self.action_num)), (
|
||||
f"Q-value matrix shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
|
||||
f"actual: {q.shape}."
|
||||
) # [B, action_num]
|
||||
return q
|
||||
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
q = self.q_values_for_all_actions(states) # [B, action_num]
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
q = self.q_values_for_all_actions(states, **kwargs) # [B, action_num]
|
||||
return q.gather(1, actions.long()).reshape(-1) # [B, action_num] + [B, 1] => [B]
|
||||
|
||||
@abstractmethod
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Implementation of `q_values_for_all_actions`."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ class VNet(AbsNet, metaclass=ABCMeta):
|
|||
def state_dim(self) -> int:
|
||||
return self._state_dim
|
||||
|
||||
def _shape_check(self, states: torch.Tensor) -> bool:
|
||||
def _shape_check(self, states: torch.Tensor, **kwargs) -> bool:
|
||||
"""Check whether the states have valid shapes.
|
||||
|
||||
Args:
|
||||
|
@ -39,7 +39,7 @@ class VNet(AbsNet, metaclass=ABCMeta):
|
|||
else:
|
||||
return states.shape[0] > 0 and match_shape(states, (None, self.state_dim))
|
||||
|
||||
def v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get V-values according to states.
|
||||
|
||||
Args:
|
||||
|
@ -50,8 +50,9 @@ class VNet(AbsNet, metaclass=ABCMeta):
|
|||
"""
|
||||
assert self._shape_check(
|
||||
states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
v = self._get_v_values(states)
|
||||
v = self._get_v_values(states, **kwargs)
|
||||
assert match_shape(
|
||||
v,
|
||||
(states.shape[0],),
|
||||
|
@ -59,6 +60,6 @@ class VNet(AbsNet, metaclass=ABCMeta):
|
|||
return v
|
||||
|
||||
@abstractmethod
|
||||
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def _get_v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Implementation of `v_values`."""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -27,7 +27,7 @@ class AbsPolicy(object, metaclass=ABCMeta):
|
|||
self._trainable = trainable
|
||||
|
||||
@abstractmethod
|
||||
def get_actions(self, states: Union[list, np.ndarray]) -> Any:
|
||||
def get_actions(self, states: Union[list, np.ndarray], **kwargs) -> Any:
|
||||
"""Get actions according to states.
|
||||
|
||||
Args:
|
||||
|
@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
|
|||
def __init__(self) -> None:
|
||||
super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)
|
||||
|
||||
def get_actions(self, states: Union[list, np.ndarray]) -> None:
|
||||
def get_actions(self, states: Union[list, np.ndarray], **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def explore(self) -> None:
|
||||
|
@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
def __init__(self, name: str) -> None:
|
||||
super(RuleBasedPolicy, self).__init__(name=name, trainable=False)
|
||||
|
||||
def get_actions(self, states: list) -> list:
|
||||
return self._rule(states)
|
||||
def get_actions(self, states: list, **kwargs) -> list:
|
||||
return self._rule(states, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _rule(self, states: list) -> list:
|
||||
def _rule(self, states: list, **kwargs) -> list:
|
||||
raise NotImplementedError
|
||||
|
||||
def explore(self) -> None:
|
||||
|
@ -129,6 +129,8 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
state_dim (int): Dimension of states.
|
||||
action_dim (int): Dimension of actions.
|
||||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
warmup (int, default=0): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -138,6 +140,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
action_dim: int,
|
||||
is_discrete_action: bool,
|
||||
trainable: bool = True,
|
||||
warmup: int = 0,
|
||||
) -> None:
|
||||
super(RLPolicy, self).__init__(name=name, trainable=trainable)
|
||||
self._state_dim = state_dim
|
||||
|
@ -145,6 +148,8 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
self._is_exploring = False
|
||||
|
||||
self._device: Optional[torch.device] = None
|
||||
self._warmup = warmup
|
||||
self._call_count = 0
|
||||
|
||||
self.is_discrete_action = is_discrete_action
|
||||
|
||||
|
@ -199,94 +204,122 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_actions(self, states: np.ndarray) -> np.ndarray:
|
||||
actions = self.get_actions_tensor(ndarray_to_tensor(states, device=self._device))
|
||||
def get_actions(self, states: np.ndarray, **kwargs) -> np.ndarray:
|
||||
self._call_count += 1
|
||||
|
||||
if self._call_count <= self._warmup:
|
||||
actions = self.get_random_actions_tensor(ndarray_to_tensor(states, device=self._device), **kwargs)
|
||||
else:
|
||||
actions = self.get_actions_tensor(ndarray_to_tensor(states, device=self._device), **kwargs)
|
||||
return actions.detach().cpu().numpy()
|
||||
|
||||
def get_actions_tensor(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def get_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions = self._get_actions_impl(states)
|
||||
actions = self._get_actions_impl(states, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
|
||||
return actions
|
||||
|
||||
def get_actions_with_probs(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions, probs = self._get_actions_with_probs_impl(states)
|
||||
def get_random_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
actions = self._get_random_actions_impl(states, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
|
||||
return actions
|
||||
|
||||
def get_actions_with_probs(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions, probs = self._get_actions_with_probs_impl(states, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
|
||||
|
||||
return actions, probs
|
||||
|
||||
def get_actions_with_logps(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def get_actions_with_logps(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
actions, logps = self._get_actions_with_logps_impl(states)
|
||||
actions, logps = self._get_actions_with_logps_impl(states, **kwargs)
|
||||
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
actions=actions,
|
||||
**kwargs,
|
||||
), f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}."
|
||||
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
|
||||
|
||||
return actions, logps
|
||||
|
||||
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
probs = self._get_states_actions_probs_impl(states, actions)
|
||||
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
|
||||
|
||||
assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0]
|
||||
|
||||
return probs
|
||||
|
||||
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
|
||||
logps = self._get_states_actions_logps_impl(states, actions)
|
||||
logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
|
||||
|
||||
assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0]
|
||||
|
||||
return logps
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
@ -327,6 +360,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
self,
|
||||
states: torch.Tensor,
|
||||
actions: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""Check whether the states and actions have valid shapes.
|
||||
|
||||
|
@ -352,7 +386,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
return True
|
||||
|
||||
@abstractmethod
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
|
||||
"""Check whether the generated action tensor is valid, i.e., has matching shape with states tensor.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -42,6 +42,8 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
the bound for every dimension. If it is a float, it will be broadcast to all dimensions.
|
||||
policy_net (ContinuousPolicyNet): The core net of this policy.
|
||||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
warmup (int, default=0): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -50,6 +52,7 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
action_range: Tuple[Union[float, List[float]], Union[float, List[float]]],
|
||||
policy_net: ContinuousPolicyNet,
|
||||
trainable: bool = True,
|
||||
warmup: int = 0,
|
||||
) -> None:
|
||||
assert isinstance(policy_net, ContinuousPolicyNet)
|
||||
|
||||
|
@ -59,6 +62,7 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
action_dim=policy_net.action_dim,
|
||||
trainable=trainable,
|
||||
is_discrete_action=False,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
|
||||
|
@ -72,7 +76,7 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
def policy_net(self) -> ContinuousPolicyNet:
|
||||
return self._policy_net
|
||||
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
|
||||
return all(
|
||||
[
|
||||
(np.array(self._lbounds) <= actions.detach().cpu().numpy()).all(),
|
||||
|
@ -80,20 +84,23 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
],
|
||||
)
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_actions(states, self._is_exploring)
|
||||
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_actions(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_probs(states, self._is_exploring)
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_random_actions(states, **kwargs)
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_logps(states, self._is_exploring)
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_probs(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_probs(states, actions)
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_logps(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_logps(states, actions)
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_probs(states, actions, **kwargs)
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_logps(states, actions, **kwargs)
|
||||
|
||||
def train_step(self, loss: torch.Tensor) -> None:
|
||||
self._policy_net.step(loss)
|
||||
|
@ -117,14 +124,22 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
self._policy_net.train()
|
||||
|
||||
def get_state(self) -> dict:
|
||||
return self._policy_net.get_state()
|
||||
return {
|
||||
"net": self._policy_net.get_state(),
|
||||
"policy": {
|
||||
"warmup": self._warmup,
|
||||
"call_count": self._call_count,
|
||||
},
|
||||
}
|
||||
|
||||
def set_state(self, policy_state: dict) -> None:
|
||||
self._policy_net.set_state(policy_state)
|
||||
self._policy_net.set_state(policy_state["net"])
|
||||
self._warmup = policy_state["policy"]["warmup"]
|
||||
self._call_count = policy_state["policy"]["call_count"]
|
||||
|
||||
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
|
||||
assert isinstance(other_policy, ContinuousRLPolicy)
|
||||
self._policy_net.soft_update(other_policy.policy_net, tau)
|
||||
|
||||
def _to_device_impl(self, device: torch.device) -> None:
|
||||
self._policy_net.to(device)
|
||||
self._policy_net.to_device(device)
|
||||
|
|
|
@ -23,6 +23,8 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
|
|||
state_dim (int): Dimension of states.
|
||||
action_num (int): Number of actions.
|
||||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
warmup (int, default=0): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -31,6 +33,7 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
|
|||
state_dim: int,
|
||||
action_num: int,
|
||||
trainable: bool = True,
|
||||
warmup: int = 0,
|
||||
) -> None:
|
||||
assert action_num >= 1
|
||||
|
||||
|
@ -40,6 +43,7 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
|
|||
action_dim=1,
|
||||
trainable=trainable,
|
||||
is_discrete_action=True,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
self._action_num = action_num
|
||||
|
@ -48,9 +52,15 @@ class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta):
|
|||
def action_num(self) -> int:
|
||||
return self._action_num
|
||||
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
|
||||
def _post_check(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> bool:
|
||||
return all([0 <= action < self.action_num for action in actions.cpu().numpy().flatten()])
|
||||
|
||||
def _get_random_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return ndarray_to_tensor(
|
||||
np.random.randint(self.action_num, size=(states.shape[0], 1)),
|
||||
device=self._device,
|
||||
)
|
||||
|
||||
|
||||
class ValueBasedPolicy(DiscreteRLPolicy):
|
||||
"""Valued-based policy.
|
||||
|
@ -61,7 +71,8 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy.
|
||||
exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options.
|
||||
warmup (int, default=50000): Minimum number of experiences to warm up this policy.
|
||||
warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -80,6 +91,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
state_dim=q_net.state_dim,
|
||||
action_num=q_net.action_num,
|
||||
trainable=trainable,
|
||||
warmup=warmup,
|
||||
)
|
||||
self._q_net = q_net
|
||||
|
||||
|
@ -91,16 +103,13 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
else []
|
||||
)
|
||||
|
||||
self._call_cnt = 0
|
||||
self._warmup = warmup
|
||||
|
||||
self._softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
@property
|
||||
def q_net(self) -> DiscreteQNet:
|
||||
return self._q_net
|
||||
|
||||
def q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray:
|
||||
def q_values_for_all_actions(self, states: np.ndarray, **kwargs) -> np.ndarray:
|
||||
"""Generate a matrix containing the Q-values for all actions for the given states.
|
||||
|
||||
Args:
|
||||
|
@ -109,9 +118,16 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
Returns:
|
||||
q_values (np.ndarray): Q-matrix.
|
||||
"""
|
||||
return self.q_values_for_all_actions_tensor(ndarray_to_tensor(states, device=self._device)).cpu().numpy()
|
||||
return (
|
||||
self.q_values_for_all_actions_tensor(
|
||||
ndarray_to_tensor(states, device=self._device),
|
||||
**kwargs,
|
||||
)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
def q_values_for_all_actions_tensor(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def q_values_for_all_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Generate a matrix containing the Q-values for all actions for the given states.
|
||||
|
||||
Args:
|
||||
|
@ -120,12 +136,12 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
Returns:
|
||||
q_values (torch.Tensor): Q-matrix.
|
||||
"""
|
||||
assert self._shape_check(states=states)
|
||||
q_values = self._q_net.q_values_for_all_actions(states)
|
||||
assert self._shape_check(states=states, **kwargs)
|
||||
q_values = self._q_net.q_values_for_all_actions(states, **kwargs)
|
||||
assert match_shape(q_values, (states.shape[0], self.action_num)) # [B, action_num]
|
||||
return q_values
|
||||
|
||||
def q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray:
|
||||
def q_values(self, states: np.ndarray, actions: np.ndarray, **kwargs) -> np.ndarray:
|
||||
"""Generate the Q values for given state-action pairs.
|
||||
|
||||
Args:
|
||||
|
@ -139,12 +155,13 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
self.q_values_tensor(
|
||||
ndarray_to_tensor(states, device=self._device),
|
||||
ndarray_to_tensor(actions, device=self._device),
|
||||
**kwargs,
|
||||
)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Generate the Q values for given state-action pairs.
|
||||
|
||||
Args:
|
||||
|
@ -154,50 +171,46 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
Returns:
|
||||
q_values (torch.Tensor): Q-values.
|
||||
"""
|
||||
assert self._shape_check(states=states, actions=actions) # actions: [B, 1]
|
||||
q_values = self._q_net.q_values(states, actions)
|
||||
assert self._shape_check(states=states, actions=actions, **kwargs) # actions: [B, 1]
|
||||
q_values = self._q_net.q_values(states, actions, **kwargs)
|
||||
assert match_shape(q_values, (states.shape[0],)) # [B]
|
||||
return q_values
|
||||
|
||||
def explore(self) -> None:
|
||||
pass # Overwrite the base method and turn off explore mode.
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
actions, _ = self._get_actions_with_probs_impl(states)
|
||||
return actions
|
||||
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._get_actions_with_probs_impl(states, **kwargs)[0]
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._call_cnt += 1
|
||||
if self._call_cnt <= self._warmup:
|
||||
actions = ndarray_to_tensor(
|
||||
np.random.randint(self.action_num, size=(states.shape[0], 1)),
|
||||
device=self._device,
|
||||
)
|
||||
probs = torch.ones(states.shape[0]).float() * (1.0 / self.action_num)
|
||||
return actions, probs
|
||||
|
||||
q_matrix = self.q_values_for_all_actions_tensor(states) # [B, action_num]
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
q_matrix = self.q_values_for_all_actions_tensor(states, **kwargs) # [B, action_num]
|
||||
q_matrix_softmax = self._softmax(q_matrix)
|
||||
_, actions = q_matrix.max(dim=1) # [B], [B]
|
||||
|
||||
if self._is_exploring:
|
||||
actions = self._exploration_func(states, actions.cpu().numpy(), self.action_num, **self._exploration_params)
|
||||
actions = self._exploration_func(
|
||||
states,
|
||||
actions.cpu().numpy(),
|
||||
self.action_num,
|
||||
**self._exploration_params,
|
||||
**kwargs,
|
||||
)
|
||||
actions = ndarray_to_tensor(actions, device=self._device)
|
||||
|
||||
actions = actions.unsqueeze(1)
|
||||
return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1]
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
actions, probs = self._get_actions_with_probs_impl(states)
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
actions, probs = self._get_actions_with_probs_impl(states, **kwargs)
|
||||
return actions, torch.log(probs)
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
q_matrix = self.q_values_for_all_actions_tensor(states)
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
q_matrix = self.q_values_for_all_actions_tensor(states, **kwargs)
|
||||
q_matrix_softmax = self._softmax(q_matrix)
|
||||
return q_matrix_softmax.gather(1, actions).squeeze(-1) # [B]
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
probs = self._get_states_actions_probs_impl(states, actions)
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
probs = self._get_states_actions_probs_impl(states, actions, **kwargs)
|
||||
return torch.log(probs)
|
||||
|
||||
def train_step(self, loss: torch.Tensor) -> None:
|
||||
|
@ -222,17 +235,25 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
self._q_net.train()
|
||||
|
||||
def get_state(self) -> dict:
|
||||
return self._q_net.get_state()
|
||||
return {
|
||||
"net": self._q_net.get_state(),
|
||||
"policy": {
|
||||
"warmup": self._warmup,
|
||||
"call_count": self._call_count,
|
||||
},
|
||||
}
|
||||
|
||||
def set_state(self, policy_state: dict) -> None:
|
||||
self._q_net.set_state(policy_state)
|
||||
self._warmup = policy_state["policy"]["warmup"]
|
||||
self._call_count = policy_state["policy"]["call_count"]
|
||||
|
||||
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
|
||||
assert isinstance(other_policy, ValueBasedPolicy)
|
||||
self._q_net.soft_update(other_policy.q_net, tau)
|
||||
|
||||
def _to_device_impl(self, device: torch.device) -> None:
|
||||
self._q_net.to(device)
|
||||
self._q_net.to_device(device)
|
||||
|
||||
|
||||
class DiscretePolicyGradient(DiscreteRLPolicy):
|
||||
|
@ -242,6 +263,8 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
name (str): Name of the policy.
|
||||
policy_net (DiscretePolicyNet): The core net of this policy.
|
||||
trainable (bool, default=True): Whether this policy is trainable.
|
||||
warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy.
|
||||
Helps exploration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -249,6 +272,7 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
name: str,
|
||||
policy_net: DiscretePolicyNet,
|
||||
trainable: bool = True,
|
||||
warmup: int = 0,
|
||||
) -> None:
|
||||
assert isinstance(policy_net, DiscretePolicyNet)
|
||||
|
||||
|
@ -257,6 +281,7 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
state_dim=policy_net.state_dim,
|
||||
action_num=policy_net.action_num,
|
||||
trainable=trainable,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
self._policy_net = policy_net
|
||||
|
@ -265,20 +290,20 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
def policy_net(self) -> DiscretePolicyNet:
|
||||
return self._policy_net
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_actions(states, self._is_exploring)
|
||||
def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_actions(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_probs(states, self._is_exploring)
|
||||
def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_probs(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_logps(states, self._is_exploring)
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._policy_net.get_actions_with_logps(states, self._is_exploring, **kwargs)
|
||||
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_probs(states, actions)
|
||||
def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_probs(states, actions, **kwargs)
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_logps(states, actions)
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return self._policy_net.get_states_actions_logps(states, actions, **kwargs)
|
||||
|
||||
def train_step(self, loss: torch.Tensor) -> None:
|
||||
self._policy_net.step(loss)
|
||||
|
@ -302,16 +327,24 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
self._policy_net.train()
|
||||
|
||||
def get_state(self) -> dict:
|
||||
return self._policy_net.get_state()
|
||||
return {
|
||||
"net": self._policy_net.get_state(),
|
||||
"policy": {
|
||||
"warmup": self._warmup,
|
||||
"call_count": self._call_count,
|
||||
},
|
||||
}
|
||||
|
||||
def set_state(self, policy_state: dict) -> None:
|
||||
self._policy_net.set_state(policy_state)
|
||||
self._warmup = policy_state["policy"]["warmup"]
|
||||
self._call_count = policy_state["policy"]["call_count"]
|
||||
|
||||
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
|
||||
assert isinstance(other_policy, DiscretePolicyGradient)
|
||||
self._policy_net.soft_update(other_policy.policy_net, tau)
|
||||
|
||||
def get_action_probs(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def get_action_probs(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get the probabilities for all actions according to states.
|
||||
|
||||
Args:
|
||||
|
@ -322,15 +355,16 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
"""
|
||||
assert self._shape_check(
|
||||
states=states,
|
||||
**kwargs,
|
||||
), f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}."
|
||||
action_probs = self._policy_net.get_action_probs(states)
|
||||
action_probs = self._policy_net.get_action_probs(states, **kwargs)
|
||||
assert match_shape(action_probs, (states.shape[0], self.action_num)), (
|
||||
f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, "
|
||||
f"actual: {action_probs.shape}."
|
||||
)
|
||||
return action_probs
|
||||
|
||||
def get_action_logps(self, states: torch.Tensor) -> torch.Tensor:
|
||||
def get_action_logps(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""Get the log-probabilities for all actions according to states.
|
||||
|
||||
Args:
|
||||
|
@ -339,15 +373,15 @@ class DiscretePolicyGradient(DiscreteRLPolicy):
|
|||
Returns:
|
||||
action_logps (torch.Tensor): Action probabilities with shape [batch_size, action_num].
|
||||
"""
|
||||
return torch.log(self.get_action_probs(states))
|
||||
return torch.log(self.get_action_probs(states, **kwargs))
|
||||
|
||||
def _get_state_action_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
action_probs = self.get_action_probs(states)
|
||||
def _get_state_action_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
action_probs = self.get_action_probs(states, **kwargs)
|
||||
return action_probs.gather(1, actions).squeeze(-1) # [B]
|
||||
|
||||
def _get_state_action_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
action_logps = self.get_action_logps(states)
|
||||
def _get_state_action_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
action_logps = self.get_action_logps(states, **kwargs)
|
||||
return action_logps.gather(1, actions).squeeze(-1) # [B]
|
||||
|
||||
def _to_device_impl(self, device: torch.device) -> None:
|
||||
self._policy_net.to(device)
|
||||
self._policy_net.to_device(device)
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
|||
from maro.rl.policy import AbsPolicy, RLPolicy
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
from maro.rl.workflows.callback import Callback
|
||||
|
||||
|
||||
class RLComponentBundle:
|
||||
|
@ -20,7 +21,7 @@ class RLComponentBundle:
|
|||
If None, there will be no explicit device assignment.
|
||||
policy_trainer_mapping (Dict[str, str], default=None): Policy-trainer mapping which identifying which trainer to
|
||||
train each policy. If None, then a policy's trainer's name is the first segment of the policy's name,
|
||||
seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies that provided in
|
||||
separated by dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies that provided in
|
||||
policy-trainer mapping are considered as trainable polices. Policies that not provided in policy-trainer
|
||||
mapping will not be trained.
|
||||
"""
|
||||
|
@ -33,11 +34,13 @@ class RLComponentBundle:
|
|||
trainers: List[AbsTrainer],
|
||||
device_mapping: Dict[str, str] = None,
|
||||
policy_trainer_mapping: Dict[str, str] = None,
|
||||
customized_callbacks: List[Callback] = [],
|
||||
) -> None:
|
||||
self.env_sampler = env_sampler
|
||||
self.agent2policy = agent2policy
|
||||
self.policies = policies
|
||||
self.trainers = trainers
|
||||
self.customized_callbacks = customized_callbacks
|
||||
|
||||
policy_set = set([policy.name for policy in self.policies])
|
||||
not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set]
|
||||
|
|
|
@ -189,8 +189,13 @@ class BatchEnvSampler:
|
|||
"info": [res["info"][0] for res in results],
|
||||
}
|
||||
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
|
||||
req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
|
||||
req = {
|
||||
"type": "eval",
|
||||
"policy_state": policy_state,
|
||||
"index": self._ep,
|
||||
"num_eval_episodes": num_episodes,
|
||||
} # -1 signals test
|
||||
results = self._controller.collect(req, self._eval_parallelism)
|
||||
return {
|
||||
"info": [res["info"][0] for res in results],
|
||||
|
|
|
@ -48,6 +48,7 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
|
|||
def choose_actions(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
**kwargs,
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
"""Choose action according to the given (observable) states of all agents.
|
||||
|
||||
|
@ -61,13 +62,14 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
|
|||
"""
|
||||
self.switch_to_eval_mode()
|
||||
with torch.no_grad():
|
||||
ret = self._choose_actions_impl(state_by_agent)
|
||||
ret = self._choose_actions_impl(state_by_agent, **kwargs)
|
||||
return ret
|
||||
|
||||
@abstractmethod
|
||||
def _choose_actions_impl(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
**kwargs,
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
"""Implementation of `choose_actions`."""
|
||||
raise NotImplementedError
|
||||
|
@ -99,6 +101,7 @@ class SimpleAgentWrapper(AbsAgentWrapper):
|
|||
def _choose_actions_impl(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
**kwargs,
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
# Aggregate states by policy
|
||||
states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray}
|
||||
|
@ -116,7 +119,7 @@ class SimpleAgentWrapper(AbsAgentWrapper):
|
|||
states = np.vstack(states_by_policy[policy_name]) # np.ndarray
|
||||
else:
|
||||
states = states_by_policy[policy_name] # list
|
||||
actions: Union[np.ndarray, list] = policy.get_actions(states) # np.ndarray or list
|
||||
actions: Union[np.ndarray, list] = policy.get_actions(states, **kwargs) # np.ndarray or list
|
||||
action_dict.update(zip(agents_by_policy[policy_name], actions))
|
||||
|
||||
return action_dict
|
||||
|
@ -146,6 +149,7 @@ class ExpElement:
|
|||
terminal_dict: Dict[Any, bool]
|
||||
next_state: Optional[np.ndarray]
|
||||
next_agent_state_dict: Dict[Any, np.ndarray]
|
||||
truncated: bool
|
||||
|
||||
@property
|
||||
def agent_names(self) -> list:
|
||||
|
@ -171,6 +175,7 @@ class ExpElement:
|
|||
}
|
||||
if self.next_agent_state_dict is not None and agent_name in self.next_agent_state_dict
|
||||
else {},
|
||||
truncated=self.truncated,
|
||||
)
|
||||
return ret
|
||||
|
||||
|
@ -194,6 +199,7 @@ class ExpElement:
|
|||
terminal_dict={},
|
||||
next_state=self.next_state,
|
||||
next_agent_state_dict=None if self.next_agent_state_dict is None else {},
|
||||
truncated=self.truncated,
|
||||
),
|
||||
)
|
||||
for agent_name, trainer_name in agent2trainer.items():
|
||||
|
@ -225,6 +231,7 @@ class CacheElement(ExpElement):
|
|||
terminal_dict=self.terminal_dict,
|
||||
next_state=self.next_state,
|
||||
next_agent_state_dict=self.next_agent_state_dict,
|
||||
truncated=self.truncated,
|
||||
)
|
||||
|
||||
|
||||
|
@ -240,6 +247,8 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type.
|
||||
reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward
|
||||
for the action taken for that event. If it is None, calculate reward immediately after `step()`.
|
||||
max_episode_length (int, default=None): Maximum number of steps in one episode during sampling.
|
||||
When reach this limit, the environment will be truncated and reset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -251,7 +260,10 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
trainable_policies: List[str] = None,
|
||||
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
|
||||
reward_eval_delay: int = None,
|
||||
max_episode_length: int = None,
|
||||
) -> None:
|
||||
assert learn_env is not test_env, "Please use different envs for training and testing."
|
||||
|
||||
self._learn_env = learn_env
|
||||
self._test_env = test_env
|
||||
|
||||
|
@ -262,11 +274,14 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._state: Optional[np.ndarray] = None
|
||||
self._agent_state_dict: Dict[Any, np.ndarray] = {}
|
||||
|
||||
self._trans_cache: List[CacheElement] = []
|
||||
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache
|
||||
self._transition_cache: List[CacheElement] = []
|
||||
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._transition_cache
|
||||
self._reward_eval_delay = reward_eval_delay
|
||||
self._max_episode_length = max_episode_length
|
||||
self._current_episode_length = 0
|
||||
|
||||
self._info: dict = {}
|
||||
self.metrics: dict = {}
|
||||
|
||||
assert self._reward_eval_delay is None or self._reward_eval_delay >= 0
|
||||
|
||||
|
@ -291,11 +306,17 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
[policy_name in self._rl_policy_dict for policy_name in self._trainable_policies],
|
||||
), "All trainable policies must be RL policies!"
|
||||
|
||||
self._total_number_interactions = 0
|
||||
|
||||
@property
|
||||
def env(self) -> Env:
|
||||
assert self._env is not None
|
||||
return self._env
|
||||
|
||||
def monitor_metrics(self) -> float:
|
||||
"""Metrics watched by early stopping."""
|
||||
return float(self._total_number_interactions)
|
||||
|
||||
def _switch_env(self, env: Env) -> None:
|
||||
self._env = env
|
||||
|
||||
|
@ -369,7 +390,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
def _step(self, actions: Optional[list]) -> None:
|
||||
_, self._event, self._end_of_episode = self.env.step(actions)
|
||||
self._state, self._agent_state_dict = (
|
||||
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event)
|
||||
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event, self.env.tick)
|
||||
)
|
||||
|
||||
def _calc_reward(self, cache_element: CacheElement) -> None:
|
||||
|
@ -383,37 +404,37 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None:
|
||||
"""`cache_element` == None means we are processing the last element in trans_cache"""
|
||||
if cache_element is None:
|
||||
if len(self._trans_cache) > 0:
|
||||
self._trans_cache[-1].next_state = self._trans_cache[-1].state
|
||||
|
||||
for agent_name, i in self._agent_last_index.items():
|
||||
e = self._trans_cache[i]
|
||||
e = self._transition_cache[i]
|
||||
e.terminal_dict[agent_name] = self._end_of_episode
|
||||
e.next_agent_state_dict[agent_name] = e.agent_state_dict[agent_name]
|
||||
else:
|
||||
self._trans_cache.append(cache_element)
|
||||
self._transition_cache.append(cache_element)
|
||||
|
||||
if len(self._trans_cache) > 0:
|
||||
self._trans_cache[-1].next_state = cache_element.state
|
||||
|
||||
cur_index = len(self._trans_cache) - 1
|
||||
cur_index = len(self._transition_cache) - 1
|
||||
for agent_name in cache_element.agent_names:
|
||||
if agent_name in self._agent_last_index:
|
||||
i = self._agent_last_index[agent_name]
|
||||
self._trans_cache[i].terminal_dict[agent_name] = False
|
||||
self._trans_cache[i].next_agent_state_dict[agent_name] = cache_element.agent_state_dict[agent_name]
|
||||
e = self._transition_cache[i]
|
||||
e.terminal_dict[agent_name] = False
|
||||
e.next_agent_state_dict[agent_name] = cache_element.agent_state_dict[agent_name]
|
||||
self._agent_last_index[agent_name] = cur_index
|
||||
|
||||
def _reset(self) -> None:
|
||||
self.env.reset()
|
||||
self._current_episode_length = 0
|
||||
self._info.clear()
|
||||
self._trans_cache.clear()
|
||||
self._transition_cache.clear()
|
||||
self._agent_last_index.clear()
|
||||
self._step(None)
|
||||
|
||||
def _select_trainable_agents(self, original_dict: dict) -> dict:
|
||||
return {k: v for k, v in original_dict.items() if k in self._trainable_agents}
|
||||
|
||||
@property
|
||||
def truncated(self) -> bool:
|
||||
return self._max_episode_length == self._current_episode_length
|
||||
|
||||
def sample(
|
||||
self,
|
||||
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
|
@ -430,65 +451,88 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
Returns:
|
||||
A dict that contains the collected experiences and additional information.
|
||||
"""
|
||||
# Init the env
|
||||
self._switch_env(self._learn_env)
|
||||
steps_to_go = num_steps if num_steps is not None else float("inf")
|
||||
if policy_state is not None: # Update policy state if necessary
|
||||
self.set_policy_state(policy_state)
|
||||
self._switch_env(self._learn_env) # Init the env
|
||||
self._agent_wrapper.explore() # Collect experience
|
||||
|
||||
if self._end_of_episode:
|
||||
self._reset()
|
||||
|
||||
# Update policy state if necessary
|
||||
if policy_state is not None:
|
||||
self.set_policy_state(policy_state)
|
||||
# If num_steps is None, run until the end of episode or the episode is truncated
|
||||
# If num_steps is not None, run until we collect required number of steps
|
||||
total_experiences = []
|
||||
|
||||
# Collect experience
|
||||
self._agent_wrapper.explore()
|
||||
steps_to_go = float("inf") if num_steps is None else num_steps
|
||||
while not self._end_of_episode and steps_to_go > 0:
|
||||
# Get agent actions and translate them to env actions
|
||||
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
|
||||
env_action_dict = self._translate_to_env_action(action_dict, self._event)
|
||||
while not any(
|
||||
[
|
||||
num_steps is None and (self._end_of_episode or self.truncated),
|
||||
num_steps is not None and steps_to_go == 0,
|
||||
],
|
||||
):
|
||||
if self._end_of_episode or self.truncated:
|
||||
self._reset()
|
||||
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
action_dict=self._select_trainable_agents(action_dict),
|
||||
env_action_dict=self._select_trainable_agents(env_action_dict),
|
||||
# The following will be generated later
|
||||
reward_dict={},
|
||||
terminal_dict={},
|
||||
next_state=None,
|
||||
next_agent_state_dict={},
|
||||
)
|
||||
while not any(
|
||||
[
|
||||
self._end_of_episode,
|
||||
self.truncated,
|
||||
steps_to_go == 0,
|
||||
],
|
||||
):
|
||||
# Get agent actions and translate them to env actions
|
||||
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
|
||||
env_action_dict = self._translate_to_env_action(action_dict, self._event)
|
||||
|
||||
# Update env and get new states (global & agent)
|
||||
self._step(list(env_action_dict.values()))
|
||||
self._total_number_interactions += 1
|
||||
self._current_episode_length += 1
|
||||
steps_to_go -= 1
|
||||
|
||||
if self._reward_eval_delay is None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_step(cache_element)
|
||||
self._append_cache_element(cache_element)
|
||||
steps_to_go -= 1
|
||||
self._append_cache_element(None)
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
action_dict=self._select_trainable_agents(action_dict),
|
||||
env_action_dict=self._select_trainable_agents(env_action_dict),
|
||||
# The following will be generated/updated later
|
||||
reward_dict={},
|
||||
terminal_dict={},
|
||||
next_state=None,
|
||||
next_agent_state_dict={},
|
||||
truncated=self.truncated,
|
||||
)
|
||||
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
experiences: List[ExpElement] = []
|
||||
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
|
||||
cache_element = self._trans_cache.pop(0)
|
||||
# !: Here the reward calculation method requires the given tick is enough and must be used then.
|
||||
if self._reward_eval_delay is not None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_step(cache_element)
|
||||
experiences.append(cache_element.make_exp_element())
|
||||
# Update env and get new states (global & agent)
|
||||
self._step(list(env_action_dict.values()))
|
||||
cache_element.next_state = self._state
|
||||
|
||||
self._agent_last_index = {
|
||||
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
|
||||
}
|
||||
if self._reward_eval_delay is None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_step(cache_element)
|
||||
self._append_cache_element(cache_element)
|
||||
|
||||
self._append_cache_element(None)
|
||||
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
experiences: List[ExpElement] = []
|
||||
while len(self._transition_cache) > 0 and self._transition_cache[0].tick <= tick_bound:
|
||||
cache_element = self._transition_cache.pop(0)
|
||||
# !: Here the reward calculation method requires the given tick is enough and must be used then.
|
||||
if self._reward_eval_delay is not None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_step(cache_element)
|
||||
experiences.append(cache_element.make_exp_element())
|
||||
|
||||
self._agent_last_index = {
|
||||
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
|
||||
}
|
||||
|
||||
total_experiences += experiences
|
||||
|
||||
return {
|
||||
"end_of_episode": self._end_of_episode,
|
||||
"experiences": [experiences],
|
||||
"experiences": [total_experiences],
|
||||
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
|
||||
}
|
||||
|
||||
|
@ -514,50 +558,57 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
return loaded
|
||||
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
|
||||
self._switch_env(self._test_env)
|
||||
self._reset()
|
||||
if policy_state is not None:
|
||||
self.set_policy_state(policy_state)
|
||||
info_list = []
|
||||
|
||||
self._agent_wrapper.exploit()
|
||||
while not self._end_of_episode:
|
||||
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
|
||||
env_action_dict = self._translate_to_env_action(action_dict, self._event)
|
||||
for _ in range(num_episodes):
|
||||
self._reset()
|
||||
if policy_state is not None:
|
||||
self.set_policy_state(policy_state)
|
||||
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
action_dict=self._select_trainable_agents(action_dict),
|
||||
env_action_dict=self._select_trainable_agents(env_action_dict),
|
||||
# The following will be generated later
|
||||
reward_dict={},
|
||||
terminal_dict={},
|
||||
next_state=None,
|
||||
next_agent_state_dict={},
|
||||
)
|
||||
self._agent_wrapper.exploit()
|
||||
while not self._end_of_episode:
|
||||
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
|
||||
env_action_dict = self._translate_to_env_action(action_dict, self._event)
|
||||
|
||||
# Update env and get new states (global & agent)
|
||||
self._step(list(env_action_dict.values()))
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
action_dict=self._select_trainable_agents(action_dict),
|
||||
env_action_dict=self._select_trainable_agents(env_action_dict),
|
||||
# The following will be generated later
|
||||
reward_dict={},
|
||||
terminal_dict={},
|
||||
next_state=None,
|
||||
next_agent_state_dict={},
|
||||
truncated=False, # No truncation in evaluation
|
||||
)
|
||||
|
||||
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
|
||||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
# Update env and get new states (global & agent)
|
||||
self._step(list(env_action_dict.values()))
|
||||
cache_element.next_state = self._state
|
||||
|
||||
self._append_cache_element(cache_element)
|
||||
self._append_cache_element(None)
|
||||
if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
|
||||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
|
||||
cache_element = self._trans_cache.pop(0)
|
||||
if self._reward_eval_delay is not None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
self._append_cache_element(cache_element)
|
||||
self._append_cache_element(None)
|
||||
|
||||
return {"info": [self._info]}
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
while len(self._transition_cache) > 0 and self._transition_cache[0].tick <= tick_bound:
|
||||
cache_element = self._transition_cache.pop(0)
|
||||
if self._reward_eval_delay is not None:
|
||||
self._calc_reward(cache_element)
|
||||
self._post_eval_step(cache_element)
|
||||
|
||||
info_list.append(self._info)
|
||||
|
||||
return {"info": info_list}
|
||||
|
||||
@abstractmethod
|
||||
def _post_step(self, cache_element: CacheElement) -> None:
|
||||
|
|
|
@ -59,7 +59,7 @@ class RolloutWorker(AbsWorker):
|
|||
result = (
|
||||
self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
|
||||
if req["type"] == "sample"
|
||||
else self._env_sampler.eval(policy_state=req["policy_state"])
|
||||
else self._env_sampler.eval(policy_state=req["policy_state"], num_episodes=req["num_eval_episodes"])
|
||||
)
|
||||
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
|
||||
else:
|
||||
|
|
|
@ -90,13 +90,19 @@ class ACBasedOps(AbsTrainOps):
|
|||
"""
|
||||
return self._v_critic_net.get_gradients(self._get_critic_loss(batch))
|
||||
|
||||
def update_critic(self, batch: TransitionBatch) -> None:
|
||||
def update_critic(self, batch: TransitionBatch) -> float:
|
||||
"""Update the critic network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
"""
|
||||
self._v_critic_net.step(self._get_critic_loss(batch))
|
||||
self._v_critic_net.train()
|
||||
loss = self._get_critic_loss(batch)
|
||||
self._v_critic_net.step(loss)
|
||||
return loss.detach().cpu().numpy().item()
|
||||
|
||||
def update_critic_with_grad(self, grad_dict: dict) -> None:
|
||||
"""Update the critic network with remotely computed gradients.
|
||||
|
@ -148,24 +154,26 @@ class ACBasedOps(AbsTrainOps):
|
|||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
grad (torch.Tensor): The actor gradient of the batch.
|
||||
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
return self._policy.get_gradients(loss), early_stop
|
||||
|
||||
def update_actor(self, batch: TransitionBatch) -> bool:
|
||||
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
|
||||
"""Update the actor network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
self._policy.train()
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
self._policy.train_step(loss)
|
||||
return early_stop
|
||||
return loss.detach().cpu().numpy().item(), early_stop
|
||||
|
||||
def update_actor_with_grad(self, grad_dict_and_early_stop: Tuple[dict, bool]) -> bool:
|
||||
"""Update the actor network with remotely computed gradients.
|
||||
|
@ -202,6 +210,9 @@ class ACBasedOps(AbsTrainOps):
|
|||
# Preprocess advantages
|
||||
states = ndarray_to_tensor(batch.states, device=self._device) # s
|
||||
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
|
||||
terminals = ndarray_to_tensor(batch.terminals, device=self._device)
|
||||
truncated = ndarray_to_tensor(batch.truncated, device=self._device)
|
||||
next_states = ndarray_to_tensor(batch.next_states, device=self._device)
|
||||
if self._is_discrete_action:
|
||||
actions = actions.long()
|
||||
|
||||
|
@ -209,11 +220,34 @@ class ACBasedOps(AbsTrainOps):
|
|||
self._v_critic_net.eval()
|
||||
self._policy.eval()
|
||||
values = self._v_critic_net.v_values(states).detach().cpu().numpy()
|
||||
values = np.concatenate([values, np.zeros(1)])
|
||||
rewards = np.concatenate([batch.rewards, np.zeros(1)])
|
||||
deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s)
|
||||
batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1]
|
||||
batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam)
|
||||
|
||||
batch.returns = np.zeros(batch.size, dtype=np.float32)
|
||||
batch.advantages = np.zeros(batch.size, dtype=np.float32)
|
||||
i = 0
|
||||
while i < batch.size:
|
||||
j = i
|
||||
while j < batch.size - 1 and not (terminals[j] or truncated[j]):
|
||||
j += 1
|
||||
last_val = (
|
||||
0.0
|
||||
if terminals[j]
|
||||
else self._v_critic_net.v_values(
|
||||
next_states[j].unsqueeze(dim=0),
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
.item()
|
||||
)
|
||||
|
||||
cur_values = np.append(values[i : j + 1], last_val)
|
||||
cur_rewards = np.append(batch.rewards[i : j + 1], last_val)
|
||||
# delta = r + gamma * v(s') - v(s)
|
||||
cur_deltas = cur_rewards[:-1] + self._reward_discount * cur_values[1:] - cur_values[:-1]
|
||||
batch.returns[i : j + 1] = discount_cumsum(cur_rewards, self._reward_discount)[:-1]
|
||||
batch.advantages[i : j + 1] = discount_cumsum(cur_deltas, self._reward_discount * self._lam)
|
||||
|
||||
i = j + 1
|
||||
|
||||
if self._clip_ratio is not None:
|
||||
batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy()
|
||||
|
@ -229,7 +263,7 @@ class ACBasedOps(AbsTrainOps):
|
|||
def to_device(self, device: str = None) -> None:
|
||||
self._device = get_torch_device(device)
|
||||
self._policy.to_device(self._device)
|
||||
self._v_critic_net.to(self._device)
|
||||
self._v_critic_net.to_device(self._device)
|
||||
|
||||
|
||||
class ACBasedTrainer(SingleAgentTrainer):
|
||||
|
@ -291,21 +325,25 @@ class ACBasedTrainer(SingleAgentTrainer):
|
|||
assert isinstance(self._ops, ACBasedOps)
|
||||
|
||||
batch = self._get_batch()
|
||||
for _ in range(self._params.grad_iters):
|
||||
self._ops.update_critic(batch)
|
||||
|
||||
for _ in range(self._params.grad_iters):
|
||||
early_stop = self._ops.update_actor(batch)
|
||||
if early_stop:
|
||||
break
|
||||
|
||||
for _ in range(self._params.grad_iters):
|
||||
self._ops.update_critic(batch)
|
||||
|
||||
async def train_step_as_task(self) -> None:
|
||||
assert isinstance(self._ops, RemoteOps)
|
||||
|
||||
batch = self._get_batch()
|
||||
for _ in range(self._params.grad_iters):
|
||||
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
|
||||
|
||||
for _ in range(self._params.grad_iters):
|
||||
if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)): # early stop
|
||||
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
|
||||
self._ops.update_actor_with_grad(grad_dict)
|
||||
if early_stop:
|
||||
break
|
||||
|
||||
for _ in range(self._params.grad_iters):
|
||||
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -27,7 +27,7 @@ class DDPGParams(BaseTrainerParams):
|
|||
random_overwrite (bool, default=False): This specifies overwrite behavior when the replay memory capacity
|
||||
is reached. If True, overwrite positions will be selected randomly. Otherwise, overwrites will occur
|
||||
sequentially with wrap-around.
|
||||
min_num_to_trigger_training (int, default=0): Minimum number required to start training.
|
||||
n_start_train (int, default=0): Minimum number required to start training.
|
||||
"""
|
||||
|
||||
get_q_critic_net_func: Callable[[], QNet]
|
||||
|
@ -36,7 +36,7 @@ class DDPGParams(BaseTrainerParams):
|
|||
q_value_loss_cls: Optional[Callable] = None
|
||||
soft_update_coef: float = 1.0
|
||||
random_overwrite: bool = False
|
||||
min_num_to_trigger_training: int = 0
|
||||
n_start_train: int = 0
|
||||
|
||||
|
||||
class DDPGOps(AbsTrainOps):
|
||||
|
@ -93,9 +93,9 @@ class DDPGOps(AbsTrainOps):
|
|||
states=next_states, # s'
|
||||
actions=self._target_policy.get_actions_tensor(next_states), # miu_targ(s')
|
||||
) # Q_targ(s', miu_targ(s'))
|
||||
# y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', miu_targ(s'))
|
||||
target_q_values = (rewards + self._reward_discount * (1.0 - terminals.float()) * next_q_values).detach()
|
||||
|
||||
# y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', miu_targ(s'))
|
||||
target_q_values = (rewards + self._reward_discount * (1 - terminals.long()) * next_q_values).detach()
|
||||
q_values = self._q_critic_net.q_values(states=states, actions=actions) # Q(s, a)
|
||||
return self._q_value_loss_func(q_values, target_q_values) # MSE(Q(s, a), y(r, s', d))
|
||||
|
||||
|
@ -120,16 +120,21 @@ class DDPGOps(AbsTrainOps):
|
|||
self._q_critic_net.train()
|
||||
self._q_critic_net.apply_gradients(grad_dict)
|
||||
|
||||
def update_critic(self, batch: TransitionBatch) -> None:
|
||||
def update_critic(self, batch: TransitionBatch) -> float:
|
||||
"""Update the critic network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
"""
|
||||
self._q_critic_net.train()
|
||||
self._q_critic_net.step(self._get_critic_loss(batch))
|
||||
loss = self._get_critic_loss(batch)
|
||||
self._q_critic_net.step(loss)
|
||||
return loss.detach().cpu().numpy().item()
|
||||
|
||||
def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]:
|
||||
"""Compute the actor loss of the batch.
|
||||
|
||||
Args:
|
||||
|
@ -137,6 +142,7 @@ class DDPGOps(AbsTrainOps):
|
|||
|
||||
Returns:
|
||||
loss (torch.Tensor): The actor loss of the batch.
|
||||
early_stop (bool): The early stop indicator, set to False in current implementation.
|
||||
"""
|
||||
assert isinstance(batch, TransitionBatch)
|
||||
self._policy.train()
|
||||
|
@ -147,19 +153,23 @@ class DDPGOps(AbsTrainOps):
|
|||
actions=self._policy.get_actions_tensor(states), # miu(s)
|
||||
).mean() # -Q(s, miu(s))
|
||||
|
||||
return policy_loss
|
||||
early_stop = False
|
||||
|
||||
return policy_loss, early_stop
|
||||
|
||||
@remote
|
||||
def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
|
||||
def get_actor_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
|
||||
"""Compute the actor network's gradients of a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
grad (torch.Tensor): The actor gradient of the batch.
|
||||
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
return self._policy.get_gradients(self._get_actor_loss(batch))
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
return self._policy.get_gradients(loss), early_stop
|
||||
|
||||
def update_actor_with_grad(self, grad_dict: dict) -> None:
|
||||
"""Update the actor network with remotely computed gradients.
|
||||
|
@ -170,14 +180,20 @@ class DDPGOps(AbsTrainOps):
|
|||
self._policy.train()
|
||||
self._policy.apply_gradients(grad_dict)
|
||||
|
||||
def update_actor(self, batch: TransitionBatch) -> None:
|
||||
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
|
||||
"""Update the actor network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
self._policy.train()
|
||||
self._policy.train_step(self._get_actor_loss(batch))
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
self._policy.train_step(loss)
|
||||
return loss.detach().cpu().numpy().item(), early_stop
|
||||
|
||||
def get_non_policy_state(self) -> dict:
|
||||
return {
|
||||
|
@ -200,8 +216,8 @@ class DDPGOps(AbsTrainOps):
|
|||
self._device = get_torch_device(device=device)
|
||||
self._policy.to_device(self._device)
|
||||
self._target_policy.to_device(self._device)
|
||||
self._q_critic_net.to(self._device)
|
||||
self._target_q_critic_net.to(self._device)
|
||||
self._q_critic_net.to_device(self._device)
|
||||
self._target_q_critic_net.to_device(self._device)
|
||||
|
||||
|
||||
class DDPGTrainer(SingleAgentTrainer):
|
||||
|
@ -263,10 +279,10 @@ class DDPGTrainer(SingleAgentTrainer):
|
|||
def train_step(self) -> None:
|
||||
assert isinstance(self._ops, DDPGOps)
|
||||
|
||||
if self._replay_memory.n_sample < self._params.min_num_to_trigger_training:
|
||||
if self._replay_memory.n_sample < self._params.n_start_train:
|
||||
print(
|
||||
f"Skip this training step due to lack of experiences "
|
||||
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})",
|
||||
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})",
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -280,19 +296,21 @@ class DDPGTrainer(SingleAgentTrainer):
|
|||
async def train_step_as_task(self) -> None:
|
||||
assert isinstance(self._ops, RemoteOps)
|
||||
|
||||
if self._replay_memory.n_sample < self._params.min_num_to_trigger_training:
|
||||
if self._replay_memory.n_sample < self._params.n_start_train:
|
||||
print(
|
||||
f"Skip this training step due to lack of experiences "
|
||||
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})",
|
||||
f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})",
|
||||
)
|
||||
return
|
||||
|
||||
for _ in range(self._params.num_epochs):
|
||||
batch = self._get_batch()
|
||||
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
|
||||
self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch))
|
||||
|
||||
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
|
||||
self._ops.update_actor_with_grad(grad_dict)
|
||||
self._try_soft_update_target()
|
||||
if early_stop:
|
||||
break
|
||||
|
||||
def _try_soft_update_target(self) -> None:
|
||||
"""Soft update the target policy and target critic."""
|
||||
|
|
|
@ -161,15 +161,20 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
"""
|
||||
return self._q_critic_net.get_gradients(self._get_critic_loss(batch, next_actions))
|
||||
|
||||
def update_critic(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> None:
|
||||
def update_critic(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> float:
|
||||
"""Update the critic network using a batch.
|
||||
|
||||
Args:
|
||||
batch (MultiTransitionBatch): Batch.
|
||||
next_actions (List[torch.Tensor]): List of next actions of all policies.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
"""
|
||||
self._q_critic_net.train()
|
||||
self._q_critic_net.step(self._get_critic_loss(batch, next_actions))
|
||||
loss = self._get_critic_loss(batch, next_actions)
|
||||
self._q_critic_net.step(loss)
|
||||
return loss.detach().cpu().numpy().item()
|
||||
|
||||
def update_critic_with_grad(self, grad_dict: dict) -> None:
|
||||
"""Update the critic network with remotely computed gradients.
|
||||
|
@ -180,7 +185,7 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
self._q_critic_net.train()
|
||||
self._q_critic_net.apply_gradients(grad_dict)
|
||||
|
||||
def _get_actor_loss(self, batch: MultiTransitionBatch) -> torch.Tensor:
|
||||
def _get_actor_loss(self, batch: MultiTransitionBatch) -> Tuple[torch.Tensor, bool]:
|
||||
"""Compute the actor loss of the batch.
|
||||
|
||||
Args:
|
||||
|
@ -188,11 +193,13 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
|
||||
Returns:
|
||||
loss (torch.Tensor): The actor loss of the batch.
|
||||
early_stop (bool): The early stop indicator, set to False in current implementation.
|
||||
"""
|
||||
latest_action, latest_action_logp = self.get_latest_action(batch)
|
||||
states = ndarray_to_tensor(batch.states, device=self._device) # x
|
||||
actions = [ndarray_to_tensor(action, device=self._device) for action in batch.actions] # a
|
||||
actions[self._policy_idx] = latest_action
|
||||
|
||||
self._policy.train()
|
||||
self._q_critic_net.freeze()
|
||||
actor_loss = -(
|
||||
|
@ -203,28 +210,39 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
* latest_action_logp
|
||||
).mean() # Q(x, a^j_1, ..., a_i, ..., a^j_N)
|
||||
self._q_critic_net.unfreeze()
|
||||
return actor_loss
|
||||
|
||||
early_stop = False
|
||||
|
||||
return actor_loss, early_stop
|
||||
|
||||
@remote
|
||||
def get_actor_grad(self, batch: MultiTransitionBatch) -> Dict[str, torch.Tensor]:
|
||||
def get_actor_grad(self, batch: MultiTransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
|
||||
"""Compute the actor network's gradients of a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
return self._policy.get_gradients(loss), early_stop
|
||||
|
||||
def update_actor(self, batch: MultiTransitionBatch) -> Tuple[float, bool]:
|
||||
"""Update the actor network using a batch.
|
||||
|
||||
Args:
|
||||
batch (MultiTransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
grad (torch.Tensor): The actor gradient of the batch.
|
||||
"""
|
||||
return self._policy.get_gradients(self._get_actor_loss(batch))
|
||||
|
||||
def update_actor(self, batch: MultiTransitionBatch) -> None:
|
||||
"""Update the actor network using a batch.
|
||||
|
||||
Args:
|
||||
batch (MultiTransitionBatch): Batch.
|
||||
loss (float): The detached loss of this batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
self._policy.train()
|
||||
self._policy.train_step(self._get_actor_loss(batch))
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
self._policy.train_step(loss)
|
||||
return loss.detach().cpu().numpy().item(), early_stop
|
||||
|
||||
def update_actor_with_grad(self, grad_dict: dict) -> None:
|
||||
"""Update the critic network with remotely computed gradients.
|
||||
|
@ -275,8 +293,8 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
self._policy.to_device(self._device)
|
||||
self._target_policy.to_device(self._device)
|
||||
|
||||
self._q_critic_net.to(self._device)
|
||||
self._target_q_critic_net.to(self._device)
|
||||
self._q_critic_net.to_device(self._device)
|
||||
self._target_q_critic_net.to_device(self._device)
|
||||
|
||||
|
||||
class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
||||
|
@ -378,6 +396,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
agent_states=agent_states,
|
||||
next_agent_states=next_agent_states,
|
||||
terminals=np.array(terminal_flags),
|
||||
truncated=np.array([exp_element.truncated for exp_element in exp_elements]),
|
||||
)
|
||||
self._replay_memory.put(transition_batch)
|
||||
|
||||
|
@ -459,7 +478,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
ops.update_critic_with_grad(critic_grad)
|
||||
|
||||
# Update actors
|
||||
actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch) for ops in self._actor_ops_list])
|
||||
actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch)[0] for ops in self._actor_ops_list])
|
||||
for ops, actor_grad in zip(self._actor_ops_list, actor_grad_list):
|
||||
ops.update_actor_with_grad(actor_grad)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ class SoftActorCriticParams(BaseTrainerParams):
|
|||
num_epochs: int = 1
|
||||
n_start_train: int = 0
|
||||
q_value_loss_cls: Optional[Callable] = None
|
||||
soft_update_coef: float = 1.0
|
||||
soft_update_coef: float = 0.05
|
||||
|
||||
|
||||
class SoftActorCriticOps(AbsTrainOps):
|
||||
|
@ -58,6 +58,7 @@ class SoftActorCriticOps(AbsTrainOps):
|
|||
|
||||
def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._q_net1.train()
|
||||
self._q_net2.train()
|
||||
states = ndarray_to_tensor(batch.states, device=self._device) # s
|
||||
next_states = ndarray_to_tensor(batch.next_states, device=self._device) # s'
|
||||
actions = ndarray_to_tensor(batch.actions, device=self._device) # a
|
||||
|
@ -67,11 +68,13 @@ class SoftActorCriticOps(AbsTrainOps):
|
|||
assert isinstance(self._policy, ContinuousRLPolicy)
|
||||
|
||||
with torch.no_grad():
|
||||
next_actions, next_logps = self._policy.get_actions_with_logps(states)
|
||||
q1 = self._target_q_net1.q_values(next_states, next_actions)
|
||||
q2 = self._target_q_net2.q_values(next_states, next_actions)
|
||||
q = torch.min(q1, q2)
|
||||
y = rewards + self._reward_discount * (1.0 - terminals.float()) * (q - self._entropy_coef * next_logps)
|
||||
next_actions, next_logps = self._policy.get_actions_with_logps(next_states)
|
||||
target_q1 = self._target_q_net1.q_values(next_states, next_actions)
|
||||
target_q2 = self._target_q_net2.q_values(next_states, next_actions)
|
||||
target_q = torch.min(target_q1, target_q2)
|
||||
y = rewards + self._reward_discount * (1.0 - terminals.float()) * (
|
||||
target_q - self._entropy_coef * next_logps
|
||||
)
|
||||
|
||||
q1 = self._q_net1.q_values(states, actions)
|
||||
q2 = self._q_net2.q_values(states, actions)
|
||||
|
@ -92,14 +95,36 @@ class SoftActorCriticOps(AbsTrainOps):
|
|||
self._q_net1.apply_gradients(grad_dicts[0])
|
||||
self._q_net2.apply_gradients(grad_dicts[1])
|
||||
|
||||
def update_critic(self, batch: TransitionBatch) -> None:
|
||||
def update_critic(self, batch: TransitionBatch) -> Tuple[float, float]:
|
||||
"""Update the critic network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss_q1 (float): The detached q_net1 loss of this batch.
|
||||
loss_q2 (float): The detached q_net2 loss of this batch.
|
||||
"""
|
||||
self._q_net1.train()
|
||||
self._q_net2.train()
|
||||
loss_q1, loss_q2 = self._get_critic_loss(batch)
|
||||
self._q_net1.step(loss_q1)
|
||||
self._q_net2.step(loss_q2)
|
||||
return loss_q1.detach().cpu().numpy().item(), loss_q2.detach().cpu().numpy().item()
|
||||
|
||||
def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]:
|
||||
"""Compute the actor loss of the batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor): The actor loss of the batch.
|
||||
early_stop (bool): The early stop indicator, set to False in current implementation.
|
||||
"""
|
||||
self._q_net1.freeze()
|
||||
self._q_net2.freeze()
|
||||
|
||||
def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
self._policy.train()
|
||||
states = ndarray_to_tensor(batch.states, device=self._device) # s
|
||||
actions, logps = self._policy.get_actions_with_logps(states)
|
||||
|
@ -108,19 +133,46 @@ class SoftActorCriticOps(AbsTrainOps):
|
|||
q = torch.min(q1, q2)
|
||||
|
||||
loss = (self._entropy_coef * logps - q).mean()
|
||||
return loss
|
||||
|
||||
self._q_net1.unfreeze()
|
||||
self._q_net2.unfreeze()
|
||||
|
||||
early_stop = False
|
||||
|
||||
return loss, early_stop
|
||||
|
||||
@remote
|
||||
def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
|
||||
return self._policy.get_gradients(self._get_actor_loss(batch))
|
||||
def get_actor_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]:
|
||||
"""Compute the actor network's gradients of a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
grad_dict (Dict[str, torch.Tensor]): The actor gradient of the batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
return self._policy.get_gradients(loss), early_stop
|
||||
|
||||
def update_actor_with_grad(self, grad_dict: dict) -> None:
|
||||
self._policy.train()
|
||||
self._policy.apply_gradients(grad_dict)
|
||||
|
||||
def update_actor(self, batch: TransitionBatch) -> None:
|
||||
def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]:
|
||||
"""Update the actor network using a batch.
|
||||
|
||||
Args:
|
||||
batch (TransitionBatch): Batch.
|
||||
|
||||
Returns:
|
||||
loss (float): The detached loss of this batch.
|
||||
early_stop (bool): Early stop indicator.
|
||||
"""
|
||||
self._policy.train()
|
||||
self._policy.train_step(self._get_actor_loss(batch))
|
||||
loss, early_stop = self._get_actor_loss(batch)
|
||||
self._policy.train_step(loss)
|
||||
return loss.detach().cpu().numpy().item(), early_stop
|
||||
|
||||
def get_non_policy_state(self) -> dict:
|
||||
return {
|
||||
|
@ -142,10 +194,13 @@ class SoftActorCriticOps(AbsTrainOps):
|
|||
|
||||
def to_device(self, device: str = None) -> None:
|
||||
self._device = get_torch_device(device=device)
|
||||
self._q_net1.to(self._device)
|
||||
self._q_net2.to(self._device)
|
||||
self._target_q_net1.to(self._device)
|
||||
self._target_q_net2.to(self._device)
|
||||
|
||||
self._policy.to_device(self._device)
|
||||
|
||||
self._q_net1.to_device(self._device)
|
||||
self._q_net2.to_device(self._device)
|
||||
self._target_q_net1.to_device(self._device)
|
||||
self._target_q_net2.to_device(self._device)
|
||||
|
||||
|
||||
class SoftActorCriticTrainer(SingleAgentTrainer):
|
||||
|
@ -211,9 +266,11 @@ class SoftActorCriticTrainer(SingleAgentTrainer):
|
|||
for _ in range(self._params.num_epochs):
|
||||
batch = self._get_batch()
|
||||
self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch))
|
||||
self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch))
|
||||
|
||||
grad_dict, early_stop = await self._ops.get_actor_grad(batch)
|
||||
self._ops.update_actor_with_grad(grad_dict)
|
||||
self._try_soft_update_target()
|
||||
if early_stop:
|
||||
break
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
|
|
@ -35,29 +35,18 @@ class AbsIndexScheduler(object, metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
|
||||
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
"""Generate a list of indexes that can be used to retrieve items from the replay memory.
|
||||
|
||||
Args:
|
||||
batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience
|
||||
item is present are returned.
|
||||
forbid_last (bool, default=False): Whether the latest element is allowed to be sampled.
|
||||
If this is true, the last index will always be excluded from the result.
|
||||
|
||||
Returns:
|
||||
indexes (np.ndarray): The list of indexes.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_last_index(self) -> int:
|
||||
"""Get the index of the latest element in the memory.
|
||||
|
||||
Returns:
|
||||
index (int): The index of the latest element in the memory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomIndexScheduler(AbsIndexScheduler):
|
||||
"""Index scheduler that returns random indexes when sampling.
|
||||
|
@ -93,14 +82,11 @@ class RandomIndexScheduler(AbsIndexScheduler):
|
|||
self._size = min(self._size + batch_size, self._capacity)
|
||||
return indexes
|
||||
|
||||
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
|
||||
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}"
|
||||
assert self._size > 0, "Cannot sample from an empty memory."
|
||||
return np.random.choice(self._size, size=batch_size, replace=True)
|
||||
|
||||
def get_last_index(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FIFOIndexScheduler(AbsIndexScheduler):
|
||||
"""First-in-first-out index scheduler.
|
||||
|
@ -135,19 +121,15 @@ class FIFOIndexScheduler(AbsIndexScheduler):
|
|||
self._head = (self._head + overwrite) % self._capacity
|
||||
return self.get_put_indexes(batch_size)
|
||||
|
||||
def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
|
||||
tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity
|
||||
def get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
indexes = (
|
||||
np.arange(self._head, tmp)
|
||||
if tmp > self._head
|
||||
else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)])
|
||||
np.arange(self._head, self._tail)
|
||||
if self._tail > self._head
|
||||
else np.concatenate([np.arange(self._head, self._capacity), np.arange(self._tail)])
|
||||
)
|
||||
self._head = tmp
|
||||
self._head = self._tail
|
||||
return indexes
|
||||
|
||||
def get_last_index(self) -> int:
|
||||
return (self._tail - 1) % self._capacity
|
||||
|
||||
|
||||
class AbsReplayMemory(object, metaclass=ABCMeta):
|
||||
"""Abstract replay memory class with basic interfaces.
|
||||
|
@ -176,9 +158,9 @@ class AbsReplayMemory(object, metaclass=ABCMeta):
|
|||
"""Please refer to the doc string in AbsIndexScheduler."""
|
||||
return self._idx_scheduler.get_put_indexes(batch_size)
|
||||
|
||||
def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray:
|
||||
def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray:
|
||||
"""Please refer to the doc string in AbsIndexScheduler."""
|
||||
return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last)
|
||||
return self._idx_scheduler.get_sample_indexes(batch_size)
|
||||
|
||||
|
||||
class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
||||
|
@ -204,7 +186,8 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
self._states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
|
||||
self._actions = np.zeros((self._capacity, self._action_dim), dtype=np.float32)
|
||||
self._rewards = np.zeros(self._capacity, dtype=np.float32)
|
||||
self._terminals = np.zeros(self._capacity, dtype=np.bool)
|
||||
self._terminals = np.zeros(self._capacity, dtype=bool)
|
||||
self._truncated = np.zeros(self._capacity, dtype=bool)
|
||||
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
|
||||
self._returns = np.zeros(self._capacity, dtype=np.float32)
|
||||
self._advantages = np.zeros(self._capacity, dtype=np.float32)
|
||||
|
@ -233,6 +216,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
assert match_shape(transition_batch.actions, (batch_size, self._action_dim))
|
||||
assert match_shape(transition_batch.rewards, (batch_size,))
|
||||
assert match_shape(transition_batch.terminals, (batch_size,))
|
||||
assert match_shape(transition_batch.truncated, (batch_size,))
|
||||
assert match_shape(transition_batch.next_states, (batch_size, self._state_dim))
|
||||
if transition_batch.returns is not None:
|
||||
match_shape(transition_batch.returns, (batch_size,))
|
||||
|
@ -255,6 +239,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
self._actions[indexes] = transition_batch.actions
|
||||
self._rewards[indexes] = transition_batch.rewards
|
||||
self._terminals[indexes] = transition_batch.terminals
|
||||
self._truncated[indexes] = transition_batch.truncated
|
||||
self._next_states[indexes] = transition_batch.next_states
|
||||
if transition_batch.returns is not None:
|
||||
self._returns[indexes] = transition_batch.returns
|
||||
|
@ -273,7 +258,7 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
Returns:
|
||||
batch (TransitionBatch): The sampled batch.
|
||||
"""
|
||||
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
|
||||
indexes = self._get_sample_indexes(batch_size)
|
||||
return self.sample_by_indexes(indexes)
|
||||
|
||||
def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
|
||||
|
@ -292,16 +277,13 @@ class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
actions=self._actions[indexes],
|
||||
rewards=self._rewards[indexes],
|
||||
terminals=self._terminals[indexes],
|
||||
truncated=self._truncated[indexes],
|
||||
next_states=self._next_states[indexes],
|
||||
returns=self._returns[indexes],
|
||||
advantages=self._advantages[indexes],
|
||||
old_logps=self._old_logps[indexes],
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_forbid_last(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomReplayMemory(ReplayMemory):
|
||||
def __init__(
|
||||
|
@ -318,15 +300,11 @@ class RandomReplayMemory(ReplayMemory):
|
|||
RandomIndexScheduler(capacity, random_overwrite),
|
||||
)
|
||||
self._random_overwrite = random_overwrite
|
||||
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
|
||||
|
||||
@property
|
||||
def random_overwrite(self) -> bool:
|
||||
return self._random_overwrite
|
||||
|
||||
def _get_forbid_last(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FIFOReplayMemory(ReplayMemory):
|
||||
def __init__(
|
||||
|
@ -342,9 +320,6 @@ class FIFOReplayMemory(ReplayMemory):
|
|||
FIFOIndexScheduler(capacity),
|
||||
)
|
||||
|
||||
def _get_forbid_last(self) -> bool:
|
||||
return not self._terminals[self._idx_scheduler.get_last_index()]
|
||||
|
||||
|
||||
class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
||||
"""In-memory experience storage facility for a multi trainer.
|
||||
|
@ -373,7 +348,8 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
self._actions = [np.zeros((self._capacity, action_dim), dtype=np.float32) for action_dim in self._action_dims]
|
||||
self._rewards = [np.zeros(self._capacity, dtype=np.float32) for _ in range(self.agent_num)]
|
||||
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
|
||||
self._terminals = np.zeros(self._capacity, dtype=np.bool)
|
||||
self._terminals = np.zeros(self._capacity, dtype=bool)
|
||||
self._truncated = np.zeros(self._capacity, dtype=bool)
|
||||
|
||||
assert len(agent_states_dims) == self.agent_num
|
||||
self._agent_states_dims = agent_states_dims
|
||||
|
@ -408,6 +384,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
assert match_shape(transition_batch.rewards[i], (batch_size,))
|
||||
|
||||
assert match_shape(transition_batch.terminals, (batch_size,))
|
||||
assert match_shape(transition_batch.truncated, (batch_size,))
|
||||
assert match_shape(transition_batch.next_states, (batch_size, self._state_dim))
|
||||
|
||||
assert len(transition_batch.agent_states) == self.agent_num
|
||||
|
@ -430,6 +407,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
self._actions[i][indexes] = transition_batch.actions[i]
|
||||
self._rewards[i][indexes] = transition_batch.rewards[i]
|
||||
self._terminals[indexes] = transition_batch.terminals
|
||||
self._truncated[indexes] = transition_batch.truncated
|
||||
|
||||
self._next_states[indexes] = transition_batch.next_states
|
||||
for i in range(self.agent_num):
|
||||
|
@ -446,7 +424,7 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
Returns:
|
||||
batch (MultiTransitionBatch): The sampled batch.
|
||||
"""
|
||||
indexes = self._get_sample_indexes(batch_size, self._get_forbid_last())
|
||||
indexes = self._get_sample_indexes(batch_size)
|
||||
return self.sample_by_indexes(indexes)
|
||||
|
||||
def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
|
||||
|
@ -465,15 +443,12 @@ class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta):
|
|||
actions=[action[indexes] for action in self._actions],
|
||||
rewards=[reward[indexes] for reward in self._rewards],
|
||||
terminals=self._terminals[indexes],
|
||||
truncated=self._truncated[indexes],
|
||||
next_states=self._next_states[indexes],
|
||||
agent_states=[state[indexes] for state in self._agent_states],
|
||||
next_agent_states=[state[indexes] for state in self._next_agent_states],
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_forbid_last(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomMultiReplayMemory(MultiReplayMemory):
|
||||
def __init__(
|
||||
|
@ -492,15 +467,11 @@ class RandomMultiReplayMemory(MultiReplayMemory):
|
|||
agent_states_dims,
|
||||
)
|
||||
self._random_overwrite = random_overwrite
|
||||
self._scheduler = RandomIndexScheduler(capacity, random_overwrite)
|
||||
|
||||
@property
|
||||
def random_overwrite(self) -> bool:
|
||||
return self._random_overwrite
|
||||
|
||||
def _get_forbid_last(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FIFOMultiReplayMemory(MultiReplayMemory):
|
||||
def __init__(
|
||||
|
@ -517,6 +488,3 @@ class FIFOMultiReplayMemory(MultiReplayMemory):
|
|||
FIFOIndexScheduler(capacity),
|
||||
agent_states_dims,
|
||||
)
|
||||
|
||||
def _get_forbid_last(self) -> bool:
|
||||
return not self._terminals[self._idx_scheduler.get_last_index()]
|
||||
|
|
|
@ -254,6 +254,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
|
|||
exp_element.action_dict[agent_name],
|
||||
exp_element.reward_dict[agent_name],
|
||||
exp_element.terminal_dict[agent_name],
|
||||
exp_element.truncated,
|
||||
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
|
||||
),
|
||||
)
|
||||
|
@ -264,7 +265,8 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
|
|||
actions=np.vstack([exp[1] for exp in exps]),
|
||||
rewards=np.array([exp[2] for exp in exps]),
|
||||
terminals=np.array([exp[3] for exp in exps]),
|
||||
next_states=np.vstack([exp[4] for exp in exps]),
|
||||
truncated=np.array([exp[4] for exp in exps]),
|
||||
next_states=np.vstack([exp[5] for exp in exps]),
|
||||
)
|
||||
transition_batch = self._preprocess_batch(transition_batch)
|
||||
self.replay_memory.put(transition_batch)
|
||||
|
|
|
@ -19,6 +19,7 @@ class TransitionBatch:
|
|||
rewards: np.ndarray # 1D
|
||||
next_states: np.ndarray # 2D
|
||||
terminals: np.ndarray # 1D
|
||||
truncated: np.ndarray # 1D
|
||||
returns: np.ndarray = None # 1D
|
||||
advantages: np.ndarray = None # 1D
|
||||
old_logps: np.ndarray = None # 1D
|
||||
|
@ -34,6 +35,7 @@ class TransitionBatch:
|
|||
assert len(self.rewards.shape) == 1 and self.rewards.shape[0] == self.states.shape[0]
|
||||
assert self.next_states.shape == self.states.shape
|
||||
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]
|
||||
assert len(self.truncated.shape) == 1 and self.truncated.shape[0] == self.states.shape[0]
|
||||
|
||||
def make_kth_sub_batch(self, i: int, k: int) -> TransitionBatch:
|
||||
return TransitionBatch(
|
||||
|
@ -42,6 +44,7 @@ class TransitionBatch:
|
|||
rewards=self.rewards[i::k],
|
||||
next_states=self.next_states[i::k],
|
||||
terminals=self.terminals[i::k],
|
||||
truncated=self.truncated[i::k],
|
||||
returns=self.returns[i::k] if self.returns is not None else None,
|
||||
advantages=self.advantages[i::k] if self.advantages is not None else None,
|
||||
old_logps=self.old_logps[i::k] if self.old_logps is not None else None,
|
||||
|
@ -60,7 +63,7 @@ class MultiTransitionBatch:
|
|||
agent_states: List[np.ndarray] # List of 2D
|
||||
next_agent_states: List[np.ndarray] # List of 2D
|
||||
terminals: np.ndarray # 1D
|
||||
|
||||
truncated: np.ndarray # 1D
|
||||
returns: Optional[List[np.ndarray]] = None # List of 1D
|
||||
advantages: Optional[List[np.ndarray]] = None # List of 1D
|
||||
|
||||
|
@ -81,6 +84,7 @@ class MultiTransitionBatch:
|
|||
assert self.agent_states[i].shape[0] == self.states.shape[0]
|
||||
|
||||
assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0]
|
||||
assert len(self.truncated.shape) == 1 and self.truncated.shape[0] == self.states.shape[0]
|
||||
assert self.next_states.shape == self.states.shape
|
||||
|
||||
assert len(self.next_agent_states) == len(self.agent_states)
|
||||
|
@ -98,6 +102,7 @@ class MultiTransitionBatch:
|
|||
agent_states = [state[i::k] for state in self.agent_states]
|
||||
next_agent_states = [state[i::k] for state in self.next_agent_states]
|
||||
terminals = self.terminals[i::k]
|
||||
truncated = self.truncated[i::k]
|
||||
returns = None if self.returns is None else [r[i::k] for r in self.returns]
|
||||
advantages = None if self.advantages is None else [advantage[i::k] for advantage in self.advantages]
|
||||
return MultiTransitionBatch(
|
||||
|
@ -108,6 +113,7 @@ class MultiTransitionBatch:
|
|||
agent_states,
|
||||
next_agent_states,
|
||||
terminals,
|
||||
truncated,
|
||||
returns,
|
||||
advantages,
|
||||
)
|
||||
|
@ -123,6 +129,7 @@ def merge_transition_batches(batch_list: List[TransitionBatch]) -> TransitionBat
|
|||
rewards=np.concatenate([batch.rewards for batch in batch_list], axis=0),
|
||||
next_states=np.concatenate([batch.next_states for batch in batch_list], axis=0),
|
||||
terminals=np.concatenate([batch.terminals for batch in batch_list]),
|
||||
truncated=np.concatenate([batch.truncated for batch in batch_list]),
|
||||
returns=np.concatenate([batch.returns for batch in batch_list]),
|
||||
advantages=np.concatenate([batch.advantages for batch in batch_list]),
|
||||
old_logps=None
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import typing
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler
|
||||
from maro.rl.training import TrainingManager
|
||||
from maro.utils import LoggerV2
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from maro.rl.workflows.main import TrainingWorkflow
|
||||
|
||||
EnvSampler = Union[AbsEnvSampler, BatchEnvSampler]
|
||||
|
||||
|
||||
class Callback(object):
|
||||
def __init__(self) -> None:
|
||||
self.workflow: Optional[TrainingWorkflow] = None
|
||||
self.env_sampler: Optional[EnvSampler] = None
|
||||
self.training_manager: Optional[TrainingManager] = None
|
||||
self.logger: Optional[LoggerV2] = None
|
||||
|
||||
def on_episode_start(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_end(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_training_start(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_training_end(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_validation_start(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_validation_end(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_test_start(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
def on_test_end(self, ep: int) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
def __init__(self, patience: int) -> None:
|
||||
super(EarlyStopping, self).__init__()
|
||||
|
||||
self._patience = patience
|
||||
self._best_ep: int = -1
|
||||
self._best: float = float("-inf")
|
||||
|
||||
def on_validation_end(self, ep: int) -> None:
|
||||
cur = self.env_sampler.monitor_metrics()
|
||||
if cur > self._best:
|
||||
self._best_ep = ep
|
||||
self._best = cur
|
||||
self.logger.info(f"Current metric: {cur} @ ep {ep}. Best metric: {self._best} @ ep {self._best_ep}")
|
||||
|
||||
if ep - self._best_ep > self._patience:
|
||||
self.workflow.early_stop = True
|
||||
self.logger.info(
|
||||
f"Validation metric has not been updated for {ep - self._best_ep} "
|
||||
f"epochs (patience = {self._patience} epochs). Early stop.",
|
||||
)
|
||||
|
||||
|
||||
class Checkpoint(Callback):
|
||||
def __init__(self, path: str, interval: int) -> None:
|
||||
super(Checkpoint, self).__init__()
|
||||
|
||||
self._path = path
|
||||
self._interval = interval
|
||||
|
||||
def on_training_end(self, ep: int) -> None:
|
||||
if ep % self._interval == 0:
|
||||
self.training_manager.save(os.path.join(self._path, str(ep)))
|
||||
self.logger.info(f"[Episode {ep}] All trainer states saved under {self._path}")
|
||||
|
||||
|
||||
class MetricsRecorder(Callback):
|
||||
def __init__(self, path: str) -> None:
|
||||
super(MetricsRecorder, self).__init__()
|
||||
|
||||
self._full_metrics: Dict[int, dict] = {}
|
||||
self._valid_metrics: Dict[int, dict] = {}
|
||||
self._path = path
|
||||
|
||||
def _dump_metric_history(self) -> None:
|
||||
if len(self._full_metrics) > 0:
|
||||
metric_list = [self._full_metrics[ep] for ep in sorted(self._full_metrics.keys())]
|
||||
df = pd.DataFrame.from_records(metric_list)
|
||||
df.to_csv(os.path.join(self._path, "metrics_full.csv"), index=True)
|
||||
if len(self._valid_metrics) > 0:
|
||||
metric_list = [self._valid_metrics[ep] for ep in sorted(self._valid_metrics.keys())]
|
||||
df = pd.DataFrame.from_records(metric_list)
|
||||
df.to_csv(os.path.join(self._path, "metrics_valid.csv"), index=True)
|
||||
|
||||
def on_training_end(self, ep: int) -> None:
|
||||
if len(self.env_sampler.metrics) > 0:
|
||||
metrics = copy.deepcopy(self.env_sampler.metrics)
|
||||
metrics["ep"] = ep
|
||||
if ep in self._full_metrics:
|
||||
self._full_metrics[ep].update(metrics)
|
||||
else:
|
||||
self._full_metrics[ep] = metrics
|
||||
self._dump_metric_history()
|
||||
|
||||
def on_validation_end(self, ep: int) -> None:
|
||||
if len(self.env_sampler.metrics) > 0:
|
||||
metrics = copy.deepcopy(self.env_sampler.metrics)
|
||||
metrics["ep"] = ep
|
||||
if ep in self._full_metrics:
|
||||
self._full_metrics[ep].update(metrics)
|
||||
else:
|
||||
self._full_metrics[ep] = metrics
|
||||
if ep in self._valid_metrics:
|
||||
self._valid_metrics[ep].update(metrics)
|
||||
else:
|
||||
self._valid_metrics[ep] = metrics
|
||||
self._dump_metric_history()
|
||||
|
||||
|
||||
class CallbackManager(object):
|
||||
def __init__(
|
||||
self,
|
||||
workflow: TrainingWorkflow,
|
||||
callbacks: List[Callback],
|
||||
env_sampler: EnvSampler,
|
||||
training_manager: TrainingManager,
|
||||
logger: LoggerV2,
|
||||
) -> None:
|
||||
super(CallbackManager, self).__init__()
|
||||
|
||||
self._callbacks = callbacks
|
||||
for callback in self._callbacks:
|
||||
callback.workflow = workflow
|
||||
callback.env_sampler = env_sampler
|
||||
callback.training_manager = training_manager
|
||||
callback.logger = logger
|
||||
|
||||
def on_episode_start(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_episode_start(ep)
|
||||
|
||||
def on_episode_end(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_episode_end(ep)
|
||||
|
||||
def on_training_start(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_training_start(ep)
|
||||
|
||||
def on_training_end(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_training_end(ep)
|
||||
|
||||
def on_validation_start(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_validation_start(ep)
|
||||
|
||||
def on_validation_end(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_validation_end(ep)
|
||||
|
||||
def on_test_start(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_test_start(ep)
|
||||
|
||||
def on_test_end(self, ep: int) -> None:
|
||||
for callback in self._callbacks:
|
||||
callback.on_test_end(ep)
|
|
@ -76,6 +76,11 @@ class ConfigParser:
|
|||
f"positive ints",
|
||||
)
|
||||
|
||||
early_stop_patience = self._config["main"].get("early_stop_patience", None)
|
||||
if early_stop_patience is not None:
|
||||
if not isinstance(early_stop_patience, int) or early_stop_patience <= 0:
|
||||
raise ValueError(f"Invalid early stop patience: {early_stop_patience}. Should be a positive integer.")
|
||||
|
||||
if "logging" in self._config["main"]:
|
||||
self._validate_logging_section("main", self._config["main"]["logging"])
|
||||
|
||||
|
@ -196,9 +201,10 @@ class ConfigParser:
|
|||
raise TypeError(f"{self._validation_err_pfx}: 'training.proxy.backend' must be an int")
|
||||
|
||||
def _validate_checkpointing_section(self, section: dict) -> None:
|
||||
if "path" not in section:
|
||||
raise KeyError(f"{self._validation_err_pfx}: missing field 'path' under section 'checkpointing'")
|
||||
if not isinstance(section["path"], str):
|
||||
ckpt_path = section.get("path", None)
|
||||
if ckpt_path is None:
|
||||
section["path"] = os.path.join(self._config["log_path"], "checkpoints")
|
||||
elif not isinstance(section["path"], str):
|
||||
raise TypeError(f"{self._validation_err_pfx}: 'training.checkpointing.path' must be a string")
|
||||
|
||||
if "interval" in section:
|
||||
|
@ -231,10 +237,9 @@ class ConfigParser:
|
|||
local/log/path -> "/logs"
|
||||
Defaults to False.
|
||||
"""
|
||||
log_dir = os.path.dirname(self._config["log_path"])
|
||||
path_map = {
|
||||
self._config["scenario_path"]: "/scenario" if containerize else self._config["scenario_path"],
|
||||
log_dir: "/logs" if containerize else log_dir,
|
||||
self._config["log_path"]: "/logs" if containerize else self._config["log_path"],
|
||||
}
|
||||
|
||||
load_path = self._config["training"].get("load_path", None)
|
||||
|
@ -286,12 +291,16 @@ class ConfigParser:
|
|||
else:
|
||||
main_proc_env["EVAL_SCHEDULE"] = " ".join([str(val) for val in sorted(sch)])
|
||||
|
||||
main_proc_env["NUM_EVAL_EPISODES"] = str(self._config["main"].get("num_eval_episodes", 1))
|
||||
if "early_stop_patience" in self._config["main"]:
|
||||
main_proc_env["EARLY_STOP_PATIENCE"] = str(self._config["main"]["early_stop_patience"])
|
||||
|
||||
load_path = self._config["training"].get("load_path", None)
|
||||
if load_path is not None:
|
||||
env["main"]["LOAD_PATH"] = path_mapping[load_path]
|
||||
main_proc_env["LOAD_PATH"] = path_mapping[load_path]
|
||||
load_episode = self._config["training"].get("load_episode", None)
|
||||
if load_episode is not None:
|
||||
env["main"]["LOAD_EPISODE"] = str(load_episode)
|
||||
main_proc_env["LOAD_EPISODE"] = str(load_episode)
|
||||
|
||||
if "checkpointing" in self._config["training"]:
|
||||
conf = self._config["training"]["checkpointing"]
|
||||
|
@ -385,9 +394,8 @@ class ConfigParser:
|
|||
)
|
||||
|
||||
# All components write logs to the same file
|
||||
log_dir, log_file = os.path.split(self._config["log_path"])
|
||||
for _, vars in env.values():
|
||||
vars["LOG_PATH"] = os.path.join(path_mapping[log_dir], log_file)
|
||||
vars["LOG_PATH"] = path_mapping[self._config["log_path"]]
|
||||
|
||||
return env
|
||||
|
||||
|
|
|
@ -24,6 +24,8 @@ main:
|
|||
# A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are
|
||||
# 1-based.
|
||||
eval_schedule: 10
|
||||
early_stop_patience: 10 # Number of epochs waiting for a better validation metrics. Could be `null`.
|
||||
num_eval_episodes: 10 # Number of Episodes to run in evaluation.
|
||||
# Minimum number of samples to start training in one epoch. The workflow will re-run experience collection
|
||||
# until we have at least `min_n_sample` of experiences.
|
||||
min_n_sample: 1
|
||||
|
@ -68,8 +70,9 @@ training:
|
|||
checkpointing:
|
||||
# Directory to save trainer snapshots under. Snapshot files created at different episodes will be saved under
|
||||
# separate folders named using episode numbers. For example, if a snapshot is created for a trainer named "dqn"
|
||||
# at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt".
|
||||
path: "/path/to/your/checkpoint/folder"
|
||||
# at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt". If null, the
|
||||
# default checkpoint folder would be created under `log_path`.
|
||||
path: "/path/to/your/checkpoint/folder" # or `null`
|
||||
interval: 10 # Interval at which trained policies / models are persisted to disk.
|
||||
proxy: # Proxy settings. Ignored if training.mode is "simple".
|
||||
host: "127.0.0.1" # Proxy service host's IP address. Ignored if run in containerized environments.
|
||||
|
|
|
@ -14,21 +14,21 @@ from maro.rl.training import TrainingManager
|
|||
from maro.rl.utils import get_torch_device
|
||||
from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none
|
||||
from maro.rl.utils.training import get_latest_ep
|
||||
from maro.rl.workflows.utils import env_str_helper
|
||||
from maro.utils import LoggerV2
|
||||
from maro.rl.workflows.callback import CallbackManager, Checkpoint, EarlyStopping, MetricsRecorder
|
||||
from maro.utils import LoggerV2, set_seeds
|
||||
|
||||
|
||||
class WorkflowEnvAttributes:
|
||||
def __init__(self) -> None:
|
||||
# Number of training episodes
|
||||
self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES")))
|
||||
self.num_episodes = int(get_env("NUM_EPISODES"))
|
||||
|
||||
# Maximum number of steps in on round of sampling.
|
||||
self.num_steps = int_or_none(get_env("NUM_STEPS", required=False))
|
||||
|
||||
# Minimum number of data samples to start a round of training. If the data samples are insufficient, re-run
|
||||
# data sampling until we have at least `min_n_sample` data entries.
|
||||
self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE")))
|
||||
self.min_n_sample = int(get_env("MIN_N_SAMPLE"))
|
||||
|
||||
# Path to store logs.
|
||||
self.log_path = get_env("LOG_PATH")
|
||||
|
@ -46,6 +46,8 @@ class WorkflowEnvAttributes:
|
|||
|
||||
# Evaluating schedule.
|
||||
self.eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
|
||||
self.early_stop_patience = int_or_none(get_env("EARLY_STOP_PATIENCE", required=False))
|
||||
self.num_eval_episodes = int_or_none(get_env("NUM_EVAL_EPISODES", required=False))
|
||||
|
||||
# Restore configurations.
|
||||
self.load_path = get_env("LOAD_PATH", required=False)
|
||||
|
@ -58,7 +60,7 @@ class WorkflowEnvAttributes:
|
|||
# Parallel sampling configurations.
|
||||
self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None
|
||||
if self.parallel_rollout:
|
||||
self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT")))
|
||||
self.port = int(get_env("ROLLOUT_CONTROLLER_PORT"))
|
||||
self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False))
|
||||
self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False))
|
||||
|
||||
|
@ -67,13 +69,13 @@ class WorkflowEnvAttributes:
|
|||
# Distributed training configurations.
|
||||
if self.train_mode != "simple":
|
||||
self.proxy_address = (
|
||||
env_str_helper(get_env("TRAIN_PROXY_HOST")),
|
||||
int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))),
|
||||
str(get_env("TRAIN_PROXY_HOST")),
|
||||
int(get_env("TRAIN_PROXY_FRONTEND_PORT")),
|
||||
)
|
||||
|
||||
self.logger = LoggerV2(
|
||||
"MAIN",
|
||||
dump_path=self.log_path,
|
||||
dump_path=os.path.join(self.log_path, "log.txt"),
|
||||
dump_mode="a",
|
||||
stdout_level=self.log_level_stdout,
|
||||
file_level=self.log_level_file,
|
||||
|
@ -83,6 +85,7 @@ class WorkflowEnvAttributes:
|
|||
def _get_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="MARO RL workflow parser")
|
||||
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
|
||||
parser.add_argument("--seed", type=int, help="The random seed set before running this job")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -112,88 +115,112 @@ def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes
|
|||
if args.evaluate_only:
|
||||
evaluate_only_workflow(rl_component_bundle, env_attr)
|
||||
else:
|
||||
training_workflow(rl_component_bundle, env_attr)
|
||||
TrainingWorkflow().run(rl_component_bundle, env_attr)
|
||||
|
||||
|
||||
def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
|
||||
env_attr.logger.info("Start training workflow.")
|
||||
class TrainingWorkflow(object):
|
||||
def run(self, rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
|
||||
env_attr.logger.info("Start training workflow.")
|
||||
|
||||
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
|
||||
env_sampler = _get_env_sampler(rl_component_bundle, env_attr)
|
||||
|
||||
# evaluation schedule
|
||||
env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
|
||||
eval_point_index = 0
|
||||
# evaluation schedule
|
||||
env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
|
||||
eval_point_index = 0
|
||||
|
||||
training_manager = TrainingManager(
|
||||
rl_component_bundle=rl_component_bundle,
|
||||
explicit_assign_device=(env_attr.train_mode == "simple"),
|
||||
proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
|
||||
logger=env_attr.logger,
|
||||
)
|
||||
|
||||
if env_attr.load_path:
|
||||
assert isinstance(env_attr.load_path, str)
|
||||
|
||||
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
|
||||
path = os.path.join(env_attr.load_path, str(ep))
|
||||
|
||||
loaded = env_sampler.load_policy_state(path)
|
||||
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
|
||||
|
||||
loaded = training_manager.load(path)
|
||||
env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
|
||||
start_ep = ep + 1
|
||||
else:
|
||||
start_ep = 1
|
||||
|
||||
# main loop
|
||||
for ep in range(start_ep, env_attr.num_episodes + 1):
|
||||
collect_time = training_time = 0.0
|
||||
total_experiences: List[List[ExpElement]] = []
|
||||
total_info_list: List[dict] = []
|
||||
n_sample = 0
|
||||
while n_sample < env_attr.min_n_sample:
|
||||
tc0 = time.time()
|
||||
result = env_sampler.sample(
|
||||
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
|
||||
num_steps=env_attr.num_steps,
|
||||
)
|
||||
experiences: List[List[ExpElement]] = result["experiences"]
|
||||
info_list: List[dict] = result["info"]
|
||||
|
||||
n_sample += len(experiences[0])
|
||||
total_experiences.extend(experiences)
|
||||
total_info_list.extend(info_list)
|
||||
|
||||
collect_time += time.time() - tc0
|
||||
|
||||
env_sampler.post_collect(total_info_list, ep)
|
||||
|
||||
env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
|
||||
tu0 = time.time()
|
||||
training_manager.record_experiences(total_experiences)
|
||||
training_manager.train_step()
|
||||
if env_attr.checkpoint_path and (not env_attr.checkpoint_interval or ep % env_attr.checkpoint_interval == 0):
|
||||
assert isinstance(env_attr.checkpoint_path, str)
|
||||
pth = os.path.join(env_attr.checkpoint_path, str(ep))
|
||||
training_manager.save(pth)
|
||||
env_attr.logger.info(f"All trainer states saved under {pth}")
|
||||
training_time += time.time() - tu0
|
||||
|
||||
# performance details
|
||||
env_attr.logger.info(
|
||||
f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
|
||||
training_manager = TrainingManager(
|
||||
rl_component_bundle=rl_component_bundle,
|
||||
explicit_assign_device=(env_attr.train_mode == "simple"),
|
||||
proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
|
||||
logger=env_attr.logger,
|
||||
)
|
||||
if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
|
||||
eval_point_index += 1
|
||||
result = env_sampler.eval(
|
||||
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
|
||||
)
|
||||
env_sampler.post_evaluate(result["info"], ep)
|
||||
|
||||
if isinstance(env_sampler, BatchEnvSampler):
|
||||
env_sampler.exit()
|
||||
training_manager.exit()
|
||||
callbacks = [MetricsRecorder(path=env_attr.log_path)]
|
||||
if env_attr.checkpoint_path is not None:
|
||||
callbacks.append(
|
||||
Checkpoint(
|
||||
path=env_attr.checkpoint_path,
|
||||
interval=1 if env_attr.checkpoint_interval is None else env_attr.checkpoint_interval,
|
||||
),
|
||||
)
|
||||
if env_attr.early_stop_patience is not None:
|
||||
callbacks.append(EarlyStopping(patience=env_attr.early_stop_patience))
|
||||
callbacks.extend(rl_component_bundle.customized_callbacks)
|
||||
cbm = CallbackManager(self, callbacks, env_sampler, training_manager, env_attr.logger)
|
||||
|
||||
if env_attr.load_path:
|
||||
assert isinstance(env_attr.load_path, str)
|
||||
|
||||
ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
|
||||
path = os.path.join(env_attr.load_path, str(ep))
|
||||
|
||||
loaded = env_sampler.load_policy_state(path)
|
||||
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
|
||||
|
||||
loaded = training_manager.load(path)
|
||||
env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
|
||||
start_ep = ep + 1
|
||||
else:
|
||||
start_ep = 1
|
||||
|
||||
# main loop
|
||||
self.early_stop = False
|
||||
for ep in range(start_ep, env_attr.num_episodes + 1):
|
||||
if self.early_stop: # Might be set in `cbm.on_validation_end()`
|
||||
break
|
||||
|
||||
cbm.on_episode_start(ep)
|
||||
|
||||
collect_time = training_time = 0.0
|
||||
total_experiences: List[List[ExpElement]] = []
|
||||
total_info_list: List[dict] = []
|
||||
n_sample = 0
|
||||
while n_sample < env_attr.min_n_sample:
|
||||
tc0 = time.time()
|
||||
result = env_sampler.sample(
|
||||
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
|
||||
num_steps=env_attr.num_steps,
|
||||
)
|
||||
experiences: List[List[ExpElement]] = result["experiences"]
|
||||
info_list: List[dict] = result["info"]
|
||||
|
||||
n_sample += len(experiences[0])
|
||||
total_experiences.extend(experiences)
|
||||
total_info_list.extend(info_list)
|
||||
|
||||
collect_time += time.time() - tc0
|
||||
|
||||
env_sampler.post_collect(total_info_list, ep)
|
||||
|
||||
tu0 = time.time()
|
||||
env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
|
||||
cbm.on_training_start(ep)
|
||||
training_manager.record_experiences(total_experiences)
|
||||
training_manager.train_step()
|
||||
cbm.on_training_end(ep)
|
||||
training_time += time.time() - tu0
|
||||
|
||||
# performance details
|
||||
env_attr.logger.info(
|
||||
f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
|
||||
)
|
||||
if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
|
||||
cbm.on_validation_start(ep)
|
||||
|
||||
eval_point_index += 1
|
||||
result = env_sampler.eval(
|
||||
policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
|
||||
num_episodes=env_attr.num_eval_episodes,
|
||||
)
|
||||
env_sampler.post_evaluate(result["info"], ep)
|
||||
|
||||
cbm.on_validation_end(ep)
|
||||
|
||||
cbm.on_episode_end(ep)
|
||||
|
||||
if isinstance(env_sampler, BatchEnvSampler):
|
||||
env_sampler.exit()
|
||||
training_manager.exit()
|
||||
|
||||
|
||||
def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
|
||||
|
@ -210,7 +237,7 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor
|
|||
loaded = env_sampler.load_policy_state(path)
|
||||
env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")
|
||||
|
||||
result = env_sampler.eval()
|
||||
result = env_sampler.eval(num_episodes=env_attr.num_eval_episodes)
|
||||
env_sampler.post_evaluate(result["info"], -1)
|
||||
|
||||
if isinstance(env_sampler, BatchEnvSampler):
|
||||
|
@ -218,9 +245,13 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
|
||||
args = _get_args()
|
||||
if args.seed is not None:
|
||||
set_seeds(seed=args.seed)
|
||||
|
||||
scenario_path = get_env("SCENARIO_PATH")
|
||||
scenario_path = os.path.normpath(scenario_path)
|
||||
sys.path.insert(0, os.path.dirname(scenario_path))
|
||||
module = importlib.import_module(os.path.basename(scenario_path))
|
||||
|
||||
main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=_get_args())
|
||||
main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=args)
|
||||
|
|
|
@ -8,21 +8,20 @@ import sys
|
|||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import RolloutWorker
|
||||
from maro.rl.utils.common import get_env, int_or_none
|
||||
from maro.rl.workflows.utils import env_str_helper
|
||||
from maro.utils import LoggerV2
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
|
||||
scenario_path = get_env("SCENARIO_PATH")
|
||||
scenario_path = os.path.normpath(scenario_path)
|
||||
sys.path.insert(0, os.path.dirname(scenario_path))
|
||||
module = importlib.import_module(os.path.basename(scenario_path))
|
||||
|
||||
rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")
|
||||
|
||||
worker_idx = int(env_str_helper(get_env("ID")))
|
||||
worker_idx = int(get_env("ID"))
|
||||
logger = LoggerV2(
|
||||
f"ROLLOUT-WORKER.{worker_idx}",
|
||||
dump_path=get_env("LOG_PATH"),
|
||||
dump_path=os.path.join(get_env("LOG_PATH"), f"ROLLOUT-WORKER.{worker_idx}.txt"),
|
||||
dump_mode="a",
|
||||
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
|
||||
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
|
||||
|
@ -30,7 +29,7 @@ if __name__ == "__main__":
|
|||
worker = RolloutWorker(
|
||||
idx=worker_idx,
|
||||
rl_component_bundle=rl_component_bundle,
|
||||
producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")),
|
||||
producer_host=get_env("ROLLOUT_CONTROLLER_HOST"),
|
||||
producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")),
|
||||
logger=logger,
|
||||
)
|
||||
|
|
|
@ -8,11 +8,10 @@ import sys
|
|||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training import TrainOpsWorker
|
||||
from maro.rl.utils.common import get_env, int_or_none
|
||||
from maro.rl.workflows.utils import env_str_helper
|
||||
from maro.utils import LoggerV2
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
|
||||
scenario_path = get_env("SCENARIO_PATH")
|
||||
scenario_path = os.path.normpath(scenario_path)
|
||||
sys.path.insert(0, os.path.dirname(scenario_path))
|
||||
module = importlib.import_module(os.path.basename(scenario_path))
|
||||
|
@ -22,15 +21,15 @@ if __name__ == "__main__":
|
|||
worker_idx = int_or_none(get_env("ID"))
|
||||
logger = LoggerV2(
|
||||
f"TRAIN-WORKER.{worker_idx}",
|
||||
dump_path=get_env("LOG_PATH"),
|
||||
dump_path=os.path.join(get_env("LOG_PATH"), f"TRAIN-WORKER.{worker_idx}.txt"),
|
||||
dump_mode="a",
|
||||
stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
|
||||
file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
|
||||
)
|
||||
worker = TrainOpsWorker(
|
||||
idx=int(env_str_helper(get_env("ID"))),
|
||||
idx=int(get_env("ID")),
|
||||
rl_component_bundle=rl_component_bundle,
|
||||
producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")),
|
||||
producer_host=get_env("TRAIN_PROXY_HOST"),
|
||||
producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")),
|
||||
logger=logger,
|
||||
)
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def env_str_helper(string: Optional[str]) -> str:
|
||||
assert string is not None
|
||||
return string
|
|
@ -61,7 +61,7 @@ def is_float_type(v_type: type):
|
|||
Returns:
|
||||
bool: True if an float type.
|
||||
"""
|
||||
return v_type is float or v_type is np.float or v_type is np.float32 or v_type is np.float64
|
||||
return v_type is float or v_type is np.float16 or v_type is np.float32 or v_type is np.float64
|
||||
|
||||
|
||||
def parse_value(value: object):
|
||||
|
|
|
@ -6,7 +6,7 @@ deepdiff>=5.7.0
|
|||
geopy>=2.0.0
|
||||
holidays>=0.10.3
|
||||
kubernetes>=21.7.0
|
||||
numpy>=1.19.5,<1.24.0
|
||||
numpy>=1.19.5
|
||||
pandas>=0.25.3
|
||||
paramiko>=2.9.2
|
||||
pytest>=7.1.2
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import cast
|
||||
|
||||
from maro.simulator import Env
|
||||
|
||||
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
|
||||
|
||||
env_conf = {
|
||||
"topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4
|
||||
"start_tick": 0,
|
||||
"durations": 100000, # Set a very large number
|
||||
"options": {},
|
||||
}
|
||||
|
||||
learn_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
|
||||
test_env = Env(business_engine_cls=GymBusinessEngine, **env_conf)
|
||||
num_agents = len(learn_env.agent_idx_list)
|
||||
|
||||
gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env
|
||||
gym_action_space = gym_env.action_space
|
||||
gym_state_dim = gym_env.observation_space.shape[0]
|
||||
gym_action_dim = gym_action_space.shape[0]
|
||||
action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high
|
||||
action_limit = gym_action_space.high[0]
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl.policy.abs_policy import AbsPolicy
|
||||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
from maro.rl.rollout.env_sampler import AbsAgentWrapper, SimpleAgentWrapper
|
||||
from maro.simulator.core import Env
|
||||
|
||||
from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine
|
||||
from tests.rl.gym_wrapper.simulator.common import Action, DecisionEvent
|
||||
|
||||
|
||||
class GymEnvSampler(AbsEnvSampler):
|
||||
def __init__(
|
||||
self,
|
||||
learn_env: Env,
|
||||
test_env: Env,
|
||||
policies: List[AbsPolicy],
|
||||
agent2policy: Dict[Any, str],
|
||||
trainable_policies: List[str] = None,
|
||||
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
|
||||
reward_eval_delay: int = None,
|
||||
max_episode_length: int = None,
|
||||
) -> None:
|
||||
super(GymEnvSampler, self).__init__(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
trainable_policies=trainable_policies,
|
||||
agent_wrapper_cls=agent_wrapper_cls,
|
||||
reward_eval_delay=reward_eval_delay,
|
||||
max_episode_length=max_episode_length,
|
||||
)
|
||||
|
||||
self._sample_rewards = []
|
||||
self._eval_rewards = []
|
||||
|
||||
def _get_global_and_agent_state_impl(
|
||||
self,
|
||||
event: DecisionEvent,
|
||||
tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
|
||||
return None, {0: event.state}
|
||||
|
||||
def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict:
|
||||
return {k: Action(v) for k, v in action_dict.items()}
|
||||
|
||||
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
|
||||
be = self._env.business_engine
|
||||
assert isinstance(be, GymBusinessEngine)
|
||||
return {0: be.get_reward_at_tick(tick)}
|
||||
|
||||
def _post_step(self, cache_element: CacheElement) -> None:
|
||||
if not (self._end_of_episode or self.truncated):
|
||||
return
|
||||
rewards = list(self._env.metrics["reward_record"].values())
|
||||
self._sample_rewards.append((len(rewards), np.sum(rewards)))
|
||||
|
||||
def _post_eval_step(self, cache_element: CacheElement) -> None:
|
||||
if not (self._end_of_episode or self.truncated):
|
||||
return
|
||||
rewards = list(self._env.metrics["reward_record"].values())
|
||||
self._eval_rewards.append((len(rewards), np.sum(rewards)))
|
||||
|
||||
def post_collect(self, info_list: list, ep: int) -> None:
|
||||
if len(self._sample_rewards) > 0:
|
||||
cur = {
|
||||
"n_steps": sum([n for n, _ in self._sample_rewards]),
|
||||
"n_segment": len(self._sample_rewards),
|
||||
"avg_reward": np.mean([r for _, r in self._sample_rewards]),
|
||||
"avg_n_steps": np.mean([n for n, _ in self._sample_rewards]),
|
||||
"max_n_steps": np.max([n for n, _ in self._sample_rewards]),
|
||||
"n_interactions": self._total_number_interactions,
|
||||
}
|
||||
self.metrics.update(cur)
|
||||
# clear validation metrics
|
||||
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
|
||||
self._sample_rewards.clear()
|
||||
else:
|
||||
self.metrics = {"n_interactions": self._total_number_interactions}
|
||||
|
||||
def post_evaluate(self, info_list: list, ep: int) -> None:
|
||||
if len(self._eval_rewards) > 0:
|
||||
cur = {
|
||||
"val/n_steps": sum([n for n, _ in self._eval_rewards]),
|
||||
"val/n_segment": len(self._eval_rewards),
|
||||
"val/avg_reward": np.mean([r for _, r in self._eval_rewards]),
|
||||
"val/avg_n_steps": np.mean([n for n, _ in self._eval_rewards]),
|
||||
"val/max_n_steps": np.max([n for n, _ in self._eval_rewards]),
|
||||
}
|
||||
self.metrics.update(cur)
|
||||
self._eval_rewards.clear()
|
||||
else:
|
||||
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import List, Optional, cast
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from maro.backends.frame import FrameBase, SnapshotList
|
||||
from maro.event_buffer import CascadeEvent, EventBuffer, MaroEvents
|
||||
from maro.simulator.scenarios import AbsBusinessEngine
|
||||
|
||||
from .common import Action, DecisionEvent
|
||||
|
||||
|
||||
class GymBusinessEngine(AbsBusinessEngine):
|
||||
def __init__(
|
||||
self,
|
||||
event_buffer: EventBuffer,
|
||||
topology: Optional[str],
|
||||
start_tick: int,
|
||||
max_tick: int,
|
||||
snapshot_resolution: int,
|
||||
max_snapshots: Optional[int],
|
||||
additional_options: dict = None,
|
||||
) -> None:
|
||||
super(GymBusinessEngine, self).__init__(
|
||||
scenario_name="gym",
|
||||
event_buffer=event_buffer,
|
||||
topology=topology,
|
||||
start_tick=start_tick,
|
||||
max_tick=max_tick,
|
||||
snapshot_resolution=snapshot_resolution,
|
||||
max_snapshots=max_snapshots,
|
||||
additional_options=additional_options,
|
||||
)
|
||||
|
||||
self._gym_scenario_name = topology
|
||||
self._gym_env = gym.make(self._gym_scenario_name)
|
||||
|
||||
self.reset()
|
||||
|
||||
self._frame: FrameBase = FrameBase()
|
||||
self._snapshots: SnapshotList = self._frame.snapshots
|
||||
|
||||
self._register_events()
|
||||
|
||||
@property
|
||||
def gym_env(self) -> gym.Env:
|
||||
return self._gym_env
|
||||
|
||||
@property
|
||||
def frame(self) -> FrameBase:
|
||||
return self._frame
|
||||
|
||||
@property
|
||||
def snapshots(self) -> SnapshotList:
|
||||
return self._snapshots
|
||||
|
||||
def _register_events(self) -> None:
|
||||
self._event_buffer.register_event_handler(MaroEvents.TAKE_ACTION, self._on_action_received)
|
||||
|
||||
def _on_action_received(self, event: CascadeEvent) -> None:
|
||||
action = cast(Action, cast(list, event.payload)[0]).action
|
||||
|
||||
self._last_obs, reward, self._is_done, self._truncated, info = self._gym_env.step(action)
|
||||
self._reward_record[event.tick] = reward
|
||||
self._info_record[event.tick] = info
|
||||
|
||||
def step(self, tick: int) -> None:
|
||||
self._event_buffer.insert_event(self._event_buffer.gen_decision_event(tick, DecisionEvent(self._last_obs)))
|
||||
|
||||
@property
|
||||
def configs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def get_reward_at_tick(self, tick: int) -> float:
|
||||
return self._reward_record[tick]
|
||||
|
||||
def get_info_at_tick(self, tick: int) -> object: # TODO
|
||||
return self._info_record[tick]
|
||||
|
||||
def reset(self, keep_seed: bool = False) -> None:
|
||||
self._last_obs = self._gym_env.reset(seed=np.random.randint(low=0, high=4096))[0]
|
||||
self._is_done = False
|
||||
self._truncated = False
|
||||
self._reward_record = {}
|
||||
self._info_record = {}
|
||||
|
||||
def post_step(self, tick: int) -> bool:
|
||||
return self._is_done or self._truncated or tick + 1 == self._max_tick
|
||||
|
||||
def get_agent_idx_list(self) -> List[int]:
|
||||
return [0]
|
||||
|
||||
def get_metrics(self) -> dict:
|
||||
return {
|
||||
"reward_record": {k: v for k, v in self._reward_record.items()},
|
||||
}
|
||||
|
||||
def set_seed(self, seed: int) -> None:
|
||||
pass
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.common import BaseAction, BaseDecisionEvent
|
||||
|
||||
|
||||
class Action(BaseAction):
|
||||
def __init__(self, action: np.ndarray) -> None:
|
||||
self.action = action
|
||||
|
||||
|
||||
class DecisionEvent(BaseDecisionEvent):
|
||||
def __init__(self, state: np.ndarray) -> None:
|
||||
self.state = state
|
После Ширина: | Высота: | Размер: 139 KiB |
После Ширина: | Высота: | Размер: 101 KiB |
После Ширина: | Высота: | Размер: 131 KiB |
После Ширина: | Высота: | Размер: 86 KiB |
После Ширина: | Высота: | Размер: 172 KiB |
После Ширина: | Высота: | Размер: 130 KiB |
После Ширина: | Высота: | Размер: 120 KiB |
После Ширина: | Высота: | Размер: 81 KiB |
После Ширина: | Высота: | Размер: 165 KiB |
После Ширина: | Высота: | Размер: 113 KiB |
|
@ -0,0 +1,54 @@
|
|||
# Performance for Gym Task Suite
|
||||
|
||||
We benchmarked the MARO RL Toolkit implementation in Gym task suite. Some are compared to the benchmarks in
|
||||
[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). We've tried to align the
|
||||
hyper-parameters for these benchmarks , but limited by the environment version difference, there may be some gaps
|
||||
between the performance here and that in Spinning Up benchmarks. Generally speaking, the performance is comparable.
|
||||
|
||||
## Experimental Setting
|
||||
|
||||
The hyper-parameters are set to align with those used in
|
||||
[Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#experiment-details):
|
||||
|
||||
**Batch Size**:
|
||||
|
||||
- For on-policy algorithms: 4000 steps of interaction per batch update;
|
||||
- For off-policy algorithms: size 100 for each gradient descent step;
|
||||
|
||||
**Network**:
|
||||
|
||||
- For on-policy algorithms: size (64, 32) with tanh units for both policy and value function;
|
||||
- For off-policy algorithms: size (256, 256) with relu units;
|
||||
|
||||
**Performance metric**:
|
||||
|
||||
- For on-policy algorithms: measured as the average trajectory return across the batch collected at each epoch;
|
||||
- For off-policy algorithms: measured once every 10,000 steps by running the deterministic policy (or, in the case of SAC, the mean policy) without action noise for ten trajectories, and reporting the average return over those test trajectories;
|
||||
|
||||
**Total timesteps**: set to 4M for all task suites and algorithms.
|
||||
|
||||
More details about the parameters can be found in *tests/rl/tasks/*.
|
||||
|
||||
## Performance
|
||||
|
||||
Five environments from the MuJoCo Gym task suite are reported in Spinning Up, they are: HalfCheetah, Hopper, Walker2d,
|
||||
Swimmer, and Ant. The commit id of the code used to conduct the experiments for MARO RL benchmarks is ee25ce1e97.
|
||||
The commands used are:
|
||||
|
||||
```sh
|
||||
# Step 1: Set up the MuJoCo Environment in file tests/rl/gym_wrapper/common.py
|
||||
|
||||
# Step 2: Use the command below to run experiment with ALGORITHM (ddpg, ppo, sac) and random seed SEED.
|
||||
python tests/rl/run.py tests/rl/tasks/ALGORITHM/config.yml --seed SEED
|
||||
|
||||
# Step 3: Plot performance curves by environment with specific smooth window size WINDOWSIZE.
|
||||
python tests/rl/plot.py --smooth WINDOWSIZE
|
||||
```
|
||||
|
||||
| **Env** | **Spinning Up** | **MARO RL w/o Smooth** | **MARO RL w/ Smooth** |
|
||||
|:---------------:|:---------------:|:----------------------:|:---------------------:|
|
||||
| [**HalfCheetah**](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) | ![Hab](https://spinningup.openai.com/en/latest/_images/pytorch_halfcheetah_performance.svg) | ![Ha1](./log/HalfCheetah_1.png) | ![Ha11](./log/HalfCheetah_11.png) |
|
||||
| [**Hopper**](https://gymnasium.farama.org/environments/mujoco/hopper/) | ![Hob](https://spinningup.openai.com/en/latest/_images/pytorch_hopper_performance.svg) | ![Ho1](./log/Hopper_1.png) | ![Ho11](./log/Hopper_11.png) |
|
||||
| [**Walker2d**](https://gymnasium.farama.org/environments/mujoco/walker2d/) | ![Wab](https://spinningup.openai.com/en/latest/_images/pytorch_walker2d_performance.svg) | ![Wa1](./log/Walker2d_1.png) | ![Wa11](./log/Walker2d_11.png) |
|
||||
| [**Swimmer**](https://gymnasium.farama.org/environments/mujoco/swimmer/) | ![Swb](https://spinningup.openai.com/en/latest/_images/pytorch_swimmer_performance.svg) | ![Sw1](./log/Swimmer_1.png) | ![Sw11](./log/Swimmer_11.png) |
|
||||
| [**Ant**](https://gymnasium.farama.org/environments/mujoco/ant/) | ![Anb](https://spinningup.openai.com/en/latest/_images/pytorch_ant_performance.svg) | ![An1](./log/Ant_1.png) | ![An11](./log/Ant_11.png) |
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
LOG_DIR = "tests/rl/log"
|
||||
|
||||
color_map = {
|
||||
"ppo": "green",
|
||||
"sac": "goldenrod",
|
||||
"ddpg": "firebrick",
|
||||
"vpg": "cornflowerblue",
|
||||
"td3": "mediumpurple",
|
||||
}
|
||||
|
||||
|
||||
def smooth(data: np.ndarray, window_size: int) -> np.ndarray:
|
||||
if window_size > 1:
|
||||
"""
|
||||
smooth data with moving window average.
|
||||
that is,
|
||||
smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
|
||||
where the "smooth" param is width of that window (2k+1)
|
||||
"""
|
||||
y = np.ones(window_size)
|
||||
x = np.asarray(data)
|
||||
z = np.ones_like(x)
|
||||
smoothed_x = np.convolve(x, y, "same") / np.convolve(z, y, "same")
|
||||
return smoothed_x
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def get_off_policy_data(log_dir: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
file_path = os.path.join(log_dir, "metrics_full.csv")
|
||||
df = pd.read_csv(file_path)
|
||||
x, y = df["n_interactions"], df["val/avg_reward"]
|
||||
mask = ~np.isnan(y)
|
||||
x, y = x[mask], y[mask]
|
||||
return x, y
|
||||
|
||||
|
||||
def get_on_policy_data(log_dir: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
file_path = os.path.join(log_dir, "metrics_full.csv")
|
||||
df = pd.read_csv(file_path)
|
||||
x, y = df["n_interactions"], df["avg_reward"]
|
||||
return x, y
|
||||
|
||||
|
||||
def plot_performance_curves(title: str, dir_names: List[str], smooth_window_size: int) -> None:
|
||||
for algorithm in color_map.keys():
|
||||
if algorithm in ["ddpg", "sac", "td3"]:
|
||||
func = get_off_policy_data
|
||||
elif algorithm in ["ppo", "vpg"]:
|
||||
func = get_on_policy_data
|
||||
|
||||
log_dirs = [os.path.join(LOG_DIR, name) for name in dir_names if algorithm in name]
|
||||
series = [func(log_dir) for log_dir in log_dirs if os.path.exists(log_dir)]
|
||||
if len(series) == 0:
|
||||
continue
|
||||
|
||||
x = series[0][0]
|
||||
assert all(len(_x) == len(x) for _x, _ in series), f"Input data should share the same length!"
|
||||
ys = np.array([smooth(y, smooth_window_size) for _, y in series])
|
||||
y_mean = np.mean(ys, axis=0)
|
||||
y_std = np.std(ys, axis=0)
|
||||
|
||||
plt.plot(x, y_mean, label=algorithm, color=color_map[algorithm])
|
||||
plt.fill_between(x, y_mean - y_std, y_mean + y_std, color=color_map[algorithm], alpha=0.2)
|
||||
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.title(title)
|
||||
plt.xlabel("Total Env Interactions")
|
||||
plt.ylabel(f"Average Trajectory Return \n(moving average with window size = {smooth_window_size})")
|
||||
plt.savefig(os.path.join(LOG_DIR, f"{title}_{smooth_window_size}.png"), bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--smooth", "-s", type=int, default=11, help="smooth window size")
|
||||
args = parser.parse_args()
|
||||
|
||||
for env_name in ["HalfCheetah", "Hopper", "Walker2d", "Swimmer", "Ant"]:
|
||||
plot_performance_curves(
|
||||
title=env_name,
|
||||
dir_names=[
|
||||
f"{algorithm}_{env_name.lower()}_{seed}"
|
||||
for algorithm in ["ppo", "sac", "ddpg"]
|
||||
for seed in [42, 729, 1024, 2023, 3500]
|
||||
],
|
||||
smooth_window_size=args.smooth,
|
||||
)
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
|
||||
from maro.cli.local.commands import run
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("conf_path", help="Path of the job deployment")
|
||||
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
|
||||
parser.add_argument("--seed", type=int, help="The random seed set before running this job")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
run(conf_path=args.conf_path, containerize=False, seed=args.seed, evaluate_only=args.evaluate_only)
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Normal
|
||||
from torch.optim import Adam
|
||||
|
||||
from maro.rl.model import ContinuousACBasedNet, VNet
|
||||
from maro.rl.model.fc_block import FullyConnected
|
||||
from maro.rl.policy import ContinuousRLPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
|
||||
|
||||
from tests.rl.gym_wrapper.common import (
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_state_dim,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
)
|
||||
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [64, 32],
|
||||
"activation": torch.nn.Tanh,
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [64, 32],
|
||||
"activation": torch.nn.Tanh,
|
||||
}
|
||||
actor_learning_rate = 3e-4
|
||||
critic_learning_rate = 1e-3
|
||||
|
||||
|
||||
class MyContinuousACBasedNet(ContinuousACBasedNet):
|
||||
def __init__(self, state_dim: int, action_dim: int) -> None:
|
||||
super(MyContinuousACBasedNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
|
||||
log_std = -0.5 * np.ones(action_dim, dtype=np.float32)
|
||||
self._log_std = torch.nn.Parameter(torch.as_tensor(log_std))
|
||||
self._mu_net = FullyConnected(
|
||||
input_dim=state_dim,
|
||||
hidden_dims=actor_net_conf["hidden_dims"],
|
||||
output_dim=action_dim,
|
||||
activation=actor_net_conf["activation"],
|
||||
)
|
||||
self._optim = Adam(self.parameters(), lr=actor_learning_rate)
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
distribution = self._distribution(states)
|
||||
actions = distribution.sample()
|
||||
logps = distribution.log_prob(actions).sum(axis=-1)
|
||||
return actions, logps
|
||||
|
||||
def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
distribution = self._distribution(states)
|
||||
logps = distribution.log_prob(actions).sum(axis=-1)
|
||||
return logps
|
||||
|
||||
def _distribution(self, states: torch.Tensor) -> Normal:
|
||||
mu = self._mu_net(states.float())
|
||||
std = torch.exp(self._log_std)
|
||||
return Normal(mu, std)
|
||||
|
||||
|
||||
class MyVCriticNet(VNet):
|
||||
def __init__(self, state_dim: int) -> None:
|
||||
super(MyVCriticNet, self).__init__(state_dim=state_dim)
|
||||
self._critic = FullyConnected(
|
||||
input_dim=state_dim,
|
||||
output_dim=1,
|
||||
hidden_dims=critic_net_conf["hidden_dims"],
|
||||
activation=critic_net_conf["activation"],
|
||||
)
|
||||
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._critic(states.float()).squeeze(-1)
|
||||
|
||||
|
||||
def get_ac_policy(
|
||||
name: str,
|
||||
action_lower_bound: list,
|
||||
action_upper_bound: list,
|
||||
gym_state_dim: int,
|
||||
gym_action_dim: int,
|
||||
) -> ContinuousRLPolicy:
|
||||
return ContinuousRLPolicy(
|
||||
name=name,
|
||||
action_range=(action_lower_bound, action_upper_bound),
|
||||
policy_net=MyContinuousACBasedNet(gym_state_dim, gym_action_dim),
|
||||
)
|
||||
|
||||
|
||||
def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer:
|
||||
return ActorCriticTrainer(
|
||||
name=name,
|
||||
reward_discount=0.99,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
|
||||
grad_iters=80,
|
||||
lam=0.97,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
algorithm = "ac"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
get_ac_policy(f"{algorithm}_{i}.policy", action_lower_bound, action_upper_bound, gym_state_dim, gym_action_dim)
|
||||
for i in range(num_agents)
|
||||
]
|
||||
trainers = [get_ac_trainer(f"{algorithm}_{i}", gym_state_dim) for i in range(num_agents)]
|
||||
|
||||
device_mapping = None
|
||||
if torch.cuda.is_available():
|
||||
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
|
||||
|
||||
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=GymEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
device_mapping=device_mapping,
|
||||
)
|
||||
|
||||
__all__ = ["rl_component_bundle"]
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for GYM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
job: gym_rl_workflow
|
||||
scenario_path: "tests/rl/tasks/ac"
|
||||
log_path: "tests/rl/log/ac"
|
||||
main:
|
||||
num_episodes: 1000
|
||||
num_steps: null
|
||||
eval_schedule: 5
|
||||
num_eval_episodes: 10
|
||||
min_n_sample: 5000
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: null
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from gym import spaces
|
||||
from torch.optim import Adam
|
||||
|
||||
from maro.rl.model import QNet
|
||||
from maro.rl.model.algorithm_nets.ddpg import ContinuousDDPGNet
|
||||
from maro.rl.model.fc_block import FullyConnected
|
||||
from maro.rl.policy import ContinuousRLPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training.algorithms import DDPGParams, DDPGTrainer
|
||||
from maro.rl.utils import ndarray_to_tensor
|
||||
|
||||
from tests.rl.gym_wrapper.common import (
|
||||
action_limit,
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_action_space,
|
||||
gym_state_dim,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
)
|
||||
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activation": torch.nn.ReLU,
|
||||
"output_activation": torch.nn.Tanh,
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activation": torch.nn.ReLU,
|
||||
}
|
||||
actor_learning_rate = 1e-3
|
||||
critic_learning_rate = 1e-3
|
||||
|
||||
|
||||
class MyContinuousDDPGNet(ContinuousDDPGNet):
|
||||
def __init__(
|
||||
self,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
action_limit: float,
|
||||
action_space: spaces.Space,
|
||||
noise_scale: float = 0.1,
|
||||
) -> None:
|
||||
super(MyContinuousDDPGNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
|
||||
self._net = FullyConnected(
|
||||
input_dim=state_dim,
|
||||
output_dim=action_dim,
|
||||
hidden_dims=actor_net_conf["hidden_dims"],
|
||||
activation=actor_net_conf["activation"],
|
||||
output_activation=actor_net_conf["output_activation"],
|
||||
)
|
||||
self._optim = Adam(self._net.parameters(), lr=critic_learning_rate)
|
||||
self._action_limit = action_limit
|
||||
self._noise_scale = noise_scale
|
||||
self._action_space = action_space
|
||||
|
||||
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
|
||||
action = self._net(states) * self._action_limit
|
||||
if exploring:
|
||||
noise = torch.randn(self.action_dim) * self._noise_scale
|
||||
action += noise.to(action.device)
|
||||
action = torch.clamp(action, -self._action_limit, self._action_limit)
|
||||
return action
|
||||
|
||||
def _get_random_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[ndarray_to_tensor(self._action_space.sample(), device=self._device) for _ in range(states.shape[0])],
|
||||
)
|
||||
|
||||
|
||||
class MyQCriticNet(QNet):
|
||||
def __init__(self, state_dim: int, action_dim: int) -> None:
|
||||
super(MyQCriticNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
self._critic = FullyConnected(
|
||||
input_dim=state_dim + action_dim,
|
||||
output_dim=1,
|
||||
hidden_dims=critic_net_conf["hidden_dims"],
|
||||
activation=critic_net_conf["activation"],
|
||||
)
|
||||
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._critic(torch.cat([states, actions], dim=1).float()).squeeze(-1)
|
||||
|
||||
|
||||
def get_ddpg_policy(
|
||||
name: str,
|
||||
action_lower_bound: list,
|
||||
action_upper_bound: list,
|
||||
gym_state_dim: int,
|
||||
gym_action_dim: int,
|
||||
action_limit: float,
|
||||
) -> ContinuousRLPolicy:
|
||||
return ContinuousRLPolicy(
|
||||
name=name,
|
||||
action_range=(action_lower_bound, action_upper_bound),
|
||||
policy_net=MyContinuousDDPGNet(gym_state_dim, gym_action_dim, action_limit, gym_action_space),
|
||||
warmup=10000,
|
||||
)
|
||||
|
||||
|
||||
def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer:
|
||||
return DDPGTrainer(
|
||||
name=name,
|
||||
reward_discount=0.99,
|
||||
replay_memory_capacity=1000000,
|
||||
batch_size=100,
|
||||
params=DDPGParams(
|
||||
get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
|
||||
num_epochs=50,
|
||||
n_start_train=1000,
|
||||
soft_update_coef=0.005,
|
||||
update_target_every=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
algorithm = "ddpg"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
get_ddpg_policy(
|
||||
f"{algorithm}_{i}.policy",
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_state_dim,
|
||||
gym_action_dim,
|
||||
action_limit,
|
||||
)
|
||||
for i in range(num_agents)
|
||||
]
|
||||
trainers = [get_ddpg_trainer(f"{algorithm}_{i}", gym_state_dim, gym_action_dim) for i in range(num_agents)]
|
||||
|
||||
device_mapping = None
|
||||
if torch.cuda.is_available():
|
||||
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
|
||||
|
||||
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=GymEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
device_mapping=device_mapping,
|
||||
)
|
||||
|
||||
__all__ = ["rl_component_bundle"]
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for GYM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
job: gym_rl_workflow
|
||||
scenario_path: "tests/rl/tasks/ddpg"
|
||||
log_path: "tests/rl/log/ddpg_walker2d"
|
||||
main:
|
||||
num_episodes: 80000
|
||||
num_steps: 50
|
||||
eval_schedule: 200
|
||||
num_eval_episodes: 10
|
||||
min_n_sample: 1
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: null
|
||||
interval: 200
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training.algorithms.ppo import PPOParams, PPOTrainer
|
||||
|
||||
from tests.rl.gym_wrapper.common import (
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_state_dim,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
)
|
||||
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
|
||||
from tests.rl.tasks.ac import MyVCriticNet, get_ac_policy
|
||||
|
||||
get_ppo_policy = get_ac_policy
|
||||
|
||||
|
||||
def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
|
||||
return PPOTrainer(
|
||||
name=name,
|
||||
reward_discount=0.99,
|
||||
replay_memory_capacity=4000,
|
||||
batch_size=4000,
|
||||
params=PPOParams(
|
||||
get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
|
||||
grad_iters=80,
|
||||
lam=0.97,
|
||||
clip_ratio=0.2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
algorithm = "ppo"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
get_ppo_policy(f"{algorithm}_{i}.policy", action_lower_bound, action_upper_bound, gym_state_dim, gym_action_dim)
|
||||
for i in range(num_agents)
|
||||
]
|
||||
trainers = [get_ppo_trainer(f"{algorithm}_{i}", gym_state_dim) for i in range(num_agents)]
|
||||
|
||||
device_mapping = None
|
||||
if torch.cuda.is_available():
|
||||
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
|
||||
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=GymEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
max_episode_length=1000,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
device_mapping=device_mapping,
|
||||
)
|
||||
|
||||
__all__ = ["rl_component_bundle"]
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for GYM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
job: gym_rl_workflow
|
||||
scenario_path: "tests/rl/tasks/ppo"
|
||||
log_path: "tests/rl/log/ppo_walker2d"
|
||||
main:
|
||||
num_episodes: 1000
|
||||
num_steps: 4000
|
||||
eval_schedule: 5
|
||||
num_eval_episodes: 10
|
||||
min_n_sample: 1
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: null
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gym import spaces
|
||||
from torch.distributions import Normal
|
||||
from torch.optim import Adam
|
||||
|
||||
from maro.rl.model import ContinuousSACNet, QNet
|
||||
from maro.rl.model.fc_block import FullyConnected
|
||||
from maro.rl.policy import ContinuousRLPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.training.algorithms import SoftActorCriticParams, SoftActorCriticTrainer
|
||||
from maro.rl.utils import ndarray_to_tensor
|
||||
|
||||
from tests.rl.gym_wrapper.common import (
|
||||
action_limit,
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_action_dim,
|
||||
gym_action_space,
|
||||
gym_state_dim,
|
||||
learn_env,
|
||||
num_agents,
|
||||
test_env,
|
||||
)
|
||||
from tests.rl.gym_wrapper.env_sampler import GymEnvSampler
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activation": torch.nn.ReLU,
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activation": torch.nn.ReLU,
|
||||
}
|
||||
actor_learning_rate = 1e-3
|
||||
critic_learning_rate = 1e-3
|
||||
|
||||
LOG_STD_MAX = 2
|
||||
LOG_STD_MIN = -20
|
||||
|
||||
|
||||
class MyContinuousSACNet(ContinuousSACNet):
|
||||
def __init__(self, state_dim: int, action_dim: int, action_limit: float, action_space: spaces.Space) -> None:
|
||||
super(MyContinuousSACNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
|
||||
self._net = FullyConnected(
|
||||
input_dim=state_dim,
|
||||
output_dim=actor_net_conf["hidden_dims"][-1],
|
||||
hidden_dims=actor_net_conf["hidden_dims"][:-1],
|
||||
activation=actor_net_conf["activation"],
|
||||
output_activation=actor_net_conf["activation"],
|
||||
)
|
||||
self._mu = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
|
||||
self._log_std = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
|
||||
self._action_limit = action_limit
|
||||
self._optim = Adam(self.parameters(), lr=actor_learning_rate)
|
||||
|
||||
self._action_space = action_space
|
||||
|
||||
def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
net_out = self._net(states.float())
|
||||
mu = self._mu(net_out)
|
||||
log_std = torch.clamp(self._log_std(net_out), LOG_STD_MIN, LOG_STD_MAX)
|
||||
std = torch.exp(log_std)
|
||||
|
||||
pi_distribution = Normal(mu, std)
|
||||
pi_action = pi_distribution.rsample() if exploring else mu
|
||||
|
||||
logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
|
||||
logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=1)
|
||||
|
||||
pi_action = torch.tanh(pi_action) * self._action_limit
|
||||
|
||||
return pi_action, logp_pi
|
||||
|
||||
def _get_random_actions_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[ndarray_to_tensor(self._action_space.sample(), device=self._device) for _ in range(states.shape[0])],
|
||||
)
|
||||
|
||||
|
||||
class MyQCriticNet(QNet):
|
||||
def __init__(self, state_dim: int, action_dim: int) -> None:
|
||||
super(MyQCriticNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
|
||||
self._critic = FullyConnected(
|
||||
input_dim=state_dim + action_dim,
|
||||
output_dim=1,
|
||||
hidden_dims=critic_net_conf["hidden_dims"],
|
||||
activation=critic_net_conf["activation"],
|
||||
)
|
||||
self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
return self._critic(torch.cat([states, actions], dim=1).float()).squeeze(-1)
|
||||
|
||||
|
||||
def get_sac_policy(
|
||||
name: str,
|
||||
action_lower_bound: list,
|
||||
action_upper_bound: list,
|
||||
gym_state_dim: int,
|
||||
gym_action_dim: int,
|
||||
action_limit: float,
|
||||
) -> ContinuousRLPolicy:
|
||||
return ContinuousRLPolicy(
|
||||
name=name,
|
||||
action_range=(action_lower_bound, action_upper_bound),
|
||||
policy_net=MyContinuousSACNet(gym_state_dim, gym_action_dim, action_limit, action_space=gym_action_space),
|
||||
warmup=10000,
|
||||
)
|
||||
|
||||
|
||||
def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCriticTrainer:
|
||||
return SoftActorCriticTrainer(
|
||||
name=name,
|
||||
reward_discount=0.99,
|
||||
replay_memory_capacity=1000000,
|
||||
batch_size=100,
|
||||
params=SoftActorCriticParams(
|
||||
get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
|
||||
update_target_every=1,
|
||||
entropy_coef=0.2,
|
||||
num_epochs=50,
|
||||
n_start_train=1000,
|
||||
soft_update_coef=0.005,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
algorithm = "sac"
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
policies = [
|
||||
get_sac_policy(
|
||||
f"{algorithm}_{i}.policy",
|
||||
action_lower_bound,
|
||||
action_upper_bound,
|
||||
gym_state_dim,
|
||||
gym_action_dim,
|
||||
action_limit,
|
||||
)
|
||||
for i in range(num_agents)
|
||||
]
|
||||
trainers = [get_sac_trainer(f"{algorithm}_{i}", gym_state_dim, gym_action_dim) for i in range(num_agents)]
|
||||
|
||||
device_mapping = None
|
||||
if torch.cuda.is_available():
|
||||
device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)}
|
||||
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=GymEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
device_mapping=device_mapping,
|
||||
)
|
||||
|
||||
__all__ = ["rl_component_bundle"]
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for GYM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
job: gym_rl_workflow
|
||||
scenario_path: "tests/rl/tasks/sac"
|
||||
log_path: "tests/rl/log/sac_walker2d"
|
||||
main:
|
||||
num_episodes: 80000
|
||||
num_steps: 50
|
||||
eval_schedule: 200
|
||||
num_eval_episodes: 10
|
||||
min_n_sample: 1
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: null
|
||||
interval: 200
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -311,7 +311,7 @@ class TestFrame(unittest.TestCase):
|
|||
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])
|
||||
|
||||
# 2 padding (NAN) in the end
|
||||
self.assertTrue((states[-2:].astype(np.int) == 0).all())
|
||||
self.assertTrue((states[-2:].astype(int) == 0).all())
|
||||
|
||||
states = static_snapshot[1::"a3"]
|
||||
|
||||
|
|