Merge V0.3 into master: update decision event logic & RL component bundle (#569)
* updated images and refined doc * updated images * updated CIM-AC example * refined proxy retry logic * call policy update only for AbsCorePolicy * add limitation of AbsCorePolicy in Actor.collect() * refined actor to return only experiences for policies that received new experiences * fix MsgKey issue in rollout_manager * fix typo in learner * call exit function for parallel rollout manager * update supply chain example distributed training scripts * 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prorotype * reformat render * fix supply chain business engine action type problem * reset supply chain example render figsize from 4 to 3 * Add render to all modes of supply chain example * fix or policy typos * 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes * refined parallel policy manager * updated rl/__init__/py * fixed lint issues and CIM local learner bugs * deleted unwanted supply_chain test files * revised default config for cim-dqn * removed test_store.py as it is no longer needed * 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms * updated figures * removed unwanted import * refactored CIM-DQN example * added MultiProcessRolloutManager and MultiProcessTrainingManager * updated doc * lint issue fix * lint issue fix * fixed import formatting * [Feature] Prioritized Experience Replay (#355) * added prioritized experience replay * deleted unwanted supply_chain test files * fixed import order * import fix * fixed lint issues * fixed import formatting * added note in docstring that rank-based PER has yet to be implemented Co-authored-by: ysqyang <v-yangqi@microsoft.com> * rm AbsDecisionGenerator * small fixes * bug fix * reorganized training folder structure * fixed lint issues * fixed lint issues * policy manager refined * lint fix * restructured CIM-dqn sync code * added policy version index and used it as a measure of experience staleness * lint issue fix * lint issue fix * switched log_dir and proxy_kwargs order * cim example refinement * eval schedule sorted only when it's a list * eval schedule sorted only when it's a list * update sc env wrapper * added docker scripts for cim-dqn * refactored example folder structure and added workflow templates * fixed lint issues * fixed lint issues * fixed template bugs * removed unused imports * refactoring sc in progress * simplified cim meta * fixed build.sh path bug * template refinement * deleted obsolete svgs * updated learner logs * minor edits * refactored templates for easy merge with async PR * added component names for rollout manager and policy manager * fixed incorrect position to add last episode to eval schedule * added max_lag option in templates * formatting edit in docker_compose_yml script * moved local learner and early stopper outside sync_tools * refactored rl toolkit folder structure * refactored rl toolkit folder structure * moved env_wrapper and agent_wrapper inside rl/learner * refined scripts * fixed typo in script * changes needed for running sc * removed unwanted imports * config change for testing sc scenario * changes for perf testing * Asynchronous Training (#364) * remote inference code draft * changed actor to rollout_worker and updated init files * removed unwanted import * updated inits * more async code * added async scripts * added async training code & scripts for CIM-dqn * changed async to async_tools to avoid conflict with python keyword 
* reverted unwanted change to dockerfile * added doc for policy server * addressed PR comments and fixed a bug in docker_compose_yml.py * fixed lint issue * resolved PR comment * resolved merge conflicts * added async templates * added proxy.close() for actor and policy_server * fixed incorrect position to add last episode to eval schedule * reverted unwanted changes * added missing async files * rm unwanted echo in kill.sh Co-authored-by: ysqyang <v-yangqi@microsoft.com> * renamed sync to synchronous and async to asynchronous to avoid conflict with keyword * added missing policy version increment in LocalPolicyManager * refined rollout manager recv logic * removed a debugging print * added sleep in distributed launcher to avoid hanging * updated api doc and rl toolkit doc * refined dynamic imports using importlib * 1. moved policy update triggers to policy manager; 2. added version control in policy manager * fixed a few bugs and updated cim RL example * fixed a few more bugs * added agent wrapper instantiation to workflows * added agent wrapper instantiation to workflows * removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet * fixed incorrect get_ac_policy signature for CIM * moved exploration inside core policy * added state to exploration call to support context-dependent exploration * separated non_rl_policy_index and rl_policy_index in workflows * modified sc example code according to workflow changes * modified sc example code according to workflow changes * added replay_agent_ids parameter to get_env_func for RL examples * fixed a few bugs * added maro/simulator/scenarios/supply_chain as bind mount * added post-step, post-collect, post-eval and post-update callbacks * fixed lint issues * fixed lint issues * moved instantiation of policy manager inside simple learner * fixed env_wrapper get_reward signature * minor edits * removed get_experience kwargs from env_wrapper * 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows * added rollout exp distribution option in RL examples * removed unwanted files * 1. made logger internal in learner; 2. removed logger creation in abs classes * checked out supply chain test files from v0.2_sc * 1. added missing model.eval() to choose_action; 2. added entropy features to AC * fixed a bug in ac entropy * abbreviated coefficient to coeff * removed -dqn from job name in rl example config * added tmp patch to dev.df * renamed image name for running rl examples * added get_loss interface for core policies * added policy manager in rl_toolkit.rst * 1. env_wrapper bug fix; 2. policy manager update logic refinement * refactored policy and algorithms * policy interface redesigned * refined policy interfaces * fixed typo * fixed bugs in refactored policy interface * fixed some bugs * refactoring in progress * policy interface and policy manager redesigned * 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts * fixed bug in distributed policy manager * fixed lint issues * fixed lint issues * added scipy in setup * 1. trimmed rollout manager code; 2. added option to docker scripts * updated api doc for policy manager * 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script * 1. simplified rl example structure; 2.
fixed lint issues * further rl toolkit code simplifications * more numpy-based optimization in RL toolkit * moved replay buffer inside policy * bug fixes * numpy optimization and associated refactoring * extracted shaping logic out of env_sampler * fixed bug in CIM shaping and lint issues * preliminary implementation of parallel batch inference * fixed bug in ddpg transition recording * put get_state, get_env_actions, get_reward back in EnvSampler * simplified exploration and core model interfaces * bug fixes and doc update * added improve() interface for RLPolicy for single-thread support * fixed simple policy manager bug * updated doc, rst, notebook * updated notebook * fixed lint issues * fixed entropy bugs in ac.py * reverted to simple policy manager as default * 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit * fixed lint issues and updated rl toolkit images * removed obsolete images * added back agent2policy for general workflow use * V0.2 rl refinement dist (#377) * Support `slice` operation in ExperienceSet * Support naive distributed policy training by proxy * Dynamically allocate trainers according to number of experiences * code check * code check * code check * Fix a bug in distributed training with no gradient * Code check * Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy * 1. call allocate_trainer() at the start of update(); 2. refine according to code review * Code check * Refine code with new interface * Update docs of PolicyManager and ExperienceSet * Add images for rl_toolkit docs * Update diagram of PolicyManager * Refine with new interface * Extract allocation strategy into `allocation_strategy.py` * add `distributed_learn()` in policies for data-parallel training * Update doc of RL_toolkit * Add gradient workers for data-parallel * Refine code and update docs * Lint check * Refine by comments * Rename `trainer` to `worker` * Rename `distributed_learn` to `learn_with_data_parallel` * Refine allocator and remove redundant code in policy_manager * remove arguments in allocate_by_policy and so on * added checkpointing for simple and multi-process policy managers * 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager * added missing set_state and get_state for CIM policies * removed blank line * updated RL workflow README * Integrate `data_parallel` arguments into `worker_allocator` (#402) * 1. simplified workflow config; 2. added comments to CIM shaping * lint issue fix * 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading * 1. moved post_step callback inside env sampler; 2.
updated README for rl workflows * refined READEME for CIM * VM scheduling with RL (#375) * added part of vm scheduling RL code * refined vm env_wrapper code style * added DQN * added get_experiences func for ac in vm scheduling * added post_step callback to env wrapper * moved Aiming's tracking and plotting logic into callbacks * added eval env wrapper * renamed AC config variable name for VM * vm scheduling RL code finished * updated README * fixed various bugs and hard coding for vm_scheduling * uncommented callbacks for VM scheduling * Minor revision for better code style * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * vm example refactoring * fixed bugs in vm_scheduling * removed unwanted files from cim dir * reverted to simple policy manager as default * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * resolved rebase conflicts * fixed bugs in vm_scheduling * added get_state and set_state to vm_scheduling policy models * updated README for vm_scheduling with RL Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * SC refinement (#397) * Refine test scripts & pending_order_daily logic * Refactor code for better code style: complete type hint, correct typos, remove unused items. Refactor code for better code style: complete type hint, correct typos, remove unused items. * Polish test_supply_chain.py * update import format * Modify vehicle steps logic & remove outdated test case * Optimize imports * Optimize imports * Lint error * Lint error * Lint error * Add SupplyChainAction * Lint error Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * refined workflow scripts * fixed bug in ParallelAgentWrapper * 1. fixed lint issues; 2. refined main script in workflows * lint issue fix * restored default config for rl example * Update rollout.py * refined env var processing in policy manager workflow * added hasattr check in agent wrapper * updated docker_compose_yml.py * Minor refinement * Minor PR. Prepare to merge latest master branch into v0.3 branch. 
(#412) * Prepare to merge master_mirror * Lint error * Minor * Merge latest master into v0.3 (#426) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . 
* Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * Minor * Remove docs/source/examples/multi_agent_dqn_cim.rst * Update .gitignore * Update .gitignore Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> * Change `Env.set_seed()` logic (#456) * Change Env.set_seed() logic * Redesign CIM reset logic; fix lint issues; * Lint * Seed type assertion * Remove all SC related files (#473) * RL Toolkit V3 (#471) * added daemon=True for multi-process rollout, policy manager and inference * removed obsolete files * [REDO][PR#406]V0.2 rl refinement taskq (#408) * Add a usable task_queue * Rename some variables * 1. Add ; 2. Integrate related files; 3. Remove * merge `data_parallel` and `num_grad_workers` into `data_parallelism` * Fix bugs in docker_compose_yml.py and Simple/Multi-process mode. * Move `grad_worker` into marl/rl/workflows * 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers where possible in task_queue. * Refine code and update docs of `TaskQueue` * Support priority for tasks in `task_queue` * Update diagram of policy manager and task queue. * Add configurable `single_task_limit` and correct docstring about `data_parallelism` * Fix lint errors in `supply chain` * RL policy redesign (V2) (#405) * Draft v2.0 for V2 * Polish models with more comments * Polish policies with more comments * Lint * Lint * Add developer doc for models. * Add developer doc for policies. * Remove policy manager V2 since it is not used and out-of-date * Lint * Lint * refined messy workflow code * merged 'scenario_dir' and 'scenario' in rl config * 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods * 1. temporarily renamed RLPolicy from policy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2 * merged cim and cim_v2 * lint issue fix * refined logging logic * lint issue fix * reverted unwanted changes * . . . . ReplayMemory & IndexScheduler ReplayMemory & IndexScheduler . MultiReplayMemory get_actions_with_logps EnvSampler on the road EnvSampler Minor * LearnerManager * Use batch to transfer data & add SHAPE_CHECK_FLAG * Rename learner to trainer * Add property for policy._is_exploring * CIM test scenario for V3. Manual test passed. Next step: run it, make it work. * env_sampler.py could run * env_sampler refine on the way * First runnable version done * AC could run, but the result is bad. Need to check the logic * Refine abstract method & shape check error info. * Docs * Very detailed compare. Try again. * AC done * DQN check done * Minor * DDPG, not tested * Minors * A rough draft of MAAC * Cannot use CIM as the multi-agent scenario. * Minor * MAAC refinement on the way * Remove ActionWithAux * Refine batch & memory * MAAC example works * Reproducibility fix: policy sharing between env_sampler and trainer_manager. * Detail refinement * Simplify the user-configured workflow * Minor * Refine example codes * Minor polish * Migrate rollout_manager to V3 * Error on the way * Redesign torch.device management * Rl v3 maddpg (#418) * Add MADDPG trainer * Fit independent critics and shared critic modes.
* Add a new property: num_policies * Lint * Fix a bug in `sum(rewards)` * Rename `MADDPG` to `DiscreteMADDPG` and fix type hint. * Rename maddpg in examples. * Preparation for data parallel (#420) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * Rename train worker to train ops; add placeholder for abstract methods; * Lint Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> * [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * dsitributed training pipeline draft * added temporary test files for review purposes * Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py * fixed merge conflicts * unified sync and async interfaces * refactored rl_v3; refinement in progress * Finish the runnable pipeline under new design * Remove outdated files; refine class names; optimize imports; * Lint * Minor maddpg related refinement * Lint Co-authored-by: Default <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Miner bug fix * Coroutine-related bug fix ("get_policy_state") (#452) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * deleted unwanted folder * removed unwanted changes * resolved PR452 comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Quick fix * Redesign experience recording logic (#453) * Two not important fix * Temp draft. 
Prepare to WFH * Done * Lint * Lint * Calculating advantages / returns (#454) * V1.0 * Complete DDPG * Rl v3 hanging issue fix (#455) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * Final test & format. Ready to merge. Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * Rl v3 parallel rollout (#457) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * load balancing dispatcher * added parallel rollout * lint * Tracker variable type issue; rename to env_sampler_creator; * Rl v3 parallel rollout follow ups (#458) * AbsWorker & AbsDispatcher * Pass env idx to AbsTrainer.record() method, and let the trainer to decide how to record experiences sampled from different worlds. * Fix policy_creator reuse bug * Format code * Merge AbsTrainerManager & SimpleTrainerManager * AC test passed * Lint * Remove AbsTrainer.build() method. Put all initialization operations into __init__ * Redesign AC preprocess batches logic Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * MADDPG performance bug fix (#459) * Fix MARL (MADDPG) terminal recording bug; some other minor refinements; * Restore Trainer.build() method * Calculate latest action in the get_actor_grad method in MADDPG. * Share critic bug fix * Rl v3 example update (#461) * updated vm_scheduling example and cim notebook * fixed bugs in vm_scheduling * added local train method * bug fix * modified async client logic to fix hidden issue * reverted to default config * fixed PR comments and some bugs * removed hardcode Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Done (#462) * Rl v3 load save (#463) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL Toolkit data parallelism revamp & config utils (#464) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments * 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local * 1. fixed rollout exit issue; 2. refined config * removed config file from example * fixed lint issues * fixed lint issues * added main.py under examples/rl * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL doc string (#465) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. 
updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * Rl config doc (#467) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * added detailed doc * lint * wording refined * resolved some PR comments * resolved more PR comments * typo fix Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL online doc (#469) * Model, policy, trainer * RL workflows and env sampler doc in RST (#468) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * 1. type-sensitive env variable getter; 2. updated READMEs for examples * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * added detailed doc * lint * wording refined * resolved some PR comments * rewriting rl toolkit rst * resolved more PR comments * typo fix * updated rst Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Default <huo53926@126.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Finish docs/source/key_components/rl_toolkit.rst * API doc * RL online doc image fix (#470) * resolved some PR comments * fix * fixed PR comments * added numfig=True setting in conf.py for sphinx Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Resolve PR comments * Add example github link Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Rl v3 pr comment resolution (#474) * added load/save feature * 1. resolved pr comments; 2. reverted maro/cli/k8s * fixed some bugs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * RL renaming v2 (#476) * Change all Logger in RL to LoggerV2 * TrainerManager => TrainingManager * Add Trainer suffix to all algorithms * Finish docs * Update interface names * Minor fix * Cherry pick latest RL (#498) * Cherry pick * Remove SC related files * Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509) * Cherry pick RL changes from sc_refinement (2a4869) * Limit time display precision * RL incremental refactor (#501) * Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training. AC & PPO for continuous action policy; refine AC & PPO logic. Cherry pick RL changes from GYM-DDPG Cherry pick RL changes from GYM-SAC Minor error in doc string * Add min_n_sample in template and parser * Resolve PR comments. Fix a minor issue in SAC. * RL component bundle (#513) * CIM passed * Update workers * Refine annotations * VM passed * Code formatting. 
* Minor import loop issue * Pass batch in PPO again * Remove Scenario * Complete docs * Minor * Remove segment * Optimize logic in RLComponentBundle * Resolve PR comments * Move 'post' methods from RLComponentBundle to EnvSampler * Add method to get mapping of available tick to frame index (#415) * add method to get mapping of available tick to frame index * fix lint issue * fix naming issue * Cherry pick from sc_refinement (#527) * Cherry pick from sc_refinement * Cherry pick from sc_refinement * Refine `terminal` / `next_agent_state` logic (#531) * Optimize RL toolkit * Fix bug in terminal/next_state generation * Rewrite terminal/next_state logic again * Minor renaming * Minor bug fix * Resolve PR comments * Merge master into v0.3 (#536) * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolved some conflicts that occur when installing. By adding these version numbers it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add & sort requirements.dev.txt Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Merge master into v0.3 (#545) * update github workflow config * MARO v0.3: a new design of RL Toolkit, CLI refactorization, and corresponding updates. (#539)
reverted maro/cli/k8s * fixed some bugs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * RL renaming v2 (#476) * Change all Logger in RL to LoggerV2 * TrainerManager => TrainingManager * Add Trainer suffix to all algorithms * Finish docs * Update interface names * Minor fix * Cherry pick latest RL (#498) * Cherry pick * Remove SC related files * Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509) * Cherry pick RL changes from sc_refinement (2a4869) * Limit time display precision * RL incremental refactor (#501) * Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training. AC & PPO for continuous action policy; refine AC & PPO logic. Cherry pick RL changes from GYM-DDPG Cherry pick RL changes from GYM-SAC Minor error in doc string * Add min_n_sample in template and parser * Resolve PR comments. Fix a minor issue in SAC. * RL component bundle (#513) * CIM passed * Update workers * Refine annotations * VM passed * Code formatting. * Minor import loop issue * Pass batch in PPO again * Remove Scenario * Complete docs * Minor * Remove segment * Optimize logic in RLComponentBundle * Resolve PR comments * Move 'post methods from RLComponenetBundle to EnvSampler * Add method to get mapping of available tick to frame index (#415) * add method to get mapping of available tick to frame index * fix lint issue * fix naming issue * Cherry pick from sc_refinement (#527) * Cherry pick from sc_refinement * Cherry pick from sc_refinement * Refine `terminal` / `next_agent_state` logic (#531) * Optimize RL toolkit * Fix bug in terminal/next_state generation * Rewrite terminal/next_state logic again * Minor renaming * Minor bug fix * Resolve PR comments * Merge master into v0.3 (#536) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). 
* Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolve some conflicts occurs when installing. By adding these version number it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add & sort requirements.dev.txt Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Remove random_config.py * Remove test_trajectory_utils.py * Pass tests * Update rl docs * Remove python 3.6 in test * Update docs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wang.Jinyu <jinywan@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Logger bug hotfix (#543) * Rename param * Rename param * Quick fix in env_data_process * frame data precision issue fix (#544) * fix frame precision issue * add .xmake to .gitignore * update frame precision lost warning message * add assert to frame precision checking * typo fix * add TODO for future Long data type issue fix * Minor cleaning Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * Update requirements. (#552) * Fix several encoding issues; update requirements. 
* Test & minor * Remove torch in requirements.build.txt * Polish * Update README * Resolve PR comments * Keep working * Keep working * Update test requirements * Done (#554) * Update requirements in example and notebook (#553) * Update requirements in example and notebook * Remove autopep8 * Add jupyterlab packages back Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Refine decision event logic (#559) * Add DecisionEventPayload * Change decision payload name * Refine action logic * Add doc for env.step * Restore pre-commit config * Resolve PR comments * Refactor decision event & action * Pre-commit * Resolve PR comments * Refine rl component bundle (#549) * Config files * Done * Minor bugfix * Add autoflake * Update isort exclude; add pre-commit to requirements * Check only isort * Minor * Format * Test passed * Run pre-commit * Minor bugfix in rl_component_bundle * Pass mypy * Fix a bug in RL notebook * A minor bug fix * Add upper bound for numpy version in test Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Huoran Li <huo53926@126.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jinyu Wang <jinyu@RL4Inv.l1ea1prscrcu1p4sa0eapum5vc.bx.internal.cloudapp.net>
This commit is contained in:
Parent
38eb389df1
Commit
6512879608
@ -121,14 +121,16 @@ of user-defined functions for message auto-handling, cluster provision, and job

```sh
# Install MARO from source.
bash scripts/install_maro.sh
bash scripts/install_maro.sh;
pip install -r ./requirements.dev.txt;
```

- Windows

```powershell
# Install MARO from source.
.\scripts\install_maro.bat
.\scripts\install_maro.bat;
pip install -r ./requirements.dev.txt;
```

- *Notes: If your package is not found, remember to set your PYTHONPATH*

@ -326,236 +326,84 @@ In CIM scenario, there are 3 node types:
|
|||
port
|
||||
++++
|
||||
|
||||
capacity
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The capacity of port for stocking containers.
|
||||
|
||||
empty
|
||||
*****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty container volume on the port.
|
||||
|
||||
full
|
||||
****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden container volume on the port.
|
||||
|
||||
on_shipper
|
||||
**********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty containers, which are released to the shipper.
|
||||
|
||||
on_consignee
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden containers, which are delivered to the consignee.
|
||||
|
||||
shortage
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Shortage of empty container at current tick.
|
||||
|
||||
acc_storage
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated shortage number to the current tick.
|
||||
|
||||
booking
|
||||
*******
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Order booking number of a port at the current tick.
|
||||
|
||||
acc_booking
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated order booking number of a port to the current tick.
|
||||
|
||||
fulfillment
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Fulfilled order number of a port at the current tick.
|
||||
|
||||
acc_fulfillment
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated fulfilled order number of a port to the current tick.
|
||||
|
||||
transfer_cost
|
||||
*************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
Cost of transferring container, which also covers loading and discharging cost.
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+==================+=======+========+==================================================================================+
|
||||
| capacity | int | 1 | The capacity of port for stocking containers. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| empty | int | 1 | Empty container volume on the port. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| full | int | 1 | Laden container volume on the port. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| on_shipper | int | 1 | Empty containers, which are released to the shipper. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| on_consignee | int | 1 | Laden containers, which are delivered to the consignee. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| shortage | int | 1 | Per tick state. Shortage of empty container at current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| acc_storage | int | 1 | Accumulated shortage number to the current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| booking | int | 1 | Per tick state. Order booking number of a port at the current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| acc_booking | int | 1 | Accumulated order booking number of a port to the current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| fulfillment | int | 1 | Fulfilled order number of a port at the current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| acc_fulfillment | int | 1 | Accumulated fulfilled order number of a port to the current tick. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
|
||||
| transfer_cost | float | 1 | Cost of transferring container, which also covers loading and discharging cost. |
|
||||
+------------------+-------+--------+----------------------------------------------------------------------------------+
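These per-port attributes are exactly what gets sliced out of the snapshot list when shaping RL states. A minimal sketch of querying them, assuming the standard maro.simulator.Env API (the topology, durations, and frame index below are illustrative, not part of this diff):

```python
from maro.simulator import Env

# Illustrative settings; any CIM topology exposes the same port attributes.
env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", start_tick=0, durations=100)
env.step(None)  # advance to the first decision event (or the end of the episode)

# Slice "shortage" and "booking" of every port at frame index 0; the result is
# a flat numpy array laid out port-major, attribute-minor.
ports = env.snapshot_list["ports"]
print(ports[0::["shortage", "booking"]])
```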
|
||||
|
||||
vessel
|
||||
++++++
|
||||
|
||||
capacity
|
||||
********
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+========================+========+==========+=================================================================================================================================================================================================================================+
|
||||
| capacity | int | 1 | The capacity of vessel for transferring containers. NOTE: This attribute is ignored in current implementation. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| empty | int | 1 | Empty container volume on the vessel. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| full | int | 1 | Laden container volume on the vessel. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| remaining_space | int | 1 | Remaining space of the vessel. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| early_discharge | int | 1 | Discharged empty container number for loading laden containers. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| is_parking | short | 1 | Is parking or not |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| loc_port_idx | int | 1 | The port index the vessel is parking at. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| route_idx | int | 1 | Which route current vessel belongs to. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| last_loc_idx | int | 1 | Last stop port index in route, it is used to identify where is current vessel. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| next_loc_idx | int | 1 | Next stop port index in route, it is used to identify where is current vessel. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| past_stop_list | int | dynamic | NOTE: This and following attribute are special, that its slot number is determined by configuration, but different with a list attribute, its slot number is fixed at runtime. Stop indices that we have stopped in the past. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| past_stop_tick_list | int | dynamic | Ticks that we stopped at the port in the past. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| future_stop_list | int | dynamic | Stop indices that we will stop in the future. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| future_stop_tick_list | int | dynamic | Ticks that we will stop in the future. |
|
||||
+------------------------+--------+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The capacity of vessel for transferring containers.
|
||||
|
||||
NOTE:
|
||||
This attribute is ignored in current implementation.
|
||||
|
||||
empty
|
||||
*****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty container volume on the vessel.
|
||||
|
||||
full
|
||||
****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden container volume on the vessel.
|
||||
|
||||
remaining_space
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Remaining space of the vessel.
|
||||
|
||||
early_discharge
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Discharged empty container number for loading laden containers.
|
||||
|
||||
route_idx
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Which route current vessel belongs to.
|
||||
|
||||
last_loc_idx
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Last stop port index in the route; used to identify where the current vessel is.
|
||||
|
||||
next_loc_idx
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Next stop port index in the route; used to identify where the current vessel is.
|
||||
|
||||
past_stop_list
|
||||
**************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
NOTE:
|
||||
This and the following attributes are special: their slot number is determined by the configuration,
but unlike a list attribute, the slot number is fixed at runtime.
|
||||
|
||||
Stop indices that we have stopped in the past.
|
||||
|
||||
past_stop_tick_list
|
||||
*******************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Ticks that we stopped at the port in the past.
|
||||
|
||||
future_stop_list
|
||||
****************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Stop indices that we will stop in the future.
|
||||
|
||||
future_stop_tick_list
|
||||
*********************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Ticks that we will stop in the future.
|
||||
|
||||
matrices
|
||||
++++++++
|
||||
|
||||
The matrices node is used to store big matrices for ports, vessels and containers.
|
||||
|
||||
full_on_ports
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: port number * port number
|
||||
|
||||
Distribution of full from port to port.
|
||||
|
||||
full_on_vessels
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: vessel number * port number
|
||||
|
||||
Distribution of full from vessel to port.
|
||||
|
||||
vessel_plans
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: vessel number * port number
|
||||
|
||||
Planned route info for vessels.
|
||||
+------------------+-------+------------------------------+---------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+==================+=======+==============================+=============================================+
|
||||
| full_on_ports | int | port number * port number | Distribution of full from port to port. |
|
||||
+------------------+-------+------------------------------+---------------------------------------------+
|
||||
| full_on_vessels | int | vessel number * port number | Distribution of full from vessel to port. |
|
||||
+------------------+-------+------------------------------+---------------------------------------------+
|
||||
| vessel_plans | int | vessel number * port number | Planed route info for vessels. |
|
||||
+------------------+-------+------------------------------+---------------------------------------------+
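Because each matrix attribute is stored flat (one slot per cell), it is usually reshaped right after the query. A hedged sketch under the same illustrative setup as above:

```python
from maro.simulator import Env

env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", start_tick=0, durations=100)
env.step(None)

# "full_on_ports" carries port number * port number slots per tick; view it as a matrix.
port_number = len(env.snapshot_list["ports"])
full_on_ports = env.snapshot_list["matrices"][0::"full_on_ports"].reshape(port_number, port_number)
print(full_on_ports.sum(axis=1))  # row sums: full containers originating from each port
```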
|
||||
|
||||
How to
|
||||
~~~~~~
|
||||
|
@ -597,133 +445,47 @@ Nodes and attributes in scenario
|
|||
station
|
||||
+++++++
|
||||
|
||||
bikes
|
||||
*****
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+===================+=======+========+===========================================================================================================+
|
||||
| bikes | int | 1 | How many bikes avaiable in current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| shortage | int | 1 | Per tick state. Lack number of bikes in current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| trip_requirement | int | 1 | Per tick states. How many requirements in current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| fulfillment | int | 1 | How many requirement is fit in current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| capacity | int | 1 | Max number of bikes this station can take. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| id | int | 1 | Id of current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| weekday | short | 1 | Weekday at current tick. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| temperature | short | 1 | Temperature at current tick. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| weather | short | 1 | Weather at current tick. (0: sunny, 1: rainy, 2: snowy, 3: sleet) |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| holiday | short | 1 | If it is holidy at current tick. (0: holiday, 1: not holiday) |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| extra_cost | int | 1 | Cost after we reach the capacity after executing action, we have to move extra bikes to other stations. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| transfer_cost | int | 1 | Cost to execute action to transfer bikes to other station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| failed_return | int | 1 | Per tick state. How many bikes failed to return to current station. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
| min_bikes | int | 1 | Min bikes number in a frame. |
|
||||
+-------------------+-------+--------+-----------------------------------------------------------------------------------------------------------+
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
How many bikes are available in the current station.
|
||||
|
||||
shortage
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Lack number of bikes in current station.
|
||||
|
||||
trip_requirement
|
||||
****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. How many trip requirements in the current station.
|
||||
|
||||
fulfillment
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
How many trip requirements were fulfilled in the current station.
|
||||
|
||||
capacity
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Max number of bikes this station can take.
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current station.
|
||||
|
||||
weekday
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Weekday at current tick.
|
||||
|
||||
temperature
|
||||
***********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Temperature at current tick.
|
||||
|
||||
weather
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Weather at current tick.
|
||||
|
||||
0: sunny, 1: rainy, 2: snowy, 3: sleet.
|
||||
|
||||
holiday
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Whether it is a holiday at the current tick.
|
||||
|
||||
0: holiday, 1: not holiday
|
||||
|
||||
extra_cost
|
||||
**********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Cost incurred when the station is over capacity after executing an action, so extra bikes
have to be moved to other stations.
|
||||
|
||||
transfer_cost
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Cost to execute action to transfer bikes to other station.
|
||||
|
||||
failed_return
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. How many bikes failed to return to current station.
|
||||
|
||||
min_bikes
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Minimum number of bikes in a frame.
|
||||
|
||||
matrices
|
||||
++++++++
|
||||
|
||||
trips_adj
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: station number * station number
|
||||
|
||||
Used to store trip requirement number between 2 stations.
|
||||
+------------+-------+----------------------------------+------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+============+=======+==================================+============================================================+
|
||||
| trips_adj | int | station number * station number | Used to store trip requirement number between 2 stations. |
|
||||
+------------+-------+----------------------------------+------------------------------------------------------------+
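The same snapshot interface applies to the Citi Bike frame; a minimal sketch (the scenario/topology names, durations, and frame index are illustrative):

```python
from maro.simulator import Env

env = Env(scenario="citi_bike", topology="toy.3s_4t", start_tick=0, durations=1440)
env.step(None)

# Bikes and shortage of every station at frame index 0.
stations = env.snapshot_list["stations"]
print(stations[0::["bikes", "shortage"]])

# Trip requirements between stations, viewed as a station x station matrix.
station_number = len(stations)
trips_adj = env.snapshot_list["matrices"][0::"trips_adj"].reshape(station_number, station_number)
```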
|
||||
|
||||
|
||||
VM-scheduling
|
||||
|
@ -743,315 +505,121 @@ Nodes and attributes in scenario
|
|||
Cluster
|
||||
+++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of the cluster.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current cluster.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current cluster.
|
||||
|
||||
total_machine_num
|
||||
******************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in the cluster.
|
||||
|
||||
empty_machine_num
|
||||
******************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The number of empty machines in this cluster. An empty machine means that its allocated CPU cores are 0.
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+====================+=======+========+==========================================================================================================+
|
||||
| id | short | 1 | Id of the cluster. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| region_id | short | 1 | Region id of current cluster. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| zond_id | short | 1 | Zone id of current cluster. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| data_center_id | short | 1 | Data center id of current cluster. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| total_machine_num | int | 1 | Total number of machines in the cluster. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
| empty_machine_num | int | 1 | The number of empty machines in this cluster. A empty machine means that its allocated CPU cores are 0. |
|
||||
+--------------------+-------+--------+----------------------------------------------------------------------------------------------------------+
|
||||
|
||||
data_centers
|
||||
++++++++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of current data center.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current data center.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current data center.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in the current data center.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The number of empty machines in current data center.
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+====================+=======+========+=======================================================+
|
||||
| id | short | 1 | Id of current data center. |
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
| region_id | short | 1 | Region id of current data center. |
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
| zone_id | short | 1 | Zone id of current data center. |
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
| total_machine_num | int | 1 | Total number of machine in current data center. |
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
| empty_machine_num | int | 1 | The number of empty machines in current data center. |
|
||||
+--------------------+-------+--------+-------------------------------------------------------+
|
||||
|
||||
pms
|
||||
+++
|
||||
|
||||
Physical machine node.
|
||||
|
||||
id
|
||||
***
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+=====================+=======+========+=================================================================================+
|
||||
| id | int | 1 | Id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| cpu_cores_capacity | short | 1 | Max number of cpu core can be used for current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| memory_capacity | short | 1 | Max number of memory can be used for current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| pm_type | short | 1 | Type of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| cpu_cores_allocated | short | 1 | How many cpu core is allocated. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| memory_allocated | short | 1 | How many memory is allocated. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| cpu_utilization | float | 1 | CPU utilization of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| energy_consumption | float | 1 | Energy consumption of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| oversubscribable | short | 1 | Physical machine non-oversubscribable is -1, empty: 0, oversubscribable is 1. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| region_id | short | 1 | Region id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| zone_id | short | 1 | Zone id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| data_center_id | short | 1 | Data center id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| cluster_id | short | 1 | Cluster id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
| rack_id | short | 1 | Rack id of current machine. |
|
||||
+---------------------+-------+--------+---------------------------------------------------------------------------------+
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current machine.
|
||||
|
||||
cpu_cores_capacity
|
||||
******************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Max number of CPU cores that can be used for the current machine.
|
||||
|
||||
memory_capacity
|
||||
***************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Max amount of memory that can be used for the current machine.
|
||||
|
||||
pm_type
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Type of current machine.
|
||||
|
||||
cpu_cores_allocated
|
||||
*******************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
How many CPU cores are allocated.
|
||||
|
||||
memory_allocated
|
||||
****************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
How much memory is allocated.
|
||||
|
||||
cpu_utilization
|
||||
***************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
CPU utilization of current machine.
|
||||
|
||||
energy_consumption
|
||||
******************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
Energy consumption of current machine.
|
||||
|
||||
oversubscribable
|
||||
****************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Physical machine type: non-oversubscribable is -1, empty: 0, oversubscribable is 1.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current machine.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current machine.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current machine.
|
||||
|
||||
cluster_id
|
||||
**********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Cluster id of current machine.
|
||||
|
||||
rack_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Rack id of current machine.
|
||||
|
||||
Rack
|
||||
rack
|
||||
++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current rack.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current rack.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current rack.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current rack.
|
||||
|
||||
cluster_id
|
||||
**********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Cluster id of current rack.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines on this rack.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that are not in use on this rack.
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+====================+=======+========+===================================================+
|
||||
| id | int | 1 | Id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| region_id | short | 1 | Region id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| zone_id | short | 1 | Zone id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| data_center_id | short | 1 | Data center id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| cluster_id | short | 1 | Cluster id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| total_machine_num | int | 1 | Total number of machines on this rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| empty_machine_num | int | 1 | Number of machines that not in use on this rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
|
||||
regions
|
||||
+++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of current region.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in this region.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that are not in use in this region.
|
||||
+--------------------+-------+--------+------------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+====================+=======+========+======================================================+
|
||||
| id | short | 1 | Id of current region. |
|
||||
+--------------------+-------+--------+------------------------------------------------------+
|
||||
| total_machine_num | int | 1 | Total number of machines in this region. |
|
||||
+--------------------+-------+--------+------------------------------------------------------+
|
||||
| empty_machine_num | int | 1 | Number of machines that not in use in this region. |
|
||||
+--------------------+-------+--------+------------------------------------------------------+
|
||||
|
||||
zones
|
||||
+++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of this zone.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in this zone.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that are not in use in this zone.
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| Field | Type | Slots | Description |
|
||||
+====================+=======+========+===================================================+
|
||||
| id | short | 1 | Id of this zone. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| region_id | short | 1 | Region id of current rack. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| total_machine_num | int | 1 | Total number of machines in this zone. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
|
||||
| empty_machine_num | int | 1 | Number of machines that not in use in this zone. |
|
||||
+--------------------+-------+--------+---------------------------------------------------+
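And likewise for the VM scheduling frame; a hedged sketch of reading a few physical machine attributes (topology, durations, and frame index are illustrative):

```python
from maro.simulator import Env

env = Env(scenario="vm_scheduling", topology="azure.2019.10k", start_tick=0, durations=8638)
env.step(None)

# Allocated CPU cores and utilization of every physical machine at frame index 0.
pms = env.snapshot_list["pms"]
print(pms[0::["cpu_cores_allocated", "cpu_utilization"]])
```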
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .rl_component_bundle import CIMBundle as rl_component_bundle_cls
|
||||
from .rl_component_bundle import rl_component_bundle
|
||||
|
||||
__all__ = [
|
||||
"rl_component_bundle_cls",
|
||||
"rl_component_bundle",
|
||||
]
|
||||
|
|
|
@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG
|
|||
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
|
||||
return ActorCriticTrainer(
|
||||
name=name,
|
||||
reward_discount=0.0,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=0.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
|
|
|
@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
|
|||
def get_dqn(name: str) -> DQNTrainer:
|
||||
return DQNTrainer(
|
||||
name=name,
|
||||
reward_discount=0.0,
|
||||
replay_memory_capacity=10000,
|
||||
batch_size=32,
|
||||
params=DQNParams(
|
||||
reward_discount=0.0,
|
||||
update_target_every=5,
|
||||
num_epochs=10,
|
||||
soft_update_coef=0.1,
|
||||
double=False,
|
||||
replay_memory_capacity=10000,
|
||||
random_overwrite=False,
|
||||
batch_size=32,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol
|
|||
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
|
||||
return DiscreteMADDPGTrainer(
|
||||
name=name,
|
||||
reward_discount=0.0,
|
||||
params=DiscreteMADDPGParams(
|
||||
reward_discount=0.0,
|
||||
num_epoch=10,
|
||||
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
|
||||
shared_critic=False,
|
||||
|
|
|
@ -16,12 +16,11 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy
|
|||
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
|
||||
return PPOTrainer(
|
||||
name=name,
|
||||
reward_discount=0.0,
|
||||
params=PPOParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=0.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=0.0,
|
||||
clip_ratio=0.1,
|
||||
),
|
||||
|
|
|
@ -7,11 +7,6 @@ env_conf = {
|
|||
"durations": 560,
|
||||
}
|
||||
|
||||
if env_conf["topology"].startswith("toy"):
|
||||
num_agents = int(env_conf["topology"].split(".")[1][0])
|
||||
else:
|
||||
num_agents = int(env_conf["topology"].split(".")[1][:2])
|
||||
|
||||
port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
|
||||
vessel_attributes = ["empty", "full", "remaining_space"]
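These attribute lists are what the CIM env sampler slices out of the snapshots when it builds a state vector. A hedged, simplified sketch of that pattern; the helper below is illustrative, not the example's actual code:

```python
import numpy as np

def get_state(env, frame_idx: int, port_idx: int, vessel_idx: int) -> np.ndarray:
    """Concatenate the configured port and vessel attributes into one flat state vector."""
    port_features = env.snapshot_list["ports"][frame_idx:port_idx:port_attributes]
    vessel_features = env.snapshot_list["vessels"][frame_idx:vessel_idx:vessel_attributes]
    return np.concatenate([port_features, vessel_features])
```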
|
||||
|
||||
|
|
|
@ -1,77 +1,48 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
from maro.simulator import Env
|
||||
|
||||
from .algorithms.ac import get_ac, get_ac_policy
|
||||
from .algorithms.dqn import get_dqn, get_dqn_policy
|
||||
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
|
||||
from .algorithms.ppo import get_ppo, get_ppo_policy
|
||||
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
|
||||
from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim
|
||||
from examples.cim.rl.env_sampler import CIMEnvSampler
|
||||
|
||||
# Environments
|
||||
learn_env = Env(**env_conf)
|
||||
test_env = learn_env
|
||||
|
||||
class CIMBundle(RLComponentBundle):
|
||||
def get_env_config(self) -> dict:
|
||||
return env_conf
|
||||
# Agent, policy, and trainers
|
||||
num_agents = len(learn_env.agent_idx_list)
|
||||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
|
||||
if algorithm == "ac":
|
||||
policies = [get_ac_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
|
||||
trainers = [get_ac(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
|
||||
elif algorithm == "ppo":
|
||||
policies = [get_ppo_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
|
||||
trainers = [get_ppo(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
|
||||
elif algorithm == "dqn":
|
||||
policies = [get_dqn_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
|
||||
trainers = [get_dqn(f"{algorithm}_{i}") for i in range(num_agents)]
|
||||
elif algorithm == "discrete_maddpg":
|
||||
policies = [get_maddpg_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
|
||||
trainers = [get_maddpg(state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)]
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
def get_test_env_config(self) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
def get_env_sampler(self) -> AbsEnvSampler:
|
||||
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
|
||||
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}
|
||||
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
if algorithm == "ac":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "ppo":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "discrete_maddpg":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return policy_creator
|
||||
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "ppo":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)}
|
||||
elif algorithm == "discrete_maddpg":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return trainer_creator
|
||||
# Build RLComponentBundle
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=CIMEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
reward_eval_delay=reward_shaping_conf["time_window"],
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
)
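A hedged sketch of how the optional arguments added to RLComponentBundle in this change could be layered on top of the objects built above; the device string and the explicit trainer mapping are illustrative, everything else reuses names defined earlier in this file:

# Reuses learn_env, test_env, policies, trainers, agent2policy, algorithm and num_agents from above.
rl_component_bundle_on_gpu = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
    # Optional: pin every policy to one GPU (device string is illustrative).
    device_mapping={policy.name: "cuda:0" for policy in policies},
    # Optional: spell out the default "<trainer>.policy" -> "<trainer>" rule explicitly.
    policy_trainer_mapping={f"{algorithm}_{i}.policy": f"{algorithm}_{i}" for i in range(num_agents)},
)
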
@@ -1 +1,3 @@
-PuLP==2.1
+matplotlib>=3.1.2
+pulp>=2.1.0
+tweepy>=4.10.0

@ -0,0 +1,47 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for CIM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
|
||||
|
||||
job: cim_rl_workflow
|
||||
scenario_path: "examples/cim/rl"
|
||||
log_path: "log/rl_job/cim.txt"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
eval_schedule: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
parallelism:
|
||||
sampling: 3
|
||||
eval: null
|
||||
min_env_samples: 3
|
||||
grace_factor: 0.2
|
||||
controller:
|
||||
host: "127.0.0.1"
|
||||
port: 20000
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: parallel
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/cim"
|
||||
interval: 5
|
||||
proxy:
|
||||
host: "127.0.0.1"
|
||||
frontend: 10000
|
||||
backend: 10001
|
||||
num_workers: 2
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -12,7 +12,7 @@ import yaml
|
|||
from ilp_agent import IlpAgent
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import DecisionEvent
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
from maro.utils import LogFormat, Logger, convert_dottable
|
||||
|
||||
|
@ -46,7 +46,7 @@ if __name__ == "__main__":
|
|||
env.set_seed(config.env.seed)
|
||||
|
||||
metrics: object = None
|
||||
decision_event: DecisionPayload = None
|
||||
decision_event: DecisionEvent = None
|
||||
is_done: bool = False
|
||||
action: Action = None
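As a reminder of the loop these annotations drive, here is a short sketch; the agent call is schematic and not part of this change:

metrics, decision_event, is_done = env.step(None)
while not is_done:
    # Build an Action for the current DecisionEvent (e.g. via the ILP agent), then advance the env.
    action = ilp_agent.choose_action(decision_event, env)  # hypothetical call signature
    metrics, decision_event, is_done = env.step(action)
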
|
||||
|
||||
|
|
|
@@ -1,8 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .rl_component_bundle import VMBundle as rl_component_bundle_cls
+from .rl_component_bundle import rl_component_bundle
 
 __all__ = [
-    "rl_component_bundle_cls",
+    "rl_component_bundle",
 ]

@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str)
|
|||
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
|
||||
return ActorCriticTrainer(
|
||||
name=name,
|
||||
reward_discount=0.9,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
|
||||
reward_discount=0.9,
|
||||
grad_iters=100,
|
||||
critic_loss_cls=torch.nn.MSELoss,
|
||||
min_logp=-20,
|
||||
|
|
|
@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
|
|||
def get_dqn(name: str) -> DQNTrainer:
|
||||
return DQNTrainer(
|
||||
name=name,
|
||||
reward_discount=0.9,
|
||||
replay_memory_capacity=10000,
|
||||
batch_size=32,
|
||||
data_parallelism=2,
|
||||
params=DQNParams(
|
||||
reward_discount=0.9,
|
||||
update_target_every=5,
|
||||
num_epochs=100,
|
||||
soft_update_coef=0.1,
|
||||
double=False,
|
||||
replay_memory_capacity=10000,
|
||||
random_overwrite=False,
|
||||
batch_size=32,
|
||||
data_parallelism=2,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,14 +5,15 @@ import time
|
|||
from collections import defaultdict
|
||||
from os import makedirs
|
||||
from os.path import dirname, join, realpath
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rollout import AbsAgentWrapper, AbsEnvSampler, CacheElement, SimpleAgentWrapper
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction
|
||||
|
||||
from .config import (
|
||||
num_features,
|
||||
|
@ -30,8 +31,25 @@ makedirs(plt_path, exist_ok=True)
|
|||
|
||||
|
||||
class VMEnvSampler(AbsEnvSampler):
|
||||
def __init__(self, learn_env: Env, test_env: Env) -> None:
|
||||
super(VMEnvSampler, self).__init__(learn_env, test_env)
|
||||
def __init__(
|
||||
self,
|
||||
learn_env: Env,
|
||||
test_env: Env,
|
||||
policies: List[AbsPolicy],
|
||||
agent2policy: Dict[Any, str],
|
||||
trainable_policies: List[str] = None,
|
||||
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
|
||||
reward_eval_delay: int = None,
|
||||
) -> None:
|
||||
super(VMEnvSampler, self).__init__(
|
||||
learn_env,
|
||||
test_env,
|
||||
policies,
|
||||
agent2policy,
|
||||
trainable_policies,
|
||||
agent_wrapper_cls,
|
||||
reward_eval_delay,
|
||||
)
|
||||
|
||||
self._learn_env.set_seed(seed)
|
||||
self._test_env.set_seed(test_seed)
|
||||
|
@ -44,7 +62,7 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
|
||||
def _get_global_and_agent_state_impl(
|
||||
self,
|
||||
event: DecisionPayload,
|
||||
event: DecisionEvent,
|
||||
tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
|
||||
|
@ -71,14 +89,14 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
def _translate_to_env_action(
|
||||
self,
|
||||
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
event: DecisionPayload,
|
||||
event: DecisionEvent,
|
||||
) -> Dict[Any, object]:
|
||||
if action_dict["AGENT"] == self.num_pms:
|
||||
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
|
||||
else:
|
||||
return {"AGENT": AllocateAction(vm_id=event.vm_id, pm_id=action_dict["AGENT"][0])}
|
||||
|
||||
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
|
||||
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
|
||||
action = env_action_dict["AGENT"]
|
||||
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
|
||||
if isinstance(action, PostponeAction): # postponement
|
||||
|
@ -121,7 +139,7 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
],
|
||||
)
|
||||
|
||||
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
|
||||
def _get_allocation_reward(self, event: DecisionEvent, alpha: float, beta: float):
|
||||
vm_unit_price = self._env.business_engine._get_unit_price(
|
||||
event.vm_cpu_cores_requirement,
|
||||
event.vm_memory_requirement,
|
||||
|
|
|
@ -1,67 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
from maro.simulator import Env
|
||||
|
||||
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
|
||||
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
|
||||
from .algorithms.ac import get_ac, get_ac_policy
|
||||
from .algorithms.dqn import get_dqn, get_dqn_policy
|
||||
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
|
||||
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
|
||||
|
||||
# Environments
|
||||
learn_env = Env(**env_conf)
|
||||
test_env = Env(**test_env_conf)
|
||||
|
||||
class VMBundle(RLComponentBundle):
|
||||
def get_env_config(self) -> dict:
|
||||
return env_conf
|
||||
# Agent, policy, and trainers
|
||||
action_num = num_pms + 1
|
||||
agent2policy = {"AGENT": f"{algorithm}.policy"}
|
||||
if algorithm == "ac":
|
||||
policies = [get_ac_policy(state_dim, action_num, num_features, f"{algorithm}.policy")]
|
||||
trainers = [get_ac(state_dim, num_features, algorithm)]
|
||||
elif algorithm == "dqn":
|
||||
policies = [get_dqn_policy(state_dim, action_num, num_features, f"{algorithm}.policy")]
|
||||
trainers = [get_dqn(algorithm)]
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
def get_test_env_config(self) -> Optional[dict]:
|
||||
return test_env_conf
|
||||
|
||||
def get_env_sampler(self) -> AbsEnvSampler:
|
||||
return VMEnvSampler(self.env, self.test_env)
|
||||
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
return {"AGENT": f"{algorithm}.policy"}
|
||||
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1
|
||||
|
||||
if algorithm == "ac":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_ac_policy,
|
||||
state_dim,
|
||||
action_num,
|
||||
num_features,
|
||||
f"{algorithm}.policy",
|
||||
),
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_dqn_policy,
|
||||
state_dim,
|
||||
action_num,
|
||||
num_features,
|
||||
f"{algorithm}.policy",
|
||||
),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return policy_creator
|
||||
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
|
||||
elif algorithm == "dqn":
|
||||
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return trainer_creator
|
||||
# Build RLComponentBundle
|
||||
rl_component_bundle = RLComponentBundle(
|
||||
env_sampler=VMEnvSampler(
|
||||
learn_env=learn_env,
|
||||
test_env=test_env,
|
||||
policies=policies,
|
||||
agent2policy=agent2policy,
|
||||
),
|
||||
agent2policy=agent2policy,
|
||||
policies=policies,
|
||||
trainers=trainers,
|
||||
)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ class VMSchedulingAgent(object):
|
|||
def __init__(self, algorithm):
|
||||
self._algorithm = algorithm
|
||||
|
||||
def choose_action(self, decision_event: DecisionPayload, env: Env) -> Action:
|
||||
def choose_action(self, decision_event: DecisionEvent, env: Env) -> Action:
|
||||
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM."""
|
||||
valid_pm_num: int = len(decision_event.valid_pms)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class BestFit(RuleBasedAlgorithm):
|
||||
|
@ -13,7 +13,7 @@ class BestFit(RuleBasedAlgorithm):
|
|||
super().__init__()
|
||||
self._metric_type: str = kwargs["metric_type"]
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
# Use a rule to choose a valid PM.
|
||||
chosen_idx: int = self._pick_pm_func(decision_event, env)
|
||||
# Take action to allocate on the chosen PM.
|
||||
|
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class BinPacking(RuleBasedAlgorithm):
|
||||
|
@ -24,7 +24,7 @@ class BinPacking(RuleBasedAlgorithm):
|
|||
self._bins = [[] for _ in range(self._pm_cpu_core_num + 1)]
|
||||
self._bin_size = [0] * (self._pm_cpu_core_num + 1)
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
# Initialize the bin.
|
||||
self._init_bin()
|
||||
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class FirstFit(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
# Use a valid PM based on its order.
|
||||
chosen_idx: int = decision_event.valid_pms[0]
|
||||
# Take action to allocate on the chosen PM.
|
||||
|
|
|
@ -6,14 +6,14 @@ import random
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class RandomPick(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
valid_pm_num: int = len(decision_event.valid_pms)
|
||||
# Randomly choose a valid PM.
|
||||
chosen_idx: int = random.randint(0, valid_pm_num - 1)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class RoundRobin(RuleBasedAlgorithm):
|
||||
|
@ -15,7 +15,7 @@ class RoundRobin(RuleBasedAlgorithm):
|
|||
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].shape[0]
|
||||
)
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
# Choose the valid PM whose index is next to the previously chosen PM's index.
|
||||
chosen_idx: int = (self._prev_idx + 1) % self._pm_num
|
||||
while chosen_idx not in decision_event.valid_pms:
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
import abc
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent
|
||||
|
||||
|
||||
class RuleBasedAlgorithm(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
def allocate_vm(self, decision_event: DecisionEvent, env: Env) -> AllocateAction:
|
||||
"""This method will determine allocate which PM to the current VM."""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -25,7 +25,7 @@ def start_cim_dashboard(source_path: str, epoch_num: int, prefix: str):
|
|||
--ports.csv: Record ports' attributes in this file.
|
||||
--vessel.csv: Record vessels' attributes in this file.
|
||||
--matrices.csv: Record transfer volume information in this file.
|
||||
………………
|
||||
......
|
||||
--epoch_{epoch_num-1}
|
||||
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
|
||||
--config.yml: Record the relationship between ports' index and name.
|
||||
|
|
|
@ -24,7 +24,7 @@ def start_citi_bike_dashboard(source_path: str, epoch_num: int, prefix: str):
|
|||
--stations.csv: Record stations' attributes in this file.
|
||||
--matrices.csv: Record transfer volume information in this file.
|
||||
--stations_summary.csv: Record the summary data of current epoch.
|
||||
………………
|
||||
......
|
||||
--epoch_{epoch_num-1}
|
||||
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
|
||||
--full_stations.json: Record the relationship between ports' index and name.
|
||||
|
|
|
@ -28,7 +28,7 @@ def start_vis(source_path: str, force: str, **kwargs: dict):
|
|||
-input_file_folder_path
|
||||
--epoch_0 : Data of current epoch.
|
||||
--holder_info.csv: Attributes of current epoch.
|
||||
………………
|
||||
......
|
||||
--epoch_{epoch_num-1}
|
||||
--manifest.yml: Record basic info like scenario name, name of index_name_mapping file.
|
||||
--index_name_mapping file: Record the relationship between an index and its name.
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
class BaseDecisionEvent:
|
||||
"""Base class for all decision events.
|
||||
|
||||
We made this design for the convenience of users. As a summary, there are two types of events in MARO:
|
||||
- CascadeEvent & AtomEvent: used to drive the MARO Env / business engine.
|
||||
- DecisionEvent: exposed to users as a means of communication.
|
||||
|
||||
The latter one serves as the `payload` of the former ones inside of MARO Env.
|
||||
|
||||
Therefore, the related namings might be a little bit tricky.
|
||||
- Inside MARO Env: `decision_event` is actually a CascadeEvent. DecisionEvent is the payload of them.
|
||||
- Outside MARO Env (for users): `decision_event` is a DecisionEvent.
|
||||
"""
|
||||
|
||||
|
||||
class BaseAction:
|
||||
"""Base class for all action payloads"""
|
|
@ -4,8 +4,9 @@
|
|||
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, cast
|
||||
|
||||
from ..common import BaseAction, BaseDecisionEvent
|
||||
from .event import ActualEvent, AtomEvent, CascadeEvent
|
||||
from .event_linked_list import EventLinkedList
|
||||
from .event_pool import EventPool
|
||||
|
@ -122,9 +123,7 @@ class EventBuffer:
|
|||
Returns:
|
||||
AtomEvent: Atom event object
|
||||
"""
|
||||
event = self._event_pool.gen(tick, event_type, payload, False)
|
||||
assert isinstance(event, AtomEvent)
|
||||
return event
|
||||
return cast(AtomEvent, self._event_pool.gen(tick, event_type, payload, is_cascade=False))
|
||||
|
||||
def gen_cascade_event(self, tick: int, event_type: object, payload: object) -> CascadeEvent:
|
||||
"""Generate an cascade event that used to hold immediate events that
|
||||
|
@ -138,31 +137,32 @@ class EventBuffer:
|
|||
Returns:
|
||||
CascadeEvent: Cascade event object.
|
||||
"""
|
||||
event = self._event_pool.gen(tick, event_type, payload, True)
|
||||
assert isinstance(event, CascadeEvent)
|
||||
return event
|
||||
return cast(CascadeEvent, self._event_pool.gen(tick, event_type, payload, is_cascade=True))
|
||||
|
||||
def gen_decision_event(self, tick: int, payload: object) -> CascadeEvent:
|
||||
def gen_decision_event(self, tick: int, payload: BaseDecisionEvent) -> CascadeEvent:
|
||||
"""Generate a decision event that will stop current simulation, and ask agent for action.
|
||||
|
||||
Args:
|
||||
tick (int): Tick that the event will be processed.
|
||||
payload (object): Payload of event, used to pass data to handlers.
|
||||
payload (BaseDecisionEvent): Payload of event, used to pass data to handlers.
|
||||
Returns:
|
||||
CascadeEvent: Event object
|
||||
"""
|
||||
assert isinstance(payload, BaseDecisionEvent)
|
||||
return self.gen_cascade_event(tick, MaroEvents.PENDING_DECISION, payload)
|
||||
|
||||
def gen_action_event(self, tick: int, payload: object) -> CascadeEvent:
|
||||
def gen_action_event(self, tick: int, payloads: List[BaseAction]) -> CascadeEvent:
|
||||
"""Generate an event that used to dispatch action to business engine.
|
||||
|
||||
Args:
|
||||
tick (int): Tick that the event will be processed.
|
||||
payload (object): Payload of event, used to pass data to handlers.
|
||||
payloads (List[BaseAction]): Payloads of event, used to pass data to handlers.
|
||||
Returns:
|
||||
CascadeEvent: Event object
|
||||
"""
|
||||
return self.gen_cascade_event(tick, MaroEvents.TAKE_ACTION, payload)
|
||||
assert isinstance(payloads, list)
|
||||
assert all(isinstance(p, BaseAction) for p in payloads)
|
||||
return self.gen_cascade_event(tick, MaroEvents.TAKE_ACTION, payloads)
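A short sketch of the tightened payload contracts introduced here; the EventBuffer import path is an assumption, and my_decision_event / my_action stand for scenario-specific subclasses like the ones sketched above:

buffer = EventBuffer()  # assumes `from maro.event_buffer import EventBuffer`
decision = buffer.gen_decision_event(tick=5, payload=my_decision_event)  # payload must be a BaseDecisionEvent
action_event = buffer.gen_action_event(tick=5, payloads=[my_action])  # payloads must be a list of BaseAction
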
|
||||
|
||||
def register_event_handler(self, event_type: object, handler: Callable) -> None:
|
||||
"""Register an event with handler, when there is an event need to be processed,
|
||||
|
|
|
@@ -1,6 +1,6 @@
-pyjwt
-numpy<1.20.0
-Cython>=0.29.14
+PyJWT>=2.4.0
+numpy>=1.19.0
+cython>=0.29.14
 altair>=4.1.0
 streamlit>=0.69.1
 tqdm>=4.51.0

@@ -3,8 +3,12 @@
 
 from .abs_proxy import AbsProxy
 from .abs_worker import AbsWorker
+from .port_config import DEFAULT_ROLLOUT_PRODUCER_PORT, DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT
 
 __all__ = [
     "AbsProxy",
     "AbsWorker",
+    "DEFAULT_ROLLOUT_PRODUCER_PORT",
+    "DEFAULT_TRAINING_FRONTEND_PORT",
+    "DEFAULT_TRAINING_BACKEND_PORT",
 ]

@ -2,6 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import zmq
|
||||
from tornado.ioloop import IOLoop
|
||||
|
@ -33,7 +34,7 @@ class AbsWorker(object):
|
|||
super(AbsWorker, self).__init__()
|
||||
|
||||
self._id = f"worker.{idx}"
|
||||
self._logger = logger if logger else DummyLogger()
|
||||
self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger()
|
||||
|
||||
# ZMQ sockets and streams
|
||||
self._context = Context.instance()
|
||||
|
|
|
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+DEFAULT_ROLLOUT_PRODUCER_PORT = 20000
+DEFAULT_TRAINING_FRONTEND_PORT = 10000
+DEFAULT_TRAINING_BACKEND_PORT = 10001

@ -98,14 +98,15 @@ class MultiLinearExplorationScheduler(AbsExplorationScheduler):
|
|||
start_ep: int = 1,
|
||||
initial_value: float = None,
|
||||
) -> None:
|
||||
super().__init__(exploration_params, param_name, initial_value=initial_value)
|
||||
|
||||
# validate splits
|
||||
splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)]
|
||||
splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
|
||||
splits.sort()
|
||||
for (ep, _), (ep2, _) in zip(splits, splits[1:]):
|
||||
if ep == ep2:
|
||||
raise ValueError("The zeroth element of split points must be unique")
|
||||
|
||||
super().__init__(exploration_params, param_name, initial_value=initial_value)
|
||||
self.final_value = final_value
|
||||
self._splits = splits
|
||||
self._ep = start_ep
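A hedged construction sketch for the revised scheduler; the keyword names follow the parameters visible in this hunk and the numbers are illustrative. After this change the schedule's start value is read from exploration_params[param_name] (after any initial_value override applied by the base class) rather than passed in separately:

exploration_params = {"epsilon": 0.4}
scheduler = MultiLinearExplorationScheduler(
    exploration_params=exploration_params,
    param_name="epsilon",
    splits=[(50, 0.32)],  # intermediate (episode, value) points; episode values must be unique
    last_ep=400,
    final_value=0.0,
)
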
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch.nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -18,7 +18,11 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
def __init__(self) -> None:
|
||||
super(AbsNet, self).__init__()
|
||||
|
||||
self._optim: Optional[Optimizer] = None
|
||||
@property
|
||||
def optim(self) -> Optimizer:
|
||||
optim = getattr(self, "_optim", None)
|
||||
assert isinstance(optim, Optimizer), "Each AbsNet must have an optimizer"
|
||||
return optim
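A minimal sketch of what the new `optim` property expects from concrete subclasses: the subclass builds its own layers and assigns `self._optim` itself. The class name and network shape are illustrative, and the AbsNet import location is an assumption:

import torch

from maro.rl.model import AbsNet  # assumed export location


class MyTinyNet(AbsNet):
    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self._fc = torch.nn.Linear(input_dim, output_dim)
        # Without this assignment, the `optim` property's assertion fails on the first `step()` call.
        self._optim = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._fc(x)
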
|
||||
|
||||
def step(self, loss: torch.Tensor) -> None:
|
||||
"""Run a training step to update the net's parameters according to the given loss.
|
||||
|
@ -26,9 +30,9 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
Args:
|
||||
loss (torch.tensor): Loss used to update the model.
|
||||
"""
|
||||
self._optim.zero_grad()
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
self._optim.step()
|
||||
self.optim.step()
|
||||
|
||||
def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
"""Get the gradients with respect to all parameters according to the given loss.
|
||||
|
@ -39,7 +43,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
Returns:
|
||||
Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters.
|
||||
"""
|
||||
self._optim.zero_grad()
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
return {name: param.grad for name, param in self.named_parameters()}
|
||||
|
||||
|
@ -51,7 +55,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
"""
|
||||
for name, param in self.named_parameters():
|
||||
param.grad = grad[name]
|
||||
self._optim.step()
|
||||
self.optim.step()
|
||||
|
||||
def _forward_unimplemented(self, *input: Any) -> None:
|
||||
pass
|
||||
|
@ -64,7 +68,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
"""
|
||||
return {
|
||||
"network": self.state_dict(),
|
||||
"optim": self._optim.state_dict(),
|
||||
"optim": self.optim.state_dict(),
|
||||
}
|
||||
|
||||
def set_state(self, net_state: dict) -> None:
|
||||
|
@ -74,7 +78,7 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
net_state (dict): A dict that contains the net's state.
|
||||
"""
|
||||
self.load_state_dict(net_state["network"])
|
||||
self._optim.load_state_dict(net_state["optim"])
|
||||
self.optim.load_state_dict(net_state["optim"])
|
||||
|
||||
def soft_update(self, other_model: AbsNet, tau: float) -> None:
|
||||
"""Soft update the net's parameters according to another net, i.e.,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -46,7 +46,7 @@ class FullyConnected(nn.Module):
|
|||
skip_connection: bool = False,
|
||||
dropout_p: float = None,
|
||||
gradient_threshold: float = None,
|
||||
name: str = None,
|
||||
name: str = "NONAME",
|
||||
) -> None:
|
||||
super(FullyConnected, self).__init__()
|
||||
self._input_dim = input_dim
|
||||
|
@ -84,7 +84,7 @@ class FullyConnected(nn.Module):
|
|||
self._name = name
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self._net(x)
|
||||
out = self._net(x.float())
|
||||
if self._skip_connection:
|
||||
out += x
|
||||
return self._softmax(out) if self._softmax else out
|
||||
|
@ -101,12 +101,12 @@ class FullyConnected(nn.Module):
|
|||
def output_dim(self) -> int:
|
||||
return self._output_dim
|
||||
|
||||
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module:
|
||||
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
|
||||
"""Build a basic layer.
|
||||
|
||||
BN -> Linear -> Activation -> Dropout
|
||||
"""
|
||||
components = []
|
||||
components: List[Tuple[str, nn.Module]] = []
|
||||
if self._batch_norm:
|
||||
components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
|
||||
components.append(("linear", nn.Linear(input_dim, output_dim)))
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -27,14 +27,14 @@ class AbsPolicy(object, metaclass=ABCMeta):
|
|||
self._trainable = trainable
|
||||
|
||||
@abstractmethod
|
||||
def get_actions(self, states: object) -> object:
|
||||
def get_actions(self, states: Union[list, np.ndarray]) -> Any:
|
||||
"""Get actions according to states.
|
||||
|
||||
Args:
|
||||
states (object): States.
|
||||
states (Union[list, np.ndarray]): States.
|
||||
|
||||
Returns:
|
||||
actions (object): Actions.
|
||||
actions (Any): Actions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
|
|||
def __init__(self) -> None:
|
||||
super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)
|
||||
|
||||
def get_actions(self, states: object) -> None:
|
||||
def get_actions(self, states: Union[list, np.ndarray]) -> None:
|
||||
return None
|
||||
|
||||
def explore(self) -> None:
|
||||
|
@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
def __init__(self, name: str) -> None:
|
||||
super(RuleBasedPolicy, self).__init__(name=name, trainable=False)
|
||||
|
||||
def get_actions(self, states: List[object]) -> List[object]:
|
||||
def get_actions(self, states: list) -> list:
|
||||
return self._rule(states)
|
||||
|
||||
@abstractmethod
|
||||
def _rule(self, states: List[object]) -> List[object]:
|
||||
def _rule(self, states: list) -> list:
|
||||
raise NotImplementedError
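Under the loosened list-based typing above, a rule-based policy simply maps a list of states to a list of actions. A small illustrative subclass (the class name and rule are made up; numpy and RuleBasedPolicy are assumed to be in scope as in this module):

class ArgmaxRulePolicy(RuleBasedPolicy):
    """Picks the index of the largest feature of each state."""

    def _rule(self, states: list) -> list:
        return [int(np.argmax(state)) for state in states]
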
|
||||
|
||||
def explore(self) -> None:
|
||||
|
@ -304,7 +304,7 @@ class RLPolicy(AbsPolicy, metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_state(self) -> object:
|
||||
def get_state(self) -> dict:
|
||||
"""Get the state of the policy."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -62,12 +62,10 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
)
|
||||
|
||||
self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
|
||||
assert self._lbounds is not None and self._ubounds is not None
|
||||
|
||||
self._policy_net = policy_net
|
||||
|
||||
@property
|
||||
def action_bounds(self) -> Tuple[List[float], List[float]]:
|
||||
def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]:
|
||||
return self._lbounds, self._ubounds
|
||||
|
||||
@property
|
||||
|
@ -118,7 +116,7 @@ class ContinuousRLPolicy(RLPolicy):
|
|||
def train(self) -> None:
|
||||
self._policy_net.train()
|
||||
|
||||
def get_state(self) -> object:
|
||||
def get_state(self) -> dict:
|
||||
return self._policy_net.get_state()
|
||||
|
||||
def set_state(self, policy_state: dict) -> None:
|
||||
|
|
|
@ -85,9 +85,11 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
|
||||
self._exploration_func = exploration_strategy[0]
|
||||
self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing
|
||||
self._exploration_schedulers = [
|
||||
opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options
|
||||
]
|
||||
self._exploration_schedulers = (
|
||||
[opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
|
||||
if exploration_scheduling_options is not None
|
||||
else []
|
||||
)
|
||||
|
||||
self._call_cnt = 0
|
||||
self._warmup = warmup
|
||||
|
@ -219,7 +221,7 @@ class ValueBasedPolicy(DiscreteRLPolicy):
|
|||
def train(self) -> None:
|
||||
self._q_net.train()
|
||||
|
||||
def get_state(self) -> object:
|
||||
def get_state(self) -> dict:
|
||||
return self._q_net.get_state()
|
||||
|
||||
def set_state(self, policy_state: dict) -> None:
|
||||
|
|
|
@ -1,194 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.policy import AbsPolicy, RLPolicy
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
from maro.simulator import Env
|
||||
|
||||
|
||||
class RLComponentBundle(object):
|
||||
class RLComponentBundle:
|
||||
"""Bundle of all necessary components to run a RL job in MARO.
|
||||
|
||||
Users should create their own subclass of `RLComponentBundle` and implement following methods:
|
||||
- get_env_config()
|
||||
- get_test_env_config()
|
||||
- get_env_sampler()
|
||||
- get_agent2policy()
|
||||
- get_policy_creator()
|
||||
- get_trainer_creator()
|
||||
|
||||
Following methods could be overwritten when necessary:
|
||||
- get_device_mapping()
|
||||
|
||||
Please refer to the doc string of each method for detailed explanations.
|
||||
env_sampler (AbsEnvSampler): Environment sampler of the scenario.
|
||||
agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job. For example:
|
||||
{agent1: policy1, agent2: policy1, agent3: policy2}.
|
||||
policies (List[AbsPolicy]): Policies.
|
||||
trainers (List[AbsTrainer]): Trainers.
|
||||
device_mapping (Dict[str, str], default=None): Device mapping that identifies the device on which each policy is
placed. If None, there will be no explicit device assignment.
policy_trainer_mapping (Dict[str, str], default=None): Policy-trainer mapping that identifies which trainer trains
each policy. If None, a policy's trainer name defaults to the first segment of the policy's name, separated
by a dot. For example, "ppo_1.policy" is trained by "ppo_1". Only policies listed in the policy-trainer
mapping are considered trainable; policies not listed there will not be trained.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super(RLComponentBundle, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
env_sampler: AbsEnvSampler,
|
||||
agent2policy: Dict[Any, str],
|
||||
policies: List[AbsPolicy],
|
||||
trainers: List[AbsTrainer],
|
||||
device_mapping: Dict[str, str] = None,
|
||||
policy_trainer_mapping: Dict[str, str] = None,
|
||||
) -> None:
|
||||
self.env_sampler = env_sampler
|
||||
self.agent2policy = agent2policy
|
||||
self.policies = policies
|
||||
self.trainers = trainers
|
||||
|
||||
self.trainer_creator: Optional[Dict[str, Callable[[], AbsTrainer]]] = None
|
||||
policy_set = set([policy.name for policy in self.policies])
|
||||
not_found = [policy_name for policy_name in self.agent2policy.values() if policy_name not in policy_set]
|
||||
if len(not_found) > 0:
|
||||
raise ValueError(f"The following policies are required but cannot be found: [{', '.join(not_found)}]")
|
||||
|
||||
self.agent2policy: Optional[Dict[Any, str]] = None
|
||||
self.trainable_agent2policy: Optional[Dict[Any, str]] = None
|
||||
self.policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None
|
||||
self.policy_names: Optional[List[str]] = None
|
||||
self.trainable_policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None
|
||||
self.trainable_policy_names: Optional[List[str]] = None
|
||||
# Remove unused policies
|
||||
kept_policies = []
|
||||
for policy in self.policies:
|
||||
if policy.name not in self.agent2policy.values():
|
||||
raise Warning(f"Policy {policy.name} is removed since it is not used by any agent.")
|
||||
else:
|
||||
kept_policies.append(policy)
|
||||
self.policies = kept_policies
|
||||
policy_set = set([policy.name for policy in self.policies])
|
||||
|
||||
self.device_mapping: Optional[Dict[str, str]] = None
|
||||
self.policy_trainer_mapping: Optional[Dict[str, str]] = None
|
||||
self.device_mapping = (
|
||||
{k: v for k, v in device_mapping.items() if k in policy_set} if device_mapping is not None else {}
|
||||
)
|
||||
self.policy_trainer_mapping = (
|
||||
policy_trainer_mapping
|
||||
if policy_trainer_mapping is not None
|
||||
else {policy_name: policy_name.split(".")[0] for policy_name in policy_set}
|
||||
)
|
||||
|
||||
self._policy_cache: Optional[Dict[str, AbsPolicy]] = None
|
||||
|
||||
# Will be created when `env_sampler()` is first called
|
||||
self._env_sampler: Optional[AbsEnvSampler] = None
|
||||
|
||||
self._complete_resources()
|
||||
|
||||
########################################################################################
|
||||
# Users MUST implement the following methods #
|
||||
########################################################################################
|
||||
@abstractmethod
|
||||
def get_env_config(self) -> dict:
|
||||
"""Return the environment configuration to build the MARO Env for training.
|
||||
|
||||
Returns:
|
||||
Environment configuration.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_test_env_config(self) -> Optional[dict]:
|
||||
"""Return the environment configuration to build the MARO Env for testing. If returns `None`, the training
|
||||
environment will be reused as testing environment.
|
||||
|
||||
Returns:
|
||||
Environment configuration or `None`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_env_sampler(self) -> AbsEnvSampler:
|
||||
"""Return the environment sampler of the scenario.
|
||||
|
||||
Returns:
|
||||
The environment sampler of the scenario.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
"""Return agent name to policy name mapping of the RL job. This mapping identifies which policy should
|
||||
the agents use. For example: {agent1: policy1, agent2: policy1, agent3: policy2}.
|
||||
|
||||
Returns:
|
||||
Agent name to policy name mapping.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
"""Return policy creator. Policy creator is a dictionary that contains a group of functions that generate
|
||||
policy instances. The key of this dictionary is the policy name, and the value is the function that generates
|
||||
the corresponding policy instance. Note that the creation function should not take any parameters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
"""Return trainer creator. Trainer creator is similar to policy creator, but is used to creator trainers."""
|
||||
raise NotImplementedError
|
||||
|
||||
########################################################################################
|
||||
# Users could overwrite the following methods #
|
||||
########################################################################################
|
||||
def get_device_mapping(self) -> Dict[str, str]:
|
||||
"""Return the device mapping that identifying which device to put each policy.
|
||||
|
||||
If user does not overwrite this method, then all policies will be put on CPU by default.
|
||||
"""
|
||||
return {policy_name: "cpu" for policy_name in self.get_policy_creator()}
|
||||
|
||||
def get_policy_trainer_mapping(self) -> Dict[str, str]:
|
||||
"""Return the policy-trainer mapping which identifying which trainer to train each policy.
|
||||
|
||||
If the user does not overwrite this method, then a policy's trainer's name is the first segment of the policy's
name, separated by a dot. For example, "ppo_1.policy" is trained by "ppo_1".
|
||||
|
||||
Only policies that are provided in the policy-trainer mapping are considered trainable policies. Policies that are
not provided in the policy-trainer mapping will not be trained since no trainer is assigned to them.
|
||||
"""
|
||||
return {policy_name: policy_name.split(".")[0] for policy_name in self.policy_creator}
|
||||
|
||||
########################################################################################
|
||||
# Methods invisible to users #
|
||||
########################################################################################
|
||||
@property
|
||||
def env_sampler(self) -> AbsEnvSampler:
|
||||
if self._env_sampler is None:
|
||||
self._env_sampler = self.get_env_sampler()
|
||||
self._env_sampler.build(self)
|
||||
return self._env_sampler
|
||||
|
||||
def _complete_resources(self) -> None:
|
||||
"""Generate all attributes by calling user-defined logics. Do necessary checking and transformations."""
|
||||
env_config = self.get_env_config()
|
||||
test_env_config = self.get_test_env_config()
|
||||
self.env = Env(**env_config)
|
||||
self.test_env = self.env if test_env_config is None else Env(**test_env_config)
|
||||
|
||||
self.trainer_creator = self.get_trainer_creator()
|
||||
self.device_mapping = self.get_device_mapping()
|
||||
|
||||
self.policy_creator = self.get_policy_creator()
|
||||
self.agent2policy = self.get_agent2policy()
|
||||
|
||||
self.policy_trainer_mapping = self.get_policy_trainer_mapping()
|
||||
|
||||
required_policies = set(self.agent2policy.values())
|
||||
self.policy_creator = {name: self.policy_creator[name] for name in required_policies}
|
||||
# Check missing trainers
|
||||
self.policy_trainer_mapping = {
|
||||
name: self.policy_trainer_mapping[name] for name in required_policies if name in self.policy_trainer_mapping
|
||||
policy_name: trainer_name
|
||||
for policy_name, trainer_name in self.policy_trainer_mapping.items()
|
||||
if policy_name in policy_set
|
||||
}
|
||||
self.policy_names = list(required_policies)
|
||||
assert len(required_policies) == len(self.policy_creator) # Should have same size after filter
|
||||
trainer_set = set([trainer.name for trainer in self.trainers])
|
||||
not_found = [
|
||||
trainer_name for trainer_name in self.policy_trainer_mapping.values() if trainer_name not in trainer_set
|
||||
]
|
||||
if len(not_found) > 0:
|
||||
raise ValueError(f"The following trainers are required but cannot be found: [{', '.join(not_found)}]")
|
||||
|
||||
required_trainers = set(self.policy_trainer_mapping.values())
|
||||
self.trainer_creator = {name: self.trainer_creator[name] for name in required_trainers}
|
||||
assert len(required_trainers) == len(self.trainer_creator) # Should have same size after filter
|
||||
# Remove unused trainers
|
||||
kept_trainers = []
|
||||
for trainer in self.trainers:
|
||||
if trainer.name not in self.policy_trainer_mapping.values():
|
||||
raise Warning(f"Trainer {trainer.name} if removed since no policy is trained by it.")
|
||||
else:
|
||||
kept_trainers.append(trainer)
|
||||
self.trainers = kept_trainers
|
||||
|
||||
self.trainable_policy_names = list(self.policy_trainer_mapping.keys())
|
||||
self.trainable_policy_creator = {
|
||||
policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names
|
||||
}
|
||||
self.trainable_agent2policy = {
|
||||
@property
|
||||
def trainable_agent2policy(self) -> Dict[Any, str]:
|
||||
return {
|
||||
agent_name: policy_name
|
||||
for agent_name, policy_name in self.agent2policy.items()
|
||||
if policy_name in self.trainable_policy_names
|
||||
if policy_name in self.policy_trainer_mapping
|
||||
}
|
||||
|
||||
def pre_create_policy_instances(self) -> None:
|
||||
"""Pre-create policy instances, and return the pre-created policy instances when the external callers
|
||||
want to create new policies. This will ensure that each policy will have at most one reusable duplicate.
|
||||
Under specific scenarios (for example, simple training & rollout), this will reduce unnecessary overheads.
|
||||
"""
|
||||
old_policy_creator = self.policy_creator
|
||||
self._policy_cache: Dict[str, AbsPolicy] = {}
|
||||
for policy_name in self.policy_names:
|
||||
self._policy_cache[policy_name] = old_policy_creator[policy_name]()
|
||||
|
||||
def _get_policy_instance(policy_name: str) -> AbsPolicy:
|
||||
return self._policy_cache[policy_name]
|
||||
|
||||
self.policy_creator = {
|
||||
policy_name: partial(_get_policy_instance, policy_name) for policy_name in self.policy_names
|
||||
}
|
||||
|
||||
self.trainable_policy_creator = {
|
||||
policy_name: self.policy_creator[policy_name] for policy_name in self.trainable_policy_names
|
||||
}
|
||||
@property
|
||||
def trainable_policies(self) -> List[RLPolicy]:
|
||||
policies = []
|
||||
for policy in self.policies:
|
||||
if policy.name in self.policy_trainer_mapping:
|
||||
assert isinstance(policy, RLPolicy)
|
||||
policies.append(policy)
|
||||
return policies
|
||||
|
|
|
@ -4,12 +4,13 @@
|
|||
import os
|
||||
import time
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from zmq import Context, Poller
|
||||
|
||||
from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT
|
||||
from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes
|
||||
from maro.rl.utils.objects import FILE_SUFFIX
|
||||
from maro.utils import DummyLogger, LoggerV2
|
||||
|
@ -37,19 +38,19 @@ class ParallelTaskController(object):
|
|||
self._poller = Poller()
|
||||
self._poller.register(self._task_endpoint, zmq.POLLIN)
|
||||
|
||||
self._workers = set()
|
||||
self._logger = logger
|
||||
self._workers: set = set()
|
||||
self._logger: Union[DummyLogger, LoggerV2] = logger if logger is not None else DummyLogger()
|
||||
|
||||
def _wait_for_workers_ready(self, k: int) -> None:
|
||||
while len(self._workers) < k:
|
||||
self._workers.add(self._task_endpoint.recv_multipart()[0])
|
||||
|
||||
def _recv_result_for_target_index(self, index: int) -> object:
|
||||
def _recv_result_for_target_index(self, index: int) -> Any:
|
||||
rep = bytes_to_pyobj(self._task_endpoint.recv_multipart()[-1])
|
||||
assert isinstance(rep, dict)
|
||||
return rep["result"] if rep["index"] == index else None
|
||||
|
||||
def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: int = None) -> List[dict]:
|
||||
def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: float = None) -> List[dict]:
|
||||
"""Send a task request to a set of remote workers and collect the results.
|
||||
|
||||
Args:
|
||||
|
@ -70,7 +71,7 @@ class ParallelTaskController(object):
|
|||
min_replies = parallelism
|
||||
|
||||
start_time = time.time()
|
||||
results = []
|
||||
results: list = []
|
||||
for worker_id in list(self._workers)[:parallelism]:
|
||||
self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)])
|
||||
self._logger.debug(f"Sent {parallelism} roll-out requests...")
|
||||
|
@ -81,7 +82,7 @@ class ParallelTaskController(object):
|
|||
results.append(result)
|
||||
|
||||
if grace_factor is not None:
|
||||
countdown = int((time.time() - start_time) * grace_factor) * 1000 # milliseconds
|
||||
countdown = int((time.time() - start_time) * grace_factor) * 1000.0 # milliseconds
|
||||
self._logger.debug(f"allowing {countdown / 1000} seconds for remaining results")
|
||||
while len(results) < parallelism and countdown > 0:
|
||||
start = time.time()
|
||||
|
@ -125,15 +126,18 @@ class BatchEnvSampler:
|
|||
def __init__(
|
||||
self,
|
||||
sampling_parallelism: int,
|
||||
port: int = 20000,
|
||||
port: int = None,
|
||||
min_env_samples: int = None,
|
||||
grace_factor: float = None,
|
||||
eval_parallelism: int = None,
|
||||
logger: LoggerV2 = None,
|
||||
) -> None:
|
||||
super(BatchEnvSampler, self).__init__()
|
||||
self._logger = logger if logger else DummyLogger()
|
||||
self._controller = ParallelTaskController(port=port, logger=logger)
|
||||
self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger()
|
||||
self._controller = ParallelTaskController(
|
||||
port=port if port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism
|
||||
self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism
|
||||
|
@ -143,11 +147,15 @@ class BatchEnvSampler:
|
|||
self._ep = 0
|
||||
self._end_of_episode = True
|
||||
|
||||
def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Optional[int] = None) -> dict:
|
||||
def sample(
|
||||
self,
|
||||
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Collect experiences from a set of remote roll-out workers.
|
||||
|
||||
Args:
|
||||
policy_state (Dict[str, object]): Policy state dict. If it is not None, then we need to update all
|
||||
policy_state (Dict[str, Any]): Policy state dict. If it is not None, then we need to update all
|
||||
policies according to the latest policy states, then start the experience collection.
|
||||
num_steps (Optional[int], default=None): Number of environment steps to collect experiences for. If
|
||||
it is None, interactions with the (remote) environments will continue until the terminal state is
|
||||
|
@ -181,7 +189,7 @@ class BatchEnvSampler:
|
|||
"info": [res["info"][0] for res in results],
|
||||
}
|
||||
|
||||
def eval(self, policy_state: Dict[str, object] = None) -> dict:
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
|
||||
req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test
|
||||
results = self._controller.collect(req, self._eval_parallelism)
|
||||
return {
|
||||
|
@ -209,3 +217,11 @@ class BatchEnvSampler:
|
|||
|
||||
def exit(self) -> None:
|
||||
self._controller.exit()
|
||||
|
||||
def post_collect(self, info_list: list, ep: int) -> None:
|
||||
req = {"type": "post_collect", "info_list": info_list, "index": ep}
|
||||
self._controller.collect(req, 1)
|
||||
|
||||
def post_evaluate(self, info_list: list, ep: int) -> None:
|
||||
req = {"type": "post_evaluate", "info_list": info_list, "index": ep}
|
||||
self._controller.collect(req, 1)
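A hedged end-to-end sketch of the sampler's lifecycle using only the parameters and return keys visible in this file; the values are illustrative, and remote rollout workers are assumed to be running and connected to the controller port already:

sampler = BatchEnvSampler(sampling_parallelism=3, eval_parallelism=1)  # port defaults to DEFAULT_ROLLOUT_PRODUCER_PORT
rollout_result = sampler.sample(policy_state=None, num_steps=None)  # None -> run each remote env to episode end
eval_result = sampler.eval(policy_state=None)
sampler.post_collect(rollout_result["info"], ep=1)
sampler.exit()
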
|
||||
|
|
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
|||
|
||||
import collections
|
||||
import os
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
@ -18,9 +17,6 @@ from maro.rl.policy import AbsPolicy, RLPolicy
|
|||
from maro.rl.utils.objects import FILE_SUFFIX
|
||||
from maro.simulator import Env
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
|
||||
|
||||
class AbsAgentWrapper(object, metaclass=ABCMeta):
|
||||
"""Agent wrapper. Used to manager agents & policies during experience collection.
|
||||
|
@ -51,16 +47,16 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
|
|||
|
||||
def choose_actions(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
) -> Dict[Any, Union[np.ndarray, List[object]]]:
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
"""Choose action according to the given (observable) states of all agents.
|
||||
|
||||
Args:
|
||||
state_by_agent (Dict[Any, Union[np.ndarray, List[object]]]): Dictionary containing each agent's states.
|
||||
state_by_agent (Dict[Any, Union[np.ndarray, list]]): Dictionary containing each agent's states.
|
||||
If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects.
|
||||
|
||||
Returns:
|
||||
actions (Dict[Any, Union[np.ndarray, List[object]]]): Dict that contains the action for all agents.
|
||||
actions (Dict[Any, Union[np.ndarray, list]]): Dict that contains the action for all agents.
|
||||
If the policy is a `RLPolicy`, its action is a Numpy array. Otherwise, its action is a list of objects.
|
||||
"""
|
||||
self.switch_to_eval_mode()
|
||||
|
@ -71,8 +67,8 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def _choose_actions_impl(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
) -> Dict[Any, Union[np.ndarray, List[object]]]:
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
"""Implementation of `choose_actions`."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -95,15 +91,15 @@ class AbsAgentWrapper(object, metaclass=ABCMeta):
|
|||
class SimpleAgentWrapper(AbsAgentWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
policy_dict: Dict[str, RLPolicy], # {policy_name: RLPolicy}
|
||||
policy_dict: Dict[str, AbsPolicy], # {policy_name: AbsPolicy}
|
||||
agent2policy: Dict[Any, str], # {agent_name: policy_name}
|
||||
) -> None:
|
||||
super(SimpleAgentWrapper, self).__init__(policy_dict=policy_dict, agent2policy=agent2policy)
|
||||
|
||||
def _choose_actions_impl(
|
||||
self,
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
) -> Dict[Any, Union[np.ndarray, List[object]]]:
|
||||
state_by_agent: Dict[Any, Union[np.ndarray, list]],
|
||||
) -> Dict[Any, Union[np.ndarray, list]]:
|
||||
# Aggregate states by policy
|
||||
states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray}
|
||||
agents_by_policy = collections.defaultdict(list) # {str: list of str}
|
||||
|
@ -112,15 +108,15 @@ class SimpleAgentWrapper(AbsAgentWrapper):
|
|||
states_by_policy[policy_name].append(state)
|
||||
agents_by_policy[policy_name].append(agent_name)
|
||||
|
||||
action_dict = {}
|
||||
action_dict: dict = {}
|
||||
for policy_name in agents_by_policy:
|
||||
policy = self._policy_dict[policy_name]
|
||||
|
||||
if isinstance(policy, RLPolicy):
|
||||
states = np.vstack(states_by_policy[policy_name]) # np.ndarray
|
||||
else:
|
||||
states = states_by_policy[policy_name] # List[object]
|
||||
actions = policy.get_actions(states) # np.ndarray or List[object]
|
||||
states = states_by_policy[policy_name] # list
|
||||
actions: Union[np.ndarray, list] = policy.get_actions(states) # np.ndarray or list
|
||||
action_dict.update(zip(agents_by_policy[policy_name], actions))
|
||||
|
||||
return action_dict
|
||||
|
@ -188,7 +184,7 @@ class ExpElement:
|
|||
Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this
|
||||
dict is the trainer name.
|
||||
"""
|
||||
ret = collections.defaultdict(
|
||||
ret: Dict[str, ExpElement] = collections.defaultdict(
|
||||
lambda: ExpElement(
|
||||
tick=self.tick,
|
||||
state=self.state,
|
||||
|
@ -213,7 +209,7 @@ class ExpElement:
|
|||
|
||||
@dataclass
|
||||
class CacheElement(ExpElement):
|
||||
event: object
|
||||
event: Any
|
||||
env_action_dict: Dict[Any, np.ndarray]
|
||||
|
||||
def make_exp_element(self) -> ExpElement:
|
||||
|
@ -238,6 +234,9 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
Args:
|
||||
learn_env (Env): Environment used for training.
|
||||
test_env (Env): Environment used for testing.
|
||||
policies (List[AbsPolicy]): List of policies.
|
||||
agent2policy (Dict[Any, str]): Agent name to policy name mapping of the RL job.
|
||||
trainable_policies (List[str]): Name of trainable policies.
|
||||
agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type.
|
||||
reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward
|
||||
for the action taken for that event. If it is None, calculate reward immediately after `step()`.
|
||||
|
@ -247,6 +246,9 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self,
|
||||
learn_env: Env,
|
||||
test_env: Env,
|
||||
policies: List[AbsPolicy],
|
||||
agent2policy: Dict[Any, str],
|
||||
trainable_policies: List[str] = None,
|
||||
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
|
||||
reward_eval_delay: int = None,
|
||||
) -> None:
|
||||
|
@ -255,7 +257,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
self._agent_wrapper_cls = agent_wrapper_cls
|
||||
|
||||
self._event = None
|
||||
self._event: Optional[list] = None
|
||||
self._end_of_episode = True
|
||||
self._state: Optional[np.ndarray] = None
|
||||
self._agent_state_dict: Dict[Any, np.ndarray] = {}
|
||||
|
@ -264,31 +266,23 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache
|
||||
self._reward_eval_delay = reward_eval_delay
|
||||
|
||||
self._info = {}
|
||||
self._info: dict = {}
|
||||
|
||||
assert self._reward_eval_delay is None or self._reward_eval_delay >= 0
|
||||
|
||||
def build(
|
||||
self,
|
||||
rl_component_bundle: RLComponentBundle,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
|
||||
"""
|
||||
#
|
||||
self._env: Optional[Env] = None
|
||||
|
||||
self._policy_dict = {
|
||||
policy_name: rl_component_bundle.policy_creator[policy_name]()
|
||||
for policy_name in rl_component_bundle.policy_names
|
||||
}
|
||||
|
||||
self._policy_dict: Dict[str, AbsPolicy] = {policy.name: policy for policy in policies}
|
||||
self._rl_policy_dict: Dict[str, RLPolicy] = {
|
||||
name: policy for name, policy in self._policy_dict.items() if isinstance(policy, RLPolicy)
|
||||
policy.name: policy for policy in policies if isinstance(policy, RLPolicy)
|
||||
}
|
||||
self._agent2policy = rl_component_bundle.agent2policy
|
||||
self._agent2policy = agent2policy
|
||||
self._agent_wrapper = self._agent_wrapper_cls(self._policy_dict, self._agent2policy)
|
||||
self._trainable_policies = set(rl_component_bundle.trainable_policy_names)
|
||||
|
||||
if trainable_policies is not None:
|
||||
self._trainable_policies = trainable_policies
|
||||
else:
|
||||
self._trainable_policies = list(self._policy_dict.keys()) # Default: all policies are trainable
|
||||
self._trainable_agents = {
|
||||
agent_id for agent_id, policy_name in self._agent2policy.items() if policy_name in self._trainable_policies
|
||||
}
|
||||
|
@ -297,23 +291,31 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
[policy_name in self._rl_policy_dict for policy_name in self._trainable_policies],
|
||||
), "All trainable policies must be RL policies!"
|
||||
|
||||
@property
|
||||
def env(self) -> Env:
|
||||
assert self._env is not None
|
||||
return self._env
|
||||
|
||||
def _switch_env(self, env: Env) -> None:
|
||||
self._env = env
|
||||
|
||||
def assign_policy_to_device(self, policy_name: str, device: torch.device) -> None:
|
||||
self._rl_policy_dict[policy_name].to_device(device)
|
||||
|
||||
def _get_global_and_agent_state(
|
||||
self,
|
||||
event: object,
|
||||
event: Any,
|
||||
tick: int = None,
|
||||
) -> Tuple[Optional[object], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]:
|
||||
"""Get the global and individual agents' states.
|
||||
|
||||
Args:
|
||||
event (object): Event.
|
||||
event (Any): Event.
|
||||
tick (int, default=None): Current tick.
|
||||
|
||||
Returns:
|
||||
Global state (Optional[object])
|
||||
Dict of agent states (Dict[Any, Union[np.ndarray, List[object]]]). If the policy is a `RLPolicy`,
|
||||
Global state (Optional[Any])
|
||||
Dict of agent states (Dict[Any, Union[np.ndarray, list]]). If the policy is a `RLPolicy`,
|
||||
its state is a Numpy array. Otherwise, its state is a list of objects.
|
||||
"""
|
||||
global_state, agent_state_dict = self._get_global_and_agent_state_impl(event, tick)
|
||||
|
@ -327,23 +329,23 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def _get_global_and_agent_state_impl(
|
||||
self,
|
||||
event: object,
|
||||
event: Any,
|
||||
tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _translate_to_env_action(
|
||||
self,
|
||||
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
event: object,
|
||||
) -> Dict[Any, object]:
|
||||
action_dict: Dict[Any, Union[np.ndarray, list]],
|
||||
event: Any,
|
||||
) -> dict:
|
||||
"""Translate model-generated actions into an object that can be executed by the env.
|
||||
|
||||
Args:
|
||||
action_dict (Dict[Any, Union[np.ndarray, List[object]]]): Action for all agents. If the policy is a
|
||||
action_dict (Dict[Any, Union[np.ndarray, list]]): Action for all agents. If the policy is a
|
||||
`RLPolicy`, its (input) action is a Numpy array. Otherwise, its (input) action is a list of objects.
|
||||
event (object): Decision event.
|
||||
event (Any): Decision event.
|
||||
|
||||
Returns:
|
||||
A dict that contains env actions for all agents.
|
||||
|
@ -351,12 +353,12 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: int) -> Dict[Any, float]:
|
||||
def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
|
||||
"""Get rewards according to the env actions.
|
||||
|
||||
Args:
|
||||
env_action_dict (Dict[Any, object]): Dict that contains env actions for all agents.
|
||||
event (object): Decision event.
|
||||
env_action_dict (dict): Dict that contains env actions for all agents.
|
||||
event (Any): Decision event.
|
||||
tick (int): Current tick.
|
||||
|
||||
Returns:
|
||||
|
@ -365,7 +367,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
def _step(self, actions: Optional[list]) -> None:
|
||||
_, self._event, self._end_of_episode = self._env.step(actions)
|
||||
_, self._event, self._end_of_episode = self.env.step(actions)
|
||||
self._state, self._agent_state_dict = (
|
||||
(None, {}) if self._end_of_episode else self._get_global_and_agent_state(self._event)
|
||||
)
|
||||
|
@ -403,7 +405,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._agent_last_index[agent_name] = cur_index
|
||||
|
||||
def _reset(self) -> None:
|
||||
self._env.reset()
|
||||
self.env.reset()
|
||||
self._info.clear()
|
||||
self._trans_cache.clear()
|
||||
self._agent_last_index.clear()
|
||||
|
@ -412,7 +414,11 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
def _select_trainable_agents(self, original_dict: dict) -> dict:
|
||||
return {k: v for k, v in original_dict.items() if k in self._trainable_agents}
|
||||
|
||||
def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Optional[int] = None) -> dict:
|
||||
def sample(
|
||||
self,
|
||||
policy_state: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Sample experiences.
|
||||
|
||||
Args:
|
||||
|
@ -425,7 +431,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
A dict that contains the collected experiences and additional information.
|
||||
"""
|
||||
# Init the env
|
||||
self._env = self._learn_env
|
||||
self._switch_env(self._learn_env)
|
||||
if self._end_of_episode:
|
||||
self._reset()
|
||||
|
||||
|
@ -443,7 +449,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self._env.tick,
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
|
@ -466,7 +472,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
steps_to_go -= 1
|
||||
self._append_cache_element(None)
|
||||
|
||||
tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
experiences: List[ExpElement] = []
|
||||
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
|
||||
cache_element = self._trans_cache.pop(0)
|
||||
|
@ -508,8 +514,8 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
return loaded
|
||||
|
||||
def eval(self, policy_state: Dict[str, dict] = None) -> dict:
|
||||
self._env = self._test_env
|
||||
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
|
||||
self._switch_env(self._test_env)
|
||||
self._reset()
|
||||
if policy_state is not None:
|
||||
self.set_policy_state(policy_state)
|
||||
|
@ -521,7 +527,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
|
||||
# Store experiences in the cache
|
||||
cache_element = CacheElement(
|
||||
tick=self._env.tick,
|
||||
tick=self.env.tick,
|
||||
event=self._event,
|
||||
state=self._state,
|
||||
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
|
||||
|
@ -544,7 +550,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
|
|||
self._append_cache_element(cache_element)
|
||||
self._append_cache_element(None)
|
||||
|
||||
tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
|
||||
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
|
||||
cache_element = self._trans_cache.pop(0)
|
||||
if self._reward_eval_delay is not None:
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
|
||||
import typing
|
||||
|
||||
from maro.rl.distributed import AbsWorker
|
||||
from maro.rl.distributed import DEFAULT_ROLLOUT_PRODUCER_PORT, AbsWorker
|
||||
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
|
||||
from maro.utils import LoggerV2
|
||||
|
||||
|
@ -19,7 +19,7 @@ class RolloutWorker(AbsWorker):
|
|||
Args:
|
||||
idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}",
|
||||
so that the parallel roll-out controller can keep track of its connection status.
|
||||
rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
|
||||
rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
|
||||
producer_host (str): IP address of the parallel task controller host to connect to.
|
||||
producer_port (int, default=20000): Port of the parallel task controller host to connect to.
|
||||
logger (LoggerV2, default=None): The logger of the workflow.
|
||||
|
@ -30,13 +30,13 @@ class RolloutWorker(AbsWorker):
|
|||
idx: int,
|
||||
rl_component_bundle: RLComponentBundle,
|
||||
producer_host: str,
|
||||
producer_port: int = 20000,
|
||||
producer_port: int = None,
|
||||
logger: LoggerV2 = None,
|
||||
) -> None:
|
||||
super(RolloutWorker, self).__init__(
|
||||
idx=idx,
|
||||
producer_host=producer_host,
|
||||
producer_port=producer_port,
|
||||
producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT,
|
||||
logger=logger,
|
||||
)
|
||||
self._env_sampler = rl_component_bundle.env_sampler
|
||||
|
@ -53,13 +53,20 @@ class RolloutWorker(AbsWorker):
|
|||
else:
|
||||
req = bytes_to_pyobj(msg[-1])
|
||||
assert isinstance(req, dict)
|
||||
assert req["type"] in {"sample", "eval", "set_policy_state"}
|
||||
if req["type"] == "sample":
|
||||
result = self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
|
||||
elif req["type"] == "eval":
|
||||
result = self._env_sampler.eval(policy_state=req["policy_state"])
|
||||
else:
|
||||
self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"])
|
||||
result = True
|
||||
assert req["type"] in {"sample", "eval", "set_policy_state", "post_collect", "post_evaluate"}
|
||||
|
||||
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
|
||||
if req["type"] in ("sample", "eval"):
|
||||
result = (
|
||||
self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
|
||||
if req["type"] == "sample"
|
||||
else self._env_sampler.eval(policy_state=req["policy_state"])
|
||||
)
|
||||
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
|
||||
else:
|
||||
if req["type"] == "set_policy_state":
|
||||
self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"])
|
||||
elif req["type"] == "post_collect":
|
||||
self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"])
|
||||
else:
|
||||
self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"])
|
||||
self._stream.send(pyobj_to_bytes({"result": True, "index": req["index"]}))
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from .proxy import TrainingProxy
|
||||
from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
|
||||
from .train_ops import AbsTrainOps, RemoteOps, remote
|
||||
from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams
|
||||
from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer
|
||||
from .training_manager import TrainingManager
|
||||
from .worker import TrainOpsWorker
|
||||
|
||||
|
@ -18,9 +18,9 @@ __all__ = [
|
|||
"RemoteOps",
|
||||
"remote",
|
||||
"AbsTrainer",
|
||||
"BaseTrainerParams",
|
||||
"MultiAgentTrainer",
|
||||
"SingleAgentTrainer",
|
||||
"TrainerParams",
|
||||
"TrainingManager",
|
||||
"TrainOpsWorker",
|
||||
]
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
from maro.rl.training.algorithms.base import ACBasedParams, ACBasedTrainer
|
||||
|
||||
|
@ -13,18 +12,8 @@ class ActorCriticParams(ACBasedParams):
|
|||
for detailed information.
|
||||
"""
|
||||
|
||||
def extract_ops_params(self) -> Dict[str, object]:
|
||||
return {
|
||||
"get_v_critic_net_func": self.get_v_critic_net_func,
|
||||
"reward_discount": self.reward_discount,
|
||||
"critic_loss_cls": self.critic_loss_cls,
|
||||
"lam": self.lam,
|
||||
"min_logp": self.min_logp,
|
||||
"is_discrete_action": self.is_discrete_action,
|
||||
}
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.get_v_critic_net_func is not None
|
||||
assert self.clip_ratio is None
|
||||
|
||||
|
||||
class ActorCriticTrainer(ACBasedTrainer):
|
||||
|
@ -34,5 +23,20 @@ class ActorCriticTrainer(ACBasedTrainer):
|
|||
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/vpg
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: ActorCriticParams) -> None:
|
||||
super(ActorCriticTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: ActorCriticParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(ActorCriticTrainer, self).__init__(
|
||||
name,
|
||||
params,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
|
|
|
@ -3,19 +3,19 @@
|
|||
|
||||
from abc import ABCMeta
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
from typing import Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from maro.rl.model import VNet
|
||||
from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy
|
||||
from maro.rl.training import AbsTrainOps, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
|
||||
from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, remote
|
||||
from maro.rl.utils import TransitionBatch, discount_cumsum, get_torch_device, ndarray_to_tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACBasedParams(TrainerParams, metaclass=ABCMeta):
|
||||
class ACBasedParams(BaseTrainerParams, metaclass=ABCMeta):
|
||||
"""
|
||||
Parameter bundle for Actor-Critic based algorithms (Actor-Critic & PPO)
|
||||
|
||||
|
@ -23,18 +23,16 @@ class ACBasedParams(TrainerParams, metaclass=ABCMeta):
|
|||
grad_iters (int, default=1): Number of iterations to calculate gradients.
|
||||
critic_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE.
|
||||
lam (float, default=0.9): Lambda value for generalized advantage estimation (TD-Lambda).
|
||||
min_logp (float, default=None): Lower bound for clamping logP values during learning.
|
||||
min_logp (float, default=float("-inf")): Lower bound for clamping logP values during learning.
|
||||
This is to prevent logP from becoming very large in magnitude and causing stability issues.
|
||||
If it is None, it means no lower bound.
|
||||
is_discrete_action (bool, default=True): Indicator of continuous or discrete action policy.
|
||||
"""
|
||||
|
||||
get_v_critic_net_func: Callable[[], VNet] = None
|
||||
get_v_critic_net_func: Callable[[], VNet]
|
||||
grad_iters: int = 1
|
||||
critic_loss_cls: Callable = None
|
||||
critic_loss_cls: Optional[Callable] = None
|
||||
lam: float = 0.9
|
||||
min_logp: Optional[float] = None
|
||||
is_discrete_action: bool = True
|
||||
min_logp: float = float("-inf")
|
||||
clip_ratio: Optional[float] = None
|
||||
|
||||
|
||||
class ACBasedOps(AbsTrainOps):
|
||||
|
@ -43,33 +41,26 @@ class ACBasedOps(AbsTrainOps):
|
|||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy_creator: Callable[[], RLPolicy],
|
||||
get_v_critic_net_func: Callable[[], VNet],
|
||||
parallelism: int = 1,
|
||||
policy: RLPolicy,
|
||||
params: ACBasedParams,
|
||||
reward_discount: float = 0.9,
|
||||
critic_loss_cls: Callable = None,
|
||||
clip_ratio: float = None,
|
||||
lam: float = 0.9,
|
||||
min_logp: float = None,
|
||||
is_discrete_action: bool = True,
|
||||
parallelism: int = 1,
|
||||
) -> None:
|
||||
super(ACBasedOps, self).__init__(
|
||||
name=name,
|
||||
policy_creator=policy_creator,
|
||||
policy=policy,
|
||||
parallelism=parallelism,
|
||||
)
|
||||
|
||||
assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy)
|
||||
assert isinstance(self._policy, (ContinuousRLPolicy, DiscretePolicyGradient))
|
||||
|
||||
self._reward_discount = reward_discount
|
||||
self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss()
|
||||
self._clip_ratio = clip_ratio
|
||||
self._lam = lam
|
||||
self._min_logp = min_logp
|
||||
self._v_critic_net = get_v_critic_net_func()
|
||||
self._is_discrete_action = is_discrete_action
|
||||
|
||||
self._device = None
|
||||
self._critic_loss_func = params.critic_loss_cls() if params.critic_loss_cls is not None else torch.nn.MSELoss()
|
||||
self._clip_ratio = params.clip_ratio
|
||||
self._lam = params.lam
|
||||
self._min_logp = params.min_logp
|
||||
self._v_critic_net = params.get_v_critic_net_func()
|
||||
self._is_discrete_action = isinstance(self._policy, DiscretePolicyGradient)
|
||||
|
||||
def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
"""Compute the critic loss of the batch.
|
||||
|
@ -249,14 +240,32 @@ class ACBasedTrainer(SingleAgentTrainer):
|
|||
https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: ACBasedParams) -> None:
|
||||
super(ACBasedTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: ACBasedParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(ACBasedTrainer, self).__init__(
|
||||
name,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
self._params = params
|
||||
|
||||
def _register_policy(self, policy: RLPolicy) -> None:
|
||||
assert isinstance(policy, (ContinuousRLPolicy, DiscretePolicyGradient))
|
||||
self._policy = policy
|
||||
|
||||
def build(self) -> None:
|
||||
self._ops = self.get_ops()
|
||||
self._ops = cast(ACBasedOps, self.get_ops())
|
||||
self._replay_memory = FIFOReplayMemory(
|
||||
capacity=self._params.replay_memory_capacity,
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
)
|
||||
|
@ -266,10 +275,11 @@ class ACBasedTrainer(SingleAgentTrainer):
|
|||
|
||||
def get_local_ops(self) -> AbsTrainOps:
|
||||
return ACBasedOps(
|
||||
name=self._policy_name,
|
||||
policy_creator=self._policy_creator,
|
||||
parallelism=self._params.data_parallelism,
|
||||
**self._params.extract_ops_params(),
|
||||
name=self._policy.name,
|
||||
policy=self._policy,
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
params=self._params,
|
||||
)
|
||||
|
||||
def _get_batch(self) -> TransitionBatch:
|
||||
|
|
|
@ -2,19 +2,19 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
from maro.rl.model import QNet
|
||||
from maro.rl.policy import ContinuousRLPolicy, RLPolicy
|
||||
from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
|
||||
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
|
||||
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
|
||||
from maro.utils import clone
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPGParams(TrainerParams):
|
||||
class DDPGParams(BaseTrainerParams):
|
||||
"""
|
||||
get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net.
|
||||
num_epochs (int, default=1): Number of training epochs per call to ``learn``.
|
||||
|
@ -30,25 +30,14 @@ class DDPGParams(TrainerParams):
|
|||
min_num_to_trigger_training (int, default=0): Minimum number required to start training.
|
||||
"""
|
||||
|
||||
get_q_critic_net_func: Callable[[], QNet] = None
|
||||
get_q_critic_net_func: Callable[[], QNet]
|
||||
num_epochs: int = 1
|
||||
update_target_every: int = 5
|
||||
q_value_loss_cls: Callable = None
|
||||
q_value_loss_cls: Optional[Callable] = None
|
||||
soft_update_coef: float = 1.0
|
||||
random_overwrite: bool = False
|
||||
min_num_to_trigger_training: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.get_q_critic_net_func is not None
|
||||
|
||||
def extract_ops_params(self) -> Dict[str, object]:
|
||||
return {
|
||||
"get_q_critic_net_func": self.get_q_critic_net_func,
|
||||
"reward_discount": self.reward_discount,
|
||||
"q_value_loss_cls": self.q_value_loss_cls,
|
||||
"soft_update_coef": self.soft_update_coef,
|
||||
}
|
||||
|
||||
|
||||
class DDPGOps(AbsTrainOps):
|
||||
"""DDPG algorithm implementation. Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html"""
|
||||
|
@ -56,31 +45,31 @@ class DDPGOps(AbsTrainOps):
|
|||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy_creator: Callable[[], RLPolicy],
|
||||
get_q_critic_net_func: Callable[[], QNet],
|
||||
reward_discount: float,
|
||||
policy: RLPolicy,
|
||||
params: DDPGParams,
|
||||
reward_discount: float = 0.9,
|
||||
parallelism: int = 1,
|
||||
q_value_loss_cls: Callable = None,
|
||||
soft_update_coef: float = 1.0,
|
||||
) -> None:
|
||||
super(DDPGOps, self).__init__(
|
||||
name=name,
|
||||
policy_creator=policy_creator,
|
||||
policy=policy,
|
||||
parallelism=parallelism,
|
||||
)
|
||||
|
||||
assert isinstance(self._policy, ContinuousRLPolicy)
|
||||
|
||||
self._target_policy = clone(self._policy)
|
||||
self._target_policy: ContinuousRLPolicy = clone(self._policy)
|
||||
self._target_policy.set_name(f"target_{self._policy.name}")
|
||||
self._target_policy.eval()
|
||||
self._q_critic_net = get_q_critic_net_func()
|
||||
self._q_critic_net = params.get_q_critic_net_func()
|
||||
self._target_q_critic_net: QNet = clone(self._q_critic_net)
|
||||
self._target_q_critic_net.eval()
|
||||
|
||||
self._reward_discount = reward_discount
|
||||
self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss()
|
||||
self._soft_update_coef = soft_update_coef
|
||||
self._q_value_loss_func = (
|
||||
params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss()
|
||||
)
|
||||
self._soft_update_coef = params.soft_update_coef
|
||||
|
||||
def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
"""Compute the critic loss of the batch.
|
||||
|
@ -207,7 +196,7 @@ class DDPGOps(AbsTrainOps):
|
|||
self._target_policy.soft_update(self._policy, self._soft_update_coef)
|
||||
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)
|
||||
|
||||
def to_device(self, device: str) -> None:
|
||||
def to_device(self, device: str = None) -> None:
|
||||
self._device = get_torch_device(device=device)
|
||||
self._policy.to_device(self._device)
|
||||
self._target_policy.to_device(self._device)
|
||||
|
@ -223,30 +212,49 @@ class DDPGTrainer(SingleAgentTrainer):
|
|||
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: DDPGParams) -> None:
|
||||
super(DDPGTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: DDPGParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(DDPGTrainer, self).__init__(
|
||||
name,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
self._params = params
|
||||
self._policy_version = self._target_policy_version = 0
|
||||
self._memory_size = 0
|
||||
|
||||
def build(self) -> None:
|
||||
self._ops = self.get_ops()
|
||||
self._ops = cast(DDPGOps, self.get_ops())
|
||||
self._replay_memory = RandomReplayMemory(
|
||||
capacity=self._params.replay_memory_capacity,
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
random_overwrite=self._params.random_overwrite,
|
||||
)
|
||||
|
||||
def _register_policy(self, policy: RLPolicy) -> None:
|
||||
assert isinstance(policy, ContinuousRLPolicy)
|
||||
self._policy = policy
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
||||
def get_local_ops(self) -> AbsTrainOps:
|
||||
return DDPGOps(
|
||||
name=self._policy_name,
|
||||
policy_creator=self._policy_creator,
|
||||
parallelism=self._params.data_parallelism,
|
||||
**self._params.extract_ops_params(),
|
||||
name=self._policy.name,
|
||||
policy=self._policy,
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
params=self._params,
|
||||
)
|
||||
|
||||
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
|
||||
|
|
|
@ -2,18 +2,18 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict
|
||||
from typing import Dict, cast
|
||||
|
||||
import torch
|
||||
|
||||
from maro.rl.policy import RLPolicy, ValueBasedPolicy
|
||||
from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
|
||||
from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
|
||||
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
|
||||
from maro.utils import clone
|
||||
|
||||
|
||||
@dataclass
|
||||
class DQNParams(TrainerParams):
|
||||
class DQNParams(BaseTrainerParams):
|
||||
"""
|
||||
num_epochs (int, default=1): Number of training epochs.
|
||||
update_target_every (int, default=5): Number of gradient steps between target model updates.
|
||||
|
@ -33,42 +33,34 @@ class DQNParams(TrainerParams):
|
|||
double: bool = False
|
||||
random_overwrite: bool = False
|
||||
|
||||
def extract_ops_params(self) -> Dict[str, object]:
|
||||
return {
|
||||
"reward_discount": self.reward_discount,
|
||||
"soft_update_coef": self.soft_update_coef,
|
||||
"double": self.double,
|
||||
}
|
||||
|
||||
|
||||
class DQNOps(AbsTrainOps):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy_creator: Callable[[], RLPolicy],
|
||||
parallelism: int = 1,
|
||||
policy: RLPolicy,
|
||||
params: DQNParams,
|
||||
reward_discount: float = 0.9,
|
||||
soft_update_coef: float = 0.1,
|
||||
double: bool = False,
|
||||
parallelism: int = 1,
|
||||
) -> None:
|
||||
super(DQNOps, self).__init__(
|
||||
name=name,
|
||||
policy_creator=policy_creator,
|
||||
policy=policy,
|
||||
parallelism=parallelism,
|
||||
)
|
||||
|
||||
assert isinstance(self._policy, ValueBasedPolicy)
|
||||
|
||||
self._reward_discount = reward_discount
|
||||
self._soft_update_coef = soft_update_coef
|
||||
self._double = double
|
||||
self._soft_update_coef = params.soft_update_coef
|
||||
self._double = params.double
|
||||
self._loss_func = torch.nn.MSELoss()
|
||||
|
||||
self._target_policy: ValueBasedPolicy = clone(self._policy)
|
||||
self._target_policy.set_name(f"target_{self._policy.name}")
|
||||
self._target_policy.eval()
|
||||
|
||||
def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor:
|
||||
"""Compute the loss of the batch.
|
||||
|
||||
Args:
|
||||
|
@ -78,6 +70,8 @@ class DQNOps(AbsTrainOps):
|
|||
loss (torch.Tensor): The loss of the batch.
|
||||
"""
|
||||
assert isinstance(batch, TransitionBatch)
|
||||
assert isinstance(self._policy, ValueBasedPolicy)
|
||||
|
||||
self._policy.train()
|
||||
states = ndarray_to_tensor(batch.states, device=self._device)
|
||||
next_states = ndarray_to_tensor(batch.next_states, device=self._device)
|
||||
|
@ -100,7 +94,7 @@ class DQNOps(AbsTrainOps):
|
|||
return self._loss_func(q_values, target_q_values)
|
||||
|
||||
@remote
|
||||
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]:
|
||||
"""Compute the network's gradients of a batch.
|
||||
|
||||
Args:
|
||||
|
@ -141,7 +135,7 @@ class DQNOps(AbsTrainOps):
|
|||
"""Soft update the target policy."""
|
||||
self._target_policy.soft_update(self._policy, self._soft_update_coef)
|
||||
|
||||
def to_device(self, device: str) -> None:
|
||||
def to_device(self, device: str = None) -> None:
|
||||
self._device = get_torch_device(device)
|
||||
self._policy.to_device(self._device)
|
||||
self._target_policy.to_device(self._device)
|
||||
|
@ -153,29 +147,48 @@ class DQNTrainer(SingleAgentTrainer):
|
|||
See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: DQNParams) -> None:
|
||||
super(DQNTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: DQNParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(DQNTrainer, self).__init__(
|
||||
name,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
self._params = params
|
||||
self._q_net_version = self._target_q_net_version = 0
|
||||
|
||||
def build(self) -> None:
|
||||
self._ops = self.get_ops()
|
||||
self._ops = cast(DQNOps, self.get_ops())
|
||||
self._replay_memory = RandomReplayMemory(
|
||||
capacity=self._params.replay_memory_capacity,
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._ops.policy_state_dim,
|
||||
action_dim=self._ops.policy_action_dim,
|
||||
random_overwrite=self._params.random_overwrite,
|
||||
)
|
||||
|
||||
def _register_policy(self, policy: RLPolicy) -> None:
|
||||
assert isinstance(policy, ValueBasedPolicy)
|
||||
self._policy = policy
|
||||
|
||||
def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
|
||||
return transition_batch
|
||||
|
||||
def get_local_ops(self) -> AbsTrainOps:
|
||||
return DQNOps(
|
||||
name=self._policy_name,
|
||||
policy_creator=self._policy_creator,
|
||||
parallelism=self._params.data_parallelism,
|
||||
**self._params.extract_ops_params(),
|
||||
name=self._policy.name,
|
||||
policy=self._policy,
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
params=self._params,
|
||||
)
|
||||
|
||||
def _get_batch(self, batch_size: int = None) -> TransitionBatch:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -12,14 +12,21 @@ import torch
|
|||
from maro.rl.model import MultiQNet
|
||||
from maro.rl.policy import DiscretePolicyGradient, RLPolicy
|
||||
from maro.rl.rollout import ExpElement
|
||||
from maro.rl.training import AbsTrainOps, MultiAgentTrainer, RandomMultiReplayMemory, RemoteOps, TrainerParams, remote
|
||||
from maro.rl.training import (
|
||||
AbsTrainOps,
|
||||
BaseTrainerParams,
|
||||
MultiAgentTrainer,
|
||||
RandomMultiReplayMemory,
|
||||
RemoteOps,
|
||||
remote,
|
||||
)
|
||||
from maro.rl.utils import MultiTransitionBatch, get_torch_device, ndarray_to_tensor
|
||||
from maro.rl.utils.objects import FILE_SUFFIX
|
||||
from maro.utils import clone
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscreteMADDPGParams(TrainerParams):
|
||||
class DiscreteMADDPGParams(BaseTrainerParams):
|
||||
"""
|
||||
get_q_critic_net_func (Callable[[], MultiQNet]): Function to get multi Q critic net.
|
||||
num_epochs (int, default=10): Number of training epochs.
|
||||
|
@ -30,44 +37,28 @@ class DiscreteMADDPGParams(TrainerParams):
|
|||
shared_critic (bool, default=False): Whether different policies use shared critic or individual policies.
|
||||
"""
|
||||
|
||||
get_q_critic_net_func: Callable[[], MultiQNet] = None
|
||||
get_q_critic_net_func: Callable[[], MultiQNet]
|
||||
num_epoch: int = 10
|
||||
update_target_every: int = 5
|
||||
soft_update_coef: float = 0.5
|
||||
q_value_loss_cls: Callable = None
|
||||
q_value_loss_cls: Optional[Callable] = None
|
||||
shared_critic: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.get_q_critic_net_func is not None
|
||||
|
||||
def extract_ops_params(self) -> Dict[str, object]:
|
||||
return {
|
||||
"get_q_critic_net_func": self.get_q_critic_net_func,
|
||||
"shared_critic": self.shared_critic,
|
||||
"reward_discount": self.reward_discount,
|
||||
"soft_update_coef": self.soft_update_coef,
|
||||
"update_target_every": self.update_target_every,
|
||||
"q_value_loss_func": self.q_value_loss_cls() if self.q_value_loss_cls is not None else torch.nn.MSELoss(),
|
||||
}
|
||||
|
||||
|
||||
class DiscreteMADDPGOps(AbsTrainOps):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy_creator: Callable[[], RLPolicy],
|
||||
get_q_critic_net_func: Callable[[], MultiQNet],
|
||||
policy: RLPolicy,
|
||||
param: DiscreteMADDPGParams,
|
||||
shared_critic: bool,
|
||||
policy_idx: int,
|
||||
parallelism: int = 1,
|
||||
shared_critic: bool = False,
|
||||
reward_discount: float = 0.9,
|
||||
soft_update_coef: float = 0.5,
|
||||
update_target_every: int = 5,
|
||||
q_value_loss_func: Callable = None,
|
||||
) -> None:
|
||||
super(DiscreteMADDPGOps, self).__init__(
|
||||
name=name,
|
||||
policy_creator=policy_creator,
|
||||
policy=policy,
|
||||
parallelism=parallelism,
|
||||
)
|
||||
|
||||
|
@ -75,23 +66,21 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
self._shared_critic = shared_critic
|
||||
|
||||
# Actor
|
||||
if self._policy_creator:
|
||||
if self._policy:
|
||||
assert isinstance(self._policy, DiscretePolicyGradient)
|
||||
self._target_policy: DiscretePolicyGradient = clone(self._policy)
|
||||
self._target_policy.set_name(f"target_{self._policy.name}")
|
||||
self._target_policy.eval()
|
||||
|
||||
# Critic
|
||||
self._q_critic_net: MultiQNet = get_q_critic_net_func()
|
||||
self._q_critic_net: MultiQNet = param.get_q_critic_net_func()
|
||||
self._target_q_critic_net: MultiQNet = clone(self._q_critic_net)
|
||||
self._target_q_critic_net.eval()
|
||||
|
||||
self._reward_discount = reward_discount
|
||||
self._q_value_loss_func = q_value_loss_func
|
||||
self._update_target_every = update_target_every
|
||||
self._soft_update_coef = soft_update_coef
|
||||
|
||||
self._device = None
|
||||
self._q_value_loss_func = param.q_value_loss_cls() if param.q_value_loss_cls is not None else torch.nn.MSELoss()
|
||||
self._update_target_every = param.update_target_every
|
||||
self._soft_update_coef = param.soft_update_coef
|
||||
|
||||
def get_target_action(self, batch: MultiTransitionBatch) -> torch.Tensor:
|
||||
"""Get the target policies' actions according to the batch.
|
||||
|
@ -248,7 +237,7 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
|
||||
def soft_update_target(self) -> None:
|
||||
"""Soft update the target policies and target critics."""
|
||||
if self._policy_creator:
|
||||
if self._policy:
|
||||
self._target_policy.soft_update(self._policy, self._soft_update_coef)
|
||||
if not self._shared_critic:
|
||||
self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef)
|
||||
|
@ -264,13 +253,13 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
self._target_q_critic_net.set_state(ops_state_dict["target_critic"])
|
||||
|
||||
def get_actor_state(self) -> dict:
|
||||
if self._policy_creator:
|
||||
if self._policy:
|
||||
return {"policy": self._policy.get_state(), "target_policy": self._target_policy.get_state()}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def set_actor_state(self, ops_state_dict: dict) -> None:
|
||||
if self._policy_creator:
|
||||
if self._policy:
|
||||
self._policy.set_state(ops_state_dict["policy"])
|
||||
self._target_policy.set_state(ops_state_dict["target_policy"])
|
||||
|
||||
|
@ -280,9 +269,9 @@ class DiscreteMADDPGOps(AbsTrainOps):
|
|||
def set_non_policy_state(self, state: dict) -> None:
|
||||
self.set_critic_state(state)
|
||||
|
||||
def to_device(self, device: str) -> None:
|
||||
def to_device(self, device: str = None) -> None:
|
||||
self._device = get_torch_device(device)
|
||||
if self._policy_creator:
|
||||
if self._policy:
|
||||
self._policy.to_device(self._device)
|
||||
self._target_policy.to_device(self._device)
|
||||
|
||||
|
@ -296,31 +285,51 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
See https://arxiv.org/abs/1706.02275 for details.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: DiscreteMADDPGParams) -> None:
|
||||
super(DiscreteMADDPGTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: DiscreteMADDPGParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(DiscreteMADDPGTrainer, self).__init__(
|
||||
name,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
self._params = params
|
||||
self._ops_params = self._params.extract_ops_params()
|
||||
|
||||
self._state_dim = params.get_q_critic_net_func().state_dim
|
||||
self._policy_version = self._target_policy_version = 0
|
||||
self._shared_critic_ops_name = f"{self._name}.shared_critic"
|
||||
|
||||
self._actor_ops_list = []
|
||||
self._critic_ops = None
|
||||
self._replay_memory = None
|
||||
self._policy2agent = {}
|
||||
self._actor_ops_list: List[DiscreteMADDPGOps] = []
|
||||
self._critic_ops: Optional[DiscreteMADDPGOps] = None
|
||||
self._policy2agent: Dict[str, str] = {}
|
||||
self._ops_dict: Dict[str, DiscreteMADDPGOps] = {}
|
||||
|
||||
def build(self) -> None:
|
||||
for policy_name in self._policy_creator:
|
||||
self._ops_dict[policy_name] = self.get_ops(policy_name)
|
||||
self._placeholder_policy = self._policy_dict[self._policy_names[0]]
|
||||
|
||||
for policy in self._policy_dict.values():
|
||||
self._ops_dict[policy.name] = cast(DiscreteMADDPGOps, self.get_ops(policy.name))
|
||||
|
||||
self._actor_ops_list = list(self._ops_dict.values())
|
||||
|
||||
if self._params.shared_critic:
|
||||
self._ops_dict[self._shared_critic_ops_name] = self.get_ops(self._shared_critic_ops_name)
|
||||
assert self._critic_ops is not None
|
||||
self._ops_dict[self._shared_critic_ops_name] = cast(
|
||||
DiscreteMADDPGOps,
|
||||
self.get_ops(self._shared_critic_ops_name),
|
||||
)
|
||||
self._critic_ops = self._ops_dict[self._shared_critic_ops_name]
|
||||
|
||||
self._replay_memory = RandomMultiReplayMemory(
|
||||
capacity=self._params.replay_memory_capacity,
|
||||
capacity=self._replay_memory_capacity,
|
||||
state_dim=self._state_dim,
|
||||
action_dims=[ops.policy_action_dim for ops in self._actor_ops_list],
|
||||
agent_states_dims=[ops.policy_state_dim for ops in self._actor_ops_list],
|
||||
|
@ -342,7 +351,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
rewards: List[np.ndarray] = []
|
||||
agent_states: List[np.ndarray] = []
|
||||
next_agent_states: List[np.ndarray] = []
|
||||
for policy_name in self._policy_names:
|
||||
for policy_name in self._policy_dict:
|
||||
agent_name = self._policy2agent[policy_name]
|
||||
actions.append(np.vstack([exp_element.action_dict[agent_name] for exp_element in exp_elements]))
|
||||
rewards.append(np.array([exp_element.reward_dict[agent_name] for exp_element in exp_elements]))
|
||||
|
@ -374,23 +383,25 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
|
||||
def get_local_ops(self, name: str) -> AbsTrainOps:
|
||||
if name == self._shared_critic_ops_name:
|
||||
ops_params = dict(self._ops_params)
|
||||
ops_params.update(
|
||||
{
|
||||
"policy_idx": -1,
|
||||
"shared_critic": False,
|
||||
},
|
||||
return DiscreteMADDPGOps(
|
||||
name=name,
|
||||
policy=self._placeholder_policy,
|
||||
param=self._params,
|
||||
shared_critic=False,
|
||||
policy_idx=-1,
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
)
|
||||
return DiscreteMADDPGOps(name=name, **ops_params)
|
||||
else:
|
||||
ops_params = dict(self._ops_params)
|
||||
ops_params.update(
|
||||
{
|
||||
"policy_creator": self._policy_creator[name],
|
||||
"policy_idx": self._policy_names.index(name),
|
||||
},
|
||||
return DiscreteMADDPGOps(
|
||||
name=name,
|
||||
policy=self._policy_dict[name],
|
||||
param=self._params,
|
||||
shared_critic=self._params.shared_critic,
|
||||
policy_idx=self._policy_names.index(name),
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
)
|
||||
return DiscreteMADDPGOps(name=name, **ops_params)
|
||||
|
||||
def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch:
|
||||
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)
|
||||
|
@ -405,6 +416,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
|
||||
# Update critic
|
||||
if self._params.shared_critic:
|
||||
assert self._critic_ops is not None
|
||||
self._critic_ops.update_critic(batch, next_actions)
|
||||
critic_state_dict = self._critic_ops.get_critic_state()
|
||||
# Sync latest critic to ops
|
||||
|
@ -431,6 +443,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
|
||||
# Update critic
|
||||
if self._params.shared_critic:
|
||||
assert self._critic_ops is not None
|
||||
critic_grad = await asyncio.gather(*[self._critic_ops.get_critic_grad(batch, next_actions)])
|
||||
assert isinstance(critic_grad, list) and isinstance(critic_grad[0], dict)
|
||||
self._critic_ops.update_critic_with_grad(critic_grad[0])
|
||||
|
@ -460,10 +473,11 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
for ops in self._actor_ops_list:
|
||||
ops.soft_update_target()
|
||||
if self._params.shared_critic:
|
||||
assert self._critic_ops is not None
|
||||
self._critic_ops.soft_update_target()
|
||||
self._target_policy_version = self._policy_version
|
||||
|
||||
def get_policy_state(self) -> Dict[str, object]:
|
||||
def get_policy_state(self) -> Dict[str, dict]:
|
||||
self._assert_ops_exists()
|
||||
ret_policy_state = {}
|
||||
for ops in self._actor_ops_list:
|
||||
|
@ -484,6 +498,7 @@ class DiscreteMADDPGTrainer(MultiAgentTrainer):
|
|||
|
||||
trainer_state = {ops.name: ops.get_state() for ops in self._actor_ops_list}
|
||||
if self._params.shared_critic:
|
||||
assert self._critic_ops is not None
|
||||
trainer_state[self._critic_ops.name] = self._critic_ops.get_state()
|
||||
|
||||
policy_state_dict = {ops_name: state["policy"] for ops_name, state in trainer_state.items()}
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from maro.rl.model import VNet
|
||||
from maro.rl.policy import DiscretePolicyGradient, RLPolicy
|
||||
from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer
|
||||
from maro.rl.utils import TransitionBatch, discount_cumsum, ndarray_to_tensor
|
||||
from maro.utils import clone
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -23,21 +23,7 @@ class PPOParams(ACBasedParams):
|
|||
If it is None, the actor loss is calculated using the usual policy gradient theorem.
|
||||
"""
|
||||
|
||||
clip_ratio: float = None
|
||||
|
||||
def extract_ops_params(self) -> Dict[str, object]:
|
||||
return {
|
||||
"get_v_critic_net_func": self.get_v_critic_net_func,
|
||||
"reward_discount": self.reward_discount,
|
||||
"critic_loss_cls": self.critic_loss_cls,
|
||||
"clip_ratio": self.clip_ratio,
|
||||
"lam": self.lam,
|
||||
"min_logp": self.min_logp,
|
||||
"is_discrete_action": self.is_discrete_action,
|
||||
}
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.get_v_critic_net_func is not None
|
||||
assert self.clip_ratio is not None
|
||||
|
||||
|
||||
|
@ -45,31 +31,20 @@ class DiscretePPOWithEntropyOps(ACBasedOps):
|
|||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy_creator: Callable[[], RLPolicy],
|
||||
get_v_critic_net_func: Callable[[], VNet],
|
||||
policy: RLPolicy,
|
||||
params: ACBasedParams,
|
||||
parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
critic_loss_cls: Callable = None,
|
||||
clip_ratio: float = None,
|
||||
lam: float = 0.9,
|
||||
min_logp: float = None,
|
||||
is_discrete_action: bool = True,
|
||||
) -> None:
|
||||
super(DiscretePPOWithEntropyOps, self).__init__(
|
||||
name=name,
|
||||
policy_creator=policy_creator,
|
||||
get_v_critic_net_func=get_v_critic_net_func,
|
||||
parallelism=parallelism,
|
||||
reward_discount=reward_discount,
|
||||
critic_loss_cls=critic_loss_cls,
|
||||
clip_ratio=clip_ratio,
|
||||
lam=lam,
|
||||
min_logp=min_logp,
|
||||
is_discrete_action=is_discrete_action,
|
||||
name,
|
||||
policy,
|
||||
params,
|
||||
reward_discount,
|
||||
parallelism,
|
||||
)
|
||||
assert is_discrete_action
|
||||
assert isinstance(self._policy, DiscretePolicyGradient)
|
||||
self._policy_old = self._policy_creator()
|
||||
assert self._is_discrete_action
|
||||
self._policy_old: DiscretePolicyGradient = clone(policy)
|
||||
self.update_policy_old()
|
||||
|
||||
def update_policy_old(self) -> None:
|
||||
|
@ -172,8 +147,23 @@ class PPOTrainer(ACBasedTrainer):
|
|||
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, params: PPOParams) -> None:
|
||||
super(PPOTrainer, self).__init__(name, params)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
params: PPOParams,
|
||||
replay_memory_capacity: int = 10000,
|
||||
batch_size: int = 128,
|
||||
data_parallelism: int = 1,
|
||||
reward_discount: float = 0.9,
|
||||
) -> None:
|
||||
super(PPOTrainer, self).__init__(
|
||||
name,
|
||||
params,
|
||||
replay_memory_capacity,
|
||||
batch_size,
|
||||
data_parallelism,
|
||||
reward_discount,
|
||||
)
|
||||
|
||||
|
||||
class DiscretePPOWithEntropyTrainer(ACBasedTrainer):
|
||||
|
@ -182,10 +172,11 @@ class DiscretePPOWithEntropyTrainer(ACBasedTrainer):
|
|||
|
||||
def get_local_ops(self) -> DiscretePPOWithEntropyOps:
|
||||
return DiscretePPOWithEntropyOps(
|
||||
name=self._policy_name,
|
||||
policy_creator=self._policy_creator,
|
||||
parallelism=self._params.data_parallelism,
|
||||
**self._params.extract_ops_params(),
|
||||
name=self._policy.name,
|
||||
policy=self._policy,
|
||||
parallelism=self._data_parallelism,
|
||||
reward_discount=self._reward_discount,
|
||||
params=self._params,
|
||||
)
|
||||
|
||||
def train_step(self) -> None:
|
||||
|
|
|
@@ -2,73 +2,59 @@
# Licensed under the MIT license.

from dataclasses import dataclass
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, cast

import torch

from maro.rl.model import QNet
from maro.rl.policy import ContinuousRLPolicy, RLPolicy
-from maro.rl.training import AbsTrainOps, RandomReplayMemory, RemoteOps, SingleAgentTrainer, TrainerParams, remote
+from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote
from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor
from maro.utils import clone


@dataclass
-class SoftActorCriticParams(TrainerParams):
-    get_q_critic_net_func: Callable[[], QNet] = None
+class SoftActorCriticParams(BaseTrainerParams):
+    get_q_critic_net_func: Callable[[], QNet]
    update_target_every: int = 5
    random_overwrite: bool = False
    entropy_coef: float = 0.1
    num_epochs: int = 1
    n_start_train: int = 0
-    q_value_loss_cls: Callable = None
+    q_value_loss_cls: Optional[Callable] = None
    soft_update_coef: float = 1.0

-    def __post_init__(self) -> None:
-        assert self.get_q_critic_net_func is not None
-
-    def extract_ops_params(self) -> Dict[str, object]:
-        return {
-            "get_q_critic_net_func": self.get_q_critic_net_func,
-            "entropy_coef": self.entropy_coef,
-            "reward_discount": self.reward_discount,
-            "q_value_loss_cls": self.q_value_loss_cls,
-            "soft_update_coef": self.soft_update_coef,
-        }
-

class SoftActorCriticOps(AbsTrainOps):
    def __init__(
        self,
        name: str,
-        policy_creator: Callable[[], RLPolicy],
-        get_q_critic_net_func: Callable[[], QNet],
+        policy: RLPolicy,
+        params: SoftActorCriticParams,
+        reward_discount: float = 0.9,
        parallelism: int = 1,
-        *,
-        entropy_coef: float,
-        reward_discount: float,
-        q_value_loss_cls: Callable = None,
-        soft_update_coef: float = 1.0,
    ) -> None:
        super(SoftActorCriticOps, self).__init__(
            name=name,
-            policy_creator=policy_creator,
+            policy=policy,
            parallelism=parallelism,
        )

        assert isinstance(self._policy, ContinuousRLPolicy)

-        self._q_net1 = get_q_critic_net_func()
-        self._q_net2 = get_q_critic_net_func()
+        self._q_net1 = params.get_q_critic_net_func()
+        self._q_net2 = params.get_q_critic_net_func()
        self._target_q_net1: QNet = clone(self._q_net1)
        self._target_q_net1.eval()
        self._target_q_net2: QNet = clone(self._q_net2)
        self._target_q_net2.eval()

-        self._entropy_coef = entropy_coef
-        self._soft_update_coef = soft_update_coef
+        self._entropy_coef = params.entropy_coef
+        self._soft_update_coef = params.soft_update_coef
        self._reward_discount = reward_discount
-        self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss()
+        self._q_value_loss_func = (
+            params.q_value_loss_cls() if params.q_value_loss_cls is not None else torch.nn.MSELoss()
+        )

    def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        self._q_net1.train()

@@ -100,11 +86,11 @@ class SoftActorCriticOps(AbsTrainOps):
        grad_q2 = self._q_net2.get_gradients(loss_q2)
        return grad_q1, grad_q2

-    def update_critic_with_grad(self, grad_dict1: dict, grad_dict2: dict) -> None:
+    def update_critic_with_grad(self, grad_dicts: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None:
        self._q_net1.train()
        self._q_net2.train()
-        self._q_net1.apply_gradients(grad_dict1)
-        self._q_net2.apply_gradients(grad_dict2)
+        self._q_net1.apply_gradients(grad_dicts[0])
+        self._q_net2.apply_gradients(grad_dicts[1])

    def update_critic(self, batch: TransitionBatch) -> None:
        self._q_net1.train()

@@ -154,7 +140,7 @@ class SoftActorCriticOps(AbsTrainOps):
        self._target_q_net1.soft_update(self._q_net1, self._soft_update_coef)
        self._target_q_net2.soft_update(self._q_net2, self._soft_update_coef)

-    def to_device(self, device: str) -> None:
+    def to_device(self, device: str = None) -> None:
        self._device = get_torch_device(device=device)
        self._q_net1.to(self._device)
        self._q_net2.to(self._device)

@@ -163,22 +149,38 @@ class SoftActorCriticOps(AbsTrainOps):


class SoftActorCriticTrainer(SingleAgentTrainer):
-    def __init__(self, name: str, params: SoftActorCriticParams) -> None:
-        super(SoftActorCriticTrainer, self).__init__(name, params)
+    def __init__(
+        self,
+        name: str,
+        params: SoftActorCriticParams,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
+        super(SoftActorCriticTrainer, self).__init__(
+            name,
+            replay_memory_capacity,
+            batch_size,
+            data_parallelism,
+            reward_discount,
+        )
        self._params = params
        self._qnet_version = self._target_qnet_version = 0

+        self._replay_memory: Optional[RandomReplayMemory] = None
+
    def build(self) -> None:
-        self._ops = self.get_ops()
+        self._ops = cast(SoftActorCriticOps, self.get_ops())
        self._replay_memory = RandomReplayMemory(
-            capacity=self._params.replay_memory_capacity,
+            capacity=self._replay_memory_capacity,
            state_dim=self._ops.policy_state_dim,
            action_dim=self._ops.policy_action_dim,
            random_overwrite=self._params.random_overwrite,
        )

+    def _register_policy(self, policy: RLPolicy) -> None:
+        assert isinstance(policy, ContinuousRLPolicy)
+        self._policy = policy
+
    def train_step(self) -> None:
        assert isinstance(self._ops, SoftActorCriticOps)


@@ -218,10 +220,11 @@ class SoftActorCriticTrainer(SingleAgentTrainer):

    def get_local_ops(self) -> SoftActorCriticOps:
        return SoftActorCriticOps(
-            name=self._policy_name,
-            policy_creator=self._policy_creator,
-            parallelism=self._params.data_parallelism,
-            **self._params.extract_ops_params(),
+            name=self._policy.name,
+            policy=self._policy,
+            parallelism=self._data_parallelism,
+            reward_discount=self._reward_discount,
+            params=self._params,
        )

    def _get_batch(self, batch_size: int = None) -> TransitionBatch:
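For reference, a sketch of how the reworked SAC pieces fit together. `MyQCriticNet` and the surrounding names are assumed to come from scenario code and are illustrative only:

```python
# Hedged sketch: SoftActorCriticParams now carries only algorithm settings, while the
# trainer takes memory / batch / discount arguments directly.
sac_params = SoftActorCriticParams(
    get_q_critic_net_func=lambda: MyQCriticNet(),  # required; no longer defaults to None
    entropy_coef=0.2,
    soft_update_coef=0.05,
)
trainer = SoftActorCriticTrainer(
    name="sac",
    params=sac_params,
    replay_memory_capacity=100000,
    batch_size=256,
    reward_discount=0.99,
)
```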
@@ -2,8 +2,9 @@
# Licensed under the MIT license.

from collections import defaultdict, deque
+from typing import Deque

-from maro.rl.distributed import AbsProxy
+from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT, AbsProxy
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
from maro.rl.utils.torch_utils import average_grads
from maro.utils import LoggerV2

@@ -20,13 +21,16 @@ class TrainingProxy(AbsProxy):
        backend_port (int, default=10001): Network port for communicating with back-end workers (task consumers).
    """

-    def __init__(self, frontend_port: int = 10000, backend_port: int = 10001) -> None:
-        super(TrainingProxy, self).__init__(frontend_port=frontend_port, backend_port=backend_port)
-        self._available_workers = deque()
-        self._worker_ready = False
-        self._connected_ops = set()
-        self._result_cache = defaultdict(list)
-        self._expected_num_results = {}
+    def __init__(self, frontend_port: int = None, backend_port: int = None) -> None:
+        super(TrainingProxy, self).__init__(
+            frontend_port=frontend_port if frontend_port is not None else DEFAULT_TRAINING_FRONTEND_PORT,
+            backend_port=backend_port if backend_port is not None else DEFAULT_TRAINING_BACKEND_PORT,
+        )
+        self._available_workers: Deque = deque()
+        self._worker_ready: bool = False
+        self._connected_ops: set = set()
+        self._result_cache: dict = defaultdict(list)
+        self._expected_num_results: dict = {}
        self._logger = LoggerV2("TRAIN-PROXY")

    def _route_request_to_compute_node(self, msg: list) -> None:

@@ -48,10 +52,12 @@ class TrainingProxy(AbsProxy):

        self._connected_ops.add(msg[0])
        req = bytes_to_pyobj(msg[-1])
+        assert isinstance(req, dict)
+
        desired_parallelism = req["desired_parallelism"]
        req["args"] = list(req["args"])
        batch = req["args"][0]
-        workers = []
+        workers: list = []
        while len(workers) < desired_parallelism and self._available_workers:
            workers.append(self._available_workers.popleft())
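The proxy now resolves its ports from the package-level defaults when none are given; explicit ports still work. A short sketch (the `start()` call is an assumption about the `AbsProxy` base class, check it before relying on it):

```python
# Hedged sketch: default vs. explicit port selection for the training proxy.
proxy = TrainingProxy()  # falls back to DEFAULT_TRAINING_FRONTEND_PORT / DEFAULT_TRAINING_BACKEND_PORT
# proxy = TrainingProxy(frontend_port=20000, backend_port=20001)  # explicit override
proxy.start()  # assumption: AbsProxy exposes a blocking start() loop
```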
@@ -3,8 +3,9 @@

import inspect
from abc import ABCMeta, abstractmethod
-from typing import Callable, Tuple
+from typing import Any, Callable, Optional, Tuple, Union

import torch
import zmq
from zmq.asyncio import Context, Poller

@@ -19,24 +20,21 @@ class AbsTrainOps(object, metaclass=ABCMeta):

    Args:
        name (str): Name of the ops. This is usually a policy name.
-        policy_creator (Callable[[], RLPolicy]): Function to create a policy instance.
+        policy (RLPolicy): Policy instance.
        parallelism (int, default=1): Desired degree of data parallelism.
    """

    def __init__(
        self,
        name: str,
-        policy_creator: Callable[[], RLPolicy],
+        policy: RLPolicy,
        parallelism: int = 1,
    ) -> None:
        super(AbsTrainOps, self).__init__()
        self._name = name
-        self._policy_creator = policy_creator
-        # Create the policy.
-        if self._policy_creator:
-            self._policy = self._policy_creator()
-
+        self._policy = policy
        self._parallelism = parallelism
+        self._device: Optional[torch.device] = None

    @property
    def name(self) -> str:

@@ -44,11 +42,11 @@ class AbsTrainOps(object, metaclass=ABCMeta):

    @property
    def policy_state_dim(self) -> int:
-        return self._policy.state_dim if self._policy_creator else None
+        return self._policy.state_dim

    @property
    def policy_action_dim(self) -> int:
-        return self._policy.action_dim if self._policy_creator else None
+        return self._policy.action_dim

    @property
    def parallelism(self) -> int:

@@ -75,20 +73,20 @@ class AbsTrainOps(object, metaclass=ABCMeta):
        self.set_policy_state(ops_state_dict["policy"][1])
        self.set_non_policy_state(ops_state_dict["non_policy"])

-    def get_policy_state(self) -> Tuple[str, object]:
+    def get_policy_state(self) -> Tuple[str, dict]:
        """Get the policy's state.

        Returns:
            policy_name (str)
-            policy_state (object)
+            policy_state (Any)
        """
        return self._policy.name, self._policy.get_state()

-    def set_policy_state(self, policy_state: object) -> None:
+    def set_policy_state(self, policy_state: dict) -> None:
        """Update the policy's state.

        Args:
-            policy_state (object): The policy state.
+            policy_state (dict): The policy state.
        """
        self._policy.set_state(policy_state)


@@ -111,17 +109,17 @@ class AbsTrainOps(object, metaclass=ABCMeta):
        raise NotImplementedError

    @abstractmethod
-    def to_device(self, device: str):
+    def to_device(self, device: str = None) -> None:
        raise NotImplementedError


-def remote(func) -> Callable:
+def remote(func: Callable) -> Callable:
    """Annotation to indicate that a function / method can be called remotely.

    This annotation takes effect only when an ``AbsTrainOps`` object is wrapped by a ``RemoteOps``.
    """

-    def remote_annotate(*args, **kwargs) -> object:
+    def remote_annotate(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    return remote_annotate

@@ -137,7 +135,7 @@ class AsyncClient(object):
    """

    def __init__(self, name: str, address: Tuple[str, int], logger: LoggerV2 = None) -> None:
-        self._logger = DummyLogger() if logger is None else logger
+        self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger()
        self._name = name
        host, port = address
        self._proxy_ip = get_ip_address_by_hostname(host)

@@ -155,7 +153,7 @@ class AsyncClient(object):
        await self._socket.send(pyobj_to_bytes(req))
        self._logger.debug(f"{self._name} sent request {req['func']}")

-    async def get_response(self) -> object:
+    async def get_response(self) -> Any:
        """Waits for a result in asynchronous fashion.

        This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``get_response`` calls.

@@ -209,15 +207,15 @@ class RemoteOps(object):
        self._client = AsyncClient(self._ops.name, address, logger=logger)
        self._client.connect()

-    def __getattribute__(self, attr_name: str) -> object:
+    def __getattribute__(self, attr_name: str) -> Any:
        # Ignore methods that belong to the parent class
        try:
            return super().__getattribute__(attr_name)
        except AttributeError:
            pass

-        def remote_method(ops_state, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable:
-            async def remote_call(*args, **kwargs) -> object:
+        def remote_method(ops_state: Any, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable:
+            async def remote_call(*args: Any, **kwargs: Any) -> Any:
                req = {
                    "state": ops_state,
                    "func": func_name,
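The `remote` annotation itself moves no computation; it only marks which ops methods may be dispatched to a `TrainOpsWorker` once the ops object is wrapped by `RemoteOps`. A hedged sketch of the intended pattern (the subclass, the `_compute_loss` helper, and the return shape are illustrative, not part of the library):

```python
class MyOps(AbsTrainOps):  # illustrative subclass; the other abstract methods are omitted
    @remote
    def get_batch_grad(self, batch: TransitionBatch) -> dict:
        # Runs locally in simple mode; dispatched through the TrainingProxy when this
        # ops object is wrapped by RemoteOps and a proxy address is registered.
        loss = self._compute_loss(batch)         # hypothetical helper
        return self._policy.get_gradients(loss)  # assumes RLPolicy.get_gradients()
```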
@@ -5,7 +5,7 @@ import collections
import os
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

@@ -21,37 +21,8 @@ from .train_ops import AbsTrainOps, RemoteOps


@dataclass
-class TrainerParams:
-    """Common trainer parameters.
-
-    replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory.
-    batch_size (int, default=128): Training batch size.
-    data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when
-        a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
-        batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
-        of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
-        updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
-        parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
-        availability of compute resources. For details on distributed deep learning and data parallelism, see
-        https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance
-        of resources available on the internet.
-    reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology.
-
-    """
-
-    replay_memory_capacity: int = 10000
-    batch_size: int = 128
-    data_parallelism: int = 1
-    reward_discount: float = 0.9
-
-    @abstractmethod
-    def extract_ops_params(self) -> Dict[str, object]:
-        """Extract parameters that should be passed to the train ops.
-
-        Returns:
-            params (Dict[str, object]): Parameter dict.
-        """
-        raise NotImplementedError
+class BaseTrainerParams:
+    pass


class AbsTrainer(object, metaclass=ABCMeta):

@@ -64,16 +35,36 @@ class AbsTrainer(object, metaclass=ABCMeta):

    Args:
        name (str): Name of the trainer.
-        params (TrainerParams): Trainer's parameters.
+        replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory.
+        batch_size (int, default=128): Training batch size.
+        data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when
+            a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
+            batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
+            of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
+            updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
+            parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
+            availability of compute resources. For details on distributed deep learning and data parallelism, see
+            https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an
+            abundance of resources available on the internet.
+        reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology.
    """

-    def __init__(self, name: str, params: TrainerParams) -> None:
+    def __init__(
+        self,
+        name: str,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
        self._name = name
-        self._params = params
-        self._batch_size = self._params.batch_size
+        self._replay_memory_capacity = replay_memory_capacity
+        self._batch_size = batch_size
+        self._data_parallelism = data_parallelism
+        self._reward_discount = reward_discount
+
        self._agent2policy: Dict[Any, str] = {}
        self._proxy_address: Optional[Tuple[str, int]] = None
        self._logger = None

    @property
    def name(self) -> str:

@@ -83,13 +74,11 @@ class AbsTrainer(object, metaclass=ABCMeta):
    def agent_num(self) -> int:
        return len(self._agent2policy)

-    def register_logger(self, logger: LoggerV2) -> None:
+    def register_logger(self, logger: LoggerV2 = None) -> None:
        self._logger = logger

    def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_mapping: Dict[str, str]) -> None:
-        """Register the agent to policy dict that correspond to the current trainer. A valid policy name should start
-        with the name of its trainer. For example, "DQN.POLICY_NAME". Therefore, we could identify which policies
-        should be registered to the current trainer according to the policy's name.
+        """Register the agent to policy dict that correspond to the current trainer.

        Args:
            agent2policy (Dict[Any, str]): Agent name to policy name mapping.

@@ -102,16 +91,11 @@ class AbsTrainer(object, metaclass=ABCMeta):
        }

    @abstractmethod
-    def register_policy_creator(
-        self,
-        global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
-        policy_trainer_mapping: Dict[str, str],
-    ) -> None:
-        """Register the policy creator. Only keep the creators of the policies that the current trainer need to train.
+    def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
+        """Register the policies. Only keep the creators of the policies that the current trainer need to train.

        Args:
-            global_policy_creator (Dict[str, Callable[[], AbsPolicy]]): Dict that contains the creators for all
-                policies.
+            policies (List[AbsPolicy]): All policies.
            policy_trainer_mapping (Dict[str, str]): Policy name to trainer name mapping.
        """
        raise NotImplementedError

@@ -147,7 +131,7 @@ class AbsTrainer(object, metaclass=ABCMeta):
        self._proxy_address = proxy_address

    @abstractmethod
-    def get_policy_state(self) -> Dict[str, object]:
+    def get_policy_state(self) -> Dict[str, dict]:
        """Get policies' states.

        Returns:

@@ -171,30 +155,46 @@
class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
    """Policy trainer that trains only one policy."""

-    def __init__(self, name: str, params: TrainerParams) -> None:
-        super(SingleAgentTrainer, self).__init__(name, params)
-        self._policy_name: Optional[str] = None
-        self._policy_creator: Optional[Callable[[], RLPolicy]] = None
-        self._ops: Optional[AbsTrainOps] = None
-        self._replay_memory: Optional[ReplayMemory] = None
+    def __init__(
+        self,
+        name: str,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
+        super(SingleAgentTrainer, self).__init__(
+            name,
+            replay_memory_capacity,
+            batch_size,
+            data_parallelism,
+            reward_discount,
+        )

    @property
-    def ops(self):
-        return self._ops
+    def ops(self) -> Union[AbsTrainOps, RemoteOps]:
+        ops = getattr(self, "_ops", None)
+        assert isinstance(ops, (AbsTrainOps, RemoteOps))
+        return ops

-    def register_policy_creator(
-        self,
-        global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
-        policy_trainer_mapping: Dict[str, str],
-    ) -> None:
-        policy_names = [
-            policy_name for policy_name in global_policy_creator if policy_trainer_mapping[policy_name] == self.name
-        ]
-        if len(policy_names) != 1:
+    @property
+    def replay_memory(self) -> ReplayMemory:
+        replay_memory = getattr(self, "_replay_memory", None)
+        assert isinstance(replay_memory, ReplayMemory), "Replay memory is required."
+        return replay_memory
+
+    def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
+        policies = [policy for policy in policies if policy_trainer_mapping[policy.name] == self.name]
+        if len(policies) != 1:
            raise ValueError(f"Trainer {self._name} should have exactly one policy assigned to it")

-        self._policy_name = policy_names.pop()
-        self._policy_creator = global_policy_creator[self._policy_name]
+        policy = policies.pop()
+        assert isinstance(policy, RLPolicy)
+        self._register_policy(policy)
+
+    @abstractmethod
+    def _register_policy(self, policy: RLPolicy) -> None:
+        raise NotImplementedError

    @abstractmethod
    def get_local_ops(self) -> AbsTrainOps:

@@ -216,9 +216,9 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
        ops = self.get_local_ops()
        return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops

-    def get_policy_state(self) -> Dict[str, object]:
+    def get_policy_state(self) -> Dict[str, dict]:
        self._assert_ops_exists()
-        policy_name, state = self._ops.get_policy_state()
+        policy_name, state = self.ops.get_policy_state()
        return {policy_name: state}

    def load(self, path: str) -> None:

@@ -227,7 +227,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
        policy_state = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}"))
        non_policy_state = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}"))

-        self._ops.set_state(
+        self.ops.set_state(
            {
                "policy": policy_state,
                "non_policy": non_policy_state,

@@ -237,7 +237,7 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
    def save(self, path: str) -> None:
        self._assert_ops_exists()

-        ops_state = self._ops.get_state()
+        ops_state = self.ops.get_state()
        policy_state = ops_state["policy"]
        non_policy_state = ops_state["non_policy"]


@@ -267,46 +267,57 @@ class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
            next_states=np.vstack([exp[4] for exp in exps]),
        )
        transition_batch = self._preprocess_batch(transition_batch)
-        self._replay_memory.put(transition_batch)
+        self.replay_memory.put(transition_batch)

    @abstractmethod
    def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch:
        raise NotImplementedError

    def _assert_ops_exists(self) -> None:
-        if not self._ops:
+        if not self.ops:
            raise ValueError("'build' needs to be called to create an ops instance first.")

    async def exit(self) -> None:
        self._assert_ops_exists()
-        if isinstance(self._ops, RemoteOps):
-            await self._ops.exit()
+        ops = self.ops
+        if isinstance(ops, RemoteOps):
+            await ops.exit()


class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta):
    """Policy trainer that trains multiple policies."""

-    def __init__(self, name: str, params: TrainerParams) -> None:
-        super(MultiAgentTrainer, self).__init__(name, params)
-        self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {}
-        self._policy_names: List[str] = []
-        self._ops_dict: Dict[str, AbsTrainOps] = {}
+    def __init__(
+        self,
+        name: str,
+        replay_memory_capacity: int = 10000,
+        batch_size: int = 128,
+        data_parallelism: int = 1,
+        reward_discount: float = 0.9,
+    ) -> None:
+        super(MultiAgentTrainer, self).__init__(
+            name,
+            replay_memory_capacity,
+            batch_size,
+            data_parallelism,
+            reward_discount,
+        )

    @property
-    def ops_dict(self):
-        return self._ops_dict
+    def ops_dict(self) -> Dict[str, AbsTrainOps]:
+        ops_dict = getattr(self, "_ops_dict", None)
+        assert isinstance(ops_dict, dict)
+        return ops_dict

-    def register_policy_creator(
-        self,
-        global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
-        policy_trainer_mapping: Dict[str, str],
-    ) -> None:
-        self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {
-            policy_name: func
-            for policy_name, func in global_policy_creator.items()
-            if policy_trainer_mapping[policy_name] == self.name
-        }
-        self._policy_names = list(self._policy_creator.keys())
+    def register_policies(self, policies: List[AbsPolicy], policy_trainer_mapping: Dict[str, str]) -> None:
+        self._policy_names: List[str] = [
+            policy.name for policy in policies if policy_trainer_mapping[policy.name] == self.name
+        ]
+        self._policy_dict: Dict[str, RLPolicy] = {}
+        for policy in policies:
+            if policy_trainer_mapping[policy.name] == self.name:
+                assert isinstance(policy, RLPolicy)
+                self._policy_dict[policy.name] = policy

    @abstractmethod
    def get_local_ops(self, name: str) -> AbsTrainOps:

@@ -335,7 +346,7 @@ class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta):
        return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops

    @abstractmethod
-    def get_policy_state(self) -> Dict[str, object]:
+    def get_policy_state(self) -> Dict[str, dict]:
        raise NotImplementedError

    @abstractmethod
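The practical effect of the `register_policies` change is that a trainer now receives concrete policy objects plus a name mapping instead of creator callables. A sketch of the registration sequence a caller (normally `TrainingManager`) performs; the concrete names are illustrative:

```python
policies = [sac_policy]                          # e.g. a ContinuousRLPolicy named "sac_policy"
policy_trainer_mapping = {"sac_policy": "sac"}   # policy name -> trainer name
trainer.register_policies(policies, policy_trainer_mapping)
trainer.register_logger(logger)
trainer.build()  # must follow register_policies(), mirroring the TrainingManager comment
```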
@@ -7,7 +7,6 @@ import asyncio
import collections
import os
import typing
-from itertools import chain
from typing import Any, Dict, Iterable, List, Tuple

from maro.rl.rollout import ExpElement

@@ -26,8 +25,8 @@ class TrainingManager(object):
    Training manager. Manage and schedule all trainers to train policies.

    Args:
-        rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
-        explicit_assign_device (bool): Whether to assign policy to its device in the training manager.
+        rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
+        explicit_assign_device (bool, default=False): Whether to assign policy to its device in the training manager.
        proxy_address (Tuple[str, int], default=None): Address of the training proxy. If it is not None,
            it is registered to all trainers, which in turn create `RemoteOps` for distributed training.
        logger (LoggerV2, default=None): A logger for logging key events.

@@ -36,36 +35,33 @@ class TrainingManager(object):
    def __init__(
        self,
        rl_component_bundle: RLComponentBundle,
-        explicit_assign_device: bool,
+        explicit_assign_device: bool = False,
        proxy_address: Tuple[str, int] = None,
        logger: LoggerV2 = None,
    ) -> None:
        super(TrainingManager, self).__init__()

-        self._trainer_dict: Dict[str, AbsTrainer] = {}
        self._proxy_address = proxy_address
-        for trainer_name, func in rl_component_bundle.trainer_creator.items():
-            trainer = func()
+
+        self._trainer_dict: Dict[str, AbsTrainer] = {}
+        for trainer in rl_component_bundle.trainers:
            if self._proxy_address:
                trainer.set_proxy_address(self._proxy_address)
            trainer.register_agent2policy(
-                rl_component_bundle.trainable_agent2policy,
-                rl_component_bundle.policy_trainer_mapping,
+                agent2policy=rl_component_bundle.trainable_agent2policy,
+                policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping,
            )
-            trainer.register_policy_creator(
-                rl_component_bundle.trainable_policy_creator,
-                rl_component_bundle.policy_trainer_mapping,
+            trainer.register_policies(
+                policies=rl_component_bundle.policies,
+                policy_trainer_mapping=rl_component_bundle.policy_trainer_mapping,
            )
            trainer.register_logger(logger)
-            trainer.build()  # `build()` must be called after `register_policy_creator()`
-            self._trainer_dict[trainer_name] = trainer
+            trainer.build()  # `build()` must be called after `register_policies()`
+            self._trainer_dict[trainer.name] = trainer

        # User-defined allocation of compute devices, i.e., GPU's to the trainer ops
        if explicit_assign_device:
            for policy_name, device_name in rl_component_bundle.device_mapping.items():
                if policy_name not in rl_component_bundle.policy_trainer_mapping:  # No need to assign device
                    continue

                trainer = self._trainer_dict[rl_component_bundle.policy_trainer_mapping[policy_name]]

                if isinstance(trainer, SingleAgentTrainer):

@@ -95,13 +91,16 @@ class TrainingManager(object):
        for trainer in self._trainer_dict.values():
            trainer.train_step()

-    def get_policy_state(self) -> Dict[str, Dict[str, object]]:
+    def get_policy_state(self) -> Dict[str, dict]:
        """Get policies' states.

        Returns:
            A double-deck dict with format: {trainer_name: {policy_name: policy_state}}
        """
-        return dict(chain(*[trainer.get_policy_state().items() for trainer in self._trainer_dict.values()]))
+        policy_states: Dict[str, dict] = {}
+        for trainer in self._trainer_dict.values():
+            policy_states.update(trainer.get_policy_state())
+        return policy_states

    def record_experiences(self, experiences: List[List[ExpElement]]) -> None:
        """Record experiences collected from external modules (for example, EnvSampler).
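Putting the manager-side changes together, a hedged sketch of constructing `TrainingManager` against an `RLComponentBundle` instance (`bundle` and `experiences` are assumed to exist in the scenario code):

```python
training_manager = TrainingManager(
    rl_component_bundle=bundle,
    explicit_assign_device=True,         # now defaults to False
    proxy_address=("127.0.0.1", 10000),  # leave as None for single-process training
    logger=LoggerV2("MAIN"),
)
training_manager.record_experiences(experiences)  # experiences come from the env sampler
training_manager.train_step()
```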
@@ -6,7 +6,7 @@ from __future__ import annotations
import typing
from typing import Dict

-from maro.rl.distributed import AbsWorker
+from maro.rl.distributed import DEFAULT_TRAINING_BACKEND_PORT, AbsWorker
from maro.rl.training import SingleAgentTrainer
from maro.rl.utils.common import bytes_to_pyobj, bytes_to_string, pyobj_to_bytes
from maro.utils import LoggerV2

@@ -24,7 +24,7 @@ class TrainOpsWorker(AbsWorker):
    Args:
        idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}",
            so that the proxy can keep track of its connection status.
-        rl_component_bundle (RLComponentBundle): The RL component bundle of the job.
+        rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow.
        producer_host (str): IP address of the proxy host to connect to.
        producer_port (int, default=10001): Port of the proxy host to connect to.
    """

@@ -34,13 +34,13 @@ class TrainOpsWorker(AbsWorker):
        idx: int,
        rl_component_bundle: RLComponentBundle,
        producer_host: str,
-        producer_port: int = 10001,
+        producer_port: int = None,
        logger: LoggerV2 = None,
    ) -> None:
        super(TrainOpsWorker, self).__init__(
            idx=idx,
            producer_host=producer_host,
-            producer_port=producer_port,
+            producer_port=producer_port if producer_port is not None else DEFAULT_TRAINING_BACKEND_PORT,
            logger=logger,
        )


@@ -62,13 +62,17 @@ class TrainOpsWorker(AbsWorker):
        ops_name, req = bytes_to_string(msg[0]), bytes_to_pyobj(msg[-1])
        assert isinstance(req, dict)

+        trainer_dict: Dict[str, AbsTrainer] = {
+            trainer.name: trainer for trainer in self._rl_component_bundle.trainers
+        }
+
        if ops_name not in self._ops_dict:
-            trainer_name = ops_name.split(".")[0]
+            trainer_name = self._rl_component_bundle.policy_trainer_mapping[ops_name]
            if trainer_name not in self._trainer_dict:
-                trainer = self._rl_component_bundle.trainer_creator[trainer_name]()
-                trainer.register_policy_creator(
-                    self._rl_component_bundle.trainable_policy_creator,
-                    self._rl_component_bundle.policy_trainer_mapping,
+                trainer = trainer_dict[trainer_name]
+                trainer.register_policies(
+                    policies=self._rl_component_bundle.policies,
+                    policy_trainer_mapping=self._rl_component_bundle.policy_trainer_mapping,
                )
                self._trainer_dict[trainer_name] = trainer
@@ -4,17 +4,17 @@
import os
import pickle
import socket
-from typing import List, Optional
+from typing import Any, List, Optional


-def get_env(var_name: str, required: bool = True, default: object = None) -> str:
+def get_env(var_name: str, required: bool = True, default: str = None) -> Optional[str]:
    """Wrapper for os.getenv() that includes a check for mandatory environment variables.

    Args:
        var_name (str): Variable name.
        required (bool, default=True): Flag indicating whether the environment variable in questions is required.
            If this is true and the environment variable is not present in ``os.environ``, a ``KeyError`` is raised.
-        default (object, default=None): Default value for the environment variable if it is missing in ``os.environ``
+        default (str, default=None): Default value for the environment variable if it is missing in ``os.environ``
            and ``required`` is false. Ignored if ``required`` is True.

    Returns:

@@ -52,11 +52,11 @@ def bytes_to_string(bytes_: bytes) -> str:
    return bytes_.decode(DEFAULT_MSG_ENCODING)


-def pyobj_to_bytes(pyobj) -> bytes:
+def pyobj_to_bytes(pyobj: Any) -> bytes:
    return pickle.dumps(pyobj)


-def bytes_to_pyobj(bytes_: bytes) -> object:
+def bytes_to_pyobj(bytes_: bytes) -> Any:
    return pickle.loads(bytes_)
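The tightened `get_env` signature means call sites get `Optional[str]` back and must narrow it themselves, which is exactly what the workflow code further down does. For example:

```python
log_level = get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL")  # Optional[str]
port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT")))  # required; KeyError if unset
```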
@@ -55,5 +55,5 @@ def average_grads(grad_list: List[dict]) -> dict:
    }


-def get_torch_device(device: str = None):
+def get_torch_device(device: str = None) -> torch.device:
    return torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
@@ -207,7 +207,7 @@ class ConfigParser:
            f"{self._validation_err_pfx}: 'training.checkpointing.interval' must be an int",
        )

-    def _validate_logging_section(self, component, level_dict: dict) -> None:
+    def _validate_logging_section(self, component: str, level_dict: dict) -> None:
        if any(key not in {"stdout", "file"} for key in level_dict):
            raise KeyError(
                f"{self._validation_err_pfx}: fields under section '{component}.logging' must be 'stdout' or 'file'",

@@ -261,7 +261,7 @@ class ConfigParser:
        num_episodes = self._config["main"]["num_episodes"]
        main_proc = f"{self._config['job']}.main"
        min_n_sample = self._config["main"].get("min_n_sample", 1)
-        env = {
+        env: dict = {
            main_proc: (
                os.path.join(self._get_workflow_path(containerize=containerize), "main.py"),
                {
@@ -6,116 +6,157 @@ import importlib
import os
import sys
import time
-from typing import List, Type
+from typing import List, Union

from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
-from maro.rl.rollout import BatchEnvSampler, ExpElement
+from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler, ExpElement
from maro.rl.training import TrainingManager
from maro.rl.utils import get_torch_device
from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none
from maro.rl.utils.training import get_latest_ep
+from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2


-def get_args() -> argparse.Namespace:
+class WorkflowEnvAttributes:
+    def __init__(self) -> None:
+        # Number of training episodes
+        self.num_episodes = int(env_str_helper(get_env("NUM_EPISODES")))
+
+        # Maximum number of steps in on round of sampling.
+        self.num_steps = int_or_none(get_env("NUM_STEPS", required=False))
+
+        # Minimum number of data samples to start a round of training. If the data samples are insufficient, re-run
+        # data sampling until we have at least `min_n_sample` data entries.
+        self.min_n_sample = int(env_str_helper(get_env("MIN_N_SAMPLE")))
+
+        # Path to store logs.
+        self.log_path = get_env("LOG_PATH")
+
+        # Log levels
+        self.log_level_stdout = get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL")
+        self.log_level_file = get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL")
+
+        # Parallelism of sampling / evaluation. Used in distributed sampling.
+        self.env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
+        self.env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
+
+        # Training mode, simple or distributed
+        self.train_mode = get_env("TRAIN_MODE")
+
+        # Evaluating schedule.
+        self.eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
+
+        # Restore configurations.
+        self.load_path = get_env("LOAD_PATH", required=False)
+        self.load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
+
+        # Checkpointing configurations.
+        self.checkpoint_path = get_env("CHECKPOINT_PATH", required=False)
+        self.checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False))
+
+        # Parallel sampling configurations.
+        self.parallel_rollout = self.env_sampling_parallelism is not None or self.env_eval_parallelism is not None
+        if self.parallel_rollout:
+            self.port = int(env_str_helper(get_env("ROLLOUT_CONTROLLER_PORT")))
+            self.min_env_samples = int_or_none(get_env("MIN_ENV_SAMPLES", required=False))
+            self.grace_factor = float_or_none(get_env("GRACE_FACTOR", required=False))
+
+        self.is_single_thread = self.train_mode == "simple" and not self.parallel_rollout
+
+        # Distributed training configurations.
+        if self.train_mode != "simple":
+            self.proxy_address = (
+                env_str_helper(get_env("TRAIN_PROXY_HOST")),
+                int(env_str_helper(get_env("TRAIN_PROXY_FRONTEND_PORT"))),
+            )
+
+        self.logger = LoggerV2(
+            "MAIN",
+            dump_path=self.log_path,
+            dump_mode="a",
+            stdout_level=self.log_level_stdout,
+            file_level=self.log_level_file,
+        )
+
+
+def _get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MARO RL workflow parser")
    parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
    return parser.parse_args()


-def main(rl_component_bundle: RLComponentBundle, args: argparse.Namespace) -> None:
-    if args.evaluate_only:
-        evaluate_only_workflow(rl_component_bundle)
-    else:
-        training_workflow(rl_component_bundle)
-
-
-def training_workflow(rl_component_bundle: RLComponentBundle) -> None:
-    num_episodes = int(get_env("NUM_EPISODES"))
-    num_steps = int_or_none(get_env("NUM_STEPS", required=False))
-    min_n_sample = int_or_none(get_env("MIN_N_SAMPLE"))
-
-    logger = LoggerV2(
-        "MAIN",
-        dump_path=get_env("LOG_PATH"),
-        dump_mode="a",
-        stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
-        file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
-    )
-    logger.info("Start training workflow.")
-
-    env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
-    env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
-    parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None
-    train_mode = get_env("TRAIN_MODE")
-
-    is_single_thread = train_mode == "simple" and not parallel_rollout
-    if is_single_thread:
-        rl_component_bundle.pre_create_policy_instances()
-
-    if parallel_rollout:
-        env_sampler = BatchEnvSampler(
-            sampling_parallelism=env_sampling_parallelism,
-            port=int(get_env("ROLLOUT_CONTROLLER_PORT")),
-            min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)),
-            grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)),
-            eval_parallelism=env_eval_parallelism,
-            logger=logger,
+def _get_env_sampler(
+    rl_component_bundle: RLComponentBundle,
+    env_attr: WorkflowEnvAttributes,
+) -> Union[AbsEnvSampler, BatchEnvSampler]:
+    if env_attr.parallel_rollout:
+        assert env_attr.env_sampling_parallelism is not None
+        return BatchEnvSampler(
+            sampling_parallelism=env_attr.env_sampling_parallelism,
+            port=env_attr.port,
+            min_env_samples=env_attr.min_env_samples,
+            grace_factor=env_attr.grace_factor,
+            eval_parallelism=env_attr.env_eval_parallelism,
+            logger=env_attr.logger,
        )
    else:
        env_sampler = rl_component_bundle.env_sampler
-        if train_mode != "simple":
+        if rl_component_bundle.device_mapping is not None:
            for policy_name, device_name in rl_component_bundle.device_mapping.items():
                env_sampler.assign_policy_to_device(policy_name, get_torch_device(device_name))
+        return env_sampler
+
+
+def main(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes, args: argparse.Namespace) -> None:
+    if args.evaluate_only:
+        evaluate_only_workflow(rl_component_bundle, env_attr)
+    else:
+        training_workflow(rl_component_bundle, env_attr)
+
+
+def training_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
+    env_attr.logger.info("Start training workflow.")
+
+    env_sampler = _get_env_sampler(rl_component_bundle, env_attr)

    # evaluation schedule
-    eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False))
-    logger.info(f"Policy will be evaluated at the end of episodes {eval_schedule}")
+    env_attr.logger.info(f"Policy will be evaluated at the end of episodes {env_attr.eval_schedule}")
    eval_point_index = 0

    training_manager = TrainingManager(
        rl_component_bundle=rl_component_bundle,
-        explicit_assign_device=(train_mode == "simple"),
-        proxy_address=None
-        if train_mode == "simple"
-        else (
-            get_env("TRAIN_PROXY_HOST"),
-            int(get_env("TRAIN_PROXY_FRONTEND_PORT")),
-        ),
-        logger=logger,
+        explicit_assign_device=(env_attr.train_mode == "simple"),
+        proxy_address=None if env_attr.train_mode == "simple" else env_attr.proxy_address,
+        logger=env_attr.logger,
    )

-    load_path = get_env("LOAD_PATH", required=False)
-    load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
-    if load_path:
-        assert isinstance(load_path, str)
+    if env_attr.load_path:
+        assert isinstance(env_attr.load_path, str)

-        ep = load_episode if load_episode is not None else get_latest_ep(load_path)
-        path = os.path.join(load_path, str(ep))
+        ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
+        path = os.path.join(env_attr.load_path, str(ep))

        loaded = env_sampler.load_policy_state(path)
-        logger.info(f"Loaded policies {loaded} into env sampler from {path}")
+        env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")

        loaded = training_manager.load(path)
-        logger.info(f"Loaded trainers {loaded} from {path}")
+        env_attr.logger.info(f"Loaded trainers {loaded} from {path}")
        start_ep = ep + 1
    else:
        start_ep = 1

-    checkpoint_path = get_env("CHECKPOINT_PATH", required=False)
-    checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False))
-
    # main loop
-    for ep in range(start_ep, num_episodes + 1):
-        collect_time = training_time = 0
+    for ep in range(start_ep, env_attr.num_episodes + 1):
+        collect_time = training_time = 0.0
        total_experiences: List[List[ExpElement]] = []
        total_info_list: List[dict] = []
        n_sample = 0
-        while n_sample < min_n_sample:
+        while n_sample < env_attr.min_n_sample:
            tc0 = time.time()
            result = env_sampler.sample(
-                policy_state=training_manager.get_policy_state() if not is_single_thread else None,
-                num_steps=num_steps,
+                policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
+                num_steps=env_attr.num_steps,
            )
            experiences: List[List[ExpElement]] = result["experiences"]
            info_list: List[dict] = result["info"]

@@ -128,23 +169,25 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None:

        env_sampler.post_collect(total_info_list, ep)

-        logger.info(f"Roll-out completed for episode {ep}. Training started...")
+        env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...")
        tu0 = time.time()
        training_manager.record_experiences(total_experiences)
        training_manager.train_step()
-        if checkpoint_path and (checkpoint_interval is None or ep % checkpoint_interval == 0):
-            assert isinstance(checkpoint_path, str)
-            pth = os.path.join(checkpoint_path, str(ep))
+        if env_attr.checkpoint_path and (not env_attr.checkpoint_interval or ep % env_attr.checkpoint_interval == 0):
+            assert isinstance(env_attr.checkpoint_path, str)
+            pth = os.path.join(env_attr.checkpoint_path, str(ep))
            training_manager.save(pth)
-            logger.info(f"All trainer states saved under {pth}")
+            env_attr.logger.info(f"All trainer states saved under {pth}")
        training_time += time.time() - tu0

        # performance details
-        logger.info(f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds")
-        if eval_schedule and ep == eval_schedule[eval_point_index]:
+        env_attr.logger.info(
+            f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds",
+        )
+        if env_attr.eval_schedule and ep == env_attr.eval_schedule[eval_point_index]:
            eval_point_index += 1
            result = env_sampler.eval(
-                policy_state=training_manager.get_policy_state() if not is_single_thread else None,
+                policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None,
            )
            env_sampler.post_evaluate(result["info"], ep)


@@ -153,42 +196,19 @@ def training_workflow(rl_component_bundle: RLComponentBundle) -> None:
    training_manager.exit()


-def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None:
-    logger = LoggerV2(
-        "MAIN",
-        dump_path=get_env("LOG_PATH"),
-        dump_mode="a",
-        stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"),
-        file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
-    )
-    logger.info("Start evaluate only workflow.")
+def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes) -> None:
+    env_attr.logger.info("Start evaluate only workflow.")

-    env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False))
-    env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False))
-    parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None
+    env_sampler = _get_env_sampler(rl_component_bundle, env_attr)

-    if parallel_rollout:
-        env_sampler = BatchEnvSampler(
-            sampling_parallelism=env_sampling_parallelism,
-            port=int(get_env("ROLLOUT_CONTROLLER_PORT")),
-            min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)),
-            grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)),
-            eval_parallelism=env_eval_parallelism,
-            logger=logger,
-        )
-    else:
-        env_sampler = rl_component_bundle.env_sampler
+    if env_attr.load_path:
+        assert isinstance(env_attr.load_path, str)

-    load_path = get_env("LOAD_PATH", required=False)
-    load_episode = int_or_none(get_env("LOAD_EPISODE", required=False))
-    if load_path:
-        assert isinstance(load_path, str)
-
-        ep = load_episode if load_episode is not None else get_latest_ep(load_path)
-        path = os.path.join(load_path, str(ep))
+        ep = env_attr.load_episode if env_attr.load_episode is not None else get_latest_ep(env_attr.load_path)
+        path = os.path.join(env_attr.load_path, str(ep))

        loaded = env_sampler.load_policy_state(path)
-        logger.info(f"Loaded policies {loaded} into env sampler from {path}")
+        env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}")

    result = env_sampler.eval()
    env_sampler.post_evaluate(result["info"], -1)

@@ -198,11 +218,9 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None:


if __name__ == "__main__":
-    scenario_path = get_env("SCENARIO_PATH")
+    scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
    scenario_path = os.path.normpath(scenario_path)
    sys.path.insert(0, os.path.dirname(scenario_path))
    module = importlib.import_module(os.path.basename(scenario_path))

-    rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
-    rl_component_bundle = rl_component_bundle_cls()
-    main(rl_component_bundle, args=get_args())
+    main(getattr(module, "rl_component_bundle"), WorkflowEnvAttributes(), args=_get_args())
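With this refactor, all run-time configuration is read once into `WorkflowEnvAttributes` and handed to `main` together with the bundle. A hedged sketch of driving the entry point programmatically; the environment values are illustrative and only the variables required in `simple` mode are set:

```python
import os

os.environ.update({
    "NUM_EPISODES": "30",
    "MIN_N_SAMPLE": "1",
    "LOG_PATH": "/tmp/maro/run.log",
    "TRAIN_MODE": "simple",
})
env_attr = WorkflowEnvAttributes()
main(rl_component_bundle, env_attr, args=_get_args())  # bundle is exported by the scenario module
```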
@@ -4,23 +4,22 @@
import importlib
import os
import sys
-from typing import Type

from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import RolloutWorker
from maro.rl.utils.common import get_env, int_or_none
+from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2

if __name__ == "__main__":
-    scenario_path = get_env("SCENARIO_PATH")
+    scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
    scenario_path = os.path.normpath(scenario_path)
    sys.path.insert(0, os.path.dirname(scenario_path))
    module = importlib.import_module(os.path.basename(scenario_path))

-    rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
-    rl_component_bundle = rl_component_bundle_cls()
+    rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")

-    worker_idx = int_or_none(get_env("ID"))
+    worker_idx = int(env_str_helper(get_env("ID")))
    logger = LoggerV2(
        f"ROLLOUT-WORKER.{worker_idx}",
        dump_path=get_env("LOG_PATH"),

@@ -31,7 +30,7 @@ if __name__ == "__main__":
    worker = RolloutWorker(
        idx=worker_idx,
        rl_component_bundle=rl_component_bundle,
-        producer_host=get_env("ROLLOUT_CONTROLLER_HOST"),
+        producer_host=env_str_helper(get_env("ROLLOUT_CONTROLLER_HOST")),
        producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")),
        logger=logger,
    )
@@ -4,21 +4,20 @@
import importlib
import os
import sys
-from typing import Type

from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.training import TrainOpsWorker
from maro.rl.utils.common import get_env, int_or_none
+from maro.rl.workflows.utils import env_str_helper
from maro.utils import LoggerV2

if __name__ == "__main__":
-    scenario_path = get_env("SCENARIO_PATH")
+    scenario_path = env_str_helper(get_env("SCENARIO_PATH"))
    scenario_path = os.path.normpath(scenario_path)
    sys.path.insert(0, os.path.dirname(scenario_path))
    module = importlib.import_module(os.path.basename(scenario_path))

-    rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls")
-    rl_component_bundle = rl_component_bundle_cls()
+    rl_component_bundle: RLComponentBundle = getattr(module, "rl_component_bundle")

    worker_idx = int_or_none(get_env("ID"))
    logger = LoggerV2(

@@ -29,9 +28,9 @@ if __name__ == "__main__":
        file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"),
    )
    worker = TrainOpsWorker(
-        idx=int_or_none(get_env("ID")),
+        idx=int(env_str_helper(get_env("ID"))),
        rl_component_bundle=rl_component_bundle,
-        producer_host=get_env("TRAIN_PROXY_HOST"),
+        producer_host=env_str_helper(get_env("TRAIN_PROXY_HOST")),
        producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")),
        logger=logger,
    )
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Optional
+
+
+def env_str_helper(string: Optional[str]) -> str:
+    assert string is not None
+    return string
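`env_str_helper` exists purely to narrow `Optional[str]` (the new `get_env` return type) to `str` for type checkers, failing fast if a required variable is unset. Typical use, mirroring the worker scripts above:

```python
scenario_path = env_str_helper(get_env("SCENARIO_PATH"))  # AssertionError if the variable is missing
worker_idx = int(env_str_helper(get_env("ID")))
```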
@@ -72,7 +72,7 @@ class AbsEnv(ABC):
        return self._business_engine

    @abstractmethod
-    def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
+    def step(self, action) -> Tuple[Optional[dict], Optional[list], bool]:
        """Push the environment to next step with action.

        Args:
@ -1,10 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from importlib import import_module
|
||||
from inspect import getmembers, isclass
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from typing import Generator, List, Optional, Tuple, Union, cast
|
||||
|
||||
from maro.backends.frame import FrameBase, SnapshotList
|
||||
from maro.data_lib.dump_csv_converter import DumpConverter
|
||||
|
@ -12,6 +11,7 @@ from maro.event_buffer import ActualEvent, CascadeEvent, EventBuffer, EventState
|
|||
from maro.streamit import streamit
|
||||
from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError
|
||||
|
||||
from ..common import BaseAction, BaseDecisionEvent
|
||||
from .abs_core import AbsEnv, DecisionMode
|
||||
from .scenarios.abs_business_engine import AbsBusinessEngine
|
||||
from .utils.common import tick_to_frame_index
|
||||
|
@ -73,8 +73,8 @@ class Env(AbsEnv):
|
|||
|
||||
self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path)
|
||||
|
||||
# decision_events array for dump.
|
||||
self._decision_events = []
|
||||
# decision_payloads array for dump.
|
||||
self._decision_payloads = []
|
||||
|
||||
# The generator used to push the simulator forward.
|
||||
self._simulate_generator = self._simulate()
|
||||
|
@ -89,21 +89,48 @@ class Env(AbsEnv):
|
|||
|
||||
self._streamit_episode = 0
|
||||
|
||||
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
|
||||
def step(
|
||||
self,
|
||||
action: Union[BaseAction, List[BaseAction], None] = None,
|
||||
) -> Tuple[Optional[dict], Union[BaseDecisionEvent, List[BaseDecisionEvent], None], bool]:
|
||||
"""Push the environment to next step with action.
|
||||
|
||||
Under Sequential mode:
|
||||
- If `action` is None, an empty list will be assigned to the decision event.
|
||||
- Otherwise, the action(s) will be assigned to the decision event.
|
||||
|
||||
Under Joint mode:
|
||||
- If `action` is None, no actions will be assigned to any decision event.
|
||||
- If `action` is a single action, it will be assigned to the first decision event.
|
||||
- If `action` is a list, actions are assigned to each decision event in order. If the number of actions
|
||||
is less than the number of decision events, extra decision events will not be assigned actions. If
|
||||
the number of actions if larger than the number of decision events, extra actions will be ignored.
|
||||
If you want to assign multiple actions to specific event(s), please explicitly pass a list of list. For
|
||||
example:
|
||||
|
||||
```
|
||||
env.step(action=[[a1, a2], a3, [a4, a5]])
|
||||
```
|
||||
|
||||
Will assign `a1` & `a2` to the first decision event, `a3` to the second decision event, and `a4` & `a5`
|
||||
to the third decision event.
|
||||
|
||||
Particularly, if you only want to assign multiple actions to the first decision event, please
|
||||
pass `[[a1, a2, ..., an]]` (a list of one list) instead of `[a1, a2, ..., an]` (an 1D list of n elements),
|
||||
since the latter one will assign the n actions to the first n decision events.
|
||||
|
||||
Args:
|
||||
action (Action): Action(s) from agent.
|
||||
action (Union[BaseAction, List[BaseAction], None]): Action(s) from agent.
|
||||
|
||||
Returns:
|
||||
tuple: a tuple of (metrics, decision event, is_done).
|
||||
"""
|
||||
try:
|
||||
metrics, decision_event, _is_done = self._simulate_generator.send(action)
|
||||
metrics, decision_payloads, _is_done = self._simulate_generator.send(action)
|
||||
except StopIteration:
|
||||
return None, None, True
|
||||
|
||||
return metrics, decision_event, _is_done
|
||||
return metrics, decision_payloads, _is_done
|
||||
|
||||
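The reworked `step()` keeps the same call pattern as before; only the action types are stricter. A minimal usage sketch under the Sequential decision mode (the topology and the `choose_action` callable below are illustrative assumptions, not part of this diff):

```python
from maro.simulator import Env

def choose_action(decision_event):
    # Hypothetical policy: always skip; returning None assigns an empty action list.
    return None

env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=100)
metrics, decision_event, is_done = env.step(None)  # first call carries no action
while not is_done:
    # Sequential mode: one decision event per step, so a single action
    # (or a list of actions for that one event) is passed back.
    metrics, decision_event, is_done = env.step(choose_action(decision_event))

# Joint mode would instead yield a list of decision events; pass a list such as
# env.step([a1, a2, a3]) to assign actions positionally, or [[a1, a2], a3] to give
# the first event two actions, as described in the docstring above.
```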
def dump(self) -> None:
"""Dump environment for restore.

@@ -131,10 +158,14 @@ class Env(AbsEnv):

self._business_engine.frame.dump(dump_folder)
self._converter.start_processing(self.configs)
self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution)
self._converter.dump_descsion_events(
    self._decision_payloads,
    self._start_tick,
    self._snapshot_resolution,
)
self._business_engine.dump(dump_folder)

self._decision_events.clear()
self._decision_payloads.clear()

self._business_engine.reset(keep_seed)

@@ -267,7 +298,29 @@ class Env(AbsEnv):

additional_options=self._additional_options,
)

def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]:
def _assign_action(
    self,
    action: Union[BaseAction, List[BaseAction], None],
    decision_event: CascadeEvent,
) -> None:
    decision_event.state = EventState.EXECUTING

    if action is None:
        actions = []
    elif not isinstance(action, list):
        actions = [action]
    else:
        actions = action

    decision_event.add_immediate_event(self._event_buffer.gen_action_event(self._tick, actions), is_head=True)

def _simulate(
    self,
) -> Generator[
    Tuple[dict, Union[BaseDecisionEvent, List[BaseDecisionEvent]], bool],
    Union[BaseAction, List[BaseAction], None],
    None,
]:
"""This is the generator to wrap each episode process."""
self._streamit_episode += 1

@@ -282,7 +335,7 @@ class Env(AbsEnv):

while True:
    # Keep processing events, until no more events in this tick.
    pending_events = self._event_buffer.execute(self._tick)
    pending_events = cast(List[CascadeEvent], self._event_buffer.execute(self._tick))

    if len(pending_events) == 0:
        # We have processed all the event of current tick, lets go for next tick.

@@ -292,50 +345,25 @@ class Env(AbsEnv):

self._business_engine.frame.take_snapshot(self.frame_index)

# Append source event id to decision events, to support sequential action in joint mode.
decision_events = [event.payload for event in pending_events]

decision_events = (
    decision_events[0] if self._decision_mode == DecisionMode.Sequential else decision_events
)

# Yield current state first, and waiting for action.
actions = yield self._business_engine.get_metrics(), decision_events, False
# archive decision events.
self._decision_events.append(decision_events)

if actions is None:
    # Make business engine easy to work.
    actions = []
elif not isinstance(actions, Iterable):
    actions = [actions]
decision_payloads = [event.payload for event in pending_events]

if self._decision_mode == DecisionMode.Sequential:
    # Generate a new atom event first.
    action_event = self._event_buffer.gen_action_event(self._tick, actions)

    # NOTE: decision event always be a CascadeEvent
    # We just append the action into sub event of first pending cascade event.
    event = pending_events[0]
    assert isinstance(event, CascadeEvent)
    event.state = EventState.EXECUTING
    event.add_immediate_event(action_event, is_head=True)
    self._decision_payloads.append(decision_payloads[0])
    action = yield self._business_engine.get_metrics(), decision_payloads[0], False
    self._assign_action(action, pending_events[0])
else:
    # For joint mode, we will assign actions from beginning to end.
    # Then mark others pending events to finished if not sequential action mode.
    for i, pending_event in enumerate(pending_events):
        if i >= len(actions):
            if self._decision_mode == DecisionMode.Joint:
                # Ignore following pending events that have no action matched.
                pending_event.state = EventState.FINISHED
        else:
            # Set the state as executing, so event buffer will not pop them again.
            # Then insert the action to it.
            action = actions[i]
            pending_event.state = EventState.EXECUTING
            action_event = self._event_buffer.gen_action_event(self._tick, action)
    self._decision_payloads += decision_payloads
    actions = yield self._business_engine.get_metrics(), decision_payloads, False
    if actions is None:
        actions = []
    assert isinstance(actions, list)

            assert isinstance(pending_event, CascadeEvent)
            pending_event.add_immediate_event(action_event, is_head=True)
    for action, event in zip(actions, pending_events):
        self._assign_action(action, event)

    if self._decision_mode == DecisionMode.Joint:
        for event in pending_events[len(actions) :]:
            event.state = EventState.FINISHED

# Check the end tick of the simulation to decide if we should end the simulation.
is_end_tick = self._business_engine.post_step(self._tick)
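`Env.step()` and `_simulate()` communicate through the generator `send()` protocol: the generator yields `(metrics, decision payload(s), is_done)` and pauses until `step()` sends the agent's action back in. A toy, self-contained sketch of that hand-off (stand-in code, not MARO APIs):

```python
def toy_simulate():
    # Stand-in for Env._simulate(): yield state, pause, receive the action via send().
    for tick in range(3):
        action = yield {"tick": tick}, f"decision@{tick}", False
        print(f"tick {tick}: got {action!r}")

gen = toy_simulate()
metrics, payload, is_done = gen.send(None)  # priming call, mirrors the first env.step(None)
while not is_done:
    try:
        metrics, payload, is_done = gen.send(f"action for {payload}")
    except StopIteration:  # generator finished -> episode is done
        is_done = True
```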
@@ -714,38 +714,38 @@ class CimBusinessEngine(AbsBusinessEngine):

actions = event.payload
assert isinstance(actions, list)

if actions:
    for action in actions:
        vessel_idx = action.vessel_idx
        port_idx = action.port_idx
        move_num = action.quantity
        vessel = self._vessels[vessel_idx]
        port = self._ports[port_idx]
        port_empty = port.empty
        vessel_empty = vessel.empty
for action in actions:
    assert isinstance(action, Action)

        assert isinstance(action, Action)
    action_type = action.action_type
    vessel_idx = action.vessel_idx
    port_idx = action.port_idx
    move_num = action.quantity
    vessel = self._vessels[vessel_idx]
    port = self._ports[port_idx]
    port_empty = port.empty
    vessel_empty = vessel.empty

        if action_type == ActionType.DISCHARGE:
            assert move_num <= vessel_empty
    action_type = action.action_type

            port.empty = port_empty + move_num
            vessel.empty = vessel_empty - move_num
        else:
            assert move_num <= min(port_empty, vessel.remaining_space)
    if action_type == ActionType.DISCHARGE:
        assert move_num <= vessel_empty

            port.empty = port_empty - move_num
            vessel.empty = vessel_empty + move_num
        port.empty = port_empty + move_num
        vessel.empty = vessel_empty - move_num
    else:
        assert move_num <= min(port_empty, vessel.remaining_space)

        # Align the event type to make the output readable.
        event.event_type = Events.DISCHARGE_EMPTY if action_type == ActionType.DISCHARGE else Events.LOAD_EMPTY
        port.empty = port_empty - move_num
        vessel.empty = vessel_empty + move_num

        # Update transfer cost for port and metrics.
        self._total_operate_num += move_num
        port.transfer_cost += move_num
    # Align the event type to make the output readable.
    event.event_type = Events.DISCHARGE_EMPTY if action_type == ActionType.DISCHARGE else Events.LOAD_EMPTY

        self._vessel_plans[vessel_idx, port_idx] += self._data_cntr.vessel_period[vessel_idx]
    # Update transfer cost for port and metrics.
    self._total_operate_num += move_num
    port.transfer_cost += move_num

    self._vessel_plans[vessel_idx, port_idx] += self._data_cntr.vessel_period[vessel_idx]

def _stream_base_info(self):
    if streamit:

@@ -5,6 +5,7 @@

from enum import Enum, IntEnum

from maro.backends.frame import SnapshotList
from maro.common import BaseAction, BaseDecisionEvent


class VesselState(IntEnum):

@@ -21,7 +22,7 @@ class ActionType(Enum):

DISCHARGE = "discharge"


class Action:
class Action(BaseAction):
"""Action object that used to pass action from agent to business engine.

Args:

@@ -68,7 +69,7 @@ class ActionScope:

return "%s {load: %r, discharge: %r}" % (self.__class__.__name__, self.load, self.discharge)


class DecisionEvent:
class DecisionEvent(BaseDecisionEvent):
"""Decision event for agent.

Args:
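With `Action` and `DecisionEvent` now deriving from `maro.common.BaseAction` / `BaseDecisionEvent`, every scenario exposes a common pair of types that `Env.step()` and the event buffer can check against. A hedged sketch of what a custom scenario's types could look like under the same convention (class and field names here are made up, not part of the diff):

```python
from dataclasses import dataclass
from typing import List

from maro.common import BaseAction, BaseDecisionEvent

@dataclass
class TransferAction(BaseAction):                 # hypothetical scenario action
    source_idx: int
    target_idx: int
    quantity: int

@dataclass
class TransferDecisionEvent(BaseDecisionEvent):   # hypothetical decision payload
    tick: int
    candidate_targets: List[int]

event = TransferDecisionEvent(tick=0, candidate_targets=[2, 5])
action = TransferAction(source_idx=1, target_idx=event.candidate_targets[0], quantity=10)
assert isinstance(action, BaseAction) and isinstance(event, BaseDecisionEvent)
```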
@@ -15,7 +15,7 @@ from maro.backends.frame import FrameBase, SnapshotList

from maro.cli.data_pipeline.citi_bike import CitiBikeProcess
from maro.cli.data_pipeline.utils import chagne_file_path
from maro.data_lib import BinaryReader
from maro.event_buffer import AtomEvent, EventBuffer, MaroEvents
from maro.event_buffer import AtomEvent, CascadeEvent, EventBuffer, MaroEvents
from maro.simulator.scenarios import AbsBusinessEngine
from maro.simulator.scenarios.helpers import DocableDict
from maro.simulator.scenarios.matrix_accessor import MatrixAttributeAccessor

@@ -23,7 +23,7 @@ from maro.utils.exception.cli_exception import CommandError

from maro.utils.logger import CliLogger

from .adj_loader import load_adj_from_csv
from .common import BikeReturnPayload, BikeTransferPayload, DecisionEvent
from .common import Action, BikeReturnPayload, BikeTransferPayload, DecisionEvent
from .decision_strategy import BikeDecisionStrategy
from .events import CitiBikeEvents
from .frame_builder import build_frame

@@ -33,7 +33,6 @@ from .weather_table import WeatherTable

logger = CliLogger(name=__name__)


metrics_desc = """
Citi bike metrics used to provide statistics information at current point (may be in the middle of a tick).
It contains following keys:

@@ -519,14 +518,15 @@ class CitibikeBusinessEngine(AbsBusinessEngine):

station.bikes = station_bikes + max_accept_number

def _on_action_received(self, evt: AtomEvent):
def _on_action_received(self, evt: CascadeEvent):
    """Callback when we get an action from agent."""
    action = None
    actions = evt.payload

    if evt is None or evt.payload is None:
        return
    assert isinstance(actions, list)

    for action in actions:
        assert isinstance(action, Action)

    for action in evt.payload:
        from_station_idx: int = action.from_station_idx
        to_station_idx: int = action.to_station_idx

@@ -3,6 +3,8 @@

from enum import Enum

from maro.common import BaseAction, BaseDecisionEvent


class BikeTransferPayload:
"""Payload for bike transfer event.

@@ -63,7 +65,7 @@ class DecisionType(Enum):

Demand = "demand"


class DecisionEvent:
class DecisionEvent(BaseDecisionEvent):
"""Citi bike scenario decision event that contains station information for agent to choose action.

Args:

@@ -127,7 +129,7 @@ class DecisionEvent:

)


class Action:
class Action(BaseAction):
"""Citi bike scenario action object, that used to pass action from agent to business engine.

Args:

@@ -25,7 +25,7 @@ class Station(NodeBase):

# avg temp
temperature = NodeAttribute("i2")

# 0: sunny, 1: rainy, 2: snowy, 3: sleet
# 0: sunny, 1: rainy, 2: snowy, 3: sleet
weather = NodeAttribute("i2")

# 0: holiday, 1: not holiday
@@ -2,7 +2,7 @@

# Licensed under the MIT license.

from .business_engine import VmSchedulingBusinessEngine
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
from .common import AllocateAction, DecisionEvent, Latency, PostponeAction, VmRequestPayload
from .cpu_reader import CpuReader
from .enums import Events, PmState, PostponeType, VmCategory
from .physical_machine import PhysicalMachine

@@ -12,7 +12,7 @@ __all__ = [

"VmSchedulingBusinessEngine",
"AllocateAction",
"PostponeAction",
"DecisionPayload",
"DecisionEvent",
"Latency",
"VmRequestPayload",
"CpuReader",

@@ -17,7 +17,7 @@ from maro.simulator.scenarios.helpers import DocableDict

from maro.utils.logger import CliLogger
from maro.utils.utils import convert_dottable

from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
from .common import Action, AllocateAction, DecisionEvent, Latency, PostponeAction, VmRequestPayload
from .cpu_reader import CpuReader
from .enums import Events, PmState, PostponeType, VmCategory
from .frame_builder import build_frame

@@ -528,7 +528,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):

"""dict: Event payload details of current scenario."""
return {
    Events.REQUEST.name: VmRequestPayload.summary_key,
    MaroEvents.PENDING_DECISION.name: DecisionPayload.summary_key,
    MaroEvents.PENDING_DECISION.name: DecisionEvent.summary_key,
}

def get_agent_idx_list(self) -> List[int]:

@@ -820,7 +820,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):

if len(valid_pm_list) > 0:
    # Generate pending decision.
    decision_payload = DecisionPayload(
    decision_payload = DecisionEvent(
        frame_index=self.frame_index(tick=self._tick),
        valid_pms=valid_pm_list,
        vm_id=vm_info.id,

@@ -846,20 +846,24 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):

def _on_action_received(self, event: CascadeEvent):
    """Callback when we get an action from agent."""
    action = None
    if event is None or event.payload is None:
    actions = event.payload
    assert isinstance(actions, list)

    if len(actions) == 0:
        self._pending_vm_request_payload.pop(self._pending_action_vm_id)
        return

    cur_tick: int = event.tick
    for action in actions:
        assert isinstance(action, Action)

    cur_tick: int = event.tick

    for action in event.payload:
        vm_id: int = action.vm_id

        if vm_id not in self._pending_vm_request_payload:
            raise Exception(f"The VM id: '{vm_id}' sent by agent is invalid.")

        if type(action) == AllocateAction:
        if isinstance(action, AllocateAction):
            pm_id = action.pm_id
            vm: VirtualMachine = self._pending_vm_request_payload[vm_id].vm_info
            lifetime = vm.lifetime

@@ -899,7 +903,7 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):

        )
        self._successful_allocation += 1

        elif type(action) == PostponeAction:
        elif isinstance(action, PostponeAction):
            postpone_step = action.postpone_step
            remaining_buffer_time = self._pending_vm_request_payload[vm_id].remaining_buffer_time
            # Either postpone the requirement event or failed.
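The change from `type(action) == AllocateAction` to `isinstance(action, AllocateAction)` is more than style: once actions share a common base class, an exact-type comparison silently rejects subclasses. A small illustration with stand-in classes (not MARO code):

```python
class BaseAction: ...
class AllocateAction(BaseAction): ...
class LoggedAllocateAction(AllocateAction):  # hypothetical subclass added by a wrapper
    ...

action = LoggedAllocateAction()
print(type(action) == AllocateAction)      # False - the exact-type check misses the subclass
print(isinstance(action, AllocateAction))  # True  - still dispatched as an allocation
```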
@@ -3,10 +3,12 @@

from typing import List

from maro.common import BaseAction, BaseDecisionEvent

from .virtual_machine import VirtualMachine


class Action:
class Action(BaseAction):
"""VM Scheduling scenario action object, which was used to pass action from agent to business engine.

Args:

@@ -74,7 +76,7 @@ class VmRequestPayload:

)


class DecisionPayload:
class DecisionEvent(BaseDecisionEvent):
"""Decision event in VM Scheduling scenario that contains information for agent to choose action.

Args:

@@ -27,6 +27,11 @@ def get_available_envs():

return envs


def scenario_not_empty(scenario_path: str) -> bool:
    _, _, files = next(os.walk(scenario_path))
    return "business_engine.py" in files


def get_scenarios() -> List[str]:
    """Get built-in scenario name list.

@@ -35,7 +40,13 @@ def get_scenarios() -> List[str]:

"""
try:
    _, scenarios, _ = next(os.walk(scenarios_root_folder))
    scenarios = sorted([s for s in scenarios if not s.startswith("__")])
    scenarios = sorted(
        [
            s
            for s in scenarios
            if not s.startswith("__") and scenario_not_empty(os.path.join(scenarios_root_folder, s))
        ],
    )

except StopIteration:
    return []
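With the new `scenario_not_empty` filter, `get_scenarios()` only reports scenario packages that actually ship a `business_engine.py`. A quick usage sketch of the related helpers (the commented output is what a stock install is expected to print, not a captured run):

```python
from maro.simulator.utils import get_available_envs, get_scenarios, get_topologies

print(sorted(get_scenarios()))    # expected: ['cim', 'citi_bike', 'vm_scheduling']
print(get_topologies("cim")[:3])  # a few built-in CIM topologies
print(len(get_available_envs()))  # one entry per (scenario, topology) pair
```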
@@ -0,0 +1,452 @@

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start\n",
"\n",
"This notebook demonstrates the use of MARO's RL toolkit to optimize container inventory management (CIM). The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to be loaded or discharged. The objective is to minimize the overall container shortage over a certain period of time."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary packages\n",
"from typing import Any, Dict, List, Tuple, Union\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch.optim import Adam, RMSprop\n",
"\n",
"from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet\n",
"from maro.rl.policy import DiscretePolicyGradient\n",
"from maro.rl.rl_component.rl_component_bundle import RLComponentBundle\n",
"from maro.rl.rollout import AbsEnvSampler, CacheElement, ExpElement\n",
"from maro.rl.training import TrainingManager\n",
"from maro.rl.training.algorithms import PPOParams, PPOTrainer\n",
"from maro.simulator import Env\n",
"from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# env and shaping config\n",
"reward_shaping_conf = {\n",
"    \"time_window\": 99,\n",
"    \"fulfillment_factor\": 1.0,\n",
"    \"shortage_factor\": 1.0,\n",
"    \"time_decay\": 0.97,\n",
"}\n",
"state_shaping_conf = {\n",
"    \"look_back\": 7,\n",
"    \"max_ports_downstream\": 2,\n",
"}\n",
"port_attributes = [\"empty\", \"full\", \"on_shipper\", \"on_consignee\", \"booking\", \"shortage\", \"fulfillment\"]\n",
"vessel_attributes = [\"empty\", \"full\", \"remaining_space\"]\n",
"action_shaping_conf = {\n",
"    \"action_space\": [(i - 10) / 10 for i in range(21)],\n",
"    \"finite_vessel_space\": True,\n",
"    \"has_early_discharge\": True,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Sampler\n",
"\n",
"An environment sampler defines state, action and reward shaping logic so that policies can interact with the environment."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class CIMEnvSampler(AbsEnvSampler):\n",
"    def _get_global_and_agent_state_impl(\n",
"        self, event: DecisionEvent, tick: int = None,\n",
"    ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:\n",
"        tick = self._env.tick\n",
"        vessel_snapshots, port_snapshots = self._env.snapshot_list[\"vessels\"], self._env.snapshot_list[\"ports\"]\n",
"        port_idx, vessel_idx = event.port_idx, event.vessel_idx\n",
"        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf[\"look_back\"] - 1)]\n",
"        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')\n",
"        state = np.concatenate([\n",
"            port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],\n",
"            vessel_snapshots[tick: vessel_idx: vessel_attributes]\n",
"        ])\n",
"        return state, {port_idx: state}\n",
"\n",
"    def _translate_to_env_action(\n",
"        self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,\n",
"    ) -> Dict[Any, object]:\n",
"        action_space = action_shaping_conf[\"action_space\"]\n",
"        finite_vsl_space = action_shaping_conf[\"finite_vessel_space\"]\n",
"        has_early_discharge = action_shaping_conf[\"has_early_discharge\"]\n",
"\n",
"        port_idx, model_action = list(action_dict.items()).pop()\n",
"\n",
"        vsl_idx, action_scope = event.vessel_idx, event.action_scope\n",
"        vsl_snapshots = self._env.snapshot_list[\"vessels\"]\n",
"        vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float(\"inf\")\n",
"\n",
"        percent = abs(action_space[model_action[0]])\n",
"        zero_action_idx = len(action_space) / 2  # index corresponding to value zero.\n",
"        if model_action < zero_action_idx:\n",
"            action_type = ActionType.LOAD\n",
"            actual_action = min(round(percent * action_scope.load), vsl_space)\n",
"        elif model_action > zero_action_idx:\n",
"            action_type = ActionType.DISCHARGE\n",
"            early_discharge = vsl_snapshots[self._env.tick:vsl_idx:\"early_discharge\"][0] if has_early_discharge else 0\n",
"            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge\n",
"            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)\n",
"        else:\n",
"            actual_action, action_type = 0, None\n",
"\n",
"        return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}\n",
"\n",
"    def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:\n",
"        start_tick = tick + 1\n",
"        ticks = list(range(start_tick, start_tick + reward_shaping_conf[\"time_window\"]))\n",
"\n",
"        # Get the ports that took actions at the given tick\n",
"        ports = [int(port) for port in list(env_action_dict.keys())]\n",
"        port_snapshots = self._env.snapshot_list[\"ports\"]\n",
"        future_fulfillment = port_snapshots[ticks:ports:\"fulfillment\"].reshape(len(ticks), -1)\n",
"        future_shortage = port_snapshots[ticks:ports:\"shortage\"].reshape(len(ticks), -1)\n",
"\n",
"        decay_list = [reward_shaping_conf[\"time_decay\"] ** i for i in range(reward_shaping_conf[\"time_window\"])]\n",
"        rewards = np.float32(\n",
"            reward_shaping_conf[\"fulfillment_factor\"] * np.dot(future_fulfillment.T, decay_list)\n",
"            - reward_shaping_conf[\"shortage_factor\"] * np.dot(future_shortage.T, decay_list)\n",
"        )\n",
"        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}\n",
"\n",
"    def _post_step(self, cache_element: CacheElement) -> None:\n",
"        self._info[\"env_metric\"] = self._env.metrics\n",
"\n",
"    def _post_eval_step(self, cache_element: CacheElement) -> None:\n",
"        self._post_step(cache_element)\n",
"\n",
"    def post_collect(self, info_list: list, ep: int) -> None:\n",
"        # print the env metric from each rollout worker\n",
"        for info in info_list:\n",
"            print(f\"env summary (episode {ep}): {info['env_metric']}\")\n",
"\n",
"        # print the average env metric\n",
"        if len(info_list) > 1:\n",
"            metric_keys, num_envs = info_list[0][\"env_metric\"].keys(), len(info_list)\n",
"            avg_metric = {key: sum(info[\"env_metric\"][key] for info in info_list) / num_envs for key in metric_keys}\n",
"            print(f\"average env summary (episode {ep}): {avg_metric}\")\n",
"\n",
"    def post_evaluate(self, info_list: list, ep: int) -> None:\n",
"        self.post_collect(info_list, ep)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Policies & Trainers"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"state_dim = (\n",
"    (state_shaping_conf[\"look_back\"] + 1) * (state_shaping_conf[\"max_ports_downstream\"] + 1) * len(port_attributes)\n",
"    + len(vessel_attributes)\n",
")\n",
"action_num = len(action_shaping_conf[\"action_space\"])\n",
"\n",
"actor_net_conf = {\n",
"    \"hidden_dims\": [256, 128, 64],\n",
"    \"activation\": torch.nn.Tanh,\n",
"    \"softmax\": True,\n",
"    \"batch_norm\": False,\n",
"    \"head\": True,\n",
"}\n",
"critic_net_conf = {\n",
"    \"hidden_dims\": [256, 128, 64],\n",
"    \"output_dim\": 1,\n",
"    \"activation\": torch.nn.LeakyReLU,\n",
"    \"softmax\": False,\n",
"    \"batch_norm\": True,\n",
"    \"head\": True,\n",
"}\n",
"\n",
"actor_learning_rate = 0.001\n",
"critic_learning_rate = 0.001\n",
"\n",
"class MyActorNet(DiscreteACBasedNet):\n",
"    def __init__(self, state_dim: int, action_num: int) -> None:\n",
"        super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)\n",
"        self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)\n",
"        self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)\n",
"\n",
"    def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:\n",
"        return self._actor(states)\n",
"\n",
"\n",
"class MyCriticNet(VNet):\n",
"    def __init__(self, state_dim: int) -> None:\n",
"        super(MyCriticNet, self).__init__(state_dim=state_dim)\n",
"        self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)\n",
"        self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)\n",
"\n",
"    def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:\n",
"        return self._critic(states).squeeze(-1)\n",
"\n",
"def get_ppo_trainer(state_dim: int, name: str) -> PPOTrainer:\n",
"    return PPOTrainer(\n",
"        name=name,\n",
"        reward_discount=.0,\n",
"        params=PPOParams(\n",
"            get_v_critic_net_func=lambda: MyCriticNet(state_dim),\n",
"            grad_iters=10,\n",
"            critic_loss_cls=torch.nn.SmoothL1Loss,\n",
"            min_logp=None,\n",
"            lam=.0,\n",
"            clip_ratio=0.1,\n",
"        ),\n",
"    )\n",
"\n",
"learn_env = Env(scenario=\"cim\", topology=\"toy.4p_ssdd_l0.0\", durations=500)\n",
"test_env = learn_env\n",
"num_agents = len(learn_env.agent_idx_list)\n",
"agent2policy = {agent: f\"ppo_{agent}.policy\" for agent in learn_env.agent_idx_list}\n",
"policies = [DiscretePolicyGradient(name=f\"ppo_{i}.policy\", policy_net=MyActorNet(state_dim, action_num)) for i in range(num_agents)]\n",
"trainers = [get_ppo_trainer(state_dim, f\"ppo_{i}\") for i in range(num_agents)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RL component bundle\n",
"\n",
"An RL component bundle integrates all necessary resources to launch a learning loop."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"rl_component_bundle = RLComponentBundle(\n",
"    env_sampler=CIMEnvSampler(\n",
"        learn_env=learn_env,\n",
"        test_env=test_env,\n",
"        policies=policies,\n",
"        agent2policy=agent2policy,\n",
"        reward_eval_delay=reward_shaping_conf[\"time_window\"],\n",
"    ),\n",
"    agent2policy=agent2policy,\n",
"    policies=policies,\n",
"    trainers=trainers,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Loop\n",
"\n",
"This code cell demonstrates a typical single-threaded training workflow."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting result:\n",
"env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 688632, 'operation_number': 1940226}\n",
"\n",
"Collecting result:\n",
"env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 601337, 'operation_number': 2030600}\n",
"\n",
"Collecting result:\n",
"env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 544572, 'operation_number': 1737291}\n",
"\n",
"Collecting result:\n",
"env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 545506, 'operation_number': 2008160}\n",
"\n",
"Collecting result:\n",
"env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 442000, 'operation_number': 1935439}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 5): {'order_requirements': 1000000, 'container_shortage': 533699, 'operation_number': 891248}\n",
"\n",
"Collecting result:\n",
"env summary (episode 6): {'order_requirements': 1000000, 'container_shortage': 448461, 'operation_number': 1918664}\n",
"\n",
"Collecting result:\n",
"env summary (episode 7): {'order_requirements': 1000000, 'container_shortage': 469874, 'operation_number': 1745973}\n",
"\n",
"Collecting result:\n",
"env summary (episode 8): {'order_requirements': 1000000, 'container_shortage': 364469, 'operation_number': 1974592}\n",
"\n",
"Collecting result:\n",
"env summary (episode 9): {'order_requirements': 1000000, 'container_shortage': 425449, 'operation_number': 1821885}\n",
"\n",
"Collecting result:\n",
"env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 386687, 'operation_number': 1798356}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 10): {'order_requirements': 1000000, 'container_shortage': 950000, 'operation_number': 0}\n",
"\n",
"Collecting result:\n",
"env summary (episode 11): {'order_requirements': 1000000, 'container_shortage': 403236, 'operation_number': 1742253}\n",
"\n",
"Collecting result:\n",
"env summary (episode 12): {'order_requirements': 1000000, 'container_shortage': 373426, 'operation_number': 1682848}\n",
"\n",
"Collecting result:\n",
"env summary (episode 13): {'order_requirements': 1000000, 'container_shortage': 357254, 'operation_number': 1845318}\n",
"\n",
"Collecting result:\n",
"env summary (episode 14): {'order_requirements': 1000000, 'container_shortage': 215681, 'operation_number': 1969606}\n",
"\n",
"Collecting result:\n",
"env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 288347, 'operation_number': 1739670}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 15): {'order_requirements': 1000000, 'container_shortage': 639517, 'operation_number': 680980}\n",
"\n",
"Collecting result:\n",
"env summary (episode 16): {'order_requirements': 1000000, 'container_shortage': 258659, 'operation_number': 1747509}\n",
"\n",
"Collecting result:\n",
"env summary (episode 17): {'order_requirements': 1000000, 'container_shortage': 202262, 'operation_number': 1982958}\n",
"\n",
"Collecting result:\n",
"env summary (episode 18): {'order_requirements': 1000000, 'container_shortage': 209018, 'operation_number': 1765574}\n",
"\n",
"Collecting result:\n",
"env summary (episode 19): {'order_requirements': 1000000, 'container_shortage': 256471, 'operation_number': 1764379}\n",
"\n",
"Collecting result:\n",
"env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 259231, 'operation_number': 1737222}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 20): {'order_requirements': 1000000, 'container_shortage': 9000, 'operation_number': 1974766}\n",
"\n",
"Collecting result:\n",
"env summary (episode 21): {'order_requirements': 1000000, 'container_shortage': 268553, 'operation_number': 1697234}\n",
"\n",
"Collecting result:\n",
"env summary (episode 22): {'order_requirements': 1000000, 'container_shortage': 212987, 'operation_number': 1788601}\n",
"\n",
"Collecting result:\n",
"env summary (episode 23): {'order_requirements': 1000000, 'container_shortage': 234729, 'operation_number': 1803468}\n",
"\n",
"Collecting result:\n",
"env summary (episode 24): {'order_requirements': 1000000, 'container_shortage': 224261, 'operation_number': 1736261}\n",
"\n",
"Collecting result:\n",
"env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 191424, 'operation_number': 1952505}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 25): {'order_requirements': 1000000, 'container_shortage': 606940, 'operation_number': 710472}\n",
"\n",
"Collecting result:\n",
"env summary (episode 26): {'order_requirements': 1000000, 'container_shortage': 223272, 'operation_number': 1895614}\n",
"\n",
"Collecting result:\n",
"env summary (episode 27): {'order_requirements': 1000000, 'container_shortage': 427395, 'operation_number': 1351830}\n",
"\n",
"Collecting result:\n",
"env summary (episode 28): {'order_requirements': 1000000, 'container_shortage': 266455, 'operation_number': 1924877}\n",
"\n",
"Collecting result:\n",
"env summary (episode 29): {'order_requirements': 1000000, 'container_shortage': 362452, 'operation_number': 1747022}\n",
"\n",
"Collecting result:\n",
"env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 320532, 'operation_number': 1506639}\n",
"\n",
"Evaluation result:\n",
"env summary (episode 30): {'order_requirements': 1000000, 'container_shortage': 639581, 'operation_number': 681708}\n",
"\n"
]
}
],
"source": [
"env_sampler = rl_component_bundle.env_sampler\n",
"\n",
"num_episodes = 30\n",
"eval_schedule = [5, 10, 15, 20, 25, 30]\n",
"eval_point_index = 0\n",
"\n",
"training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)\n",
"\n",
"# main loop\n",
"for ep in range(1, num_episodes + 1):\n",
"    result = env_sampler.sample()\n",
"    experiences: List[List[ExpElement]] = result[\"experiences\"]\n",
"    info_list: List[dict] = result[\"info\"]\n",
"\n",
"    print(\"Collecting result:\")\n",
"    env_sampler.post_collect(info_list, ep)\n",
"    print()\n",
"\n",
"    training_manager.record_experiences(experiences)\n",
"    training_manager.train_step()\n",
"\n",
"    if ep == eval_schedule[eval_point_index]:\n",
"        eval_point_index += 1\n",
"        result = env_sampler.eval()\n",
"\n",
"        print(\"Evaluation result:\")\n",
"        env_sampler.post_evaluate(result[\"info\"], ep)\n",
"        print()\n",
"\n",
"training_manager.exit()"
]
}
],
"metadata": {
"interpreter": {
"hash": "8f57a09d39b50edfb56e79199ef40583334d721b06ead0e38a39e7e79092073c"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -1,4 +1,5 @@

jupyter==1.0.0
ipython-genutils>=0.2.0
ipython>=7.16.3
jupyter-client
jupyter-console
jupyter-contrib-core

@@ -7,18 +8,10 @@ jupyter-core

jupyter-highlight-selected-word
jupyter-latex-envs
jupyter-nbextensions-configurator
jupyter>=1.0.0
jupyterlab
jupyterlab-server
jupyterthemes
isort==4.3.21
autopep8==1.4.4
isort==4.3.21
pandas==0.25.3
matplotlib==3.1.2
seaborn==0.9.0
ipython==7.16.3
ipython-genutils==0.2.0
shap==0.32.1
seaborn==0.9.0
numpy<1.20.0
numba==0.46.0
matplotlib>=3.1.2
seaborn>=0.9.0
shap>=0.32.1
@@ -1,7 +1,5 @@

add-trailing-comma
altair==4.1.0
aria2p==0.9.1
astroid==2.3.3
altair>=4.1.0
aria2p>=0.9.1
azure-identity
azure-mgmt-authorization
azure-mgmt-containerservice

@@ -10,60 +8,51 @@ azure-mgmt-storage

azure-storage-file-share
black==22.3.0
certifi==2022.12.7
cryptography==36.0.1
cycler==0.10.0
cryptography>=36.0.1
Cython==0.29.14
deepdiff==5.7.0
deepdiff>=5.7.0
docker
editorconfig-checker==2.4.0
flake8==4.0.1
flask-cors==3.0.10
flask==1.1.2
flask_cors==3.0.10
flask_socketio==5.2.0
Flask>=1.1.2
Flask_Cors>=3.0.10
Flask_SocketIO>=5.2.0
flloat==0.3.0
geopy==2.0.0
guppy3==3.0.9
holidays==0.10.3
isort==4.3.21
jinja2==2.11.3
kiwisolver==1.1.0
kubernetes==21.7.0
lazy-object-proxy==1.4.3
geopy>=2.0.0
holidays>=0.10.3
Jinja2>=2.11.3
kubernetes>=21.7.0
markupsafe==2.0.1
matplotlib==3.5.2
mccabe==0.6.1
networkx==2.4
networkx==2.4
numpy<1.20.0
numpy>=1.19.5
palettable==3.3.0
pandas==0.25.3
pandas>=0.25.3
paramiko>=2.9.2
pre-commit==2.19.0
prompt_toolkit==2.0.10
psutil==5.8.0
ptvsd==4.3.2
prompt_toolkit>=2.0.10
psutil>=5.8.0
ptvsd>=4.3.2
pulp==2.6.0
pyaml==20.4.0
PyJWT==2.4.0
pyparsing==2.4.5
python-dateutil==2.8.1
PyYAML==5.4.1
pyzmq==19.0.2
PyJWT>=2.4.0
python_dateutil>=2.8.1
PyYAML>=5.4.1
pyzmq>=19.0.2
recommonmark~=0.6.0
redis==3.5.3
requests==2.25.1
scipy==1.7.0
requests>=2.25.1
scipy>=1.7.0
setuptools==58.0.4
six==1.13.0
sphinx==1.8.6
sphinx_rtd_theme==1.0.0
streamlit==1.11.1
stringcase==1.2.0
tabulate==0.8.5
streamlit>=0.69.1
stringcase>=1.2.0
tabulate>=0.8.5
termgraph==0.5.3
torch==1.6.0
torchsummary==1.5.1
tqdm==4.51.0
tornado>=6.1
tqdm>=4.51.0
urllib3==1.26.5
wrapt==1.11.2
zmq==0.0.0
setup.py

@@ -135,25 +135,20 @@ setup(

],
install_requires=[
    # TODO: use a helper function to collect these
    "numpy<1.20.0",
    "scipy<=1.7.0",
    "torch<1.8.0",
    "holidays>=0.10.3",
    "pyaml>=20.4.0",
    "numpy>=1.19.5",
    "pandas>=0.25.3",
    "paramiko>=2.9.2",
    "ptvsd>=4.3.2",
    "python_dateutil>=2.8.1",
    "PyYAML>=5.4.1",
    "pyzmq>=19.0.2",
    "redis>=3.5.3",
    "pyzmq<22.1.0",
    "requests<=2.26.0",
    "psutil<5.9.0",
    "deepdiff>=5.2.2",
    "azure-storage-blob<12.9.0",
    "azure-storage-common",
    "geopy>=2.0.0",
    "pandas<1.2",
    "PyYAML<5.5.0",
    "paramiko>=2.7.2",
    "kubernetes>=12.0.1",
    "prompt_toolkit<3.1.0",
    "stringcase>=1.2.0",
    "requests>=2.25.1",
    "scipy>=1.7.0",
    "tabulate>=0.8.5",
    "torch>=1.6.0, <1.8.0",
    "tornado>=6.1",
]
+ specific_requires,
entry_points={
@@ -1,23 +1,17 @@

matplotlib>=3.1.2
geopy
pandas<1.2
numpy<1.20.0
holidays>=0.10.3
pyaml>=20.4.0
redis>=3.5.3
pyzmq<22.1.0
influxdb
requests<=2.26.0
psutil<5.9.0
deepdiff>=5.2.2
altair>=4.1.0
azure-storage-blob<12.9.0
azure-storage-common
torch<1.8.0
pytest
coverage
termgraph
paramiko>=2.7.2
pytz==2019.3
aria2p==0.9.1
kubernetes>=12.0.1
PyYAML<5.5.0
coverage>=6.4.1
deepdiff>=5.7.0
geopy>=2.0.0
holidays>=0.10.3
kubernetes>=21.7.0
numpy>=1.19.5,<1.24.0
pandas>=0.25.3
paramiko>=2.9.2
pytest>=7.1.2
PyYAML>=5.4.1
pyzmq>=19.0.2
redis>=3.5.3
streamlit>=0.69.1
termgraph>=0.5.3
@@ -8,7 +8,7 @@ from maro.simulator.core import Env

from maro.simulator.utils import get_available_envs, get_scenarios, get_topologies
from maro.simulator.utils.common import frame_index_to_ticks, tick_to_frame_index

from .dummy.dummy_business_engine import DummyEngine
from tests.dummy.dummy_business_engine import DummyEngine
from tests.utils import backends_to_test

@@ -326,7 +326,7 @@ class TestEnv(unittest.TestCase):

    msg=f"env should stop at tick 6, but {env.tick}",
)

# avaiable snapshot should be 7 (0-6)
# available snapshot should be 7 (0-6)
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)

self.assertEqual(

@@ -376,24 +376,17 @@ class TestEnv(unittest.TestCase):

with self.assertRaises(FileNotFoundError) as ctx:
    Env("cim", "None", 100)

def test_get_avaiable_envs(self):
    scenario_names = get_scenarios()
def test_get_available_envs(self):
    scenario_names = sorted(get_scenarios())

    # we have 3 built-in scenarios
    self.assertEqual(3, len(scenario_names))

    self.assertTrue("cim" in scenario_names)
    self.assertTrue("citi_bike" in scenario_names)

    cim_topoloies = get_topologies("cim")
    citi_bike_topologies = get_topologies("citi_bike")
    vm_topoloties = get_topologies("vm_scheduling")
    self.assertListEqual(scenario_names, ["cim", "citi_bike", "vm_scheduling"])

    env_list = get_available_envs()

    self.assertEqual(
        len(env_list),
        len(cim_topoloies) + len(citi_bike_topologies) + len(vm_topoloties) + len(get_topologies("supply_chain")),
        sum(len(get_topologies(s)) for s in scenario_names),
    )

def test_frame_index_to_ticks(self):

@@ -404,7 +397,7 @@ class TestEnv(unittest.TestCase):

self.assertListEqual([0, 1], ticks[0])
self.assertListEqual([8, 9], ticks[4])

def test_get_avalible_frame_index_to_ticks_with_default_resolution(self):
def test_get_available_frame_index_to_ticks_with_default_resolution(self):
    for backend_name in backends_to_test:
        os.environ["DEFAULT_BACKEND_NAME"] = backend_name

@@ -425,7 +418,7 @@ class TestEnv(unittest.TestCase):

self.assertListEqual([t for t in t2f_mapping.keys()], [t for t in range(max_tick)])
self.assertListEqual([f for f in t2f_mapping.values()], [f for f in range(max_tick)])

def test_get_avalible_frame_index_to_ticks_with_resolution2(self):
def test_get_available_frame_index_to_ticks_with_resolution2(self):
    for backend_name in backends_to_test:
        os.environ["DEFAULT_BACKEND_NAME"] = backend_name

@@ -6,6 +6,7 @@ import time

import unittest
from typing import Optional

from maro.common import BaseDecisionEvent
from maro.event_buffer import ActualEvent, AtomEvent, CascadeEvent, DummyEvent, EventBuffer, EventState, MaroEvents
from maro.event_buffer.event_linked_list import EventLinkedList
from maro.event_buffer.event_pool import EventPool

@@ -251,9 +252,9 @@ class TestEventBuffer(unittest.TestCase):

self.eb.execute(1)

def test_sub_events_with_decision(self):
    evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
    sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
    sub2 = self.eb.gen_decision_event(1, (3, 3, 3))
    evt1 = self.eb.gen_decision_event(1, BaseDecisionEvent())
    sub1 = self.eb.gen_decision_event(1, BaseDecisionEvent())
    sub2 = self.eb.gen_decision_event(1, BaseDecisionEvent())

    evt1.add_immediate_event(sub1, is_head=True)
    evt1.add_immediate_event(sub2)