MARO v0.3: a new design of RL Toolkit, CLI refactorization, and corresponding updates. (#539)
* refined proxy coding style * updated images and refined doc * updated images * updated CIM-AC example * refined proxy retry logic * call policy update only for AbsCorePolicy * add limitation of AbsCorePolicy in Actor.collect() * refined actor to return only experiences for policies that received new experiences * fix MsgKey issue in rollout_manager * fix typo in learner * call exit function for parallel rollout manager * update supply chain example distributed training scripts * 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prorotype * reformat render * fix supply chain business engine action type problem * reset supply chain example render figsize from 4 to 3 * Add render to all modes of supply chain example * fix or policy typos * 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes * refined parallel policy manager * updated rl/__init__/py * fixed lint issues and CIM local learner bugs * deleted unwanted supply_chain test files * revised default config for cim-dqn * removed test_store.py as it is no longer needed * 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms * updated figures * removed unwanted import * refactored CIM-DQN example * added MultiProcessRolloutManager and MultiProcessTrainingManager * updated doc * lint issue fix * lint issue fix * fixed import formatting * [Feature] Prioritized Experience Replay (#355) * added prioritized experience replay * deleted unwanted supply_chain test files * fixed import order * import fix * fixed lint issues * fixed import formatting * added note in docstring that rank-based PER has yet to be implemented Co-authored-by: ysqyang <v-yangqi@microsoft.com> * rm AbsDecisionGenerator * small fixes * bug fix * reorganized training folder structure * fixed lint issues * fixed lint issues * policy manager refined * lint fix * restructured CIM-dqn sync code * added policy version index and used it as a measure of experience staleness * lint issue fix * lint issue fix * switched log_dir and proxy_kwargs order * cim example refinement * eval schedule sorted only when it's a list * eval schedule sorted only when it's a list * update sc env wrapper * added docker scripts for cim-dqn * refactored example folder structure and added workflow templates * fixed lint issues * fixed lint issues * fixed template bugs * removed unused imports * refactoring sc in progress * simplified cim meta * fixed build.sh path bug * template refinement * deleted obsolete svgs * updated learner logs * minor edits * refactored templates for easy merge with async PR * added component names for rollout manager and policy manager * fixed incorrect position to add last episode to eval schedule * added max_lag option in templates * formatting edit in docker_compose_yml script * moved local learner and early stopper outside sync_tools * refactored rl toolkit folder structure * refactored rl toolkit folder structure * moved env_wrapper and agent_wrapper inside rl/learner * refined scripts * fixed typo in script * changes needed for running sc * removed unwanted imports * config change for testing sc scenario * changes for perf testing * Asynchronous Training (#364) * remote inference code draft * changed actor to rollout_worker and updated init files * removed unwanted import * updated inits * more async code * added async scripts * added async training code & scripts for CIM-dqn * changed async to async_tools to avoid conflict with python keyword * reverted unwanted change to dockerfile * added doc for policy server * addressed PR comments and fixed a bug in docker_compose_yml.py * fixed lint issue * resolved PR comment * resolved merge conflicts * added async templates * added proxy.close() for actor and policy_server * fixed incorrect position to add last episode to eval schedule * reverted unwanted changes * added missing async files * rm unwanted echo in kill.sh Co-authored-by: ysqyang <v-yangqi@microsoft.com> * renamed sync to synchronous and async to asynchronous to avoid conflict with keyword * added missing policy version increment in LocalPolicyManager * refined rollout manager recv logic * removed a debugging print * added sleep in distributed launcher to avoid hanging * updated api doc and rl toolkit doc * refined dynamic imports using importlib * 1. moved policy update triggers to policy manager; 2. added version control in policy manager * fixed a few bugs and updated cim RL example * fixed a few more bugs * added agent wrapper instantiation to workflows * added agent wrapper instantiation to workflows * removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet * fixed incorrect get_ac_policy signature for CIM * moved exploration inside core policy * added state to exploration call to support context-dependent exploration * separated non_rl_policy_index and rl_policy_index in workflows * modified sc example code according to workflow changes * modified sc example code according to workflow changes * added replay_agent_ids parameter to get_env_func for RL examples * fixed a few bugs * added maro/simulator/scenarios/supply_chain as bind mount * added post-step, post-collect, post-eval and post-update callbacks * fixed lint issues * fixed lint issues * moved instantiation of policy manager inside simple learner * fixed env_wrapper get_reward signature * minor edits * removed get_eperience kwargs from env_wrapper * 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows * added rollout exp disribution option in RL examples * removed unwanted files * 1. made logger internal in learner; 2 removed logger creation in abs classes * checked out supply chain test files from v0.2_sc * 1. added missing model.eval() to choose_action; 2.added entropy features to AC * fixed a bug in ac entropy * abbreviated coefficient to coeff * removed -dqn from job name in rl example config * added tmp patch to dev.df * renamed image name for running rl examples * added get_loss interface for core policies * added policy manager in rl_toolkit.rst * 1. env_wrapper bug fix; 2. policy manager update logic refinement * refactored policy and algorithms * policy interface redesigned * refined policy interfaces * fixed typo * fixed bugs in refactored policy interface * fixed some bugs * refactoring in progress * policy interface and policy manager redesigned * 1. fixed bugs in ac and pg; 2. fixed bugs rl workflow scripts * fixed bug in distributed policy manager * fixed lint issues * fixed lint issues * added scipy in setup * 1. trimmed rollout manager code; 2. added option to docker scripts * updated api doc for policy manager * 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script * 1. simplified rl example structure; 2. fixed lint issues * further rl toolkit code simplifications * more numpy-based optimization in RL toolkit * moved replay buffer inside policy * bug fixes * numpy optimization and associated refactoring * extracted shaping logic out of env_sampler * fixed bug in CIM shaping and lint issues * preliminary implemetation of parallel batch inference * fixed bug in ddpg transition recording * put get_state, get_env_actions, get_reward back in EnvSampler * simplified exploration and core model interfaces * bug fixes and doc update * added improve() interface for RLPolicy for single-thread support * fixed simple policy manager bug * updated doc, rst, notebook * updated notebook * fixed lint issues * fixed entropy bugs in ac.py * reverted to simple policy manager as default * 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit * fixed lint issues and updated rl toolkit images * removed obsolete images * added back agent2policy for general workflow use * V0.2 rl refinement dist (#377) * Support `slice` operation in ExperienceSet * Support naive distributed policy training by proxy * Dynamically allocate trainers according to number of experience * code check * code check * code check * Fix a bug in distributed trianing with no gradient * Code check * Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy * 1.call allocate_trainer() at first of update(); 2.refine according to code review * Code check * Refine code with new interface * Update docs of PolicyManger and ExperienceSet * Add images for rl_toolkit docs * Update diagram of PolicyManager * Refine with new interface * Extract allocation strategy into `allocation_strategy.py` * add `distributed_learn()` in policies for data-parallel training * Update doc of RL_toolkit * Add gradient workers for data-parallel * Refine code and update docs * Lint check * Refine by comments * Rename `trainer` to `worker` * Rename `distributed_learn` to `learn_with_data_parallel` * Refine allocator and remove redundant code in policy_manager * remove arugments in allocate_by_policy and so on * added checkpointing for simple and multi-process policy managers * 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager * added missing set_state and get_state for CIM policies * removed blank line * updated RL workflow README * Integrate `data_parallel` arguments into `worker_allocator` (#402) * 1. simplified workflow config; 2. added comments to CIM shaping * lint issue fix * 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading * 1. moved post_step callback inside env sampler; 2. updated README for rl workflows * refined READEME for CIM * VM scheduling with RL (#375) * added part of vm scheduling RL code * refined vm env_wrapper code style * added DQN * added get_experiences func for ac in vm scheduling * added post_step callback to env wrapper * moved Aiming's tracking and plotting logic into callbacks * added eval env wrapper * renamed AC config variable name for VM * vm scheduling RL code finished * updated README * fixed various bugs and hard coding for vm_scheduling * uncommented callbacks for VM scheduling * Minor revision for better code style * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * vm example refactoring * fixed bugs in vm_scheduling * removed unwanted files from cim dir * reverted to simple policy manager as default * added part of vm scheduling RL code * refined vm env_wrapper code style * vm scheduling RL code finished * added config.py for vm scheduing * resolved rebase conflicts * fixed bugs in vm_scheduling * added get_state and set_state to vm_scheduling policy models * updated README for vm_scheduling with RL Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * SC refinement (#397) * Refine test scripts & pending_order_daily logic * Refactor code for better code style: complete type hint, correct typos, remove unused items. Refactor code for better code style: complete type hint, correct typos, remove unused items. * Polish test_supply_chain.py * update import format * Modify vehicle steps logic & remove outdated test case * Optimize imports * Optimize imports * Lint error * Lint error * Lint error * Add SupplyChainAction * Lint error Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * refined workflow scripts * fixed bug in ParallelAgentWrapper * 1. fixed lint issues; 2. refined main script in workflows * lint issue fix * restored default config for rl example * Update rollout.py * refined env var processing in policy manager workflow * added hasattr check in agent wrapper * updated docker_compose_yml.py * Minor refinement * Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412) * Prepare to merge master_mirror * Lint error * Minor * Merge latest master into v0.3 (#426) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * Minor * Remove docs/source/examples/multi_agent_dqn_cim.rst * Update .gitignore * Update .gitignore Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> * Change `Env.set_seed()` logic (#456) * Change Env.set_seed() logic * Redesign CIM reset logic; fix lint issues; * Lint * Seed type assertion * Remove all SC related files (#473) * RL Toolkit V3 (#471) * added daemon=True for multi-process rollout, policy manager and inference * removed obsolete files * [REDO][PR#406]V0.2 rl refinement taskq (#408) * Add a usable task_queue * Rename some variables * 1. Add ; 2. Integrate related files; 3. Remove * merge `data_parallel` and `num_grad_workers` into `data_parallelism` * Fix bugs in docker_compose_yml.py and Simple/Multi-process mode. * Move `grad_worker` into marl/rl/workflows * 1.Merge data_parallel and num_workers into data_parallelism in config; 2.Assign recently used workers as possible in task_queue. * Refine code and update docs of `TaskQueue` * Support priority for tasks in `task_queue` * Update diagram of policy manager and task queue. * Add configurable `single_task_limit` and correct docstring about `data_parallelism` * Fix lint errors in `supply chain` * RL policy redesign (V2) (#405) * Drafi v2.0 for V2 * Polish models with more comments * Polish policies with more comments * Lint * Lint * Add developer doc for models. * Add developer doc for policies. * Remove policy manager V2 since it is not used and out-of-date * Lint * Lint * refined messy workflow code * merged 'scenario_dir' and 'scenario' in rl config * 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods * 1. temporarily renamed RLPolicy from polivy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2 * merged cim and cim_v2 * lint issue fix * refined logging logic * lint issue fix * reversed unwanted changes * . . . . ReplayMemory & IndexScheduler ReplayMemory & IndexScheduler . MultiReplayMemory get_actions_with_logps EnvSampler on the road EnvSampler Minor * LearnerManager * Use batch to transfer data & add SHAPE_CHECK_FLAG * Rename learner to trainer * Add property for policy._is_exploring * CIM test scenario for V3. Manual test passed. Next step: run it, make it works. * env_sampler.py could run * env_sampler refine on the way * First runnable version done * AC could run, but the result is bad. Need to check the logic * Refine abstract method & shape check error info. * Docs * Very detailed compare. Try again. * AC done * DQN check done * Minor * DDPG, not tested * Minors * A rough draft of MAAC * Cannot use CIM as the multi-agent scenario. * Minor * MAAC refinement on the way * Remove ActionWithAux * Refine batch & memory * MAAC example works * Reproduce-able fix. Policy share between env_sampler and trainer_manager. * Detail refinement * Simplify the user configed workflow * Minor * Refine example codes * Minor polishment * Migrate rollout_manager to V3 * Error on the way * Redesign torch.device management * Rl v3 maddpg (#418) * Add MADDPG trainer * Fit independent critics and shared critic modes. * Add a new property: num_policies * Lint * Fix a bug in `sum(rewards)` * Rename `MADDPG` to `DiscreteMADDPG` and fix type hint. * Rename maddpg in examples. * Preparation for data parallel (#420) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * Rename train worker to train ops; add placeholder for abstract methods; * Lint Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> * [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450) * Preparation for data parallel * Minor refinement & lint fix * Lint * Lint * rename atomic_get_batch_grad to get_batch_grad * Fix a unexpected commit * distributed maddpg * Add critic worker * Minor * Data parallel related minorities * Refine code structure for trainers & add more doc strings * Revert a unwanted change * Use TrainWorker to do the actual calculations. * Some minor redesign of the worker's abstraction * Add set/get_policy_state_dict back * Refine set/get_policy_state_dict * Polish policy trainers move train_batch_size to abs trainer delete _train_step_impl() remove _record_impl remove unused methods a minor bug fix in maddpg * Rl v3 data parallel grad worker (#432) * Fit new `trainer_worker` in `grad_worker` and `task_queue`. * Add batch dispatch * Add `tensor_dict` for task submit interface * Move `_remote_learn` to `AbsTrainWorker`. * Complement docstring for task queue and trainer. * dsitributed training pipeline draft * added temporary test files for review purposes * Several code style refinements (#451) * Polish rl_v3/utils/ * Polish rl_v3/distributed/ * Polish rl_v3/policy_trainer/abs_trainer.py * fixed merge conflicts * unified sync and async interfaces * refactored rl_v3; refinement in progress * Finish the runnable pipeline under new design * Remove outdated files; refine class names; optimize imports; * Lint * Minor maddpg related refinement * Lint Co-authored-by: Default <huo53926@126.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Miner bug fix * Coroutine-related bug fix ("get_policy_state") (#452) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * deleted unwanted folder * removed unwanted changes * resolved PR452 comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Quick fix * Redesign experience recording logic (#453) * Two not important fix * Temp draft. Prepare to WFH * Done * Lint * Lint * Calculating advantages / returns (#454) * V1.0 * Complete DDPG * Rl v3 hanging issue fix (#455) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * Final test & format. Ready to merge. Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * Rl v3 parallel rollout (#457) * fixed rebase conflicts * renamed get_policy_func_dict to policy_creator * unified worker interfaces * recovered some files * dist training + cli code move * fixed bugs * added retry logic to client * 1. refactored CIM with various algos; 2. lint * lint * added type hint * removed some logs * lint * Make main.py more IDE friendly * Make main.py more IDE friendly * Lint * load balancing dispatcher * added parallel rollout * lint * Tracker variable type issue; rename to env_sampler_creator; * Rl v3 parallel rollout follow ups (#458) * AbsWorker & AbsDispatcher * Pass env idx to AbsTrainer.record() method, and let the trainer to decide how to record experiences sampled from different worlds. * Fix policy_creator reuse bug * Format code * Merge AbsTrainerManager & SimpleTrainerManager * AC test passed * Lint * Remove AbsTrainer.build() method. Put all initialization operations into __init__ * Redesign AC preprocess batches logic Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Huoran Li <huoranli@microsoft.com> * MADDPG performance bug fix (#459) * Fix MARL (MADDPG) terminal recording bug; some other minor refinements; * Restore Trainer.build() method * Calculate latest action in the get_actor_grad method in MADDPG. * Share critic bug fix * Rl v3 example update (#461) * updated vm_scheduling example and cim notebook * fixed bugs in vm_scheduling * added local train method * bug fix * modified async client logic to fix hidden issue * reverted to default config * fixed PR comments and some bugs * removed hardcode Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Done (#462) * Rl v3 load save (#463) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL Toolkit data parallelism revamp & config utils (#464) * added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments * 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local * 1. fixed rollout exit issue; 2. refined config * removed config file from example * fixed lint issues * fixed lint issues * added main.py under examples/rl * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL doc string (#465) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * Rl config doc (#467) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * added detailed doc * lint * wording refined * resolved some PR comments * resolved more PR comments * typo fix Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * RL online doc (#469) * Model, policy, trainer * RL workflows and env sampler doc in RST (#468) * First rough draft * Minors * Reformat * Lint * Resolve PR comments * 1. type-sensitive env variable getter; 2. updated READMEs for examples * Rl type specific env getter (#466) * 1. type-sensitive env variable getter; 2. updated READMEs for examples * fixed bugs * fixed bugs * bug fixes * lint Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Example bug fix * Optimize parser.py * Resolve PR comments * added detailed doc * lint * wording refined * resolved some PR comments * rewriting rl toolkit rst * resolved more PR comments * typo fix * updated rst Co-authored-by: Huoran Li <huoranli@microsoft.com> Co-authored-by: Default <huo53926@126.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Finish docs/source/key_components/rl_toolkit.rst * API doc * RL online doc image fix (#470) * resolved some PR comments * fix * fixed PR comments * added numfig=True setting in conf.py for sphinx Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Resolve PR comments * Add example github link Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> * Rl v3 pr comment resolution (#474) * added load/save feature * 1. resolved pr comments; 2. reverted maro/cli/k8s * fixed some bugs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * RL renaming v2 (#476) * Change all Logger in RL to LoggerV2 * TrainerManager => TrainingManager * Add Trainer suffix to all algorithms * Finish docs * Update interface names * Minor fix * Cherry pick latest RL (#498) * Cherry pick * Remove SC related files * Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509) * Cherry pick RL changes from sc_refinement (2a4869) * Limit time display precision * RL incremental refactor (#501) * Refactor rollout logic. Allow multiple sampling in one epoch, so that we can generate more data for training. AC & PPO for continuous action policy; refine AC & PPO logic. Cherry pick RL changes from GYM-DDPG Cherry pick RL changes from GYM-SAC Minor error in doc string * Add min_n_sample in template and parser * Resolve PR comments. Fix a minor issue in SAC. * RL component bundle (#513) * CIM passed * Update workers * Refine annotations * VM passed * Code formatting. * Minor import loop issue * Pass batch in PPO again * Remove Scenario * Complete docs * Minor * Remove segment * Optimize logic in RLComponentBundle * Resolve PR comments * Move 'post methods from RLComponenetBundle to EnvSampler * Add method to get mapping of available tick to frame index (#415) * add method to get mapping of available tick to frame index * fix lint issue * fix naming issue * Cherry pick from sc_refinement (#527) * Cherry pick from sc_refinement * Cherry pick from sc_refinement * Refine `terminal` / `next_agent_state` logic (#531) * Optimize RL toolkit * Fix bug in terminal/next_state generation * Rewrite terminal/next_state logic again * Minor renaming * Minor bug fix * Resolve PR comments * Merge master into v0.3 (#536) * update docker hub init (#367) * update docker hub init * replace personal account with maro-team * update hello files for CIM * update docker repository name * update docker file name * fix bugs in notebook, rectify docs * fix doc build issue * remove docs from playground; fix citibike lp example Event issue * update the exampel for vector env * update vector env example * update README due to PR comments * add link to playground above MARO installation in README * fix some typos Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * update package version * update README for package description * update image links for pypi package description * update image links for pypi package description * change the input topology schema for CIM real data mode (#372) * change the input topology schema for CIM real data mode * remove unused importing * update test config file correspondingly * add Exception for env test * add cost factors to cim data dump * update CimDataCollection field name * update field name of data collection related code * update package version * adjust interface to reflect actual signature (#374) Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> * update dataclasses requirement to setup * fix: fixing spelling grammarr * fix: fix typo spelling code commented and data_model.rst * Fix Geo vis IP address & SQL logic bugs. (#383) Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)). * Fix the "Wrong future stop tick predictions" bug (#386) * Propose my new solution Refine to the pre-process version . * Optimize import * Fix reset random seed bug (#387) * update the reset interface of Env and BE * Try to fix reset routes generation seed issue * Refine random related logics. * Minor refinement * Test check * Minor * Remove unused functions so far * Minor Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * update package version * Add _init_vessel_plans in business_engine.reset (#388) * update package version * change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391) Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Refine `event_buffer/` module (#389) * Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check * add higher python version (#398) * add higher python version * update pytorch version * update torchvision version Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * CIM scenario refinement (#400) * Cim scenario refinement (#394) * CIM refinement * Fix lint error * Fix lint error * Cim test coverage (#395) * Enrich tests * Refactor CimDataGenerator * Refactor CIM parsers * Minor refinement * Fix lint error * Fix lint error * Fix lint error * Minor refactor * Type * Add two test file folders. Make a slight change to CIM BE. * Lint error * Lint error * Remove unnecessary public interfaces of CIM BE * Cim disable auto action type detection (#399) * Haven't been tested * Modify document * Add ActionType checking * Minor * Lint error * Action quantity should be a position number * Modify related docs & notebooks * Minor * Change test file name. Prepare to merge into master. * . * Minor test patch * Add `clear()` function to class `SimRandom` (#401) * Add SimRandom.clear() * Minor * Remove commented codes * Lint error * update package version * add branch v0.3 to github workflow * update github test workflow * Update requirements.dev.txt (#444) Added the versions of dependencies and resolve some conflicts occurs when installing. By adding these version number it will tell you the exact. * Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460) Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3) --- updated-dependencies: - dependency-name: ipython dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Add & sort requirements.dev.txt Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Remove random_config.py * Remove test_trajectory_utils.py * Pass tests * Update rl docs * Remove python 3.6 in test * Update docs Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wang.Jinyu <jinywan@microsoft.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: GQ.Chen <675865907@qq.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com> Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com> Co-authored-by: slowy07 <slowy.arfy@gmail.com> Co-authored-by: GQ.Chen <v-guanchen@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: solosilence <abhishekkr23rs@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
|
@ -3,6 +3,7 @@
|
|||
*.pyd
|
||||
*.log
|
||||
*.csv
|
||||
*.parquet
|
||||
*.c
|
||||
*.cpp
|
||||
*.DS_Store
|
||||
|
@ -12,15 +13,18 @@
|
|||
.vs/
|
||||
build/
|
||||
log/
|
||||
logs/
|
||||
checkpoint/
|
||||
checkpoints/
|
||||
streamit/
|
||||
dist/
|
||||
*.egg-info/
|
||||
tools/schedule
|
||||
docs/_build
|
||||
test/
|
||||
data/
|
||||
.eggs/
|
||||
maro_venv/
|
||||
pyvenv.cfg
|
||||
htmlcov/
|
||||
.coverage
|
||||
.coveragerc
|
||||
.coverage
|
||||
.coveragerc
|
||||
.tmp/
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
FROM python:3.7-buster
|
||||
WORKDIR /maro
|
||||
|
||||
# Install Apt packages
|
||||
RUN apt-get update --fix-missing
|
||||
RUN apt-get install -y apt-utils
|
||||
RUN apt-get install -y sudo
|
||||
RUN apt-get install -y gcc
|
||||
RUN apt-get install -y libcurl4 libcurl4-openssl-dev libssl-dev curl
|
||||
RUN apt-get install -y libzmq3-dev
|
||||
RUN apt-get install -y python3-pip
|
||||
RUN apt-get install -y python3-dev libpython3.7-dev python-numpy
|
||||
RUN rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install --no-cache-dir Cython==0.29.14
|
||||
RUN pip install --no-cache-dir pyaml==20.4.0
|
||||
RUN pip install --no-cache-dir pyzmq==19.0.2
|
||||
RUN pip install --no-cache-dir numpy==1.19.1
|
||||
RUN pip install --no-cache-dir matplotlib
|
||||
RUN pip install --no-cache-dir torch==1.6.0
|
||||
RUN pip install --no-cache-dir scipy
|
||||
RUN pip install --no-cache-dir matplotlib
|
||||
RUN pip install --no-cache-dir redis
|
||||
RUN pip install --no-cache-dir networkx
|
||||
|
||||
COPY maro /maro/maro
|
||||
COPY scripts /maro/scripts/
|
||||
COPY setup.py /maro/
|
||||
RUN bash /maro/scripts/install_maro.sh
|
||||
RUN pip cache purge
|
||||
|
||||
ENV PYTHONPATH=/maro
|
||||
|
||||
CMD ["/bin/bash"]
|
|
@ -1,198 +1,330 @@
|
|||
Agent
|
||||
Distributed
|
||||
================================================================================
|
||||
|
||||
maro.rl.agent.abs\_agent
|
||||
maro.rl.distributed.abs_proxy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.agent.abs_agent
|
||||
.. automodule:: maro.rl.distributed.abs_proxy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.agent.dqn
|
||||
maro.rl.distributed.abs_worker
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.agent.dqn
|
||||
.. automodule:: maro.rl.distributed.abs_worker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.agent.ddpg
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.agent.ddpg
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.agent.policy\_optimization
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.agent.policy_optimization
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Agent Manager
|
||||
Exploration
|
||||
================================================================================
|
||||
|
||||
maro.rl.agent.abs\_agent\_manager
|
||||
maro.rl.exploration.scheduling
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.agent.abs_agent_manager
|
||||
.. automodule:: maro.rl.exploration.scheduling
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.exploration.strategies
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.exploration.strategies
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Model
|
||||
================================================================================
|
||||
|
||||
maro.rl.model.learning\_model
|
||||
maro.rl.model.algorithm_nets
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.torch.learning_model
|
||||
.. automodule:: maro.rl.model.algorithm_nets
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.abs_net
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Explorer
|
||||
.. automodule:: maro.rl.model.abs_net
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.fc_block
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.fc_block
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.multi_q_net
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.multi_q_net
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.policy_net
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.policy_net
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.q_net
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.q_net
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.model.v_net
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.model.v_net
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Policy
|
||||
================================================================================
|
||||
|
||||
maro.rl.exploration.abs\_explorer
|
||||
maro.rl.policy.abs_policy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.exploration.abs_explorer
|
||||
.. automodule:: maro.rl.policy.abs_policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.exploration.epsilon\_greedy\_explorer
|
||||
maro.rl.policy.continuous_rl_policy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.exploration.epsilon_greedy_explorer
|
||||
.. automodule:: maro.rl.policy.continuous_rl_policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.exploration.noise\_explorer
|
||||
maro.rl.policy.discrete_rl_policy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.exploration.noise_explorer
|
||||
.. automodule:: maro.rl.policy.discrete_rl_policy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Scheduler
|
||||
RL Component
|
||||
================================================================================
|
||||
|
||||
maro.rl.scheduling.scheduler
|
||||
maro.rl.rl_component.rl_component_bundle
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.scheduling.scheduler
|
||||
.. automodule:: maro.rl.rl_component.rl_component_bundle
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.scheduling.simple\_parameter\_scheduler
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.scheduling.simple_parameter_scheduler
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Shaping
|
||||
Rollout
|
||||
================================================================================
|
||||
|
||||
maro.rl.shaping.abs\_shaper
|
||||
maro.rl.rollout.batch_env_sampler
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.shaping.abs_shaper
|
||||
.. automodule:: maro.rl.rollout.batch_env_sampler
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.rollout.env_sampler
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Storage
|
||||
.. automodule:: maro.rl.rollout.env_sampler
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.rollout.worker
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.rollout.worker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Training
|
||||
================================================================================
|
||||
|
||||
maro.rl.storage.abs\_store
|
||||
maro.rl.training.algorithms
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.storage.abs_store
|
||||
.. automodule:: maro.rl.training.algorithms
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.storage.simple\_store
|
||||
maro.rl.training.proxy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.storage.simple_store
|
||||
.. automodule:: maro.rl.training.proxy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.replay_memory
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Actor
|
||||
.. automodule:: maro.rl.training.replay_memory
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.trainer
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.training.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.training_manager
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.training.training_manager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.train_ops
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.training.train_ops
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.utils
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.training.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.training.worker
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.training.worker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Utils
|
||||
================================================================================
|
||||
|
||||
maro.rl.actor.abs\_actor
|
||||
maro.rl.utils.common
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.actor.abs_actor
|
||||
.. automodule:: maro.rl.utils.common
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.actor.simple\_actor
|
||||
maro.rl.utils.message_enums
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.actor.simple_actor
|
||||
.. automodule:: maro.rl.utils.message_enums
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.utils.objects
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Learner
|
||||
.. automodule:: maro.rl.utils.objects
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.utils.torch_utils
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.utils.torch_utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.utils.trajectory_computation
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.utils.trajectory_computation
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.utils.transition_batch
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.utils.transition_batch
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Workflows
|
||||
================================================================================
|
||||
|
||||
maro.rl.learner.abs\_learner
|
||||
maro.rl.workflows.config
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.learner.abs_learner
|
||||
.. automodule:: maro.rl.workflows.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.learner.simple\_learner
|
||||
maro.rl.workflows.main
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.learner.simple_learner
|
||||
.. automodule:: maro.rl.workflows.main
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Distributed Topologies
|
||||
================================================================================
|
||||
|
||||
maro.rl.dist\_topologies.common
|
||||
maro.rl.workflows.rollout_worker
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.dist_topologies.common
|
||||
.. automodule:: maro.rl.workflows.rollout_worker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.dist\_topologies.single\_learner\_multi\_actor\_sync\_mode
|
||||
maro.rl.workflows.scenario
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.dist_topologies.single_learner_multi_actor_sync_mode
|
||||
.. automodule:: maro.rl.workflows.scenario
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.workflows.train_proxy
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.workflows.train_proxy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
maro.rl.workflows.train_worker
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
.. automodule:: maro.rl.workflows.train_worker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
|
|
@ -100,3 +100,5 @@ source_parsers = {
|
|||
}
|
||||
|
||||
source_suffix = [".md", ".rst"]
|
||||
|
||||
numfig = True
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
Example Scenario: Bike Repositioning (Citi Bike)
|
||||
================================================
|
||||
|
||||
In this example we demonstrate using a simple greedy policy for `Citi Bike <https://maro.readthedocs.io/en/latest/scenarios/citi_bike.html>`_,
|
||||
a real-world bike repositioning scenario.
|
||||
|
||||
Greedy Policy
|
||||
-------------
|
||||
|
||||
Our greedy policy is simple: if the event type is supply, the policy will make
|
||||
the current station send as many bikes as possible to one of k stations with the most empty docks. If the event type is
|
||||
demand, the policy will make the current station request as many bikes as possible from one of k stations with the most
|
||||
bikes. We use a heap data structure to find the top k supply/demand candidates from the action scope associated with
|
||||
each decision event.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class GreedyPolicy:
|
||||
...
|
||||
def choose_action(self, decision_event: DecisionEvent):
|
||||
if decision_event.type == DecisionType.Supply:
|
||||
"""
|
||||
Find k target stations with the most empty slots, randomly choose one of them and send as many bikes to
|
||||
it as allowed by the action scope
|
||||
"""
|
||||
top_k_demands = []
|
||||
for demand_candidate, available_docks in decision_event.action_scope.items():
|
||||
if demand_candidate == decision_event.station_idx:
|
||||
continue
|
||||
|
||||
heapq.heappush(top_k_demands, (available_docks, demand_candidate))
|
||||
if len(top_k_demands) > self._demand_top_k:
|
||||
heapq.heappop(top_k_demands)
|
||||
|
||||
max_reposition, target_station_idx = random.choice(top_k_demands)
|
||||
action = Action(decision_event.station_idx, target_station_idx, max_reposition)
|
||||
else:
|
||||
"""
|
||||
Find k source stations with the most bikes, randomly choose one of them and request as many bikes from
|
||||
it as allowed by the action scope.
|
||||
"""
|
||||
top_k_supplies = []
|
||||
for supply_candidate, available_bikes in decision_event.action_scope.items():
|
||||
if supply_candidate == decision_event.station_idx:
|
||||
continue
|
||||
|
||||
heapq.heappush(top_k_supplies, (available_bikes, supply_candidate))
|
||||
if len(top_k_supplies) > self._supply_top_k:
|
||||
heapq.heappop(top_k_supplies)
|
||||
|
||||
max_reposition, source_idx = random.choice(top_k_supplies)
|
||||
action = Action(source_idx, decision_event.station_idx, max_reposition)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
Interaction with the Greedy Policy
|
||||
----------------------------------
|
||||
|
||||
This environment is driven by `real trip history data <https://s3.amazonaws.com/tripdata/index.html>`_ from Citi Bike.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
env = Env(scenario=config.env.scenario, topology=config.env.topology, start_tick=config.env.start_tick,
|
||||
durations=config.env.durations, snapshot_resolution=config.env.resolution)
|
||||
|
||||
if config.env.seed is not None:
|
||||
env.set_seed(config.env.seed)
|
||||
|
||||
policy = GreedyPolicy(config.agent.supply_top_k, config.agent.demand_top_k)
|
||||
metrics, decision_event, done = env.step(None)
|
||||
while not done:
|
||||
metrics, decision_event, done = env.step(policy.choose_action(decision_event))
|
||||
|
||||
env.reset()
|
|
@ -1,168 +0,0 @@
|
|||
Multi Agent DQN for CIM
|
||||
================================================
|
||||
|
||||
This example demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container
|
||||
inventory management (CIM) problem. It is formalized as a multi-agent reinforcement learning problem,
|
||||
where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions
|
||||
by transferring a certain amount of containers to / from the vessel. The objective is for the agents to
|
||||
learn policies that minimize the overall container shortage.
|
||||
|
||||
Trajectory
|
||||
----------
|
||||
|
||||
The ``CIMTrajectoryForDQN`` inherits from ``Trajectory`` function and implements methods to be used as callbacks
|
||||
in the roll-out loop. In this example,
|
||||
* ``get_state`` converts environment observations to state vectors that encode temporal and spatial information.
|
||||
The temporal information includes relevant port and vessel information, such as shortage and remaining space,
|
||||
over the past k days (here k = 7). The spatial information includes features of the downstream ports.
|
||||
* ``get_action`` converts agents' output (an integer that maps to a percentage of containers to be loaded
|
||||
to or unloaded from the vessel) to action objects that can be executed by the environment.
|
||||
* ``get_offline_reward`` computes the reward of a given action as a linear combination of fulfillment and
|
||||
shortage within a future time frame.
|
||||
* ``on_finish`` processes a complete trajectory into data that can be used directly by the learning agents.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
class CIMTrajectoryForDQN(Trajectory):
|
||||
def __init__(
|
||||
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
|
||||
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
|
||||
finite_vessel_space=True, has_early_discharge=True
|
||||
):
|
||||
super().__init__(env)
|
||||
self.port_attributes = port_attributes
|
||||
self.vessel_attributes = vessel_attributes
|
||||
self.action_space = action_space
|
||||
self.look_back = look_back
|
||||
self.max_ports_downstream = max_ports_downstream
|
||||
self.reward_time_window = reward_time_window
|
||||
self.fulfillment_factor = fulfillment_factor
|
||||
self.shortage_factor = shortage_factor
|
||||
self.time_decay = time_decay
|
||||
self.finite_vessel_space = finite_vessel_space
|
||||
self.has_early_discharge = has_early_discharge
|
||||
|
||||
def get_state(self, event):
|
||||
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
|
||||
tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx
|
||||
ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]
|
||||
future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
|
||||
port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]
|
||||
vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
|
||||
return {port_idx: np.concatenate((port_features, vessel_features))}
|
||||
|
||||
def get_action(self, action_by_agent, event):
|
||||
vessel_snapshots = self.env.snapshot_list["vessels"]
|
||||
action_info = list(action_by_agent.values())[0]
|
||||
model_action = action_info[0] if isinstance(action_info, tuple) else action_info
|
||||
scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx
|
||||
zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero.
|
||||
vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf")
|
||||
early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
|
||||
percent = abs(self.action_space[model_action])
|
||||
|
||||
if model_action < zero_action_idx:
|
||||
action_type = ActionType.LOAD
|
||||
actual_action = min(round(percent * scope.load), vessel_space)
|
||||
elif model_action > zero_action_idx:
|
||||
action_type = ActionType.DISCHARGE
|
||||
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
|
||||
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
|
||||
else:
|
||||
actual_action, action_type = 0, ActionType.LOAD
|
||||
|
||||
return {port: Action(vessel, port, actual_action, action_type)}
|
||||
|
||||
def get_offline_reward(self, event):
|
||||
port_snapshots = self.env.snapshot_list["ports"]
|
||||
start_tick = event.tick + 1
|
||||
ticks = list(range(start_tick, start_tick + self.reward_time_window))
|
||||
|
||||
future_fulfillment = port_snapshots[ticks::"fulfillment"]
|
||||
future_shortage = port_snapshots[ticks::"shortage"]
|
||||
decay_list = [
|
||||
self.time_decay ** i for i in range(self.reward_time_window)
|
||||
for _ in range(future_fulfillment.shape[0] // self.reward_time_window)
|
||||
]
|
||||
|
||||
tot_fulfillment = np.dot(future_fulfillment, decay_list)
|
||||
tot_shortage = np.dot(future_shortage, decay_list)
|
||||
|
||||
return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)
|
||||
|
||||
def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):
|
||||
self.trajectory["event"].append(event)
|
||||
self.trajectory["state"].append(state_by_agent)
|
||||
self.trajectory["action"].append(action_by_agent)
|
||||
|
||||
def on_finish(self):
|
||||
exp_by_agent = defaultdict(lambda: defaultdict(list))
|
||||
for i in range(len(self.trajectory["state"]) - 1):
|
||||
agent_id = list(self.trajectory["state"][i].keys())[0]
|
||||
exp = exp_by_agent[agent_id]
|
||||
exp["S"].append(self.trajectory["state"][i][agent_id])
|
||||
exp["A"].append(self.trajectory["action"][i][agent_id])
|
||||
exp["R"].append(self.get_offline_reward(self.trajectory["event"][i]))
|
||||
exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0])
|
||||
|
||||
return dict(exp_by_agent)
|
||||
|
||||
|
||||
Agent
|
||||
-----
|
||||
|
||||
The out-of-the-box DQN is used as our agent.
|
||||
|
||||
.. code-block:: python
|
||||
agent_config = {
|
||||
"model": ...,
|
||||
"optimization": ...,
|
||||
"hyper_params": ...
|
||||
}
|
||||
|
||||
def get_dqn_agent():
|
||||
q_model = SimpleMultiHeadModel(
|
||||
FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"]
|
||||
)
|
||||
return DQN(q_model, DQNConfig(**agent_config["hyper_params"]))
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
The distributed training consists of one learner process and multiple actor processes. The learner optimizes
|
||||
the policy by collecting roll-out data from the actors to train the underlying agents.
|
||||
|
||||
The actor process must create a roll-out executor for performing the requested roll-outs, which means that the
|
||||
the environment simulator and shapers should be created here. In this example, inference is performed on the
|
||||
actor's side, so a set of DQN agents must be created in order to load the models (and exploration parameters)
|
||||
from the learner.
|
||||
|
||||
.. code-block:: python
|
||||
def cim_dqn_actor():
|
||||
env = Env(**training_config["env"])
|
||||
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
|
||||
actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config)
|
||||
actor.as_worker(training_config["group"])
|
||||
|
||||
The learner's side requires a concrete learner class that inherits from ``AbsLearner`` and implements the ``run``
|
||||
method which contains the main training loop. Here the implementation is similar to the single-threaded version
|
||||
except that the ``collect`` method is used to obtain roll-out data from the actors (since the roll-out executors
|
||||
are located on the actors' side). The agents created here are where training occurs and hence always contains the
|
||||
latest policies.
|
||||
|
||||
.. code-block:: python
|
||||
def cim_dqn_learner():
|
||||
env = Env(**training_config["env"])
|
||||
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
|
||||
scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"])
|
||||
actor = ActorProxy(
|
||||
training_config["group"], training_config["num_actors"],
|
||||
update_trigger=training_config["learner_update_trigger"]
|
||||
)
|
||||
learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"])
|
||||
learner.run()
|
||||
|
||||
.. note::
|
||||
|
||||
All related code snippets are supported in `maro playground <https://hub.docker.com/r/maro2020/playground>`_.
|
До Ширина: | Высота: | Размер: 186 KiB |
После Ширина: | Высота: | Размер: 27 KiB |
После Ширина: | Высота: | Размер: 28 KiB |
До Ширина: | Высота: | Размер: 180 KiB |
После Ширина: | Высота: | Размер: 24 KiB |
После Ширина: | Высота: | Размер: 28 KiB |
После Ширина: | Высота: | Размер: 64 KiB |
После Ширина: | Высота: | Размер: 29 KiB |
|
@ -89,7 +89,6 @@ Contents
|
|||
:maxdepth: 2
|
||||
:caption: Examples
|
||||
|
||||
examples/multi_agent_dqn_cim.rst
|
||||
examples/greedy_policy_citi_bike.rst
|
||||
|
||||
.. toctree::
|
||||
|
|
|
@ -43,7 +43,7 @@ The main attributes of a message instance include:
|
|||
message = Message(tag="check_in",
|
||||
source="worker_001",
|
||||
destination="master",
|
||||
payload="")
|
||||
body="")
|
||||
|
||||
Session Message
|
||||
^^^^^^^^^^^^^^^
|
||||
|
@ -71,13 +71,13 @@ The stages of each session are maintained internally by the proxy.
|
|||
task_message = SessionMessage(tag="sum",
|
||||
source="master",
|
||||
destination="worker_001",
|
||||
payload=[0, 1, 2, ...],
|
||||
body=[0, 1, 2, ...],
|
||||
session_type=SessionType.TASK)
|
||||
|
||||
notification_message = SessionMessage(tag="check_out",
|
||||
source="worker_001",
|
||||
destination="master",
|
||||
payload="",
|
||||
body="",
|
||||
session_type=SessionType.NOTIFICATION)
|
||||
|
||||
Communication Primitives
|
||||
|
|
|
@ -259,3 +259,799 @@ For better data access, we also provide some advanced features, including:
|
|||
# Also with dynamic implementation, we can get the const attributes which is shared between snapshot list, even without
|
||||
# any snapshot (need to provided one tick for padding).
|
||||
states = test_nodes_snapshots[0: [0, 1]: ["const_attribute", "const_attribute_2"]]
|
||||
|
||||
|
||||
|
||||
States in built-in scenarios' snapshot list
|
||||
-------------------------------------------
|
||||
|
||||
.. TODO: move to environment part?
|
||||
|
||||
Currently there are 3 ways to expose states in built-in scenarios:
|
||||
|
||||
Summary
|
||||
~~~~~~~~~~~
|
||||
|
||||
Summary(env.summary) is used to expose static states to outside, it provide 3 items by default:
|
||||
node_mapping, node_detail and event payload.
|
||||
|
||||
The "node_mapping" item usually contains node name and related index, but the structure may be different
|
||||
for different scenario.
|
||||
|
||||
The "node_detail" usually used to expose node definitions, like node name, attribute name and slot number,
|
||||
this is useful if you want to know what attributes are support for a scenario.
|
||||
|
||||
The "event_payload" used show that payload attributes of event in scenario, like "RETURN_FULL" event in
|
||||
CIM scenario, it contains "src_port_idx", "dest_port_idx" and "quantity".
|
||||
|
||||
Metrics
|
||||
~~~~~~~
|
||||
|
||||
Metrics(env.metrics) is designed that used to expose raw states of reward since we have removed reward
|
||||
support in v0.2 version, and it also can be used to export states that not supported by snapshot list, like dictionary or complex
|
||||
structures. Currently there are 2 ways to get the metrics from environment: env.metrics, or 1st result from env.step.
|
||||
|
||||
This metrics usually is a dictionary with several keys, but this is determined by business engine.
|
||||
|
||||
Snapshot_list
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
Snapshot list is the history of nodes (or data model) for a scenario, it only support numberic data types now.
|
||||
It supported slicing query with a numpy array, so it support batch operations, make it much faster than
|
||||
using raw python objects.
|
||||
|
||||
Nodes and attributes may different for different scenarios, following we will introduce about those in
|
||||
built-in scenarios.
|
||||
|
||||
NOTE:
|
||||
Per tick state means that the attribute value will be reset to 0 after each step.
|
||||
|
||||
CIM
|
||||
---
|
||||
|
||||
Default settings for snapshot list
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Snapshot resolution: 1
|
||||
|
||||
|
||||
Max snapshot number: same as durations
|
||||
|
||||
Nodes and attributes in scenario
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In CIM scenario, there are 3 node types:
|
||||
|
||||
|
||||
port
|
||||
++++
|
||||
|
||||
capacity
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The capacity of port for stocking containers.
|
||||
|
||||
empty
|
||||
*****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty container volume on the port.
|
||||
|
||||
full
|
||||
****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden container volume on the port.
|
||||
|
||||
on_shipper
|
||||
**********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty containers, which are released to the shipper.
|
||||
|
||||
on_consignee
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden containers, which are delivered to the consignee.
|
||||
|
||||
shortage
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Shortage of empty container at current tick.
|
||||
|
||||
acc_storage
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated shortage number to the current tick.
|
||||
|
||||
booking
|
||||
*******
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Order booking number of a port at the current tick.
|
||||
|
||||
acc_booking
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated order booking number of a port to the current tick.
|
||||
|
||||
fulfillment
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Fulfilled order number of a port at the current tick.
|
||||
|
||||
acc_fulfillment
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Accumulated fulfilled order number of a port to the current tick.
|
||||
|
||||
transfer_cost
|
||||
*************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
Cost of transferring container, which also covers loading and discharging cost.
|
||||
|
||||
vessel
|
||||
++++++
|
||||
|
||||
capacity
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The capacity of vessel for transferring containers.
|
||||
|
||||
NOTE:
|
||||
This attribute is ignored in current implementation.
|
||||
|
||||
empty
|
||||
*****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Empty container volume on the vessel.
|
||||
|
||||
full
|
||||
****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Laden container volume on the vessel.
|
||||
|
||||
remaining_space
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Remaining space of the vessel.
|
||||
|
||||
early_discharge
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Discharged empty container number for loading laden containers.
|
||||
|
||||
route_idx
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Which route current vessel belongs to.
|
||||
|
||||
last_loc_idx
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Last stop port index in route, it is used to identify where is current vessel.
|
||||
|
||||
next_loc_idx
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Next stop port index in route, it is used to identify where is current vessel.
|
||||
|
||||
past_stop_list
|
||||
**************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
NOTE:
|
||||
This and following attribute are special, that its slot number is determined by configuration,
|
||||
but different with a list attribute, its slot number is fixed at runtime.
|
||||
|
||||
Stop indices that we have stopped in the past.
|
||||
|
||||
past_stop_tick_list
|
||||
*******************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Ticks that we stopped at the port in the past.
|
||||
|
||||
future_stop_list
|
||||
****************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Stop indices that we will stop in the future.
|
||||
|
||||
future_stop_tick_list
|
||||
*********************
|
||||
|
||||
type: int
|
||||
slots: dynamic
|
||||
|
||||
Ticks that we will stop in the future.
|
||||
|
||||
matrices
|
||||
++++++++
|
||||
|
||||
Matrices node is used to store big matrix for ports, vessels and containers.
|
||||
|
||||
full_on_ports
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: port number * port number
|
||||
|
||||
Distribution of full from port to port.
|
||||
|
||||
full_on_vessels
|
||||
***************
|
||||
|
||||
type: int
|
||||
slots: vessel number * port number
|
||||
|
||||
Distribution of full from vessel to port.
|
||||
|
||||
vessel_plans
|
||||
************
|
||||
|
||||
type: int
|
||||
slots: vessel number * port number
|
||||
|
||||
Planed route info for vessels.
|
||||
|
||||
How to
|
||||
~~~~~~
|
||||
|
||||
How to use the matrix(s)
|
||||
++++++++++++++++++++++++
|
||||
|
||||
Matrix is special that it only have one instance (index 0), and the value is saved as a flat 1 dim array, we can reshape it after querying.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# assuming that we want to use full_on_ports attribute.
|
||||
|
||||
tick = 0
|
||||
|
||||
# we can get the instance number of a node by calling the len method
|
||||
port_number = len(env.snapshot_list["port"])
|
||||
|
||||
# this is a 1 dim numpy array
|
||||
full_on_ports = env.snapshot_list["matrices"][tick::"full_on_ports"]
|
||||
|
||||
# reshape it, then this is a 2 dim array that from port to port.
|
||||
full_on_ports = full_on_ports.reshape(port_number, port_number)
|
||||
|
||||
Citi-Bike
|
||||
---------
|
||||
|
||||
Default settings for snapshot list
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Snapshot resolution: 60
|
||||
|
||||
|
||||
Max snapshot number: same as durations
|
||||
|
||||
Nodes and attributes in scenario
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
station
|
||||
+++++++
|
||||
|
||||
bikes
|
||||
*****
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
How many bikes avaiable in current station.
|
||||
|
||||
shortage
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. Lack number of bikes in current station.
|
||||
|
||||
trip_requirement
|
||||
****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick states. How many requirements in current station.
|
||||
|
||||
fulfillment
|
||||
***********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
How many requirement is fit in current station.
|
||||
|
||||
capacity
|
||||
********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Max number of bikes this station can take.
|
||||
|
||||
id
|
||||
+++
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current station.
|
||||
|
||||
weekday
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Weekday at current tick.
|
||||
|
||||
temperature
|
||||
***********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Temperature at current tick.
|
||||
|
||||
weather
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Weather at current tick.
|
||||
|
||||
0: sunny, 1: rainy, 2: snowy, 3: sleet.
|
||||
|
||||
holiday
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
If it is holidy at current tick.
|
||||
|
||||
0: holiday, 1: not holiday
|
||||
|
||||
extra_cost
|
||||
**********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Cost after we reach the capacity after executing action, we have to move extra bikes
|
||||
to other stations.
|
||||
|
||||
transfer_cost
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Cost to execute action to transfer bikes to other station.
|
||||
|
||||
failed_return
|
||||
*************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Per tick state. How many bikes failed to return to current station.
|
||||
|
||||
min_bikes
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Min bikes number in a frame.
|
||||
|
||||
matrices
|
||||
++++++++
|
||||
|
||||
trips_adj
|
||||
*********
|
||||
|
||||
type: int
|
||||
slots: station number * station number
|
||||
|
||||
Used to store trip requirement number between 2 stations.
|
||||
|
||||
|
||||
VM-scheduling
|
||||
-------------
|
||||
|
||||
Default settings for snapshot list
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Snapshot resolution: 1
|
||||
|
||||
|
||||
Max snapshot number: same as durations
|
||||
|
||||
Nodes and attributes in scenario
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Cluster
|
||||
+++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of the cluster.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region is of current cluster.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current cluster.
|
||||
|
||||
total_machine_num
|
||||
******************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in the cluster.
|
||||
|
||||
empty_machine_num
|
||||
******************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The number of empty machines in this cluster. A empty machine means that its allocated CPU cores are 0.
|
||||
|
||||
data_centers
|
||||
++++++++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of current data center.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current data center.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current data center.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machine in current data center.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
The number of empty machines in current data center.
|
||||
|
||||
pms
|
||||
+++
|
||||
|
||||
Physical machine node.
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current machine.
|
||||
|
||||
cpu_cores_capacity
|
||||
******************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Max number of cpu core can be used for current machine.
|
||||
|
||||
memory_capacity
|
||||
***************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Max number of memory can be used for current machine.
|
||||
|
||||
pm_type
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Type of current machine.
|
||||
|
||||
cpu_cores_allocated
|
||||
*******************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
How many cpu core is allocated.
|
||||
|
||||
memory_allocated
|
||||
****************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
How many memory is allocated.
|
||||
|
||||
cpu_utilization
|
||||
***************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
CPU utilization of current machine.
|
||||
|
||||
energy_consumption
|
||||
******************
|
||||
|
||||
type: float
|
||||
slots: 1
|
||||
|
||||
Energy consumption of current machine.
|
||||
|
||||
oversubscribable
|
||||
****************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Physical machine type: non-oversubscribable is -1, empty: 0, oversubscribable is 1.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current machine.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current machine.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current machine.
|
||||
|
||||
cluster_id
|
||||
**********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Cluster id of current machine.
|
||||
|
||||
rack_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Rack id of current machine.
|
||||
|
||||
Rack
|
||||
++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Id of current rack.
|
||||
|
||||
region_id
|
||||
*********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Region id of current rack.
|
||||
|
||||
zone_id
|
||||
*******
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Zone id of current rack.
|
||||
|
||||
data_center_id
|
||||
**************
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Data center id of current rack.
|
||||
|
||||
cluster_id
|
||||
**********
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Cluster id of current rack.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines on this rack.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that not in use on this rack.
|
||||
|
||||
regions
|
||||
+++++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of curent region.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in this region.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that not in use in this region.
|
||||
|
||||
zones
|
||||
+++++
|
||||
|
||||
id
|
||||
***
|
||||
|
||||
type: short
|
||||
slots: 1
|
||||
|
||||
Id of this zone.
|
||||
|
||||
total_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Total number of machines in this zone.
|
||||
|
||||
empty_machine_num
|
||||
*****************
|
||||
|
||||
type: int
|
||||
slots: 1
|
||||
|
||||
Number of machines that not in use in this zone.
|
||||
|
|
|
@ -1,121 +1,198 @@
|
|||
|
||||
RL Toolkit
|
||||
==========
|
||||
|
||||
MARO provides a full-stack abstraction for reinforcement learning (RL), which enables users to
|
||||
apply predefined and customized components to various scenarios. The main abstractions include
|
||||
fundamental components such as `Agent <#agent>`_\ and `Shaper <#shaper>`_\ , and training routine
|
||||
controllers such as `Actor <#actor>` and `Learner <#learner>`.
|
||||
MARO provides a full-stack abstraction for reinforcement learning (RL) which includes various customizable
|
||||
components. In order to provide a gentle introduction for the RL toolkit, we cover the components in a top-down
|
||||
manner, starting from the learning workflow.
|
||||
|
||||
|
||||
Agent
|
||||
-----
|
||||
|
||||
The Agent is the kernel abstraction of the RL formulation for a real-world problem.
|
||||
Our abstraction decouples agent and its underlying model so that an agent can exist
|
||||
as an RL paradigm independent of the inner workings of the models it uses to generate
|
||||
actions or estimate values. For example, the actor-critic algorithm does not need to
|
||||
concern itself with the structures and optimizing schemes of the actor and critic models.
|
||||
This decoupling is achieved by the Core Model abstraction described below.
|
||||
|
||||
|
||||
.. image:: ../images/rl/agent.svg
|
||||
:target: ../images/rl/agent.svg
|
||||
:alt: Agent
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class AbsAgent(ABC):
|
||||
def __init__(self, model: AbsCoreModel, config, experience_pool=None):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = model.to(self.device)
|
||||
self.config = config
|
||||
self._experience_pool = experience_pool
|
||||
|
||||
|
||||
Core Model
|
||||
----------
|
||||
|
||||
MARO provides an abstraction for the underlying models used by agents to form policies and estimate values.
|
||||
The abstraction consists of ``AbsBlock`` and ``AbsCoreModel``, both of which subclass torch's nn.Module.
|
||||
The ``AbsBlock`` represents the smallest structural unit of an NN-based model. For instance, the ``FullyConnectedBlock``
|
||||
provided in the toolkit is a stack of fully connected layers with features like batch normalization,
|
||||
drop-out and skip connection. The ``AbsCoreModel`` is a collection of network components with
|
||||
embedded optimizers and serves as an agent's "brain" by providing a unified interface to it. regardless of how many individual models it requires and how
|
||||
complex the model architecture might be.
|
||||
|
||||
As an example, the initialization of the actor-critic algorithm may look like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
actor_stack = FullyConnectedBlock(...)
|
||||
critic_stack = FullyConnectedBlock(...)
|
||||
model = SimpleMultiHeadModel(
|
||||
{"actor": actor_stack, "critic": critic_stack},
|
||||
optim_option={
|
||||
"actor": OptimizerOption(cls=Adam, params={"lr": 0.001})
|
||||
"critic": OptimizerOption(cls=RMSprop, params={"lr": 0.0001})
|
||||
}
|
||||
)
|
||||
agent = ActorCritic("actor_critic", learning_model, config)
|
||||
|
||||
Choosing an action is simply:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model(state, task_name="actor", training=False)
|
||||
|
||||
And performing one gradient step is simply:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model.learn(critic_loss + actor_loss)
|
||||
|
||||
|
||||
Explorer
|
||||
Workflow
|
||||
--------
|
||||
|
||||
MARO provides an abstraction for exploration in RL. Some RL algorithms such as DQN and DDPG require
|
||||
explicit exploration governed by a set of parameters. The ``AbsExplorer`` class is designed to cater
|
||||
to these needs. Simple exploration schemes, such as ``EpsilonGreedyExplorer`` for discrete action space
|
||||
and ``UniformNoiseExplorer`` and ``GaussianNoiseExplorer`` for continuous action space, are provided in
|
||||
the toolkit.
|
||||
The nice thing about MARO's RL workflows is that it is abstracted neatly from business logic, policies and learning algorithms,
|
||||
making it applicable to practically any scenario that utilizes standard reinforcement learning paradigms. The workflow is
|
||||
controlled by a main process that executes 2-phase learning cycles: roll-out and training (:numref:`1`). The roll-out phase
|
||||
collects data from one or more environment simulators for training. There can be a single environment simulator located in the same thread as the main
|
||||
loop, or multiple environment simulators running in parallel on a set of remote workers (:numref:`2`) if you need to collect large amounts of data
|
||||
fast. The training phase uses the data collected during the roll-out phase to train models involved in RL policies and algorithms.
|
||||
In the case of multiple large models, this phase can be made faster by having the computationally intensive gradient-related tasks
|
||||
sent to a set of remote workers for parallel processing (:numref:`3`).
|
||||
|
||||
As an example, the exploration for DQN may be carried out with the aid of an ``EpsilonGreedyExplorer``:
|
||||
.. _1:
|
||||
.. figure:: ../images/rl/learning_workflow.svg
|
||||
:alt: Overview
|
||||
:align: center
|
||||
|
||||
Learning Workflow
|
||||
|
||||
|
||||
.. _2:
|
||||
.. figure:: ../images/rl/parallel_rollout.svg
|
||||
:alt: Overview
|
||||
:align: center
|
||||
|
||||
Parallel Roll-out
|
||||
|
||||
|
||||
.. _3:
|
||||
.. figure:: ../images/rl/distributed_training.svg
|
||||
:alt: Overview
|
||||
:align: center
|
||||
|
||||
Distributed Training
|
||||
|
||||
|
||||
Environment Sampler
|
||||
-------------------
|
||||
|
||||
An environment sampler is an entity that contains an environment simulator and a set of policies used by agents to
|
||||
interact with the environment (:numref:`4`). When creating an RL formulation for a scenario, it is necessary to define an environment
|
||||
sampler class that includes these key elements:
|
||||
|
||||
- how observations / snapshots of the environment are encoded into state vectors as input to the policy models. This
|
||||
is sometimes referred to as state shaping in applied reinforcement learning;
|
||||
- how model outputs are converted to action objects defined by the environment simulator;
|
||||
- how rewards / penalties are evaluated. This is sometimes referred to as reward shaping.
|
||||
|
||||
In parallel roll-out, each roll-out worker should have its own environment sampler instance.
|
||||
|
||||
|
||||
.. _4:
|
||||
.. figure:: ../images/rl/env_sampler.svg
|
||||
:alt: Overview
|
||||
:align: center
|
||||
|
||||
Environment Sampler
|
||||
|
||||
|
||||
Policy
|
||||
------
|
||||
|
||||
``Policy`` is the most important concept in reinforcement learning. In MARO, the highest level abstraction of a policy
|
||||
object is ``AbsPolicy``. It defines the interface ``get_actions()`` which takes a batch of states as input and returns
|
||||
corresponding actions.
|
||||
The action is defined by the policy itself. It could be a scalar or a vector or any other types.
|
||||
Env sampler should take responsibility for parsing the action to the acceptable format before passing it to the
|
||||
environment.
|
||||
|
||||
The simplest type of policy is ``RuleBasedPolicy`` which generates actions by pre-defined rules. ``RuleBasedPolicy``
|
||||
is mostly used in naive scenarios. However, in most cases where we need to train the policy by interacting with the
|
||||
environment, we need to use ``RLPolicy``. In MARO's design, a policy cannot train itself. Instead,
|
||||
polices could only be trained by :ref:`trainer` (we will introduce trainer later in this page). Therefore, in addition
|
||||
to ``get_actions()``, ``RLPolicy`` also has a set of training-related interfaces, such as ``step()``, ``get_gradients()``
|
||||
and ``set_gradients()``. These interfaces will be called by trainers for training. As you may have noticed, currently
|
||||
we assume policies are built upon deep learning models, so the training-related interfaces are specifically
|
||||
designed for gradient descent.
|
||||
|
||||
|
||||
``RLPolicy`` is further divided into three types:
|
||||
- ``ValueBasedPolicy``: For valued-based policies.
|
||||
- ``DiscretePolicyGradient``: For gradient-based policies that generate discrete actions.
|
||||
- ``ContinuousPolicyGradient``: For gradient-based policies that generate continuous actions.
|
||||
|
||||
The above classes are all concrete classes. Users do not need to implement any new classes, but can directly
|
||||
create a policy object by configuring parameters. Here is a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
explorer = EpsilonGreedyExplorer(num_actions=10)
|
||||
greedy_action = learning_model(state, training=False).argmax(dim=1).data
|
||||
exploration_action = explorer(greedy_action)
|
||||
ValueBasedPolicy(
|
||||
name="policy",
|
||||
q_net=MyQNet(state_dim=128, action_num=64),
|
||||
)
|
||||
|
||||
|
||||
Tools for Training
|
||||
------------------------------
|
||||
For now, you may have no idea about the ``q_net`` parameter, but don't worry, we will introduce it in the next section.
|
||||
|
||||
.. image:: ../images/rl/learner_actor.svg
|
||||
:target: ../images/rl/learner_actor.svg
|
||||
:alt: RL Overview
|
||||
Model
|
||||
-----
|
||||
|
||||
The RL toolkit provides tools that make local and distributed training easy:
|
||||
* Learner, the central controller of the learning process, which consists of collecting simulation data from
|
||||
remote actors and training the agents with them. The training data collection can be done in local or
|
||||
distributed fashion by loading an ``Actor`` or ``ActorProxy`` instance, respectively.
|
||||
* Actor, which implements the ``roll_out`` method where the agent interacts with the environment for one
|
||||
episode. It consists of an environment instance and an agent (a single agent or multiple agents wrapped by
|
||||
``MultiAgentWrapper``). The class provides the as_worker() method which turns it to an event loop where roll-outs
|
||||
are performed on the learner's demand. In distributed RL, there are typically many actor processes running
|
||||
simultaneously to parallelize training data collection.
|
||||
* Actor proxy, which also implements the ``roll_out`` method with the same signature, but manages a set of remote
|
||||
actors for parallel data collection.
|
||||
* Trajectory, which is primarily responsible for translating between scenario-specific information and model
|
||||
input / output. It implements the following methods which are used as callbacks in the actor's roll-out loop:
|
||||
* ``get_state``, which converts observations of an environment into model input. For example, the observation
|
||||
may be represented by a multi-level data structure, which gets encoded by a state shaper to a one-dimensional
|
||||
vector as input to a neural network. The state shaper usually goes hand in hand with the underlying policy
|
||||
or value models.
|
||||
* ``get_action``, which provides model output with necessary context so that it can be executed by the
|
||||
environment simulator.
|
||||
* ``get_reward``, which computes a reward for a given action.
|
||||
* ``on_env_feedback``, which defines things to do upon getting feedback from the environment.
|
||||
* ``on_finish``, which defines things to do upon completion of a roll-out episode.
|
||||
The above code snippet creates a ``ValueBasedPolicy`` object. Let's pay attention to the parameter ``q_net``.
|
||||
``q_net`` accepts a ``DiscreteQNet`` object, and it serves as the core part of a ``ValueBasedPolicy`` object. In
|
||||
other words, ``q_net`` defines the model structure of the Q-network in the value-based policy, and further determines
|
||||
the policy's behavior. ``DiscreteQNet`` is an abstract class, and ``MyQNet`` is a user-defined implementation
|
||||
of ``DiscreteQNet``. It can be a simple MLP, a multi-head transformer, or any other structure that the user wants.
|
||||
|
||||
MARO provides a set of abstractions of basic & commonly used PyTorch models like ``DiscereteQNet``, which enables
|
||||
users to implement their own deep learning models in a handy way. They are:
|
||||
|
||||
- ``DiscreteQNet``: For ``ValueBasedPolicy``.
|
||||
- ``DiscretePolicyNet``: For ``DiscretePolicyGradient``.
|
||||
- ``ContinuousPolicyNet``: For ``ContinuousPolicyGradient``.
|
||||
|
||||
Users should choose the proper types of models according to the type of policies, and then implement their own
|
||||
models by inheriting the abstract ones (just like ``MyQNet``).
|
||||
|
||||
There are also some other models for training purposes. For example:
|
||||
|
||||
- ``VNet``: Used in the critic part in the actor-critic algorithm.
|
||||
- ``MultiQNet``: Used in the critic part in the MADDPG algorithm.
|
||||
- ...
|
||||
|
||||
The way to use these models is exactly the same as the way to use the policy models.
|
||||
|
||||
.. _trainer:
|
||||
|
||||
Algorithm (Trainer)
|
||||
-------
|
||||
|
||||
When introducing policies, we mentioned that policies cannot train themselves. Instead, they have to be trained
|
||||
by external algorithms, which are also called trainers.
|
||||
In MARO, a trainer represents an RL algorithm, such as DQN, actor-critic,
|
||||
and so on. These two concepts are equivalent in the MARO context.
|
||||
Trainers take interaction experiences and store them in the internal memory, and then use the experiences
|
||||
in the memory to train the policies. Like ``RLPolicy``, trainers are also concrete classes, which means they could
|
||||
be used by configuring parameters. Currently, we have 4 trainers (algorithms) in MARO:
|
||||
|
||||
- ``DiscreteActorCriticTrainer``: Actor-critic algorithm for policies that generate discrete actions.
|
||||
- ``DiscretePPOTrainer``: PPO algorithm for policies that generate discrete actions.
|
||||
- ``DDPGTrainer``: DDPG algorithm for policies that generate continuous actions.
|
||||
- ``DQNTrainer``: DQN algorithm for policies that generate discrete actions.
|
||||
- ``DiscreteMADDPGTrainer``: MADDPG algorithm for policies that generate discrete actions.
|
||||
|
||||
Each trainer has a corresponding ``Param`` class to manage all related parameters. For example,
|
||||
``DiscreteActorCriticParams`` contains all parameters used in ``DiscreteActorCriticTrainer``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@dataclass
|
||||
class DiscreteActorCriticParams(TrainerParams):
|
||||
get_v_critic_net_func: Callable[[], VNet] = None
|
||||
reward_discount: float = 0.9
|
||||
grad_iters: int = 1
|
||||
critic_loss_cls: Callable = None
|
||||
clip_ratio: float = None
|
||||
lam: float = 0.9
|
||||
min_logp: Optional[float] = None
|
||||
|
||||
An example of creating an actor-critic trainer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
DiscreteActorCriticTrainer(
|
||||
name='ac',
|
||||
params=DiscreteActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim=128),
|
||||
reward_discount=.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=.0
|
||||
)
|
||||
)
|
||||
|
||||
In order to indicate which trainer each policy is trained by, in MARO, we require that the name of the policy
|
||||
start with the name of the trainer responsible for training it. For example, policy ``ac_1.policy_1`` is trained
|
||||
by the trainer named ``ac_1``. Violating this provision will make MARO unable to correctly establish the
|
||||
corresponding relationship between policy and trainer.
|
||||
|
||||
More details and examples can be found in the code base (`link`_).
|
||||
|
||||
.. _link: https://github.com/microsoft/maro/blob/master/examples/rl/cim/policy_trainer.py
|
||||
|
||||
As a summary, the relationship among policy, model, and trainer is demonstrated in :numref:`5`:
|
||||
|
||||
.. _5:
|
||||
.. figure:: ../images/rl/policy_model_trainer.svg
|
||||
:alt: Overview
|
||||
:align: center
|
||||
|
||||
Summary of policy, model, and trainer
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
# Container Inventory Management
|
||||
|
||||
Container inventory management (CIM) is a scenario where reinforcement learning (RL) can potentially prove useful. Three algorithms are used to learn the multi-agent policy in given environments. Each algorithm has a ``config`` folder which contains ``agent_config.py`` and ``training_config.py``. The former contains parameters for the underlying models and algorithm specific hyper-parameters. The latter contains parameters for the environment and the main training loop. The file ``common.py`` contains parameters and utility functions shared by some or all of these algorithms.
|
||||
|
||||
In the ``ac`` folder, , the policy is trained using the Actor-Critc algorithm in single-threaded fashion. The example can be run by simply executing ``python3 main.py``. Logs will be saved in a file named ``cim-ac.CURRENT_TIME_STAMP.log`` under the ``ac/logs`` folder, where ``CURRENT_TIME_STAMP`` is the time of executing the script.
|
||||
|
||||
In the ``dqn`` folder, the policy is trained using the DQN algorithm in multi-process / distributed mode. This example can be run in three ways.
|
||||
* ``python3 main.py`` or ``python3 main.py -w 0`` runs the example in multi-process mode, in which a main process spawns one learner process and a number of actor processes as specified in ``config/training_config.py``.
|
||||
* ``python3 main.py -w 1`` launches the learner process only. This is for distributed training and expects a number of actor processes (as specified in ``config/training_config.py``) running on some other node(s).
|
||||
* ``python3 main.py -w 2`` launches the actor process only. This is for distributed training and expects a learner process running on some other node.
|
||||
Logs will be saved in a file named ``GROUP_NAME.log`` under the ``{ac_gnn, dqn}/logs`` folder, where ``GROUP_NAME`` is specified in the "group" field in ``config/training_config.py``.
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .agent_config import agent_config
|
||||
from .training_config import training_config
|
||||
|
||||
__all__ = ["agent_config", "training_config"]
|
|
@ -1,52 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl import OptimOption
|
||||
|
||||
from examples.cim.common import common_config
|
||||
|
||||
input_dim = (
|
||||
(common_config["look_back"] + 1) *
|
||||
(common_config["max_ports_downstream"] + 1) *
|
||||
len(common_config["port_attributes"]) +
|
||||
len(common_config["vessel_attributes"])
|
||||
)
|
||||
|
||||
agent_config = {
|
||||
"model": {
|
||||
"actor": {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": len(common_config["action_space"]),
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True
|
||||
},
|
||||
"critic": {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": 1,
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True
|
||||
}
|
||||
},
|
||||
"optimization": {
|
||||
"actor": OptimOption(optim_cls=Adam, optim_params={"lr": 0.001}),
|
||||
"critic": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.001})
|
||||
},
|
||||
"hyper_params": {
|
||||
"reward_discount": .0,
|
||||
"critic_loss_func": nn.SmoothL1Loss(),
|
||||
"train_iters": 10,
|
||||
"actor_loss_coefficient": 0.1,
|
||||
"k": 1,
|
||||
"lam": 0.0
|
||||
# "clip_ratio": 0.8
|
||||
}
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
training_config = {
|
||||
"env": {
|
||||
"scenario": "cim",
|
||||
"topology": "toy.4p_ssdd_l0.0",
|
||||
"durations": 1120,
|
||||
},
|
||||
"max_episode": 50
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl import (
|
||||
Actor, ActorCritic, ActorCriticConfig, FullyConnectedBlock, MultiAgentWrapper, SimpleMultiHeadModel,
|
||||
Scheduler, OnPolicyLearner
|
||||
)
|
||||
from maro.simulator import Env
|
||||
from maro.utils import set_seeds
|
||||
|
||||
from examples.cim.ac.config import agent_config, training_config
|
||||
from examples.cim.common import CIMTrajectory, common_config
|
||||
|
||||
|
||||
def get_ac_agent():
|
||||
actor_net = FullyConnectedBlock(**agent_config["model"]["actor"])
|
||||
critic_net = FullyConnectedBlock(**agent_config["model"]["critic"])
|
||||
ac_model = SimpleMultiHeadModel(
|
||||
{"actor": actor_net, "critic": critic_net}, optim_option=agent_config["optimization"],
|
||||
)
|
||||
return ActorCritic(ac_model, ActorCriticConfig(**agent_config["hyper_params"]))
|
||||
|
||||
|
||||
class CIMTrajectoryForAC(CIMTrajectory):
|
||||
def on_finish(self):
|
||||
training_data = {}
|
||||
for event, state, action in zip(self.trajectory["event"], self.trajectory["state"], self.trajectory["action"]):
|
||||
agent_id = list(state.keys())[0]
|
||||
data = training_data.setdefault(agent_id, {"args": [[] for _ in range(4)]})
|
||||
data["args"][0].append(state[agent_id]) # state
|
||||
data["args"][1].append(action[agent_id][0]) # action
|
||||
data["args"][2].append(action[agent_id][1]) # log_p
|
||||
data["args"][3].append(self.get_offline_reward(event)) # reward
|
||||
|
||||
for agent_id in training_data:
|
||||
training_data[agent_id]["args"] = [
|
||||
np.asarray(vals, dtype=np.float32 if i == 3 else None)
|
||||
for i, vals in enumerate(training_data[agent_id]["args"])
|
||||
]
|
||||
|
||||
return training_data
|
||||
|
||||
|
||||
# Single-threaded launcher
|
||||
if __name__ == "__main__":
|
||||
set_seeds(1024) # for reproducibility
|
||||
env = Env(**training_config["env"])
|
||||
agent = MultiAgentWrapper({name: get_ac_agent() for name in env.agent_idx_list})
|
||||
actor = Actor(env, agent, CIMTrajectoryForAC, trajectory_kwargs=common_config) # local actor
|
||||
learner = OnPolicyLearner(actor, training_config["max_episode"])
|
||||
learner.run()
|
|
@ -1,99 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl import Trajectory
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType
|
||||
|
||||
common_config = {
|
||||
"port_attributes": ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"],
|
||||
"vessel_attributes": ["empty", "full", "remaining_space"],
|
||||
"action_space": list(np.linspace(-1.0, 1.0, 21)),
|
||||
# Parameters for computing states
|
||||
"look_back": 7,
|
||||
"max_ports_downstream": 2,
|
||||
# Parameters for computing actions
|
||||
"finite_vessel_space": True,
|
||||
"has_early_discharge": True,
|
||||
# Parameters for computing rewards
|
||||
"reward_time_window": 99,
|
||||
"fulfillment_factor": 1.0,
|
||||
"shortage_factor": 1.0,
|
||||
"time_decay": 0.97
|
||||
}
|
||||
|
||||
|
||||
class CIMTrajectory(Trajectory):
|
||||
def __init__(
|
||||
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
|
||||
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
|
||||
finite_vessel_space=True, has_early_discharge=True
|
||||
):
|
||||
super().__init__(env)
|
||||
self.port_attributes = port_attributes
|
||||
self.vessel_attributes = vessel_attributes
|
||||
self.action_space = action_space
|
||||
self.look_back = look_back
|
||||
self.max_ports_downstream = max_ports_downstream
|
||||
self.reward_time_window = reward_time_window
|
||||
self.fulfillment_factor = fulfillment_factor
|
||||
self.shortage_factor = shortage_factor
|
||||
self.time_decay = time_decay
|
||||
self.finite_vessel_space = finite_vessel_space
|
||||
self.has_early_discharge = has_early_discharge
|
||||
|
||||
def get_state(self, event):
|
||||
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
|
||||
tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx
|
||||
ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]
|
||||
future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
|
||||
port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]
|
||||
vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
|
||||
return {port_idx: np.concatenate((port_features, vessel_features))}
|
||||
|
||||
def get_action(self, action_by_agent, event):
|
||||
vessel_snapshots = self.env.snapshot_list["vessels"]
|
||||
action_info = list(action_by_agent.values())[0]
|
||||
model_action = action_info[0] if isinstance(action_info, tuple) else action_info
|
||||
scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx
|
||||
zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero.
|
||||
vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf")
|
||||
early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
|
||||
percent = abs(self.action_space[model_action])
|
||||
|
||||
if model_action < zero_action_idx:
|
||||
action_type = ActionType.LOAD
|
||||
actual_action = min(round(percent * scope.load), vessel_space)
|
||||
elif model_action > zero_action_idx:
|
||||
action_type = ActionType.DISCHARGE
|
||||
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
|
||||
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
|
||||
else:
|
||||
actual_action, action_type = 0, ActionType.LOAD
|
||||
|
||||
return {port: Action(vessel, port, actual_action, action_type)}
|
||||
|
||||
def get_offline_reward(self, event):
|
||||
port_snapshots = self.env.snapshot_list["ports"]
|
||||
start_tick = event.tick + 1
|
||||
ticks = list(range(start_tick, start_tick + self.reward_time_window))
|
||||
|
||||
future_fulfillment = port_snapshots[ticks::"fulfillment"]
|
||||
future_shortage = port_snapshots[ticks::"shortage"]
|
||||
decay_list = [
|
||||
self.time_decay ** i for i in range(self.reward_time_window)
|
||||
for _ in range(future_fulfillment.shape[0] // self.reward_time_window)
|
||||
]
|
||||
|
||||
tot_fulfillment = np.dot(future_fulfillment, decay_list)
|
||||
tot_shortage = np.dot(future_shortage, decay_list)
|
||||
|
||||
return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)
|
||||
|
||||
def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):
|
||||
self.trajectory["event"].append(event)
|
||||
self.trajectory["state"].append(state_by_agent)
|
||||
self.trajectory["action"].append(action_by_agent)
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .agent_config import agent_config
|
||||
from .training_config import training_config
|
||||
|
||||
__all__ = ["agent_config", "training_config"]
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import RMSprop
|
||||
|
||||
from maro.rl import DQN, DQNConfig, FullyConnectedBlock, OptimOption, PolicyGradient, SimpleMultiHeadModel
|
||||
|
||||
from examples.cim.common import common_config
|
||||
|
||||
input_dim = (
|
||||
(common_config["look_back"] + 1) *
|
||||
(common_config["max_ports_downstream"] + 1) *
|
||||
len(common_config["port_attributes"]) +
|
||||
len(common_config["vessel_attributes"])
|
||||
)
|
||||
|
||||
agent_config = {
|
||||
"model": {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": len(common_config["action_space"]), # number of possible actions
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"skip_connection": False,
|
||||
"head": True,
|
||||
"dropout_p": 0.0
|
||||
},
|
||||
"optimization": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.05}),
|
||||
"hyper_params": {
|
||||
"reward_discount": .0,
|
||||
"loss_cls": nn.SmoothL1Loss,
|
||||
"target_update_freq": 5,
|
||||
"tau": 0.1,
|
||||
"double": False
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
training_config = {
|
||||
"env": {
|
||||
"scenario": "cim",
|
||||
"topology": "toy.4p_ssdd_l0.0",
|
||||
"durations": 1120,
|
||||
},
|
||||
"max_episode": 100,
|
||||
"exploration": {
|
||||
"parameter_names": ["epsilon"],
|
||||
"split": 0.5,
|
||||
"start": 0.4,
|
||||
"mid": 0.32,
|
||||
"end": 0.0
|
||||
},
|
||||
"training": {
|
||||
"min_experiences_to_train": 1024,
|
||||
"train_iter": 10,
|
||||
"batch_size": 128,
|
||||
"prioritized_sampling_by_loss": True
|
||||
},
|
||||
"group": "cim-dqn",
|
||||
"learner_update_trigger": 2,
|
||||
"num_actors": 2
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from collections import defaultdict
|
||||
from multiprocessing import Process
|
||||
from os import makedirs
|
||||
from os.path import dirname, join, realpath
|
||||
|
||||
from maro.rl import (
|
||||
Actor, ActorProxy, DQN, DQNConfig, FullyConnectedBlock, MultiAgentWrapper, OffPolicyLearner,
|
||||
SimpleMultiHeadModel, TwoPhaseLinearParameterScheduler
|
||||
)
|
||||
from maro.simulator import Env
|
||||
from maro.utils import set_seeds
|
||||
|
||||
cim_dqn_path = dirname(realpath(__file__))
|
||||
cim_example_path = dirname(cim_dqn_path)
|
||||
sys.path.insert(0, cim_example_path)
|
||||
|
||||
from common import CIMTrajectory, common_config
|
||||
from dqn.config import agent_config, training_config
|
||||
|
||||
log_dir = join(cim_dqn_path, "log")
|
||||
makedirs(log_dir, exist_ok=True)
|
||||
|
||||
|
||||
def get_dqn_agent():
|
||||
q_model = SimpleMultiHeadModel(
|
||||
FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"]
|
||||
)
|
||||
return DQN(q_model, DQNConfig(**agent_config["hyper_params"]))
|
||||
|
||||
|
||||
class CIMTrajectoryForDQN(CIMTrajectory):
|
||||
def on_finish(self):
|
||||
exp_by_agent = defaultdict(lambda: defaultdict(list))
|
||||
for i in range(len(self.trajectory["state"]) - 1):
|
||||
agent_id = list(self.trajectory["state"][i].keys())[0]
|
||||
exp = exp_by_agent[agent_id]
|
||||
exp["S"].append(self.trajectory["state"][i][agent_id])
|
||||
exp["A"].append(self.trajectory["action"][i][agent_id])
|
||||
exp["R"].append(self.get_offline_reward(self.trajectory["event"][i]))
|
||||
exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0])
|
||||
|
||||
return dict(exp_by_agent)
|
||||
|
||||
|
||||
def cim_dqn_learner():
|
||||
env = Env(**training_config["env"])
|
||||
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
|
||||
scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"])
|
||||
actor = ActorProxy(
|
||||
training_config["group"], training_config["num_actors"],
|
||||
update_trigger=training_config["learner_update_trigger"],
|
||||
log_dir=log_dir
|
||||
)
|
||||
learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"], log_dir=log_dir)
|
||||
learner.run()
|
||||
|
||||
|
||||
def cim_dqn_actor():
|
||||
env = Env(**training_config["env"])
|
||||
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
|
||||
actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config)
|
||||
actor.as_worker(training_config["group"], log_dir=log_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-w", "--whoami", type=int, choices=[0, 1, 2], default=0,
|
||||
help="Identity of this process: 0 - multi-process mode, 1 - learner, 2 - actor"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.whoami == 0:
|
||||
actor_processes = [Process(target=cim_dqn_actor) for _ in range(training_config["num_actors"])]
|
||||
learner_process = Process(target=cim_dqn_learner)
|
||||
|
||||
for i, actor_process in enumerate(actor_processes):
|
||||
set_seeds(i) # this is to ensure that the actors explore differently.
|
||||
actor_process.start()
|
||||
|
||||
learner_process.start()
|
||||
|
||||
for actor_process in actor_processes:
|
||||
actor_process.join()
|
||||
|
||||
learner_process.join()
|
||||
elif args.whoami == 1:
|
||||
cim_dqn_learner()
|
||||
elif args.whoami == 2:
|
||||
cim_dqn_actor()
|
|
@ -0,0 +1,9 @@
|
|||
# Container Inventory Management
|
||||
|
||||
This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to be loaded or discharged. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find:
|
||||
* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details);
|
||||
* ``config.py``, which contains general configurations for the scenario;
|
||||
* ``algorithms/``, which contains configurations for the PPO, Actor-Critic, DQN and discrete-MADDPG algorithms, including network configurations;
|
||||
* ``rl_componenet_bundle.py``, which defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read `CIMBundle` to learn its basic usage.
|
||||
|
||||
We recommend that you follow this example to write your own scenarios.
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .rl_component_bundle import CIMBundle as rl_component_bundle_cls
|
||||
|
||||
__all__ = [
|
||||
"rl_component_bundle_cls",
|
||||
]
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"output_dim": 1,
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True,
|
||||
}
|
||||
actor_learning_rate = 0.001
|
||||
critic_learning_rate = 0.001
|
||||
|
||||
|
||||
class MyActorNet(DiscreteACBasedNet):
|
||||
def __init__(self, state_dim: int, action_num: int) -> None:
|
||||
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
|
||||
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
|
||||
|
||||
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._actor(states)
|
||||
|
||||
|
||||
class MyCriticNet(VNet):
|
||||
def __init__(self, state_dim: int) -> None:
|
||||
super(MyCriticNet, self).__init__(state_dim=state_dim)
|
||||
self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)
|
||||
self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._critic(states).squeeze(-1)
|
||||
|
||||
|
||||
def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
|
||||
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
|
||||
|
||||
|
||||
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
|
||||
return ActorCriticTrainer(
|
||||
name=name,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=.0,
|
||||
),
|
||||
)
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import RMSprop
|
||||
|
||||
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.training.algorithms import DQNTrainer, DQNParams
|
||||
|
||||
q_net_conf = {
|
||||
"hidden_dims": [256, 128, 64, 32],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"skip_connection": False,
|
||||
"head": True,
|
||||
"dropout_p": 0.0,
|
||||
}
|
||||
learning_rate = 0.05
|
||||
|
||||
|
||||
class MyQNet(DiscreteQNet):
|
||||
def __init__(self, state_dim: int, action_num: int) -> None:
|
||||
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
|
||||
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
|
||||
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._fc(states)
|
||||
|
||||
|
||||
def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
|
||||
return ValueBasedPolicy(
|
||||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num),
|
||||
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[(
|
||||
"epsilon", MultiLinearExplorationScheduler, {
|
||||
"splits": [(2, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 5,
|
||||
"final_value": 0.0,
|
||||
}
|
||||
)],
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
||||
def get_dqn(name: str) -> DQNTrainer:
|
||||
return DQNTrainer(
|
||||
name=name,
|
||||
params=DQNParams(
|
||||
reward_discount=.0,
|
||||
update_target_every=5,
|
||||
num_epochs=10,
|
||||
soft_update_coef=0.1,
|
||||
double=False,
|
||||
replay_memory_capacity=10000,
|
||||
random_overwrite=False,
|
||||
batch_size=32,
|
||||
),
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams
|
||||
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"output_dim": 1,
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True
|
||||
}
|
||||
actor_learning_rate = 0.001
|
||||
critic_learning_rate = 0.001
|
||||
|
||||
|
||||
# #####################################################################################################################
|
||||
class MyActorNet(DiscreteACBasedNet):
|
||||
def __init__(self, state_dim: int, action_num: int) -> None:
|
||||
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
|
||||
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
|
||||
|
||||
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
return self._actor(states)
|
||||
|
||||
|
||||
class MyMultiCriticNet(MultiQNet):
|
||||
def __init__(self, state_dim: int, action_dims: List[int]) -> None:
|
||||
super(MyMultiCriticNet, self).__init__(state_dim=state_dim, action_dims=action_dims)
|
||||
self._critic = FullyConnected(input_dim=state_dim + sum(action_dims), **critic_net_conf)
|
||||
self._optim = RMSprop(self._critic.parameters(), critic_learning_rate)
|
||||
|
||||
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
|
||||
return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1)
|
||||
|
||||
|
||||
def get_multi_critic_net(state_dim: int, action_dims: List[int]) -> MyMultiCriticNet:
|
||||
return MyMultiCriticNet(state_dim, action_dims)
|
||||
|
||||
|
||||
def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
|
||||
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
|
||||
|
||||
|
||||
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
|
||||
return DiscreteMADDPGTrainer(
|
||||
name=name,
|
||||
params=DiscreteMADDPGParams(
|
||||
reward_discount=.0,
|
||||
num_epoch=10,
|
||||
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
|
||||
shared_critic=False
|
||||
)
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
import torch
|
||||
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import PPOParams, PPOTrainer
|
||||
|
||||
from .ac import MyActorNet, MyCriticNet
|
||||
|
||||
|
||||
def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
|
||||
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
|
||||
|
||||
|
||||
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
|
||||
return PPOTrainer(
|
||||
name=name,
|
||||
params=PPOParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=.0,
|
||||
clip_ratio=0.1,
|
||||
),
|
||||
)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
env_conf = {
|
||||
"scenario": "cim",
|
||||
"topology": "toy.4p_ssdd_l0.0",
|
||||
"durations": 560
|
||||
}
|
||||
|
||||
if env_conf["topology"].startswith("toy"):
|
||||
num_agents = int(env_conf["topology"].split(".")[1][0])
|
||||
else:
|
||||
num_agents = int(env_conf["topology"].split(".")[1][:2])
|
||||
|
||||
port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
|
||||
vessel_attributes = ["empty", "full", "remaining_space"]
|
||||
|
||||
state_shaping_conf = {
|
||||
"look_back": 7,
|
||||
"max_ports_downstream": 2
|
||||
}
|
||||
|
||||
action_shaping_conf = {
|
||||
"action_space": [(i - 10) / 10 for i in range(21)],
|
||||
"finite_vessel_space": True,
|
||||
"has_early_discharge": True
|
||||
}
|
||||
|
||||
reward_shaping_conf = {
|
||||
"time_window": 99,
|
||||
"fulfillment_factor": 1.0,
|
||||
"shortage_factor": 1.0,
|
||||
"time_decay": 0.97
|
||||
}
|
||||
|
||||
# obtain state dimension from a temporary env_wrapper instance
|
||||
state_dim = (
|
||||
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
|
||||
+ len(vessel_attributes)
|
||||
)
|
||||
|
||||
action_num = len(action_shaping_conf["action_space"])
|
||||
|
||||
algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
|
||||
|
||||
from .config import (
|
||||
action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
|
||||
vessel_attributes,
|
||||
)
|
||||
|
||||
|
||||
class CIMEnvSampler(AbsEnvSampler):
|
||||
def _get_global_and_agent_state_impl(
|
||||
self, event: DecisionEvent, tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
tick = self._env.tick
|
||||
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
|
||||
port_idx, vessel_idx = event.port_idx, event.vessel_idx
|
||||
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
|
||||
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
|
||||
state = np.concatenate([
|
||||
port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
|
||||
vessel_snapshots[tick: vessel_idx: vessel_attributes]
|
||||
])
|
||||
return state, {port_idx: state}
|
||||
|
||||
def _translate_to_env_action(
|
||||
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,
|
||||
) -> Dict[Any, object]:
|
||||
action_space = action_shaping_conf["action_space"]
|
||||
finite_vsl_space = action_shaping_conf["finite_vessel_space"]
|
||||
has_early_discharge = action_shaping_conf["has_early_discharge"]
|
||||
|
||||
port_idx, model_action = list(action_dict.items()).pop()
|
||||
|
||||
vsl_idx, action_scope = event.vessel_idx, event.action_scope
|
||||
vsl_snapshots = self._env.snapshot_list["vessels"]
|
||||
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")
|
||||
|
||||
percent = abs(action_space[model_action[0]])
|
||||
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
|
||||
if model_action < zero_action_idx:
|
||||
action_type = ActionType.LOAD
|
||||
actual_action = min(round(percent * action_scope.load), vsl_space)
|
||||
elif model_action > zero_action_idx:
|
||||
action_type = ActionType.DISCHARGE
|
||||
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
|
||||
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
|
||||
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
|
||||
else:
|
||||
actual_action, action_type = 0, None
|
||||
|
||||
return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}
|
||||
|
||||
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
|
||||
start_tick = tick + 1
|
||||
ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))
|
||||
|
||||
# Get the ports that took actions at the given tick
|
||||
ports = [int(port) for port in list(env_action_dict.keys())]
|
||||
port_snapshots = self._env.snapshot_list["ports"]
|
||||
future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
|
||||
future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)
|
||||
|
||||
decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
|
||||
rewards = np.float32(
|
||||
reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
|
||||
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
|
||||
)
|
||||
return {agent_id: reward for agent_id, reward in zip(ports, rewards)}
|
||||
|
||||
def _post_step(self, cache_element: CacheElement) -> None:
|
||||
self._info["env_metric"] = self._env.metrics
|
||||
|
||||
def _post_eval_step(self, cache_element: CacheElement) -> None:
|
||||
self._post_step(cache_element)
|
||||
|
||||
def post_collect(self, info_list: list, ep: int) -> None:
|
||||
# print the env metric from each rollout worker
|
||||
for info in info_list:
|
||||
print(f"env summary (episode {ep}): {info['env_metric']}")
|
||||
|
||||
# print the average env metric
|
||||
if len(info_list) > 1:
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env summary (episode {ep}): {avg_metric}")
|
||||
|
||||
def post_evaluate(self, info_list: list, ep: int) -> None:
|
||||
self.post_collect(info_list, ep)
|
|
@ -0,0 +1,84 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
|
||||
from examples.cim.rl.env_sampler import CIMEnvSampler
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
|
||||
from .algorithms.ac import get_ac_policy
|
||||
from .algorithms.dqn import get_dqn_policy
|
||||
from .algorithms.maddpg import get_maddpg_policy
|
||||
from .algorithms.ppo import get_ppo_policy
|
||||
from .algorithms.ac import get_ac
|
||||
from .algorithms.ppo import get_ppo
|
||||
from .algorithms.dqn import get_dqn
|
||||
from .algorithms.maddpg import get_maddpg
|
||||
|
||||
|
||||
class CIMBundle(RLComponentBundle):
|
||||
def get_env_config(self) -> dict:
|
||||
return env_conf
|
||||
|
||||
def get_test_env_config(self) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
def get_env_sampler(self) -> AbsEnvSampler:
|
||||
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
|
||||
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
return {agent: f"{algorithm}_{agent}.policy"for agent in self.env.agent_idx_list}
|
||||
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
if algorithm == "ac":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "ppo":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "discrete_maddpg":
|
||||
policy_creator = {
|
||||
f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return policy_creator
|
||||
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "ppo":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "discrete_maddpg":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return trainer_creator
|
|
@ -99,7 +99,7 @@ demand is 34 (at a specific station, during a time interval of 20 minutes), the
|
|||
corresponding demand distribution shows that demand exceeding 10 bikes per time
|
||||
interval (20 minutes) is only 2%.
|
||||
|
||||
![Demand Distribution Between Tick 2400 ~ Tick 2519](./LogDemand.ny201910.2400.png)
|
||||
![Demand Distribution Between Tick 2400 ~ Tick 2519](LogDemand.ny201910.2400.png)
|
||||
|
||||
Besides, we can also find that the percentage of forecasting results that differ
|
||||
to the data extracted from trip log is not low. To dive deeper in the practical
|
||||
|
@ -110,9 +110,9 @@ show the distribution of the forecasting difference to the trip log. One for the
|
|||
interval with the *Max Diff* (16:00-18:00), one for the interval with the highest
|
||||
percentage of *Diff > 5* (10:00-12:00).
|
||||
|
||||
![Demand Distribution Between Tick 2400 ~ Tick 2519](./DemandDiff.ny201910.2400.png)
|
||||
![Demand Distribution Between Tick 2400 ~ Tick 2519](DemandDiff.ny201910.2400.png)
|
||||
|
||||
![Demand Distribution Between Tick 2040 ~ Tick 2159](./DemandDiff.ny201910.2040.png)
|
||||
![Demand Distribution Between Tick 2040 ~ Tick 2159](DemandDiff.ny201910.2040.png)
|
||||
|
||||
Maybe due to the *sparse* and *small* trip demand, and the *small* difference
|
||||
between the forecasting results and data extracted from the trip log data, the
|
||||
|
|
|
@ -75,10 +75,10 @@ class MaIlpAgent():
|
|||
event_type = finished_events[self._next_event_idx].event_type
|
||||
if event_type == CitiBikeEvents.RequireBike:
|
||||
# TODO: Replace it with a pre-defined PayLoad.
|
||||
payload = finished_events[self._next_event_idx].payload
|
||||
payload = finished_events[self._next_event_idx].body
|
||||
demand_history[interval_idx, payload.src_station] += 1
|
||||
elif event_type == CitiBikeEvents.ReturnBike:
|
||||
payload: BikeReturnPayload = finished_events[self._next_event_idx].payload
|
||||
payload: BikeReturnPayload = finished_events[self._next_event_idx].body
|
||||
supply_history[interval_idx, payload.to_station_idx] += payload.number
|
||||
|
||||
# Update the index to the finished event that has not been processed.
|
||||
|
@ -129,7 +129,7 @@ class MaIlpAgent():
|
|||
# Process to get the future supply from Pending Events.
|
||||
for pending_event in ENV.get_pending_events(tick=tick):
|
||||
if pending_event.event_type == CitiBikeEvents.ReturnBike:
|
||||
payload: BikeReturnPayload = pending_event.payload
|
||||
payload: BikeReturnPayload = pending_event.body
|
||||
supply[interval_idx, payload.to_station_idx] += payload.number
|
||||
|
||||
return demand, supply
|
||||
|
|
|
@ -21,13 +21,13 @@ def worker(group_name):
|
|||
print(f"{proxy.name}'s counter is {counter}.")
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
for msg in proxy.receive(is_continuous=False):
|
||||
print(f"{proxy.name} receive message from {msg.source}.")
|
||||
msg = proxy.receive_once()
|
||||
print(f"{proxy.name} received message from {msg.source}.")
|
||||
|
||||
if msg.tag == "INC":
|
||||
counter += 1
|
||||
print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.")
|
||||
proxy.reply(message=msg, tag="done")
|
||||
if msg.tag == "INC":
|
||||
counter += 1
|
||||
print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.")
|
||||
proxy.reply(message=msg, tag="done")
|
||||
|
||||
|
||||
def master(group_name: str, worker_num: int, is_immediate: bool = False):
|
||||
|
|
|
@ -21,12 +21,12 @@ def summation_worker(group_name):
|
|||
expected_peers={"master": 1})
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
for msg in proxy.receive(is_continuous=False):
|
||||
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
|
||||
msg = proxy.receive_once()
|
||||
print(f"{proxy.name} received message from {msg.source}. the payload is {msg.body}.")
|
||||
|
||||
if msg.tag == "job":
|
||||
replied_payload = sum(msg.payload)
|
||||
proxy.reply(message=msg, tag="sum", payload=replied_payload)
|
||||
if msg.tag == "job":
|
||||
replied_payload = sum(msg.body)
|
||||
proxy.reply(message=msg, tag="sum", body=replied_payload)
|
||||
|
||||
|
||||
def multiplication_worker(group_name):
|
||||
|
@ -41,12 +41,12 @@ def multiplication_worker(group_name):
|
|||
expected_peers={"master": 1})
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
for msg in proxy.receive(is_continuous=False):
|
||||
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
|
||||
msg = proxy.receive_once()
|
||||
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.body}.")
|
||||
|
||||
if msg.tag == "job":
|
||||
replied_payload = np.prod(msg.payload)
|
||||
proxy.reply(message=msg, tag="multiply", payload=replied_payload)
|
||||
if msg.tag == "job":
|
||||
replied_payload = np.prod(msg.body)
|
||||
proxy.reply(message=msg, tag="multiply", body=replied_payload)
|
||||
|
||||
|
||||
def master(group_name: str, sum_worker_number: int, multiply_worker_number: int, is_immediate: bool = False):
|
||||
|
@ -73,13 +73,13 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
|
|||
|
||||
# Assign sum tasks for summation workers.
|
||||
destination_payload_list = []
|
||||
for idx, peer in enumerate(proxy.peers_name["sum_worker"]):
|
||||
data_length_per_peer = int(len(sum_list) / len(proxy.peers_name["sum_worker"]))
|
||||
for idx, peer in enumerate(proxy.peers["sum_worker"]):
|
||||
data_length_per_peer = int(len(sum_list) / len(proxy.peers["sum_worker"]))
|
||||
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
|
||||
|
||||
# Assign multiply tasks for multiplication workers.
|
||||
for idx, peer in enumerate(proxy.peers_name["multiply_worker"]):
|
||||
data_length_per_peer = int(len(multiple_list) / len(proxy.peers_name["multiply_worker"]))
|
||||
for idx, peer in enumerate(proxy.peers["multiply_worker"]):
|
||||
data_length_per_peer = int(len(multiple_list) / len(proxy.peers["multiply_worker"]))
|
||||
destination_payload_list.append(
|
||||
(peer, multiple_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
|
||||
|
||||
|
@ -98,11 +98,11 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
|
|||
sum_result, multiply_result = 0, 1
|
||||
for msg in replied_msgs:
|
||||
if msg.tag == "sum":
|
||||
print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.payload}.")
|
||||
sum_result += msg.payload
|
||||
print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.body}.")
|
||||
sum_result += msg.body
|
||||
elif msg.tag == "multiply":
|
||||
print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.payload}.")
|
||||
multiply_result *= msg.payload
|
||||
print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.body}.")
|
||||
multiply_result *= msg.body
|
||||
|
||||
# Check task result correction.
|
||||
assert(sum(sum_list) == sum_result)
|
||||
|
|
|
@ -21,12 +21,12 @@ def worker(group_name):
|
|||
expected_peers={"master": 1})
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
for msg in proxy.receive(is_continuous=False):
|
||||
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
|
||||
msg = proxy.receive_once()
|
||||
print(f"{proxy.name} received message from {msg.source}. the payload is {msg.body}.")
|
||||
|
||||
if msg.tag == "sum":
|
||||
replied_payload = sum(msg.payload)
|
||||
proxy.reply(message=msg, tag="sum", payload=replied_payload)
|
||||
if msg.tag == "sum":
|
||||
replied_payload = sum(msg.body)
|
||||
proxy.reply(message=msg, tag="sum", body=replied_payload)
|
||||
|
||||
|
||||
def master(group_name: str, is_immediate: bool = False):
|
||||
|
@ -47,11 +47,11 @@ def master(group_name: str, is_immediate: bool = False):
|
|||
random_integer_list = np.random.randint(0, 100, 5)
|
||||
print(f"generate random integer list: {random_integer_list}.")
|
||||
|
||||
for peer in proxy.peers_name["worker"]:
|
||||
for peer in proxy.peers["worker"]:
|
||||
message = SessionMessage(tag="sum",
|
||||
source=proxy.name,
|
||||
destination=peer,
|
||||
payload=random_integer_list,
|
||||
body=random_integer_list,
|
||||
session_type=SessionType.TASK)
|
||||
if is_immediate:
|
||||
session_id = proxy.isend(message)
|
||||
|
@ -61,7 +61,7 @@ def master(group_name: str, is_immediate: bool = False):
|
|||
replied_msgs = proxy.send(message, timeout=-1)
|
||||
|
||||
for msg in replied_msgs:
|
||||
print(f"{proxy.name} receive {msg.source}, replied payload is {msg.payload}.")
|
||||
print(f"{proxy.name} receive {msg.source}, replied payload is {msg.body}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Reinforcement Learning (RL) Examples
|
||||
|
||||
This folder contains scenarios that employ reinforcement learning. MARO's RL toolkit provides scenario-agnostic workflows to run a variety of scenarios in single-thread, multi-process or distributed modes.
|
||||
|
||||
## How to Run
|
||||
|
||||
The entrance of a RL workflow is a YAML config file. For readers' convenience, we call this config file `config.yml` in the rest part of this doc. `config.yml` specifies the path of all necessary resources, definitions, and configurations to run the job. MARO provides a comprehensive template of the config file with detailed explanations (`maro/maro/rl/workflows/config/template.yml`). Meanwhile, MARO also provides several simple examples of `config.yml` under the current folder.
|
||||
|
||||
There are two ways to start the RL job:
|
||||
- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run_rl_example.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`.
|
||||
- (**Require install MARO from source**) You could also start the job through MARO CLI. Use the command `maro local run [-c] path/to/your/config` to run in containerized (with `-c`) or non-containerized (without `-c`) environments. Similar, you could add `--evaluate_only` if you only need to run the evaluation workflow.
|
||||
|
||||
## Create Your Own Scenarios
|
||||
|
||||
You can create your own scenarios by supplying the necessary ingredients without worrying about putting them together in a workflow. It is necessary to create an ``__init__.py`` under your scenario folder (so that it can be treated as a package) and expose a `rl_component_bundle_cls` interface. The MARO's RL workflow will use this interface to create a `RLComponentBundle` instance and start the RL workflow based on it. a `RLComponentBundle` instance defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read one of the examples to learn its basic usage.
|
||||
|
||||
## Example
|
||||
|
||||
For a complete example, please check `examples/cim/rl`.
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for CIM scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
|
||||
|
||||
job: cim_rl_workflow
|
||||
scenario_path: "examples/cim/rl"
|
||||
log_path: "log/rl_job/cim.txt"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
eval_schedule: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/cim"
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -0,0 +1,15 @@
|
|||
import argparse
|
||||
|
||||
from maro.cli.local.commands import run
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("conf_path", help='Path of the job deployment')
|
||||
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
run(conf_path=args.conf_path, containerize=False, evaluate_only=args.evaluate_only)
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Example RL config file for VM scheduling scenario.
|
||||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
|
||||
|
||||
# Run this workflow by executing one of the following commands:
|
||||
# - python .\examples\rl\run_rl_example.py .\examples\rl\vm_scheduling.yml
|
||||
# - (Requires installing MARO from source) maro local run .\examples\rl\vm_scheduling.yml
|
||||
|
||||
job: vm_scheduling_rl_workflow
|
||||
scenario_path: "examples/vm_scheduling/rl"
|
||||
log_path: "log/rl_job/vm_scheduling.txt"
|
||||
main:
|
||||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
|
||||
num_steps: null
|
||||
eval_schedule: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
rollout:
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
||||
training:
|
||||
mode: simple
|
||||
load_path: null
|
||||
load_episode: null
|
||||
checkpointing:
|
||||
path: "checkpoint/rl_job/vm_scheduling"
|
||||
interval: 5
|
||||
logging:
|
||||
stdout: INFO
|
||||
file: DEBUG
|
|
@ -22,14 +22,11 @@ with io.open(CONFIG_PATH, "r") as in_file:
|
|||
config = convert_dottable(raw_config)
|
||||
|
||||
LOG_PATH = os.path.join(FILE_PATH, "log", config.experiment_name)
|
||||
if not os.path.exists(LOG_PATH):
|
||||
os.makedirs(LOG_PATH)
|
||||
simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False)
|
||||
ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False)
|
||||
simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w")
|
||||
ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w")
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_time = timeit.default_timer()
|
||||
|
||||
env = Env(
|
||||
scenario=config.env.scenario,
|
||||
topology=config.env.topology,
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Virtual Machine Scheduling
|
||||
|
||||
A virtual machine (VM) scheduler is a cloud computing service component responsible for providing compute resources to satisfy user demands. A good resource allocation policy should aim to optimize several metrics at the same time, such as user wait time, profit, energy consumption and physical machine (PM) overload. Many commercial cloud providers use rule-based policies. Alternatively, the policy can also be optimized using reinforcement learning (RL) techniques, which involves simulating with historical data. This example demonstrates how DQN and Actor-Critic algorithms can be applied to this scenario. In this folder, you can find:
|
||||
|
||||
* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details);
|
||||
* ``config.py``, which contains general configurations for the scenario;
|
||||
* ``algorithms/``, which contains configurations for the algorithms, including network configurations;
|
||||
* ``rl_componenet_bundle.py``, which defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read `VMBundle` to learn its basic usage.
|
||||
|
||||
We recommend that you follow this example to write your own scenarios.
|
||||
|
||||
|
||||
# Some Comments About the Results
|
||||
|
||||
This example is meant to serve as a demonstration of using MARO's RL toolkit in a real-life scenario. In fact, we have yet to find a configuration that makes the policy learned by either DQN or Actor-Critic perform reasonably well in our experimental settings.
|
||||
|
||||
For reference, the best results have been achieved by the ``Best Fit`` algorithm (see ``examples/vm_scheduling/rule_based_algorithm/best_fit.py`` for details). The over-subscription rate is 115% in the over-subscription settings.
|
||||
|
||||
|Topology | PM Setting | Time Spent(s) | Total VM Requests |Successful Allocation| Energy Consumption| Total Oversubscriptions | Total Overload PMs
|
||||
|:----:|-----|:--------:|:---:|:-------:|:----:|:---:|:---:|
|
||||
|10k| 100 PMs, 32 Cores, 128 GB | 104.98|10,000| 10,000| 2,399,610 | 0 | 0|
|
||||
|10k.oversubscription| 100 PMs, 32 Cores, 128 GB| 101.00 |10,000 |10,000| 2,386,371| 279,331 | 0|
|
||||
|336k| 880 PMs, 16 Cores, 112 GB | 7,896.37 |335,985| 109,249 |26,425,878 | 0 | 0 |
|
||||
|336k.oversubscription| 880 PMs, 16 Cores, 112 GB | 7,903.33| 335,985| 115,008 | 27,440,946 | 3,868,475 | 0
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .rl_component_bundle import VMBundle as rl_component_bundle_cls
|
||||
|
||||
__all__ = [
|
||||
"rl_component_bundle_cls",
|
||||
]
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, SGD
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
|
||||
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [64, 32, 32],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
}
|
||||
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": False,
|
||||
"head": True,
|
||||
}
|
||||
|
||||
actor_learning_rate = 0.0001
|
||||
critic_learning_rate = 0.001
|
||||
|
||||
|
||||
class MyActorNet(DiscreteACBasedNet):
|
||||
def __init__(self, state_dim: int, action_num: int, num_features: int) -> None:
|
||||
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._num_features = num_features
|
||||
self._actor = FullyConnected(input_dim=num_features, output_dim=action_num, **actor_net_conf)
|
||||
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
|
||||
|
||||
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
features, masks = states[:, :self._num_features], states[:, self._num_features:]
|
||||
masks += 1e-8 # this is to prevent zero probability and infinite logP.
|
||||
return self._actor(features) * masks
|
||||
|
||||
|
||||
class MyCriticNet(VNet):
|
||||
def __init__(self, state_dim: int, num_features: int) -> None:
|
||||
super(MyCriticNet, self).__init__(state_dim=state_dim)
|
||||
self._num_features = num_features
|
||||
self._critic = FullyConnected(input_dim=num_features, output_dim=1, **critic_net_conf)
|
||||
self._optim = SGD(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
features, masks = states[:, :self._num_features], states[:, self._num_features:]
|
||||
masks += 1e-8 # this is to prevent zero probability and infinite logP.
|
||||
return self._critic(features).squeeze(-1)
|
||||
|
||||
|
||||
def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str) -> DiscretePolicyGradient:
|
||||
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num, num_features))
|
||||
|
||||
|
||||
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
|
||||
return ActorCriticTrainer(
|
||||
name=name,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
|
||||
reward_discount=0.9,
|
||||
grad_iters=100,
|
||||
critic_loss_cls=torch.nn.MSELoss,
|
||||
min_logp=-20,
|
||||
lam=.0,
|
||||
),
|
||||
)
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
|
||||
from maro.rl.exploration import MultiLinearExplorationScheduler
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
||||
|
||||
q_net_conf = {
|
||||
"hidden_dims": [64, 128, 256],
|
||||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": False,
|
||||
"skip_connection": False,
|
||||
"head": True,
|
||||
"dropout_p": 0.0,
|
||||
}
|
||||
q_net_learning_rate = 0.0005
|
||||
q_net_lr_scheduler_params = {"T_0": 500, "T_mult": 2}
|
||||
|
||||
|
||||
class MyQNet(DiscreteQNet):
|
||||
def __init__(self, state_dim: int, action_num: int, num_features: int) -> None:
|
||||
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
|
||||
self._num_features = num_features
|
||||
self._fc = FullyConnected(input_dim=num_features, output_dim=action_num, **q_net_conf)
|
||||
self._optim = SGD(self._fc.parameters(), lr=q_net_learning_rate)
|
||||
self._lr_scheduler = CosineAnnealingWarmRestarts(self._optim, **q_net_lr_scheduler_params)
|
||||
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
masks = states[:, self._num_features:]
|
||||
q_for_all_actions = self._fc(states[:, :self._num_features])
|
||||
return q_for_all_actions + (masks - 1) * 1e8
|
||||
|
||||
|
||||
class MaskedEpsGreedy:
|
||||
def __init__(self, state_dim: int, num_features: int) -> None:
|
||||
self._state_dim = state_dim
|
||||
self._num_features = num_features
|
||||
|
||||
def __call__(self, states, actions, num_actions, *, epsilon):
|
||||
masks = states[:, self._num_features:]
|
||||
return np.array([
|
||||
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
|
||||
for action, mask in zip(actions, masks)
|
||||
])
|
||||
|
||||
|
||||
def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str) -> ValueBasedPolicy:
|
||||
return ValueBasedPolicy(
|
||||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num, num_features),
|
||||
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[(
|
||||
"epsilon", MultiLinearExplorationScheduler, {
|
||||
"splits": [(100, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 400,
|
||||
"final_value": 0.0,
|
||||
}
|
||||
)],
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
||||
def get_dqn(name: str) -> DQNTrainer:
|
||||
return DQNTrainer(
|
||||
name=name,
|
||||
params=DQNParams(
|
||||
reward_discount=0.9,
|
||||
update_target_every=5,
|
||||
num_epochs=100,
|
||||
soft_update_coef=0.1,
|
||||
double=False,
|
||||
replay_memory_capacity=10000,
|
||||
random_overwrite=False,
|
||||
batch_size=32,
|
||||
data_parallelism=2,
|
||||
),
|
||||
)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from maro.simulator import Env
|
||||
|
||||
|
||||
env_conf = {
|
||||
"scenario": "vm_scheduling",
|
||||
"topology": "azure.2019.10k",
|
||||
"start_tick": 0,
|
||||
"durations": 300, # 8638
|
||||
"snapshot_resolution": 1,
|
||||
}
|
||||
|
||||
num_pms = Env(**env_conf).business_engine.pm_amount
|
||||
pm_window_size = 1
|
||||
num_features = 2 * num_pms * pm_window_size + 4
|
||||
state_dim = num_features + num_pms + 1
|
||||
|
||||
pm_attributes = ["cpu_cores_capacity", "memory_capacity", "cpu_cores_allocated", "memory_allocated"]
|
||||
# vm_attributes = ["cpu_cores_requirement", "memory_requirement", "lifetime", "remain_time", "total_income"]
|
||||
|
||||
|
||||
reward_shaping_conf = {
|
||||
"alpha": 0.0,
|
||||
"beta": 1.0,
|
||||
}
|
||||
seed = 666
|
||||
|
||||
test_env_conf = {
|
||||
"scenario": "vm_scheduling",
|
||||
"topology": "azure.2019.10k.oversubscription",
|
||||
"start_tick": 0,
|
||||
"durations": 300,
|
||||
"snapshot_resolution": 1,
|
||||
}
|
||||
test_reward_shaping_conf = {
|
||||
"alpha": 0.0,
|
||||
"beta": 1.0,
|
||||
}
|
||||
|
||||
test_seed = 1024
|
||||
|
||||
algorithm = "ac" # "dqn" or "ac"
|
|
@ -0,0 +1,200 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from os import makedirs
|
||||
from os.path import dirname, join, realpath
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
|
||||
|
||||
from .config import (
|
||||
num_features, pm_attributes, pm_window_size, reward_shaping_conf, seed, test_reward_shaping_conf, test_seed,
|
||||
)
|
||||
|
||||
timestamp = str(time.time())
|
||||
plt_path = join(dirname(realpath(__file__)), "plots", timestamp)
|
||||
makedirs(plt_path, exist_ok=True)
|
||||
|
||||
|
||||
class VMEnvSampler(AbsEnvSampler):
|
||||
def __init__(self, learn_env: Env, test_env: Env) -> None:
|
||||
super(VMEnvSampler, self).__init__(learn_env, test_env)
|
||||
|
||||
self._learn_env.set_seed(seed)
|
||||
self._test_env.set_seed(test_seed)
|
||||
|
||||
# adjust the ratio of the success allocation and the total income when computing the reward
|
||||
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
|
||||
self._durations = self._learn_env.business_engine._max_tick
|
||||
self._pm_state_history = np.zeros((pm_window_size - 1, self.num_pms, 2))
|
||||
self._legal_pm_mask = None
|
||||
|
||||
def _get_global_and_agent_state_impl(
|
||||
self, event: DecisionPayload, tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
|
||||
# get the legal number of PM.
|
||||
legal_pm_mask = np.zeros(self.num_pms + 1)
|
||||
if len(event.valid_pms) <= 0:
|
||||
# no pm available
|
||||
legal_pm_mask[self.num_pms] = 1
|
||||
else:
|
||||
legal_pm_mask[self.num_pms] = 1
|
||||
remain_cpu_dict = dict()
|
||||
for pm in event.valid_pms:
|
||||
# If two pms have the same remaining cpu, choose the one with the smaller id
|
||||
if pm_state[-1, pm, 0] not in remain_cpu_dict:
|
||||
remain_cpu_dict[pm_state[-1, pm, 0]] = 1
|
||||
legal_pm_mask[pm] = 1
|
||||
else:
|
||||
legal_pm_mask[pm] = 0
|
||||
|
||||
self._legal_pm_mask = legal_pm_mask
|
||||
state = np.concatenate((pm_state.flatten(), vm_state.flatten(), legal_pm_mask)).astype(np.float32)
|
||||
return None, {"AGENT": state}
|
||||
|
||||
def _translate_to_env_action(
|
||||
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionPayload,
|
||||
) -> Dict[Any, object]:
|
||||
if action_dict["AGENT"] == self.num_pms:
|
||||
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
|
||||
else:
|
||||
return {"AGENT": AllocateAction(vm_id=event.vm_id, pm_id=action_dict["AGENT"][0])}
|
||||
|
||||
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
|
||||
action = env_action_dict["AGENT"]
|
||||
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
|
||||
if isinstance(action, PostponeAction): # postponement
|
||||
if np.sum(self._legal_pm_mask) != 1:
|
||||
reward = -0.1 * conf["alpha"] + 0.0 * conf["beta"]
|
||||
else:
|
||||
reward = 0.0 * conf["alpha"] + 0.0 * conf["beta"]
|
||||
else:
|
||||
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else .0
|
||||
return {"AGENT": np.float32(reward)}
|
||||
|
||||
def _get_pm_state(self):
|
||||
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index::pm_attributes]
|
||||
total_pm_info = total_pm_info.reshape(self.num_pms, len(pm_attributes))
|
||||
|
||||
# normalize the attributes of pms' cpu and memory
|
||||
self._max_cpu_capacity = np.max(total_pm_info[:, 0])
|
||||
self._max_memory_capacity = np.max(total_pm_info[:, 1])
|
||||
total_pm_info[:, 2] /= self._max_cpu_capacity
|
||||
total_pm_info[:, 3] /= self._max_memory_capacity
|
||||
|
||||
# get the remaining cpu and memory of the pms
|
||||
remain_cpu = (1 - total_pm_info[:, 2]).reshape(1, self.num_pms, 1)
|
||||
remain_memory = (1 - total_pm_info[:, 3]).reshape(1, self.num_pms, 1)
|
||||
|
||||
# get the pms' information
|
||||
total_pm_info = np.concatenate((remain_cpu, remain_memory), axis=2) # (1, num_pms, 2)
|
||||
|
||||
# get the sequence pms' information
|
||||
self._pm_state_history = np.concatenate((self._pm_state_history, total_pm_info), axis=0)
|
||||
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
|
||||
|
||||
def _get_vm_state(self, event):
|
||||
return np.array([
|
||||
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
|
||||
event.vm_memory_requirement / self._max_memory_capacity,
|
||||
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
|
||||
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement)
|
||||
])
|
||||
|
||||
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
|
||||
vm_unit_price = self._env.business_engine._get_unit_price(
|
||||
event.vm_cpu_cores_requirement, event.vm_memory_requirement
|
||||
)
|
||||
return (alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time))
|
||||
|
||||
def _post_step(self, cache_element: CacheElement) -> None:
|
||||
self._info["env_metric"] = {k: v for k, v in self._env.metrics.items() if k != "total_latency"}
|
||||
self._info["env_metric"]["latency_due_to_agent"] = self._env.metrics["total_latency"].due_to_agent
|
||||
self._info["env_metric"]["latency_due_to_resource"] = self._env.metrics["total_latency"].due_to_resource
|
||||
if "actions_by_core_requirement" not in self._info:
|
||||
self._info["actions_by_core_requirement"] = defaultdict(list)
|
||||
if "action_sequence" not in self._info:
|
||||
self._info["action_sequence"] = []
|
||||
|
||||
action = cache_element.action_dict["AGENT"]
|
||||
if cache_element.state:
|
||||
mask = cache_element.state[num_features:]
|
||||
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append([action, mask])
|
||||
self._info["action_sequence"].append(action)
|
||||
|
||||
def _post_eval_step(self, cache_element: CacheElement) -> None:
|
||||
self._post_step(cache_element)
|
||||
|
||||
def post_collect(self, info_list: list, ep: int) -> None:
|
||||
# print the env metric from each rollout worker
|
||||
for info in info_list:
|
||||
print(f"env summary (episode {ep}): {info['env_metric']}")
|
||||
|
||||
# print the average env metric
|
||||
if len(info_list) > 1:
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env metric (episode {ep}): {avg_metric}")
|
||||
|
||||
def post_evaluate(self, info_list: list, ep: int) -> None:
|
||||
# print the env metric from each rollout worker
|
||||
for info in info_list:
|
||||
print(f"env summary (evaluation episode {ep}): {info['env_metric']}")
|
||||
|
||||
# print the average env metric
|
||||
if len(info_list) > 1:
|
||||
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
|
||||
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
|
||||
print(f"average env metric (evaluation episode {ep}): {avg_metric}")
|
||||
|
||||
for info in info_list:
|
||||
core_requirement = info["actions_by_core_requirement"]
|
||||
action_sequence = info["action_sequence"]
|
||||
# plot action sequence
|
||||
fig = plt.figure(figsize=(40, 32))
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
ax.plot(action_sequence)
|
||||
fig.savefig(f"{plt_path}/action_sequence_{ep}")
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
|
||||
# plot with legal action mask
|
||||
fig = plt.figure(figsize=(40, 32))
|
||||
for idx, key in enumerate(core_requirement.keys()):
|
||||
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
|
||||
for i in range(len(core_requirement[key])):
|
||||
if i == 0:
|
||||
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1], label=str(key))
|
||||
ax.legend()
|
||||
else:
|
||||
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1])
|
||||
|
||||
fig.savefig(f"{plt_path}/values_with_legal_action_{ep}")
|
||||
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
|
||||
# plot without legal actin mask
|
||||
fig = plt.figure(figsize=(40, 32))
|
||||
|
||||
for idx, key in enumerate(core_requirement.keys()):
|
||||
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
|
||||
for i in range(len(core_requirement[key])):
|
||||
if i == 0:
|
||||
ax.plot(core_requirement[key][i][0], label=str(key))
|
||||
ax.legend()
|
||||
else:
|
||||
ax.plot(core_requirement[key][i][0])
|
||||
|
||||
fig.savefig(f"{plt_path}/values_without_legal_action_{ep}")
|
||||
|
||||
plt.cla()
|
||||
plt.close("all")
|
|
@ -0,0 +1,57 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from examples.vm_scheduling.rl.algorithms.ac import get_ac_policy
|
||||
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn_policy
|
||||
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
|
||||
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
|
||||
|
||||
class VMBundle(RLComponentBundle):
|
||||
def get_env_config(self) -> dict:
|
||||
return env_conf
|
||||
|
||||
def get_test_env_config(self) -> Optional[dict]:
|
||||
return test_env_conf
|
||||
|
||||
def get_env_sampler(self) -> AbsEnvSampler:
|
||||
return VMEnvSampler(self.env, self.test_env)
|
||||
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
return {"AGENT": f"{algorithm}.policy"}
|
||||
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1
|
||||
|
||||
if algorithm == "ac":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_ac_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
|
||||
)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_dqn_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
|
||||
)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return policy_creator
|
||||
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
from .algorithms.ac import get_ac, get_ac_policy
|
||||
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
|
||||
elif algorithm == "dqn":
|
||||
from .algorithms.dqn import get_dqn, get_dqn_policy
|
||||
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return trainer_creator
|
139
maro/README.rst
|
@ -41,13 +41,13 @@
|
|||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/workflows/test/badge.svg
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3Atest
|
||||
:alt: test
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3Atest
|
||||
:alt: test
|
||||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/workflows/build/badge.svg
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild
|
||||
:alt: build
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild
|
||||
:alt: build
|
||||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/workflows/docker/badge.svg
|
||||
|
@ -56,8 +56,8 @@
|
|||
|
||||
|
||||
.. image:: https://readthedocs.org/projects/maro/badge/?version=latest
|
||||
:target: https://maro.readthedocs.io/
|
||||
:alt: docs
|
||||
:target: https://maro.readthedocs.io/
|
||||
:alt: docs
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/pypi/v/pymaro
|
||||
|
@ -142,6 +142,69 @@
|
|||
|
||||
================================================================================================================
|
||||
|
||||
|
||||
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/vm_scheduling.svg
|
||||
:target: https://maro.readthedocs.io/en/latest/scenarios/vm_scheduling.html
|
||||
:alt: VM Scheduling
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/gitter/room/microsoft/maro
|
||||
:target: https://gitter.im/Microsoft/MARO#
|
||||
:alt: Gitter
|
||||
|
||||
|
||||
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/stack_overflow.svg
|
||||
:target: https://stackoverflow.com/questions/ask?tags=maro
|
||||
:alt: Stack Overflow
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/github/release-date-pre/microsoft/maro
|
||||
:target: https://github.com/microsoft/maro/releases
|
||||
:alt: Releases
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/github/commits-since/microsoft/maro/latest/master
|
||||
:target: https://github.com/microsoft/maro/commits/master
|
||||
:alt: Commits
|
||||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/workflows/vulnerability%20scan/badge.svg
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3A%22vulnerability+scan%22
|
||||
:alt: Vulnerability Scan
|
||||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/workflows/lint/badge.svg
|
||||
:target: https://github.com/microsoft/maro/actions?query=workflow%3Alint
|
||||
:alt: Lint
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/codecov/c/github/microsoft/maro
|
||||
:target: https://codecov.io/gh/microsoft/maro
|
||||
:alt: Coverage
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/pypi/dm/pymaro
|
||||
:target: https://pypi.org/project/pymaro/#files
|
||||
:alt: Downloads
|
||||
|
||||
|
||||
.. image:: https://img.shields.io/docker/pulls/maro2020/maro
|
||||
:target: https://hub.docker.com/repository/docker/maro2020/maro
|
||||
:alt: Docker Pulls
|
||||
|
||||
|
||||
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/play_with_maro.svg
|
||||
:target: https://hub.docker.com/r/maro2020/maro
|
||||
:alt: Play with MARO
|
||||
|
||||
|
||||
|
||||
.. image:: https://github.com/microsoft/maro/blob/master/docs/source/images/logo.svg
|
||||
:target: https://maro.readthedocs.io/en/latest/
|
||||
:alt: MARO LOGO
|
||||
|
||||
================================================================================================================
|
||||
|
||||
Multi-Agent Resource Optimization (MARO) platform is an instance of Reinforcement
|
||||
learning as a Service (RaaS) for real-world resource optimization. It can be
|
||||
applied to many important industrial domains, such as `container inventory
|
||||
|
@ -172,18 +235,18 @@ Contents
|
|||
--------
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
:header-rows: 1
|
||||
|
||||
* - File/folder
|
||||
- Description
|
||||
* - ``maro``
|
||||
- MARO source code.
|
||||
* - ``docs``
|
||||
- MARO docs, it is host on `readthedocs <https://maro.readthedocs.io/en/latest/>`_.
|
||||
* - ``examples``
|
||||
- Showcase of MARO.
|
||||
* - ``notebooks``
|
||||
- MARO quick-start notebooks.
|
||||
* - File/folder
|
||||
- Description
|
||||
* - ``maro``
|
||||
- MARO source code.
|
||||
* - ``docs``
|
||||
- MARO docs, it is host on `readthedocs <https://maro.readthedocs.io/en/latest/>`_.
|
||||
* - ``examples``
|
||||
- Showcase of MARO.
|
||||
* - ``notebooks``
|
||||
- MARO quick-start notebooks.
|
||||
|
||||
|
||||
*Try `MARO playground <#run-playground>`_ to have a quick experience.*
|
||||
|
@ -199,17 +262,17 @@ Install MARO from `PyPI <https://pypi.org/project/pymaro/#files>`_
|
|||
|
||||
.. code-block:: sh
|
||||
|
||||
pip install pymaro
|
||||
pip install pymaro
|
||||
|
||||
*
|
||||
Windows
|
||||
|
||||
.. code-block:: powershell
|
||||
|
||||
# Install torch first, if you don't have one.
|
||||
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# Install torch first, if you don't have one.
|
||||
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
pip install pymaro
|
||||
pip install pymaro
|
||||
|
||||
Install MARO from Source
|
||||
------------------------
|
||||
|
@ -235,9 +298,9 @@ Install MARO from Source
|
|||
|
||||
.. code-block:: sh
|
||||
|
||||
# If your environment is not clean, create a virtual environment firstly.
|
||||
python -m venv maro_venv
|
||||
source ./maro_venv/bin/activate
|
||||
# If your environment is not clean, create a virtual environment firstly.
|
||||
python -m venv maro_venv
|
||||
source ./maro_venv/bin/activate
|
||||
|
||||
*
|
||||
Windows
|
||||
|
@ -267,16 +330,16 @@ Install MARO from Source
|
|||
|
||||
.. code-block:: sh
|
||||
|
||||
# Install MARO from source.
|
||||
bash scripts/install_maro.sh
|
||||
# Install MARO from source.
|
||||
bash scripts/install_maro.sh
|
||||
|
||||
*
|
||||
Windows
|
||||
|
||||
.. code-block:: powershell
|
||||
|
||||
# Install MARO from source.
|
||||
.\scripts\install_maro.bat
|
||||
# Install MARO from source.
|
||||
.\scripts\install_maro.bat
|
||||
|
||||
*
|
||||
*Notes: If your package is not found, remember to set your PYTHONPATH*
|
||||
|
@ -300,16 +363,16 @@ Quick Example
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator import Env
|
||||
|
||||
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100)
|
||||
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100)
|
||||
|
||||
metrics, decision_event, is_done = env.step(None)
|
||||
metrics, decision_event, is_done = env.step(None)
|
||||
|
||||
while not is_done:
|
||||
metrics, decision_event, is_done = env.step(None)
|
||||
while not is_done:
|
||||
metrics, decision_event, is_done = env.step(None)
|
||||
|
||||
print(f"environment metrics: {env.metrics}")
|
||||
print(f"environment metrics: {env.metrics}")
|
||||
|
||||
`Environment Visualization <https://maro.readthedocs.io/en/latest/>`_
|
||||
-------------------------------------------------------------------------
|
||||
|
@ -382,8 +445,8 @@ Run Playground
|
|||
|
||||
.. code-block:: sh
|
||||
|
||||
# Build playground image.
|
||||
bash ./scripts/build_playground.sh
|
||||
# Build playground image.
|
||||
bash ./scripts/build_playground.sh
|
||||
|
||||
# Run playground container.
|
||||
# Redis commander (GUI for redis) -> http://127.0.0.1:40009
|
||||
|
@ -395,8 +458,8 @@ Run Playground
|
|||
|
||||
.. code-block:: powershell
|
||||
|
||||
# Build playground image.
|
||||
.\scripts\build_playground.bat
|
||||
# Build playground image.
|
||||
.\scripts\build_playground.bat
|
||||
|
||||
# Run playground container.
|
||||
# Redis commander (GUI for redis) -> http://127.0.0.1:40009
|
||||
|
|
|
@ -74,6 +74,15 @@ def node(name: str):
|
|||
return node_dec
|
||||
|
||||
|
||||
def try_get_attribute(target, name, default=None):
|
||||
try:
|
||||
attr = object.__getattribute__(target, name)
|
||||
|
||||
return attr
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
cdef class NodeAttribute:
|
||||
def __cinit__(self, object dtype = None, SLOT_INDEX slot_num = 1, is_const = False, is_list = False):
|
||||
# Check the type of dtype, used to compact with old version
|
||||
|
@ -532,6 +541,8 @@ cdef class FrameBase:
|
|||
else:
|
||||
node._is_deleted = False
|
||||
|
||||
# Also
|
||||
|
||||
cpdef void take_snapshot(self, INT tick) except *:
|
||||
"""Take snapshot for specified point (tick) for current frame.
|
||||
|
||||
|
|
|
@ -0,0 +1,317 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from os.path import abspath, dirname, expanduser, join
|
||||
|
||||
import yaml
|
||||
|
||||
from maro.cli.utils import docker as docker_utils
|
||||
from maro.cli.utils.azure import storage as azure_storage_utils
|
||||
from maro.cli.utils.azure.aks import attach_acr
|
||||
from maro.cli.utils.azure.deployment import create_deployment
|
||||
from maro.cli.utils.azure.general import connect_to_aks, get_acr_push_permissions, set_env_credentials
|
||||
from maro.cli.utils.azure.resource_group import create_resource_group, delete_resource_group
|
||||
from maro.cli.utils.common import show_log
|
||||
from maro.rl.workflows.config import ConfigParser
|
||||
from maro.utils.logger import CliLogger
|
||||
from maro.utils.utils import LOCAL_MARO_ROOT
|
||||
|
||||
from ..utils import k8s_manifest_generator, k8s_ops
|
||||
|
||||
# metadata
|
||||
CLI_AKS_PATH = dirname(abspath(__file__))
|
||||
TEMPLATE_PATH = join(CLI_AKS_PATH, "template.json")
|
||||
NVIDIA_PLUGIN_PATH = join(CLI_AKS_PATH, "create_nvidia_plugin", "nvidia-device-plugin.yml")
|
||||
LOCAL_ROOT = expanduser("~/.maro/aks")
|
||||
DEPLOYMENT_CONF_PATH = os.path.join(LOCAL_ROOT, "conf.json")
|
||||
DOCKER_FILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df")
|
||||
DOCKER_IMAGE_NAME = "maro-aks"
|
||||
REDIS_HOST = "maro-redis"
|
||||
REDIS_PORT = 6379
|
||||
ADDRESS_REGISTRY_NAME = "address-registry"
|
||||
ADDRESS_REGISTRY_PORT = 6379
|
||||
K8S_SECRET_NAME = "azure-secret"
|
||||
|
||||
# display
|
||||
NO_DEPLOYMENT_MSG = "No Kubernetes deployment on Azure found. Use 'maro aks init' to create a deployment first"
|
||||
NO_JOB_MSG = "No job named {} has been scheduled. Use 'maro aks job add' to add the job first."
|
||||
JOB_EXISTS_MSG = "A job named {} has already been scheduled."
|
||||
|
||||
logger = CliLogger(name=__name__)
|
||||
|
||||
|
||||
# helper functions
|
||||
def get_resource_group_name(deployment_name: str):
|
||||
return f"rg-{deployment_name}"
|
||||
|
||||
|
||||
def get_acr_name(deployment_name: str):
|
||||
return f"crmaro{deployment_name}"
|
||||
|
||||
|
||||
def get_acr_server_name(acr_name: str):
|
||||
return f"{acr_name}.azurecr.io"
|
||||
|
||||
|
||||
def get_docker_image_name_in_acr(acr_name: str, docker_image_name: str):
|
||||
return f"{get_acr_server_name(acr_name)}/{docker_image_name}"
|
||||
|
||||
|
||||
def get_aks_name(deployment_name: str):
|
||||
return f"aks-maro-{deployment_name}"
|
||||
|
||||
|
||||
def get_agentpool_name(deployment_name: str):
|
||||
return f"ap{deployment_name}"
|
||||
|
||||
|
||||
def get_fileshare_name(deployment_name: str):
|
||||
return f"fs-{deployment_name}"
|
||||
|
||||
|
||||
def get_storage_account_name(deployment_name: str):
|
||||
return f"stscenario{deployment_name}"
|
||||
|
||||
|
||||
def get_virtual_network_name(location: str, deployment_name: str):
|
||||
return f"vnet-prod-{location}-{deployment_name}"
|
||||
|
||||
|
||||
def get_local_job_path(job_name: str):
|
||||
return os.path.join(LOCAL_ROOT, job_name)
|
||||
|
||||
|
||||
def get_storage_account_secret(resource_group_name: str, storage_account_name: str, namespace: str):
|
||||
storage_account_keys = azure_storage_utils.get_storage_account_keys(resource_group_name, storage_account_name)
|
||||
storage_key = storage_account_keys[0]["value"]
|
||||
secret_data = {
|
||||
"azurestorageaccountname": base64.b64encode(storage_account_name.encode()).decode(),
|
||||
"azurestorageaccountkey": base64.b64encode(bytes(storage_key.encode())).decode()
|
||||
}
|
||||
k8s_ops.create_secret(K8S_SECRET_NAME, secret_data, namespace)
|
||||
|
||||
|
||||
def get_resource_params(deployment_conf: dict) -> dict:
|
||||
"""Create ARM parameters for Azure resource deployment ().
|
||||
|
||||
See https://docs.microsoft.com/en-us/azure/azure-resource-manager/templates/overview for details.
|
||||
|
||||
Args:
|
||||
deployment_conf (dict): Configuration dict for deployment on Azure.
|
||||
|
||||
Returns:
|
||||
dict: parameter dict, should be exported to json.
|
||||
"""
|
||||
name = deployment_conf["name"]
|
||||
return {
|
||||
"acrName": get_acr_name(name),
|
||||
"acrSku": deployment_conf["container_registry_service_tier"],
|
||||
"systemPoolVMCount": deployment_conf["resources"]["k8s"]["vm_count"],
|
||||
"systemPoolVMSize": deployment_conf["resources"]["k8s"]["vm_size"],
|
||||
"userPoolName": get_agentpool_name(name),
|
||||
"userPoolVMCount": deployment_conf["resources"]["app"]["vm_count"],
|
||||
"userPoolVMSize": deployment_conf["resources"]["app"]["vm_size"],
|
||||
"aksName": get_aks_name(name),
|
||||
"location": deployment_conf["location"],
|
||||
"storageAccountName": get_storage_account_name(name),
|
||||
"fileShareName": get_fileshare_name(name)
|
||||
# "virtualNetworkName": get_virtual_network_name(deployment_conf["location"], name)
|
||||
}
|
||||
|
||||
|
||||
def prepare_docker_image_and_push_to_acr(image_name: str, context: str, docker_file_path: str, acr_name: str):
|
||||
# build and tag docker image locally and push to the Azure Container Registry
|
||||
if not docker_utils.image_exists(image_name):
|
||||
docker_utils.build_image(context, docker_file_path, image_name)
|
||||
|
||||
get_acr_push_permissions(os.environ["AZURE_CLIENT_ID"], acr_name)
|
||||
docker_utils.push(image_name, get_acr_server_name(acr_name))
|
||||
|
||||
|
||||
def start_redis_service_in_aks(host: str, port: int, namespace: str):
|
||||
k8s_ops.load_config()
|
||||
k8s_ops.create_namespace(namespace)
|
||||
k8s_ops.create_deployment(k8s_manifest_generator.get_redis_deployment_manifest(host, port), namespace)
|
||||
k8s_ops.create_service(k8s_manifest_generator.get_redis_service_manifest(host, port), namespace)
|
||||
|
||||
|
||||
# CLI command functions
|
||||
def init(deployment_conf_path: str, **kwargs):
|
||||
"""Prepare Azure resources needed for an AKS cluster using a YAML configuration file.
|
||||
|
||||
The configuration file template can be found in cli/k8s/aks/conf.yml. Use the Azure CLI to log into
|
||||
your Azure account (az login ...) and the the Azure Container Registry (az acr login ...) first.
|
||||
|
||||
Args:
|
||||
deployment_conf_path (str): Path to the deployment configuration file.
|
||||
"""
|
||||
with open(deployment_conf_path, "r") as fp:
|
||||
deployment_conf = yaml.safe_load(fp)
|
||||
|
||||
subscription = deployment_conf["azure_subscription"]
|
||||
name = deployment_conf["name"]
|
||||
if os.path.isfile(DEPLOYMENT_CONF_PATH):
|
||||
logger.warning(f"Deployment {name} has already been created")
|
||||
return
|
||||
|
||||
os.makedirs(LOCAL_ROOT, exist_ok=True)
|
||||
resource_group_name = get_resource_group_name(name)
|
||||
try:
|
||||
# Set credentials as environment variables
|
||||
set_env_credentials(LOCAL_ROOT, f"sp-{name}")
|
||||
|
||||
# create resource group
|
||||
resource_group = create_resource_group(subscription, resource_group_name, deployment_conf["location"])
|
||||
logger.info_green(f"Provisioned resource group {resource_group.name} in {resource_group.location}")
|
||||
|
||||
# Create ARM parameters and start deployment
|
||||
logger.info("Creating Azure resources...")
|
||||
resource_params = get_resource_params(deployment_conf)
|
||||
with open(TEMPLATE_PATH, 'r') as fp:
|
||||
template = json.load(fp)
|
||||
|
||||
create_deployment(subscription, resource_group_name, name, template, resource_params)
|
||||
|
||||
# Attach ACR to AKS
|
||||
aks_name, acr_name = resource_params["aksName"], resource_params["acrName"]
|
||||
attach_acr(resource_group_name, aks_name, acr_name)
|
||||
connect_to_aks(resource_group_name, aks_name)
|
||||
|
||||
# build and tag docker image locally and push to the Azure Container Registry
|
||||
logger.info("Preparing docker image...")
|
||||
prepare_docker_image_and_push_to_acr(DOCKER_IMAGE_NAME, LOCAL_MARO_ROOT, DOCKER_FILE_PATH, acr_name)
|
||||
|
||||
# start the Redis service in the k8s cluster in the deployment namespace and expose it
|
||||
logger.info("Starting Redis service in the k8s cluster...")
|
||||
start_redis_service_in_aks(REDIS_HOST, REDIS_PORT, name)
|
||||
|
||||
# Dump the deployment configuration
|
||||
with open(DEPLOYMENT_CONF_PATH, "w") as fp:
|
||||
json.dump({
|
||||
"name": name,
|
||||
"subscription": subscription,
|
||||
"resource_group": resource_group_name,
|
||||
"resources": resource_params
|
||||
}, fp)
|
||||
logger.info_green(f"Cluster '{name}' is created")
|
||||
except Exception as e:
|
||||
# If failed, remove details folder, then raise
|
||||
shutil.rmtree(LOCAL_ROOT)
|
||||
logger.error_red(f"Deployment {name} failed due to {e}, rolling back...")
|
||||
delete_resource_group(subscription, resource_group_name)
|
||||
except KeyboardInterrupt:
|
||||
shutil.rmtree(LOCAL_ROOT)
|
||||
logger.error_red(f"Deployment {name} aborted, rolling back...")
|
||||
delete_resource_group(subscription, resource_group_name)
|
||||
|
||||
|
||||
def add_job(conf_path: dict, **kwargs):
|
||||
if not os.path.isfile(DEPLOYMENT_CONF_PATH):
|
||||
logger.error_red(NO_DEPLOYMENT_MSG)
|
||||
return
|
||||
|
||||
parser = ConfigParser(conf_path)
|
||||
job_name = parser.config["job"]
|
||||
local_job_path = get_local_job_path(job_name)
|
||||
if os.path.isdir(local_job_path):
|
||||
logger.error_red(JOB_EXISTS_MSG.format(job_name))
|
||||
return
|
||||
|
||||
os.makedirs(local_job_path)
|
||||
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
|
||||
deployment_conf = json.load(fp)
|
||||
|
||||
resource_group_name, resource_name = deployment_conf["resource_group"], deployment_conf["resources"]
|
||||
fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"])
|
||||
job_dir = azure_storage_utils.get_directory(fileshare, job_name)
|
||||
scenario_path = parser.config["scenario_path"]
|
||||
logger.info(f"Uploading local directory {scenario_path}...")
|
||||
azure_storage_utils.upload_to_fileshare(job_dir, scenario_path, name="scenario")
|
||||
azure_storage_utils.get_directory(job_dir, "checkpoints")
|
||||
azure_storage_utils.get_directory(job_dir, "logs")
|
||||
|
||||
# Define mount volumes, i.e., scenario code, checkpoints, logs and load point
|
||||
job_path_in_share = f"{resource_name['fileShareName']}/{job_name}"
|
||||
volumes = [
|
||||
k8s_manifest_generator.get_azurefile_volume_spec(name, f"{job_path_in_share}/{name}", K8S_SECRET_NAME)
|
||||
for name in ["scenario", "logs", "checkpoints"]
|
||||
]
|
||||
|
||||
if "load_path" in parser.config["training"]:
|
||||
load_path = parser.config["training"]["load_path"]
|
||||
logger.info(f"Uploading local model directory {load_path}...")
|
||||
azure_storage_utils.upload_to_fileshare(job_dir, load_path, name="loadpoint")
|
||||
volumes.append(
|
||||
k8s_manifest_generator.get_azurefile_volume_spec(
|
||||
"loadpoint", f"{job_path_in_share}/loadpoint", K8S_SECRET_NAME)
|
||||
)
|
||||
|
||||
# Start k8s jobs
|
||||
k8s_ops.load_config()
|
||||
k8s_ops.create_namespace(job_name)
|
||||
get_storage_account_secret(resource_group_name, resource_name["storageAccountName"], job_name)
|
||||
k8s_ops.create_service(
|
||||
k8s_manifest_generator.get_cross_namespace_service_access_manifest(
|
||||
ADDRESS_REGISTRY_NAME, REDIS_HOST, deployment_conf["name"], ADDRESS_REGISTRY_PORT
|
||||
), job_name
|
||||
)
|
||||
for component_name, (script, env) in parser.get_job_spec(containerize=True).items():
|
||||
container_spec = k8s_manifest_generator.get_container_spec(
|
||||
get_docker_image_name_in_acr(resource_name["acrName"], DOCKER_IMAGE_NAME),
|
||||
component_name,
|
||||
script,
|
||||
env,
|
||||
volumes
|
||||
)
|
||||
manifest = k8s_manifest_generator.get_job_manifest(
|
||||
resource_name["userPoolName"],
|
||||
component_name,
|
||||
container_spec,
|
||||
volumes
|
||||
)
|
||||
k8s_ops.create_job(manifest, job_name)
|
||||
|
||||
|
||||
def remove_jobs(job_names: str, **kwargs):
|
||||
if not os.path.isfile(DEPLOYMENT_CONF_PATH):
|
||||
logger.error_red(NO_DEPLOYMENT_MSG)
|
||||
return
|
||||
|
||||
k8s_ops.load_config()
|
||||
for job_name in job_names:
|
||||
local_job_path = get_local_job_path(job_name)
|
||||
if not os.path.isdir(local_job_path):
|
||||
logger.error_red(NO_JOB_MSG.format(job_name))
|
||||
return
|
||||
|
||||
k8s_ops.delete_job(job_name)
|
||||
|
||||
|
||||
def get_job_logs(job_name: str, tail: int = -1, **kwargs):
|
||||
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
|
||||
deployment_conf = json.load(fp)
|
||||
|
||||
local_log_path = os.path.join(get_local_job_path(job_name), "log")
|
||||
resource_name = deployment_conf["resources"]
|
||||
fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"])
|
||||
job_dir = azure_storage_utils.get_directory(fileshare, job_name)
|
||||
log_dir = azure_storage_utils.get_directory(job_dir, "logs")
|
||||
azure_storage_utils.download_from_fileshare(log_dir, f"{job_name}.log", local_log_path)
|
||||
show_log(local_log_path, tail=tail)
|
||||
|
||||
|
||||
def exit(**kwargs):
|
||||
try:
|
||||
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
|
||||
deployment_conf = json.load(fp)
|
||||
except FileNotFoundError:
|
||||
logger.error(NO_DEPLOYMENT_MSG)
|
||||
return
|
||||
|
||||
name = deployment_conf["name"]
|
||||
set_env_credentials(LOCAL_ROOT, f"sp-{name}")
|
||||
delete_resource_group(deployment_conf["subscription"], deployment_conf["resource_group"])
|
|
@ -0,0 +1,12 @@
|
|||
mode: ""
|
||||
azure_subscription: your_azure_subscription_id
|
||||
name: your_deployment_name
|
||||
location: your_azure_service_location
|
||||
container_registry_service_tier: Standard # "Basic", "Standard", "Premium", see https://docs.microsoft.com/en-us/azure/container-registry/container-registry-skus for details
|
||||
resources:
|
||||
k8s:
|
||||
vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions
|
||||
vm_count: 1 # must be at least 2 for k8s to function properly.
|
||||
app:
|
||||
vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions
|
||||
vm_count: 1
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentParameters.json#",
|
||||
"contentVersion": "1.1.0.0",
|
||||
"parameters": {
|
||||
"acrName": {
|
||||
"value": "myacr"
|
||||
},
|
||||
"acrSku": {
|
||||
"value": "Basic"
|
||||
},
|
||||
"agentCount": {
|
||||
"value": 1
|
||||
},
|
||||
"agentVMSize": {
|
||||
"value": "standard_a2_v2"
|
||||
},
|
||||
"clusterName": {
|
||||
"value": "myaks"
|
||||
},
|
||||
"fileShareName": {
|
||||
"value": "myfileshare"
|
||||
},
|
||||
"location": {
|
||||
"value": "East US"
|
||||
},
|
||||
"storageAccountName": {
|
||||
"value": "mystorage"
|
||||
},
|
||||
"virtualNetworkName": {
|
||||
"value": "myvnet"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
{
|
||||
"$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
|
||||
"contentVersion": "1.1.0.0",
|
||||
"parameters": {
|
||||
"acrName": {
|
||||
"type": "string",
|
||||
"minLength": 5,
|
||||
"maxLength": 50,
|
||||
"metadata": {
|
||||
"description": "Name of your Azure Container Registry"
|
||||
}
|
||||
},
|
||||
"acrSku": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Tier of your Azure Container Registry."
|
||||
},
|
||||
"defaultValue": "Standard",
|
||||
"allowedValues": [
|
||||
"Basic",
|
||||
"Standard",
|
||||
"Premium"
|
||||
]
|
||||
},
|
||||
"systemPoolVMCount": {
|
||||
"type": "int",
|
||||
"metadata": {
|
||||
"description": "The number of VMs allocated for running the k8s system components."
|
||||
},
|
||||
"minValue": 1,
|
||||
"maxValue": 50
|
||||
},
|
||||
"systemPoolVMSize": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Virtual Machine size for running the k8s system components."
|
||||
}
|
||||
},
|
||||
"userPoolName": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Name of the user node pool."
|
||||
}
|
||||
},
|
||||
"userPoolVMCount": {
|
||||
"type": "int",
|
||||
"metadata": {
|
||||
"description": "The number of VMs allocated for running the user appplication."
|
||||
},
|
||||
"minValue": 1,
|
||||
"maxValue": 50
|
||||
},
|
||||
"userPoolVMSize": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Virtual Machine size for running the user application."
|
||||
}
|
||||
},
|
||||
"aksName": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Name of the Managed Cluster resource."
|
||||
}
|
||||
},
|
||||
"location": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Location of the Managed Cluster resource."
|
||||
}
|
||||
},
|
||||
"storageAccountName": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Azure storage account name."
|
||||
}
|
||||
},
|
||||
"fileShareName": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Azure file share name."
|
||||
}
|
||||
}
|
||||
},
|
||||
"resources": [
|
||||
{
|
||||
"name": "[parameters('acrName')]",
|
||||
"type": "Microsoft.ContainerRegistry/registries",
|
||||
"apiVersion": "2021-09-01",
|
||||
"location": "[parameters('location')]",
|
||||
"sku": {
|
||||
"name": "[parameters('acrSku')]"
|
||||
},
|
||||
"properties": {
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "[parameters('aksName')]",
|
||||
"type": "Microsoft.ContainerService/managedClusters",
|
||||
"apiVersion": "2021-10-01",
|
||||
"location": "[parameters('location')]",
|
||||
"properties": {
|
||||
"dnsPrefix": "maro",
|
||||
"agentPoolProfiles": [
|
||||
{
|
||||
"name": "system",
|
||||
"osDiskSizeGB": 0,
|
||||
"count": "[parameters('systemPoolVMCount')]",
|
||||
"vmSize": "[parameters('systemPoolVMSize')]",
|
||||
"osType": "Linux",
|
||||
"storageProfile": "ManagedDisks",
|
||||
"mode": "System",
|
||||
"type": "VirtualMachineScaleSets"
|
||||
},
|
||||
{
|
||||
"name": "[parameters('userPoolName')]",
|
||||
"osDiskSizeGB": 0,
|
||||
"count": "[parameters('userPoolVMCount')]",
|
||||
"vmSize": "[parameters('userPoolVMSize')]",
|
||||
"osType": "Linux",
|
||||
"storageProfile": "ManagedDisks",
|
||||
"mode": "User",
|
||||
"type": "VirtualMachineScaleSets"
|
||||
}
|
||||
],
|
||||
"networkProfile": {
|
||||
"networkPlugin": "azure",
|
||||
"loadBalancerSku": "standard"
|
||||
}
|
||||
},
|
||||
"identity": {
|
||||
"type": "SystemAssigned"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Microsoft.Storage/storageAccounts",
|
||||
"apiVersion": "2021-08-01",
|
||||
"name": "[parameters('storageAccountName')]",
|
||||
"location": "[parameters('location')]",
|
||||
"kind": "StorageV2",
|
||||
"sku": {
|
||||
"name": "Standard_LRS",
|
||||
"tier": "Standard"
|
||||
},
|
||||
"properties": {
|
||||
"accessTier": "Hot"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
|
||||
"apiVersion": "2021-04-01",
|
||||
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
|
||||
"dependsOn": [
|
||||
"[resourceId('Microsoft.Storage/storageAccounts', parameters('storageAccountName'))]"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -22,18 +22,6 @@
|
|||
"Premium"
|
||||
]
|
||||
},
|
||||
"adminPublicKey": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "Configure all linux machines with the SSH RSA public key string. Your key should include three parts, for example 'ssh-rsa AAAAB...snip...UcyupgH azureuser@linuxvm'"
|
||||
}
|
||||
},
|
||||
"adminUsername": {
|
||||
"type": "string",
|
||||
"metadata": {
|
||||
"description": "User name for the Linux Virtual Machines."
|
||||
}
|
||||
},
|
||||
"agentCount": {
|
||||
"type": "int",
|
||||
"metadata": {
|
||||
|
@ -87,7 +75,7 @@
|
|||
"resources": [
|
||||
{
|
||||
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
|
||||
"apiVersion": "2020-08-01-preview",
|
||||
"apiVersion": "2021-04-01",
|
||||
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
|
||||
"dependsOn": [
|
||||
"[variables('stvmId')]"
|
||||
|
@ -96,7 +84,7 @@
|
|||
{
|
||||
"name": "[parameters('acrName')]",
|
||||
"type": "Microsoft.ContainerRegistry/registries",
|
||||
"apiVersion": "2020-11-01-preview",
|
||||
"apiVersion": "2021-09-01",
|
||||
"location": "[parameters('location')]",
|
||||
"sku": {
|
||||
"name": "[parameters('acrSku')]"
|
||||
|
@ -107,7 +95,7 @@
|
|||
{
|
||||
"name": "[parameters('clusterName')]",
|
||||
"type": "Microsoft.ContainerService/managedClusters",
|
||||
"apiVersion": "2020-03-01",
|
||||
"apiVersion": "2021-10-01",
|
||||
"location": "[parameters('location')]",
|
||||
"dependsOn": [
|
||||
"[variables('vnetId')]"
|
||||
|
@ -127,16 +115,6 @@
|
|||
"type": "VirtualMachineScaleSets"
|
||||
}
|
||||
],
|
||||
"linuxProfile": {
|
||||
"adminUsername": "[parameters('adminUsername')]",
|
||||
"ssh": {
|
||||
"publicKeys": [
|
||||
{
|
||||
"keyData": "[parameters('adminPublicKey')]"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"networkProfile": {
|
||||
"networkPlugin": "azure",
|
||||
"loadBalancerSku": "standard"
|
||||
|
@ -148,7 +126,7 @@
|
|||
},
|
||||
{
|
||||
"type": "Microsoft.Storage/storageAccounts",
|
||||
"apiVersion": "2020-08-01-preview",
|
||||
"apiVersion": "2021-08-01",
|
||||
"name": "[parameters('storageAccountName')]",
|
||||
"location": "[parameters('location')]",
|
||||
"kind": "StorageV2",
|
||||
|
@ -163,7 +141,7 @@
|
|||
{
|
||||
"name": "[parameters('virtualNetworkName')]",
|
||||
"type": "Microsoft.Network/virtualNetworks",
|
||||
"apiVersion": "2020-04-01",
|
||||
"apiVersion": "2020-11-01",
|
||||
"location": "[parameters('location')]",
|
||||
"properties": {
|
||||
"addressSpace": {
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import List
|
||||
|
||||
from maro.cli.utils.common import format_env_vars
|
||||
|
||||
|
||||
def get_job_manifest(agent_pool_name: str, component_name: str, container_spec: dict, volumes: List[dict]):
|
||||
return {
|
||||
"metadata": {"name": component_name},
|
||||
"spec": {
|
||||
"template": {
|
||||
"spec": {
|
||||
"nodeSelector": {"agentpool": agent_pool_name},
|
||||
"restartPolicy": "Never",
|
||||
"volumes": volumes,
|
||||
"containers": [container_spec]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_azurefile_volume_spec(name: str, share_name: str, secret_name: str):
|
||||
return {
|
||||
"name": name,
|
||||
"azureFile": {
|
||||
"secretName": secret_name,
|
||||
"shareName": share_name,
|
||||
"readOnly": False
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_container_spec(image_name: str, component_name: str, script: str, env: dict, volumes):
|
||||
common_container_spec = {
|
||||
"image": image_name,
|
||||
"imagePullPolicy": "Always",
|
||||
"volumeMounts": [{"name": vol["name"], "mountPath": f"/{vol['name']}"} for vol in volumes]
|
||||
}
|
||||
return {
|
||||
**common_container_spec,
|
||||
**{
|
||||
"name": component_name,
|
||||
"command": ["python3", script],
|
||||
"env": format_env_vars(env, mode="k8s")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_redis_deployment_manifest(host: str, port: int):
|
||||
return {
|
||||
"metadata": {
|
||||
"name": host,
|
||||
"labels": {"app": "redis"}
|
||||
},
|
||||
"spec": {
|
||||
"selector": {
|
||||
"matchLabels": {"app": "redis"}
|
||||
},
|
||||
"replicas": 1,
|
||||
"template": {
|
||||
"metadata": {
|
||||
"labels": {"app": "redis"}
|
||||
},
|
||||
"spec": {
|
||||
"containers": [
|
||||
{
|
||||
"name": "master",
|
||||
"image": "redis:6",
|
||||
"ports": [{"containerPort": port}]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_redis_service_manifest(host: str, port: int):
|
||||
return {
|
||||
"metadata": {
|
||||
"name": host,
|
||||
"labels": {"app": "redis"}
|
||||
},
|
||||
"spec": {
|
||||
"ports": [{"port": port, "targetPort": port}],
|
||||
"selector": {"app": "redis"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_cross_namespace_service_access_manifest(
|
||||
service_name: str, target_service_name: str, target_service_namespace: str, target_service_port: int
|
||||
):
|
||||
return {
|
||||
"metadata": {
|
||||
"name": service_name,
|
||||
},
|
||||
"spec": {
|
||||
"type": "ExternalName",
|
||||
"externalName": f"{target_service_name}.{target_service_namespace}.svc.cluster.local",
|
||||
"ports": [{"port": target_service_port}]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import kubernetes
|
||||
from kubernetes import client, config
|
||||
|
||||
|
||||
def load_config():
|
||||
config.load_kube_config()
|
||||
|
||||
|
||||
def create_namespace(namespace: str):
|
||||
body = client.V1Namespace()
|
||||
body.metadata = client.V1ObjectMeta(name=namespace)
|
||||
try:
|
||||
client.CoreV1Api().create_namespace(body)
|
||||
except kubernetes.client.exceptions.ApiException:
|
||||
pass
|
||||
|
||||
|
||||
def create_deployment(conf: dict, namespace: str):
|
||||
client.AppsV1Api().create_namespaced_deployment(namespace, conf)
|
||||
|
||||
|
||||
def create_service(conf: dict, namespace: str):
|
||||
client.CoreV1Api().create_namespaced_service(namespace, conf)
|
||||
|
||||
|
||||
def create_job(conf: dict, namespace: str):
|
||||
client.BatchV1Api().create_namespaced_job(namespace, conf)
|
||||
|
||||
|
||||
def create_secret(name: str, data: dict, namespace: str):
|
||||
client.CoreV1Api().create_namespaced_secret(
|
||||
body=client.V1Secret(metadata=client.V1ObjectMeta(name=name), data=data),
|
||||
namespace=namespace
|
||||
)
|
||||
|
||||
|
||||
def delete_job(namespace: str):
|
||||
client.BatchV1Api().delete_collection_namespaced_job(namespace)
|
||||
client.CoreV1Api().delete_namespace(namespace)
|
||||
|
||||
|
||||
def describe_job(namespace: str):
|
||||
client.CoreV1Api().read_namespace(namespace)
|
|
@ -0,0 +1,253 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from os import makedirs
|
||||
from os.path import abspath, dirname, exists, expanduser, join
|
||||
|
||||
import redis
|
||||
import yaml
|
||||
|
||||
from maro.cli.utils.common import close_by_pid, show_log
|
||||
from maro.rl.workflows.config import ConfigParser
|
||||
from maro.utils.logger import CliLogger
|
||||
from maro.utils.utils import LOCAL_MARO_ROOT
|
||||
|
||||
from .utils import (
|
||||
JobStatus, RedisHashKey, start_redis, start_rl_job, start_rl_job_with_docker_compose, stop_redis,
|
||||
stop_rl_job_with_docker_compose
|
||||
)
|
||||
|
||||
# metadata
|
||||
LOCAL_ROOT = expanduser("~/.maro/local")
|
||||
SESSION_STATE_PATH = join(LOCAL_ROOT, "session.json")
|
||||
DOCKERFILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df")
|
||||
DOCKER_IMAGE_NAME = "maro-local"
|
||||
DOCKER_NETWORK = "MAROLOCAL"
|
||||
|
||||
# display
|
||||
NO_JOB_MANAGER_MSG = """No job manager found. Run "maro local init" to start the job manager first."""
|
||||
NO_JOB_MSG = """No job named {} found. Run "maro local job ls" to view existing jobs."""
|
||||
JOB_LS_TEMPLATE = "{JOB:12}{STATUS:15}{STARTED:20}"
|
||||
|
||||
logger = CliLogger(name="MARO-LOCAL")
|
||||
|
||||
|
||||
# helper functions
|
||||
def get_redis_conn(port=None):
|
||||
if port is None:
|
||||
try:
|
||||
with open(SESSION_STATE_PATH, "r") as fp:
|
||||
port = json.load(fp)["port"]
|
||||
except FileNotFoundError:
|
||||
logger.error(NO_JOB_MANAGER_MSG)
|
||||
return
|
||||
|
||||
try:
|
||||
redis_conn = redis.Redis(host="localhost", port=port)
|
||||
redis_conn.ping()
|
||||
return redis_conn
|
||||
except redis.exceptions.ConnectionError:
|
||||
logger.error(NO_JOB_MANAGER_MSG)
|
||||
|
||||
|
||||
# Functions executed on CLI commands
|
||||
def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False, **kwargs):
|
||||
# Load job configuration file
|
||||
parser = ConfigParser(conf_path)
|
||||
if containerize:
|
||||
try:
|
||||
start_rl_job_with_docker_compose(
|
||||
parser, LOCAL_MARO_ROOT, DOCKERFILE_PATH, DOCKER_IMAGE_NAME, evaluate_only=evaluate_only,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
stop_rl_job_with_docker_compose(parser.config["job"], LOCAL_MARO_ROOT)
|
||||
else:
|
||||
try:
|
||||
start_rl_job(parser, LOCAL_MARO_ROOT, evaluate_only=evaluate_only)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def init(
|
||||
port: int = 19999,
|
||||
max_running: int = 3,
|
||||
query_every: int = 5,
|
||||
timeout: int = 3,
|
||||
containerize: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
if exists(SESSION_STATE_PATH):
|
||||
with open(SESSION_STATE_PATH, "r") as fp:
|
||||
session_state = json.load(fp)
|
||||
logger.warning(
|
||||
f"Local job manager is already running at port {session_state['port']}. "
|
||||
f"Run 'maro local job add/rm' to add / remove jobs."
|
||||
)
|
||||
return
|
||||
|
||||
start_redis(port)
|
||||
|
||||
# Start job manager
|
||||
command = ["python", join(dirname(abspath(__file__)), 'job_manager.py')]
|
||||
job_manager = subprocess.Popen(
|
||||
command,
|
||||
env={
|
||||
"PYTHONPATH": LOCAL_MARO_ROOT,
|
||||
"MAX_RUNNING": str(max_running),
|
||||
"QUERY_EVERY": str(query_every),
|
||||
"SIGTERM_TIMEOUT": str(timeout),
|
||||
"CONTAINERIZE": str(containerize),
|
||||
"REDIS_PORT": str(port),
|
||||
"LOCAL_MARO_ROOT": LOCAL_MARO_ROOT,
|
||||
"DOCKER_IMAGE_NAME": DOCKER_IMAGE_NAME,
|
||||
"DOCKERFILE_PATH": DOCKERFILE_PATH
|
||||
}
|
||||
)
|
||||
|
||||
# Dump environment setting
|
||||
makedirs(LOCAL_ROOT, exist_ok=True)
|
||||
with open(SESSION_STATE_PATH, "w") as fp:
|
||||
json.dump({"port": port, "job_manager_pid": job_manager.pid, "containerized": containerize}, fp)
|
||||
|
||||
# Create log folder
|
||||
logger.info("Local job manager started")
|
||||
|
||||
|
||||
def exit(**kwargs):
|
||||
try:
|
||||
with open(SESSION_STATE_PATH, "r") as fp:
|
||||
session_state = json.load(fp)
|
||||
except FileNotFoundError:
|
||||
logger.error(NO_JOB_MANAGER_MSG)
|
||||
return
|
||||
|
||||
redis_conn = get_redis_conn()
|
||||
|
||||
# Mark all jobs as REMOVED and let the job manager terminate them properly.
|
||||
job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS)
|
||||
if job_details:
|
||||
for job_name, details in job_details.items():
|
||||
details = json.loads(details)
|
||||
details["status"] = JobStatus.REMOVED
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
logger.info(f"Gracefully terminating job {job_name.decode()}")
|
||||
|
||||
# Stop job manager
|
||||
close_by_pid(int(session_state["job_manager_pid"]))
|
||||
|
||||
# Stop Redis
|
||||
stop_redis(session_state["port"])
|
||||
|
||||
# Remove dump folder.
|
||||
shutil.rmtree(LOCAL_ROOT, True)
|
||||
|
||||
logger.info("Local job manager terminated.")
|
||||
|
||||
|
||||
def add_job(conf_path: str, **kwargs):
|
||||
redis_conn = get_redis_conn()
|
||||
if not redis_conn:
|
||||
return
|
||||
|
||||
# Load job configuration file
|
||||
with open(conf_path, "r") as fr:
|
||||
conf = yaml.safe_load(fr)
|
||||
|
||||
job_name = conf["job"]
|
||||
if redis_conn.hexists(RedisHashKey.JOB_DETAILS, job_name):
|
||||
logger.error(f"A job named '{job_name}' has already been added.")
|
||||
return
|
||||
|
||||
# Push job config to redis
|
||||
redis_conn.hset(RedisHashKey.JOB_CONF, job_name, json.dumps(conf))
|
||||
details = {
|
||||
"status": JobStatus.PENDING,
|
||||
"added": time.time()
|
||||
}
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
|
||||
|
||||
def remove_jobs(job_names, **kwargs):
|
||||
redis_conn = get_redis_conn()
|
||||
if not redis_conn:
|
||||
return
|
||||
|
||||
for job_name in job_names:
|
||||
details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)
|
||||
if not details:
|
||||
logger.error(f"No job named '{job_name}' has been scheduled or started.")
|
||||
else:
|
||||
details = json.loads(details)
|
||||
details["status"] = JobStatus.REMOVED
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
logger.info(f"Removed job {job_name}")
|
||||
|
||||
|
||||
def describe_job(job_name, **kwargs):
|
||||
redis_conn = get_redis_conn()
|
||||
if not redis_conn:
|
||||
return
|
||||
|
||||
details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)
|
||||
if not details:
|
||||
logger.error(NO_JOB_MSG.format(job_name))
|
||||
return
|
||||
|
||||
details = json.loads(details)
|
||||
err = "error_message" in details
|
||||
if err:
|
||||
err_msg = details["error_message"].split('\n')
|
||||
del details["error_message"]
|
||||
|
||||
logger.info(details)
|
||||
if err:
|
||||
for line in err_msg:
|
||||
logger.info(line)
|
||||
|
||||
|
||||
def get_job_logs(job_name: str, tail: int = -1, **kwargs):
|
||||
redis_conn = get_redis_conn()
|
||||
if not redis_conn.hexists(RedisHashKey.JOB_CONF, job_name):
|
||||
logger.error(NO_JOB_MSG.format(job_name))
|
||||
return
|
||||
|
||||
conf = json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name))
|
||||
show_log(conf["log_path"], tail=tail)
|
||||
|
||||
|
||||
def list_jobs(**kwargs):
|
||||
redis_conn = get_redis_conn()
|
||||
if not redis_conn:
|
||||
return
|
||||
|
||||
def get_time_diff_string(time_diff):
|
||||
time_diff = int(time_diff)
|
||||
days = time_diff // (3600 * 24)
|
||||
if days:
|
||||
return f"{days} days"
|
||||
|
||||
hours = time_diff // 3600
|
||||
if hours:
|
||||
return f"{hours} hours"
|
||||
|
||||
minutes = time_diff // 60
|
||||
if minutes:
|
||||
return f"{minutes} minutes"
|
||||
|
||||
return f"{time_diff} seconds"
|
||||
|
||||
# Header
|
||||
logger.info(JOB_LS_TEMPLATE.format(JOB="JOB", STATUS="STATUS", STARTED="STARTED"))
|
||||
for job_name, details in redis_conn.hgetall(RedisHashKey.JOB_DETAILS).items():
|
||||
job_name = job_name.decode()
|
||||
details = json.loads(details)
|
||||
if "start_time" in details:
|
||||
time_diff = f"{get_time_diff_string(time.time() - details['start_time'])} ago"
|
||||
logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=time_diff))
|
||||
else:
|
||||
logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=JobStatus.PENDING))
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
from maro.cli.local.utils import JobStatus, RedisHashKey, poll, start_rl_job, start_rl_job_in_containers, term
|
||||
from maro.cli.utils.docker import build_image, image_exists
|
||||
from maro.rl.workflows.config import ConfigParser
|
||||
|
||||
if __name__ == "__main__":
|
||||
redis_port = int(os.getenv("REDIS_PORT", default=19999))
|
||||
redis_conn = redis.Redis(host="localhost", port=redis_port)
|
||||
started, max_running = {}, int(os.getenv("MAX_RUNNING", default=1))
|
||||
query_every = int(os.getenv("QUERY_EVERY", default=5))
|
||||
sigterm_timeout = int(os.getenv("SIGTERM_TIMEOUT", default=3))
|
||||
containerize = os.getenv("CONTAINERIZE", default="False") == "True"
|
||||
local_maro_root = os.getenv("LOCAL_MARO_ROOT")
|
||||
docker_file_path = os.getenv("DOCKERFILE_PATH")
|
||||
docker_image_name = os.getenv("DOCKER_IMAGE_NAME")
|
||||
|
||||
# thread to monitor a job
|
||||
def monitor(job_name):
|
||||
removed, error, err_out, running = False, False, None, started[job_name]
|
||||
while running:
|
||||
error, err_out, running = poll(running)
|
||||
# check if the job has been marked as REMOVED before termination
|
||||
details = json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name))
|
||||
if details["status"] == JobStatus.REMOVED:
|
||||
removed = True
|
||||
break
|
||||
|
||||
if error:
|
||||
break
|
||||
|
||||
if removed:
|
||||
term(started[job_name], job_name, timeout=sigterm_timeout)
|
||||
redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name)
|
||||
redis_conn.hdel(RedisHashKey.JOB_CONF, job_name)
|
||||
return
|
||||
|
||||
if error:
|
||||
term(started[job_name], job_name, timeout=sigterm_timeout)
|
||||
details["status"] = JobStatus.ERROR
|
||||
details["error_message"] = err_out
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
else: # all job processes terminated normally
|
||||
details["status"] = JobStatus.FINISHED
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
|
||||
# Continue to monitor if the job is marked as REMOVED
|
||||
while json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name))["status"] != JobStatus.REMOVED:
|
||||
time.sleep(query_every)
|
||||
|
||||
term(started[job_name], job_name, timeout=sigterm_timeout)
|
||||
redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name)
|
||||
redis_conn.hdel(RedisHashKey.JOB_CONF, job_name)
|
||||
|
||||
while True:
|
||||
# check for pending jobs
|
||||
job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS)
|
||||
if job_details:
|
||||
num_running, pending = 0, []
|
||||
for job_name, details in job_details.items():
|
||||
job_name, details = job_name.decode(), json.loads(details)
|
||||
if details["status"] == JobStatus.RUNNING:
|
||||
num_running += 1
|
||||
elif details["status"] == JobStatus.PENDING:
|
||||
pending.append((job_name, json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name))))
|
||||
|
||||
for job_name, conf in pending[:max(0, max_running - num_running)]:
|
||||
if containerize and not image_exists(docker_image_name):
|
||||
redis_conn.hset(
|
||||
RedisHashKey.JOB_DETAILS, job_name, json.dumps({"status": JobStatus.IMAGE_BUILDING})
|
||||
)
|
||||
build_image(local_maro_root, docker_file_path, docker_image_name)
|
||||
|
||||
parser = ConfigParser(conf)
|
||||
if containerize:
|
||||
path_mapping = parser.get_path_mapping(containerize=True)
|
||||
started[job_name] = start_rl_job_in_containers(parser, docker_image_name)
|
||||
details["containers"] = started[job_name]
|
||||
else:
|
||||
started[job_name] = start_rl_job(parser, local_maro_root, background=True)
|
||||
details["pids"] = [proc.pid for proc in started[job_name]]
|
||||
details = {"status": JobStatus.RUNNING, "start_time": time.time()}
|
||||
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
|
||||
threading.Thread(target=monitor, args=(job_name,)).start() # start job monitoring thread
|
||||
|
||||
time.sleep(query_every)
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
|
||||
import docker
|
||||
import yaml
|
||||
|
||||
from maro.cli.utils.common import format_env_vars
|
||||
from maro.rl.workflows.config.parser import ConfigParser
|
||||
|
||||
|
||||
class RedisHashKey:
|
||||
"""Record Redis elements name, and only for maro process"""
|
||||
JOB_CONF = "job_conf"
|
||||
JOB_DETAILS = "job_details"
|
||||
|
||||
|
||||
class JobStatus:
|
||||
PENDING = "pending"
|
||||
IMAGE_BUILDING = "image_building"
|
||||
RUNNING = "running"
|
||||
ERROR = "error"
|
||||
REMOVED = "removed"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
def start_redis(port: int):
|
||||
subprocess.Popen(["redis-server", "--port", str(port)], stdout=subprocess.DEVNULL)
|
||||
|
||||
|
||||
def stop_redis(port: int):
|
||||
subprocess.Popen(["redis-cli", "-p", str(port), "shutdown"], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
|
||||
|
||||
def extract_error_msg_from_docker_log(container: docker.models.containers.Container):
|
||||
logs = container.logs().decode().splitlines()
|
||||
for i, log in enumerate(logs):
|
||||
if "Traceback (most recent call last):" in log:
|
||||
return "\n".join(logs[i:])
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def check_proc_status(proc):
|
||||
if isinstance(proc, subprocess.Popen):
|
||||
if proc.poll() is None:
|
||||
return True, 0, None
|
||||
_, err_out = proc.communicate()
|
||||
return False, proc.returncode, err_out
|
||||
else:
|
||||
client = docker.from_env()
|
||||
container_state = client.api.inspect_container(proc.id)["State"]
|
||||
return container_state["Running"], container_state["ExitCode"], extract_error_msg_from_docker_log(proc)
|
||||
|
||||
|
||||
def poll(procs):
|
||||
error, running = False, []
|
||||
for proc in procs:
|
||||
is_running, exit_code, err_out = check_proc_status(proc)
|
||||
if is_running:
|
||||
running.append(proc)
|
||||
elif exit_code:
|
||||
error = True
|
||||
break
|
||||
|
||||
return error, err_out, running
|
||||
|
||||
|
||||
def term(procs, job_name: str, timeout: int = 3):
|
||||
if isinstance(procs[0], subprocess.Popen):
|
||||
for proc in procs:
|
||||
if proc.poll() is None:
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
else:
|
||||
for proc in procs:
|
||||
try:
|
||||
proc.stop(timeout=timeout)
|
||||
proc.remove()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
client = docker.from_env()
|
||||
try:
|
||||
job_network = client.networks.get(job_name)
|
||||
job_network.remove()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def exec(cmd: str, env: dict, debug: bool = False) -> subprocess.Popen:
|
||||
stream = None if debug else subprocess.PIPE
|
||||
return subprocess.Popen(
|
||||
cmd.split(), env={**os.environ.copy(), **env}, stdout=stream, stderr=stream, encoding="utf8"
|
||||
)
|
||||
|
||||
|
||||
def start_rl_job(
|
||||
parser: ConfigParser, maro_root: str, evaluate_only: bool, background: bool = False,
|
||||
) -> List[subprocess.Popen]:
|
||||
procs = [
|
||||
exec(
|
||||
f"python {script}" + ("" if not evaluate_only else " --evaluate_only"),
|
||||
format_env_vars({**env, "PYTHONPATH": maro_root}, mode="proc"),
|
||||
debug=not background
|
||||
)
|
||||
for script, env in parser.get_job_spec().values()
|
||||
]
|
||||
if not background:
|
||||
for proc in procs:
|
||||
proc.communicate()
|
||||
|
||||
return procs
|
||||
|
||||
|
||||
def start_rl_job_in_containers(parser: ConfigParser, image_name: str) -> list:
|
||||
job_name = parser.config["job"]
|
||||
client, containers = docker.from_env(), []
|
||||
training_mode = parser.config["training"]["mode"]
|
||||
if "parallelism" in parser.config["rollout"]:
|
||||
rollout_parallelism = max(
|
||||
parser.config["rollout"]["parallelism"]["sampling"],
|
||||
parser.config["rollout"]["parallelism"].get("eval", 1)
|
||||
)
|
||||
else:
|
||||
rollout_parallelism = 1
|
||||
if training_mode != "simple" or rollout_parallelism > 1:
|
||||
# create the exclusive network for the job
|
||||
client.networks.create(job_name, driver="bridge")
|
||||
|
||||
for component, (script, env) in parser.get_job_spec(containerize=True).items():
|
||||
# volume mounts for scenario folder, policy loading, checkpointing and logging
|
||||
container = client.containers.run(
|
||||
image_name,
|
||||
command=f"python3 {script}",
|
||||
detach=True,
|
||||
name=component,
|
||||
environment=env,
|
||||
volumes=[f"{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()],
|
||||
network=job_name
|
||||
)
|
||||
|
||||
containers.append(container)
|
||||
|
||||
return containers
|
||||
|
||||
|
||||
def get_docker_compose_yml_path(maro_root: str) -> str:
|
||||
return os.path.join(maro_root, ".tmp", "docker-compose.yml")
|
||||
|
||||
|
||||
def start_rl_job_with_docker_compose(
|
||||
parser: ConfigParser, context: str, dockerfile_path: str, image_name: str, evaluate_only: bool,
|
||||
) -> None:
|
||||
common_spec = {
|
||||
"build": {"context": context, "dockerfile": dockerfile_path},
|
||||
"image": image_name,
|
||||
"volumes": [f"./{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()]
|
||||
}
|
||||
|
||||
job_name = parser.config["job"]
|
||||
manifest = {
|
||||
"version": "3.9",
|
||||
"services": {
|
||||
component: {
|
||||
**deepcopy(common_spec),
|
||||
**{
|
||||
"container_name": component,
|
||||
"command": f"python3 {script}" + ("" if not evaluate_only else " --evaluate_only"),
|
||||
"environment": format_env_vars(env, mode="docker-compose")
|
||||
}
|
||||
}
|
||||
for component, (script, env) in parser.get_job_spec(containerize=True).items()
|
||||
},
|
||||
}
|
||||
|
||||
docker_compose_file_path = get_docker_compose_yml_path(maro_root=context)
|
||||
with open(docker_compose_file_path, "w") as fp:
|
||||
yaml.safe_dump(manifest, fp)
|
||||
|
||||
subprocess.run(
|
||||
["docker-compose", "--project-name", job_name, "-f", docker_compose_file_path, "up", "--remove-orphans"]
|
||||
)
|
||||
|
||||
|
||||
def stop_rl_job_with_docker_compose(job_name: str, context: str):
|
||||
subprocess.run(["docker-compose", "--project-name", job_name, "down"])
|
||||
os.remove(get_docker_compose_yml_path(maro_root=context))
|
296
maro/cli/maro.py
|
@ -90,6 +90,15 @@ def main():
|
|||
parser_k8s.set_defaults(func=_help_func(parser=parser_k8s))
|
||||
load_parser_k8s(prev_parser=parser_k8s, global_parser=global_parser)
|
||||
|
||||
# maro aks
|
||||
parser_aks = subparsers.add_parser(
|
||||
"aks",
|
||||
help="Manage distributed cluster with Kubernetes.",
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_aks.set_defaults(func=_help_func(parser=parser_aks))
|
||||
load_parser_aks(prev_parser=parser_aks, global_parser=global_parser)
|
||||
|
||||
# maro inspector
|
||||
parser_inspector = subparsers.add_parser(
|
||||
'inspector',
|
||||
|
@ -99,13 +108,13 @@ def main():
|
|||
parser_inspector.set_defaults(func=_help_func(parser=parser_inspector))
|
||||
load_parser_inspector(parser_inspector, global_parser)
|
||||
|
||||
# maro process
|
||||
parser_process = subparsers.add_parser(
|
||||
"process",
|
||||
help="Run application by mulit-process to simulate distributed mode."
|
||||
# maro local
|
||||
parser_local = subparsers.add_parser(
|
||||
"local",
|
||||
help="Run jobs locally."
|
||||
)
|
||||
parser_process.set_defaults(func=_help_func(parser=parser_process))
|
||||
load_parser_process(prev_parser=parser_process, global_parser=global_parser)
|
||||
parser_local.set_defaults(func=_help_func(parser=parser_local))
|
||||
load_parser_local(prev_parser=parser_local, global_parser=global_parser)
|
||||
|
||||
# maro project
|
||||
parser_project = subparsers.add_parser(
|
||||
|
@ -151,152 +160,128 @@ def main():
|
|||
logger.error_red(f"{e.__class__.__name__}: {e.get_message()}")
|
||||
|
||||
|
||||
def load_parser_process(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
|
||||
def load_parser_local(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
|
||||
subparsers = prev_parser.add_subparsers()
|
||||
|
||||
# maro process create
|
||||
from maro.cli.process.create import create
|
||||
parser_setup = subparsers.add_parser(
|
||||
"create",
|
||||
help="Create local process environment.",
|
||||
# maro local run
|
||||
from maro.cli.local.commands import run
|
||||
parser = subparsers.add_parser(
|
||||
"run",
|
||||
help="Run a job in debug mode.",
|
||||
examples=CliExamples.MARO_PROCESS_SETUP,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_setup.add_argument(
|
||||
'deployment_path',
|
||||
help='Path of the local process setting deployment.',
|
||||
nargs='?',
|
||||
default=None)
|
||||
parser_setup.set_defaults(func=create)
|
||||
parser.add_argument("conf_path", help='Path of the job deployment')
|
||||
parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers")
|
||||
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
|
||||
parser.add_argument("-p", "--port", type=int, default=20000, help="")
|
||||
parser.set_defaults(func=run)
|
||||
|
||||
# maro process delete
|
||||
from maro.cli.process.delete import delete
|
||||
parser_setup = subparsers.add_parser(
|
||||
"delete",
|
||||
help="Delete the local process environment. Including closing agents and maro Redis.",
|
||||
# maro local init
|
||||
from maro.cli.local.commands import init
|
||||
parser = subparsers.add_parser(
|
||||
"init",
|
||||
help="Initialize local job manager.",
|
||||
examples=CliExamples.MARO_PROCESS_SETUP,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_setup.set_defaults(func=delete)
|
||||
parser.add_argument(
|
||||
"-p", "--port", type=int, default=19999,
|
||||
help="Port on local machine to launch the Redis server at. Defaults to 19999."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--max-running", type=int, default=3,
|
||||
help="Maximum number of jobs to allow running at the same time. Defaults to 3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--query-every", type=int, default=5,
|
||||
help="Number of seconds to wait between queries to the Redis server for pending or removed jobs. Defaults to 5."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--timeout", type=int, default=3,
|
||||
help="""
|
||||
Number of seconds to wait after sending SIGTERM to a process. If the process does not terminate
|
||||
during this time, the process will be force-killed through SIGKILL. Defaults to 3.
|
||||
"""
|
||||
)
|
||||
parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers")
|
||||
parser.set_defaults(func=init)
|
||||
|
||||
# maro process job
|
||||
parser_job = subparsers.add_parser(
|
||||
# maro local exit
|
||||
from maro.cli.local.commands import exit
|
||||
parser = subparsers.add_parser(
|
||||
"exit",
|
||||
help="Terminate the local job manager",
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser.set_defaults(func=exit)
|
||||
|
||||
# maro local job
|
||||
parser = subparsers.add_parser(
|
||||
"job",
|
||||
help="Manage jobs",
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job.set_defaults(func=_help_func(parser=parser_job))
|
||||
parser_job_subparsers = parser_job.add_subparsers()
|
||||
parser.set_defaults(func=_help_func(parser=parser))
|
||||
job_subparsers = parser.add_subparsers()
|
||||
|
||||
# maro process job start
|
||||
from maro.cli.process.job import start_job
|
||||
parser_job_start = parser_job_subparsers.add_parser(
|
||||
'start',
|
||||
help='Start a training job',
|
||||
# maro local job add
|
||||
from maro.cli.local.commands import add_job
|
||||
job_add_parser = job_subparsers.add_parser(
|
||||
"add",
|
||||
help="Start an RL job",
|
||||
examples=CliExamples.MARO_PROCESS_JOB_START,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_start.add_argument(
|
||||
'deployment_path', help='Path of the job deployment')
|
||||
parser_job_start.set_defaults(func=start_job)
|
||||
job_add_parser.add_argument("conf_path", help='Path of the job deployment')
|
||||
job_add_parser.set_defaults(func=add_job)
|
||||
|
||||
# maro process job stop
|
||||
from maro.cli.process.job import stop_job
|
||||
parser_job_stop = parser_job_subparsers.add_parser(
|
||||
'stop',
|
||||
help='Stop a training job',
|
||||
# maro local job rm
|
||||
from maro.cli.local.commands import remove_jobs
|
||||
job_stop_parser = job_subparsers.add_parser(
|
||||
"rm",
|
||||
help='Stop an RL job',
|
||||
examples=CliExamples.MARO_PROCESS_JOB_STOP,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_stop.add_argument(
|
||||
'job_name', help='Name of the job')
|
||||
parser_job_stop.set_defaults(func=stop_job)
|
||||
job_stop_parser.add_argument('job_names', help="Job names", nargs="*")
|
||||
job_stop_parser.set_defaults(func=remove_jobs)
|
||||
|
||||
# maro process job delete
|
||||
from maro.cli.process.job import delete_job
|
||||
parser_job_delete = parser_job_subparsers.add_parser(
|
||||
'delete',
|
||||
help='delete a stopped job',
|
||||
examples=CliExamples.MARO_PROCESS_JOB_DELETE,
|
||||
# maro local job describe
|
||||
from maro.cli.local.commands import describe_job
|
||||
job_stop_parser = job_subparsers.add_parser(
|
||||
"describe",
|
||||
help="Get the status of an RL job and the error information if the job fails due to some error",
|
||||
examples=CliExamples.MARO_PROCESS_JOB_STOP,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_delete.add_argument(
|
||||
'job_name', help='Name of the job or the schedule')
|
||||
parser_job_delete.set_defaults(func=delete_job)
|
||||
job_stop_parser.add_argument('job_name', help='Job name')
|
||||
job_stop_parser.set_defaults(func=describe_job)
|
||||
|
||||
# maro process job list
|
||||
from maro.cli.process.job import list_jobs
|
||||
parser_job_list = parser_job_subparsers.add_parser(
|
||||
'list',
|
||||
# maro local job ls
|
||||
from maro.cli.local.commands import list_jobs
|
||||
job_list_parser = job_subparsers.add_parser(
|
||||
"ls",
|
||||
help='List all jobs',
|
||||
examples=CliExamples.MARO_PROCESS_JOB_LIST,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_list.set_defaults(func=list_jobs)
|
||||
job_list_parser.set_defaults(func=list_jobs)
|
||||
|
||||
# maro process job logs
|
||||
from maro.cli.process.job import get_job_logs
|
||||
parser_job_logs = parser_job_subparsers.add_parser(
|
||||
'logs',
|
||||
help='Get logs of the job',
|
||||
# maro local job logs
|
||||
from maro.cli.local.commands import get_job_logs
|
||||
job_logs_parser = job_subparsers.add_parser(
|
||||
"logs",
|
||||
help="Get job logs",
|
||||
examples=CliExamples.MARO_PROCESS_JOB_LOGS,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_logs.add_argument(
|
||||
'job_name', help='Name of the job')
|
||||
parser_job_logs.set_defaults(func=get_job_logs)
|
||||
|
||||
# maro process schedule
|
||||
parser_schedule = subparsers.add_parser(
|
||||
'schedule',
|
||||
help='Manage schedules',
|
||||
parents=[global_parser]
|
||||
job_logs_parser.add_argument("job_name", help="job name")
|
||||
job_logs_parser.add_argument(
|
||||
"-n", "--tail", type=int, default=-1,
|
||||
help="Number of lines to show from the end of the given job's logs"
|
||||
)
|
||||
parser_schedule.set_defaults(func=_help_func(parser=parser_schedule))
|
||||
parser_schedule_subparsers = parser_schedule.add_subparsers()
|
||||
|
||||
# maro process schedule start
|
||||
from maro.cli.process.schedule import start_schedule
|
||||
parser_schedule_start = parser_schedule_subparsers.add_parser(
|
||||
'start',
|
||||
help='Start a schedule',
|
||||
examples=CliExamples.MARO_PROCESS_SCHEDULE_START,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_schedule_start.add_argument(
|
||||
'deployment_path', help='Path of the schedule deployment')
|
||||
parser_schedule_start.set_defaults(func=start_schedule)
|
||||
|
||||
# maro process schedule stop
|
||||
from maro.cli.process.schedule import stop_schedule
|
||||
parser_schedule_stop = parser_schedule_subparsers.add_parser(
|
||||
'stop',
|
||||
help='Stop a schedule',
|
||||
examples=CliExamples.MARO_PROCESS_SCHEDULE_STOP,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_schedule_stop.add_argument(
|
||||
'schedule_name', help='Name of the schedule')
|
||||
parser_schedule_stop.set_defaults(func=stop_schedule)
|
||||
|
||||
# maro process template
|
||||
from maro.cli.process.template import template
|
||||
parser_template = subparsers.add_parser(
|
||||
"template",
|
||||
help="Get deployment templates",
|
||||
examples=CliExamples.MARO_PROCESS_TEMPLATE,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_template.add_argument(
|
||||
"--setting_deploy",
|
||||
action="store_true",
|
||||
help="Get environment setting templates"
|
||||
)
|
||||
parser_template.add_argument(
|
||||
"export_path",
|
||||
default="./",
|
||||
nargs='?',
|
||||
help="Path of the export directory")
|
||||
parser_template.set_defaults(func=template)
|
||||
job_logs_parser.set_defaults(func=get_job_logs)
|
||||
|
||||
|
||||
def load_parser_grass(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
|
||||
|
@ -922,6 +907,81 @@ def load_parser_k8s(prev_parser: ArgumentParser, global_parser: ArgumentParser)
|
|||
parser_template.set_defaults(func=template)
|
||||
|
||||
|
||||
def load_parser_aks(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
|
||||
subparsers = prev_parser.add_subparsers()
|
||||
|
||||
# maro aks create
|
||||
from maro.cli.k8s.aks.aks_commands import init
|
||||
parser_create = subparsers.add_parser(
|
||||
"init",
|
||||
help="""
|
||||
Deploy resources and start required services on Azure. The configuration file template can be found
|
||||
in cli/k8s/aks/conf.yml. Use the Azure CLI to log into your Azure account (az login ...) and the the
|
||||
Azure Container Registry (az acr login ...) first.
|
||||
""",
|
||||
examples=CliExamples.MARO_K8S_CREATE,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_create.add_argument("deployment_conf_path", help="Path of the deployment configuration file")
|
||||
parser_create.set_defaults(func=init)
|
||||
|
||||
# maro aks exit
|
||||
from maro.cli.k8s.aks.aks_commands import exit
|
||||
parser_create = subparsers.add_parser(
|
||||
"exit",
|
||||
help="Delete deployed resources",
|
||||
examples=CliExamples.MARO_K8S_DELETE,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_create.set_defaults(func=exit)
|
||||
|
||||
# maro aks job
|
||||
parser_job = subparsers.add_parser(
|
||||
"job",
|
||||
help="AKS job-related commands",
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job.set_defaults(func=_help_func(parser=parser_job))
|
||||
job_subparsers = parser_job.add_subparsers()
|
||||
|
||||
# maro aks job add
|
||||
from maro.cli.k8s.aks.aks_commands import add_job
|
||||
parser_job_start = job_subparsers.add_parser(
|
||||
"add",
|
||||
help="Add an RL job to the AKS cluster",
|
||||
examples=CliExamples.MARO_K8S_JOB_START,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_start.add_argument("conf_path", help="Path to the job configuration file")
|
||||
parser_job_start.set_defaults(func=add_job)
|
||||
|
||||
# maro aks job rm
|
||||
from maro.cli.k8s.aks.aks_commands import remove_jobs
|
||||
parser_job_start = job_subparsers.add_parser(
|
||||
"rm",
|
||||
help="Remove previously scheduled RL jobs from the AKS cluster",
|
||||
examples=CliExamples.MARO_K8S_JOB_START,
|
||||
parents=[global_parser]
|
||||
)
|
||||
parser_job_start.add_argument("job_names", help="Name of job to be removed", nargs="*")
|
||||
parser_job_start.set_defaults(func=remove_jobs)
|
||||
|
||||
# maro aks job logs
|
||||
from maro.cli.k8s.aks.aks_commands import get_job_logs
|
||||
job_logs_parser = job_subparsers.add_parser(
|
||||
"logs",
|
||||
help="Get job logs",
|
||||
examples=CliExamples.MARO_PROCESS_JOB_LOGS,
|
||||
parents=[global_parser]
|
||||
)
|
||||
job_logs_parser.add_argument("job_name", help="job name")
|
||||
job_logs_parser.add_argument(
|
||||
"-n", "--tail", type=int, default=-1,
|
||||
help="Number of lines to show from the end of the given job's logs"
|
||||
)
|
||||
job_logs_parser.set_defaults(func=get_job_logs)
|
||||
|
||||
|
||||
def load_parser_data(prev_parser: ArgumentParser, global_parser: ArgumentParser):
|
||||
data_cmd_sub_parsers = prev_parser.add_subparsers()
|
||||
|
||||
|
|
|
@ -1,206 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import redis
|
||||
|
||||
from maro.cli.grass.lib.services.utils.params import JobStatus
|
||||
from maro.cli.process.utils.details import close_by_pid, get_child_pid
|
||||
from maro.cli.utils.details_reader import DetailsReader
|
||||
from maro.cli.utils.params import LocalPaths, ProcessRedisName
|
||||
|
||||
|
||||
class PendingJobAgent(mp.Process):
|
||||
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
|
||||
super().__init__()
|
||||
self.cluster_detail = cluster_detail
|
||||
self.redis_connection = redis_connection
|
||||
self.check_interval = check_interval
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
self._check_pending_ticket()
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def _check_pending_ticket(self):
|
||||
# Check pending job ticket
|
||||
pending_jobs = self.redis_connection.lrange(ProcessRedisName.PENDING_JOB_TICKETS, 0, -1)
|
||||
running_jobs_length = len(JobTrackingAgent.get_running_jobs(
|
||||
self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
|
||||
))
|
||||
parallel_level = self.cluster_detail["parallel_level"]
|
||||
|
||||
for job_name in pending_jobs:
|
||||
job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name))
|
||||
# Start pending job only if current running job's number less than parallel level.
|
||||
if int(parallel_level) > running_jobs_length:
|
||||
self._start_job(job_detail)
|
||||
self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name)
|
||||
running_jobs_length += 1
|
||||
|
||||
def _start_job(self, job_details: dict):
|
||||
command_pid_list = []
|
||||
for component_type, command_info in job_details["components"].items():
|
||||
component_number = command_info["num"]
|
||||
component_command = f"JOB_NAME={job_details['name']} " + command_info["command"]
|
||||
for number in range(component_number):
|
||||
job_local_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_details['name']}")
|
||||
if not os.path.exists(job_local_path):
|
||||
os.makedirs(job_local_path)
|
||||
|
||||
with open(f"{job_local_path}/{component_type}_{number}.log", "w") as log_file:
|
||||
proc = subprocess.Popen(component_command, shell=True, stdout=log_file)
|
||||
command_pid = get_child_pid(proc.pid)
|
||||
if not command_pid:
|
||||
command_pid_list.append(proc.pid)
|
||||
else:
|
||||
command_pid_list.append(command_pid)
|
||||
|
||||
job_details["status"] = JobStatus.RUNNING
|
||||
job_details["pid_list"] = command_pid_list
|
||||
self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_details["name"], json.dumps(job_details))
|
||||
|
||||
|
||||
class JobTrackingAgent(mp.Process):
|
||||
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
|
||||
super().__init__()
|
||||
self.cluster_detail = cluster_detail
|
||||
self.redis_connection = redis_connection
|
||||
self.check_interval = check_interval
|
||||
self._shutdown_count = 0
|
||||
self._countdown = cluster_detail["agent_countdown"]
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
self._check_job_status()
|
||||
time.sleep(self.check_interval)
|
||||
keep_alive = self.cluster_detail["keep_agent_alive"]
|
||||
if not keep_alive:
|
||||
self._close_agents()
|
||||
|
||||
def _check_job_status(self):
|
||||
running_jobs = self.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS))
|
||||
|
||||
for running_job_name, running_job_detail in running_jobs.items():
|
||||
# Check pid status
|
||||
still_alive = False
|
||||
for pid in running_job_detail["pid_list"]:
|
||||
if psutil.pid_exists(pid):
|
||||
still_alive = True
|
||||
|
||||
# Update if no pid exists
|
||||
if not still_alive:
|
||||
running_job_detail["status"] = JobStatus.FINISH
|
||||
del running_job_detail["pid_list"]
|
||||
self.redis_connection.hset(
|
||||
ProcessRedisName.JOB_DETAILS,
|
||||
running_job_name,
|
||||
json.dumps(running_job_detail)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_running_jobs(job_details: dict):
|
||||
running_jobs = {}
|
||||
|
||||
for job_name, job_detail in job_details.items():
|
||||
job_detail = json.loads(job_detail)
|
||||
if job_detail["status"] == JobStatus.RUNNING:
|
||||
running_jobs[job_name.decode()] = job_detail
|
||||
|
||||
return running_jobs
|
||||
|
||||
def _close_agents(self):
|
||||
if (
|
||||
not len(
|
||||
JobTrackingAgent.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS))
|
||||
) and
|
||||
not self.redis_connection.llen(ProcessRedisName.PENDING_JOB_TICKETS)
|
||||
):
|
||||
self._shutdown_count += 1
|
||||
else:
|
||||
self._shutdown_count = 0
|
||||
|
||||
if self._shutdown_count >= self._countdown:
|
||||
agent_pid = int(self.redis_connection.hget(ProcessRedisName.SETTING, "agent_pid"))
|
||||
|
||||
# close agent
|
||||
close_by_pid(pid=agent_pid, recursive=True)
|
||||
|
||||
# Set agent status to 0
|
||||
self.redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 0)
|
||||
|
||||
|
||||
class KilledJobAgent(mp.Process):
|
||||
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
|
||||
super().__init__()
|
||||
self.cluster_detail = cluster_detail
|
||||
self.redis_connection = redis_connection
|
||||
self.check_interval = check_interval
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
self._check_killed_tickets()
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def _check_killed_tickets(self):
|
||||
# Check pending job ticket
|
||||
killed_job_names = self.redis_connection.lrange(ProcessRedisName.KILLED_JOB_TICKETS, 0, -1)
|
||||
|
||||
for job_name in killed_job_names:
|
||||
job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name))
|
||||
if job_detail["status"] == JobStatus.RUNNING:
|
||||
close_by_pid(pid=job_detail["pid_list"], recursive=False)
|
||||
del job_detail["pid_list"]
|
||||
elif job_detail["status"] == JobStatus.PENDING:
|
||||
self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name)
|
||||
elif job_detail["status"] == JobStatus.FINISH:
|
||||
continue
|
||||
|
||||
job_detail["status"] = JobStatus.KILLED
|
||||
self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_name, json.dumps(job_detail))
|
||||
self.redis_connection.lrem(ProcessRedisName.KILLED_JOB_TICKETS, 0, job_name)
|
||||
|
||||
|
||||
class MasterAgent:
|
||||
def __init__(self):
|
||||
self.cluster_detail = DetailsReader.load_cluster_details("process")
|
||||
self.check_interval = self.cluster_detail["check_interval"]
|
||||
self.redis_connection = redis.Redis(
|
||||
host=self.cluster_detail["redis_info"]["host"],
|
||||
port=self.cluster_detail["redis_info"]["port"]
|
||||
)
|
||||
self.redis_connection.hset(ProcessRedisName.SETTING, "agent_pid", os.getpid())
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start agents."""
|
||||
pending_job_agent = PendingJobAgent(
|
||||
cluster_detail=self.cluster_detail,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
)
|
||||
pending_job_agent.start()
|
||||
|
||||
killed_job_agent = KilledJobAgent(
|
||||
cluster_detail=self.cluster_detail,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
)
|
||||
killed_job_agent.start()
|
||||
|
||||
job_tracking_agent = JobTrackingAgent(
|
||||
cluster_detail=self.cluster_detail,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
)
|
||||
job_tracking_agent.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
master_agent = MasterAgent()
|
||||
master_agent.start()
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
from maro.cli.utils.params import LocalParams
|
||||
from maro.cli.utils.resource_executor import ResourceInfo
|
||||
from maro.utils.exception.cli_exception import BadRequestError
|
||||
|
||||
|
||||
class ResourceTrackingAgent(mp.Process):
|
||||
def __init__(
|
||||
self,
|
||||
check_interval: int = 30
|
||||
):
|
||||
super().__init__()
|
||||
self._redis_connection = redis.Redis(host="localhost", port=LocalParams.RESOURCE_REDIS_PORT)
|
||||
try:
|
||||
if self._redis_connection.hexists(LocalParams.RESOURCE_INFO, "check_interval"):
|
||||
self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval"))
|
||||
else:
|
||||
self._check_interval = check_interval
|
||||
except Exception:
|
||||
raise BadRequestError(
|
||||
"Failure to connect to Resource Redis."
|
||||
"Please make sure at least one cluster running."
|
||||
)
|
||||
|
||||
self._set_resource_info()
|
||||
|
||||
def _set_resource_info(self):
|
||||
# Set resource agent pid.
|
||||
self._redis_connection.hset(
|
||||
LocalParams.RESOURCE_INFO,
|
||||
"agent_pid",
|
||||
os.getpid()
|
||||
)
|
||||
|
||||
# Set resource agent check interval.
|
||||
self._redis_connection.hset(
|
||||
LocalParams.RESOURCE_INFO,
|
||||
"check_interval",
|
||||
json.dumps(self._check_interval)
|
||||
)
|
||||
|
||||
# Push static resource information into Redis.
|
||||
resource = ResourceInfo.get_static_info()
|
||||
self._redis_connection.hset(
|
||||
LocalParams.RESOURCE_INFO,
|
||||
"resource",
|
||||
json.dumps(resource)
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Start tracking node status and updating details.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
while True:
|
||||
start_time = time.time()
|
||||
self.push_local_resource_usage()
|
||||
time.sleep(max(self._check_interval - (time.time() - start_time), 0))
|
||||
|
||||
self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval"))
|
||||
|
||||
def push_local_resource_usage(self):
|
||||
resource_usage = ResourceInfo.get_dynamic_info(self._check_interval)
|
||||
|
||||
self._redis_connection.rpush(
|
||||
LocalParams.CPU_USAGE,
|
||||
json.dumps(resource_usage["cpu_usage_per_core"])
|
||||
)
|
||||
|
||||
self._redis_connection.rpush(
|
||||
LocalParams.MEMORY_USAGE,
|
||||
json.dumps(resource_usage["memory_usage"])
|
||||
)
|
||||
|
||||
self._redis_connection.rpush(
|
||||
LocalParams.GPU_USAGE,
|
||||
json.dumps(resource_usage["gpu_memory_usage"])
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resource_agent = ResourceTrackingAgent()
|
||||
resource_agent.start()
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import yaml
|
||||
|
||||
from maro.cli.process.executor import ProcessExecutor
|
||||
from maro.cli.process.utils.default_param import process_setting
|
||||
|
||||
|
||||
def create(deployment_path: str, **kwargs):
|
||||
if deployment_path is not None:
|
||||
with open(deployment_path, "r") as fr:
|
||||
create_deployment = yaml.safe_load(fr)
|
||||
else:
|
||||
create_deployment = process_setting
|
||||
|
||||
executor = ProcessExecutor(create_deployment)
|
||||
executor.create()
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from maro.cli.process.executor import ProcessExecutor
|
||||
|
||||
|
||||
def delete(**kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.delete()
|
|
@ -1,10 +0,0 @@
|
|||
mode: process
|
||||
name: MyJobName # str: name of the training job
|
||||
|
||||
components: # component config
|
||||
actor:
|
||||
num: 5 # int: number of this component
|
||||
command: "python /target/path/run_actor.py" # str: command to be executed
|
||||
learner:
|
||||
num: 1
|
||||
command: "python /target/path/run_learner.py"
|
|
@ -1,16 +0,0 @@
|
|||
mode: process
|
||||
name: MyScheduleName # str: name of the training schedule
|
||||
|
||||
job_names: # list: names of the training job
|
||||
- MyJobName2
|
||||
- MyJobName3
|
||||
- MyJobName4
|
||||
- MyJobName5
|
||||
|
||||
components: # component config
|
||||
actor:
|
||||
num: 5 # int: number of this component
|
||||
command: "python /target/path/run_actor.py" # str: command to be executed
|
||||
learner:
|
||||
num: 1
|
||||
command: "python /target/path/run_learner.py"
|
|
@ -1,8 +0,0 @@
|
|||
redis_info:
|
||||
host: "localhost"
|
||||
port: 19999
|
||||
redis_mode: MARO # one of MARO, customized. customized Redis won't be exited after maro process clear.
|
||||
parallel_level: 1 # Represented the maximum number of running jobs in the same times.
|
||||
keep_agent_alive: True # If True represented the agents won't exit until the environment delete; otherwise, False.
|
||||
agent_countdown: 5 # After agent_countdown times checks, still no jobs will close agents. Available only if keep_agent_alive is 0.
|
||||
check_interval: 60 # The time interval (seconds) of agents check with Redis
|
|
@ -1,248 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
import redis
|
||||
import yaml
|
||||
|
||||
from maro.cli.grass.lib.services.utils.params import JobStatus
|
||||
from maro.cli.process.utils.details import close_by_pid, get_redis_pid_by_port
|
||||
from maro.cli.utils.abs_visible_executor import AbsVisibleExecutor
|
||||
from maro.cli.utils.details_reader import DetailsReader
|
||||
from maro.cli.utils.details_writer import DetailsWriter
|
||||
from maro.cli.utils.params import GlobalPaths, LocalPaths, ProcessRedisName
|
||||
from maro.cli.utils.resource_executor import LocalResourceExecutor
|
||||
from maro.utils.logger import CliLogger
|
||||
|
||||
logger = CliLogger(name=__name__)
|
||||
|
||||
|
||||
class ProcessExecutor(AbsVisibleExecutor):
|
||||
def __init__(self, details: dict = None):
|
||||
self.details = details if details else \
|
||||
DetailsReader.load_cluster_details("process")
|
||||
|
||||
# Connection with Redis
|
||||
redis_port = self.details["redis_info"]["port"]
|
||||
self._redis_connection = redis.Redis(host="localhost", port=redis_port)
|
||||
try:
|
||||
self._redis_connection.ping()
|
||||
except Exception:
|
||||
redis_process = subprocess.Popen(
|
||||
["redis-server", "--port", str(redis_port), "--daemonize yes"]
|
||||
)
|
||||
redis_process.wait(timeout=2)
|
||||
|
||||
# Connection with Resource Redis
|
||||
self._resource_redis = LocalResourceExecutor()
|
||||
|
||||
def create(self):
|
||||
logger.info("Starting MARO Multi-Process Mode.")
|
||||
if os.path.isdir(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process"):
|
||||
logger.warning("Process mode has been created.")
|
||||
|
||||
# Get environment setting
|
||||
DetailsWriter.save_cluster_details(
|
||||
cluster_name="process",
|
||||
cluster_details=self.details
|
||||
)
|
||||
|
||||
# Start agents
|
||||
command = f"python {LocalPaths.MARO_PROCESS_AGENT}"
|
||||
_ = subprocess.Popen(command, shell=True)
|
||||
self._redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 1)
|
||||
|
||||
# Add connection to resource Redis.
|
||||
self._resource_redis.add_cluster()
|
||||
|
||||
logger.info(f"MARO process mode setting: {self.details}")
|
||||
|
||||
def delete(self):
|
||||
process_setting = self._redis_connection.hgetall(ProcessRedisName.SETTING)
|
||||
process_setting = {
|
||||
key.decode(): json.loads(value) for key, value in process_setting.items()
|
||||
}
|
||||
|
||||
# Stop running jobs
|
||||
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
|
||||
if jobs:
|
||||
for job_name, job_detail in jobs.items():
|
||||
job_detail = json.loads(job_detail)
|
||||
if job_detail["status"] == JobStatus.RUNNING:
|
||||
close_by_pid(pid=job_detail["pid_list"], recursive=False)
|
||||
logger.info(f"Stop running job {job_name.decode()}.")
|
||||
|
||||
# Stop agents
|
||||
agent_status = int(process_setting["agent_status"])
|
||||
if agent_status:
|
||||
agent_pid = int(process_setting["agent_pid"])
|
||||
close_by_pid(pid=agent_pid, recursive=True)
|
||||
logger.info("Close agents.")
|
||||
else:
|
||||
logger.info("Agents is already closed.")
|
||||
|
||||
# Stop Redis or clear Redis
|
||||
redis_mode = self.details["redis_mode"]
|
||||
if redis_mode == "MARO":
|
||||
redis_pid = get_redis_pid_by_port(self.details["redis_info"]["port"])
|
||||
close_by_pid(pid=redis_pid, recursive=False)
|
||||
else:
|
||||
self._redis_clear()
|
||||
|
||||
# Rm connection from resource redis.
|
||||
self._resource_redis.sub_cluster()
|
||||
|
||||
logger.info("Redis cleared.")
|
||||
|
||||
# Remove local process file.
|
||||
shutil.rmtree(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process", True)
|
||||
logger.info("Process mode has been deleted.")
|
||||
|
||||
def _redis_clear(self):
|
||||
redis_keys = self._redis_connection.keys("process:*")
|
||||
for key in redis_keys:
|
||||
self._redis_connection.delete(key)
|
||||
|
||||
def start_job(self, deployment_path: str):
|
||||
# Load start_job_deployment
|
||||
with open(deployment_path, "r") as fr:
|
||||
start_job_deployment = yaml.safe_load(fr)
|
||||
|
||||
job_name = start_job_deployment["name"]
|
||||
start_job_deployment["status"] = JobStatus.PENDING
|
||||
# Push job details to redis
|
||||
self._redis_connection.hset(
|
||||
ProcessRedisName.JOB_DETAILS,
|
||||
job_name,
|
||||
json.dumps(start_job_deployment)
|
||||
)
|
||||
|
||||
self._push_pending_job(job_name)
|
||||
|
||||
def _push_pending_job(self, job_name: str):
|
||||
# Push job name to pending_job_tickets
|
||||
self._redis_connection.lpush(
|
||||
ProcessRedisName.PENDING_JOB_TICKETS,
|
||||
job_name
|
||||
)
|
||||
logger.info(f"Sending {job_name} into pending job tickets.")
|
||||
|
||||
def stop_job(self, job_name: str):
|
||||
if not self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, job_name):
|
||||
logger.error(f"No such job '{job_name}' in Redis.")
|
||||
return
|
||||
|
||||
# push job_name into kill_job_tickets
|
||||
self._redis_connection.lpush(
|
||||
ProcessRedisName.KILLED_JOB_TICKETS,
|
||||
job_name
|
||||
)
|
||||
logger.info(f"Sending {job_name} into killed job tickets.")
|
||||
|
||||
def delete_job(self, job_name: str):
|
||||
# Stop job for running and pending job.
|
||||
self.stop_job(job_name)
|
||||
|
||||
# Rm job details in Redis
|
||||
self._redis_connection.hdel(ProcessRedisName.JOB_DETAILS, job_name)
|
||||
|
||||
# Rm job's log folder
|
||||
job_folder = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}")
|
||||
shutil.rmtree(job_folder, True)
|
||||
logger.info(f"Remove local temporary log folder {job_folder}.")
|
||||
|
||||
def get_job_logs(self, job_name):
|
||||
source_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}")
|
||||
if not os.path.exists(source_path):
|
||||
logger.error(f"Cannot find the logs of {job_name}.")
|
||||
|
||||
destination = os.path.join(os.getcwd(), job_name)
|
||||
if os.path.exists(destination):
|
||||
shutil.rmtree(destination)
|
||||
shutil.copytree(source_path, destination)
|
||||
logger.info(f"Dump logs in path: {destination}.")
|
||||
|
||||
def list_job(self):
|
||||
# Get all jobs
|
||||
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
|
||||
for job_name, job_detail in jobs.items():
|
||||
job_name = job_name.decode()
|
||||
job_detail = json.loads(job_detail)
|
||||
|
||||
logger.info(job_detail)
|
||||
|
||||
def start_schedule(self, deployment_path: str):
|
||||
with open(deployment_path, "r") as fr:
|
||||
schedule_detail = yaml.safe_load(fr)
|
||||
|
||||
# push schedule details to Redis
|
||||
self._redis_connection.hset(
|
||||
ProcessRedisName.JOB_DETAILS,
|
||||
schedule_detail["name"],
|
||||
json.dumps(schedule_detail)
|
||||
)
|
||||
|
||||
job_list = schedule_detail["job_names"]
|
||||
# switch schedule details into job details
|
||||
job_detail = copy.deepcopy(schedule_detail)
|
||||
del job_detail["job_names"]
|
||||
|
||||
for job_name in job_list:
|
||||
job_detail["name"] = job_name
|
||||
|
||||
# Push job details to redis
|
||||
self._redis_connection.hset(
|
||||
ProcessRedisName.JOB_DETAILS,
|
||||
job_name,
|
||||
json.dumps(job_detail)
|
||||
)
|
||||
|
||||
self._push_pending_job(job_name)
|
||||
|
||||
def stop_schedule(self, schedule_name: str):
|
||||
if self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, schedule_name):
|
||||
schedule_details = json.loads(self._redis_connection.hget(ProcessRedisName.JOB_DETAILS, schedule_name))
|
||||
else:
|
||||
logger.error(f"Cannot find {schedule_name} in Redis. Please check schedule name.")
|
||||
return
|
||||
|
||||
if "job_names" not in schedule_details.keys():
|
||||
logger.error(f"'{schedule_name}' is not a schedule.")
|
||||
return
|
||||
|
||||
job_list = schedule_details["job_names"]
|
||||
|
||||
for job_name in job_list:
|
||||
self.stop_job(job_name)
|
||||
|
||||
def get_job_details(self):
|
||||
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
|
||||
for job_name, job_details_str in jobs.items():
|
||||
jobs[job_name] = json.loads(job_details_str)
|
||||
|
||||
return list(jobs.values())
|
||||
|
||||
def get_job_queue(self):
|
||||
pending_job_queue = self._redis_connection.lrange(
|
||||
ProcessRedisName.PENDING_JOB_TICKETS,
|
||||
0, -1
|
||||
)
|
||||
killed_job_queue = self._redis_connection.lrange(
|
||||
ProcessRedisName.KILLED_JOB_TICKETS,
|
||||
0, -1
|
||||
)
|
||||
return {
|
||||
"pending_jobs": pending_job_queue,
|
||||
"killed_jobs": killed_job_queue
|
||||
}
|
||||
|
||||
def get_resource(self):
|
||||
return self._resource_redis.get_local_resource()
|
||||
|
||||
def get_resource_usage(self, previous_length: int):
|
||||
return self._resource_redis.get_local_resource_usage(previous_length)
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
from maro.cli.process.executor import ProcessExecutor
|
||||
|
||||
|
||||
def start_job(deployment_path: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.start_job(deployment_path=deployment_path)
|
||||
|
||||
|
||||
def stop_job(job_name: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.stop_job(job_name=job_name)
|
||||
|
||||
|
||||
def delete_job(job_name: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.delete_job(job_name=job_name)
|
||||
|
||||
|
||||
def list_jobs(**kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.list_job()
|
||||
|
||||
|
||||
def get_job_logs(job_name: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.get_job_logs(job_name=job_name)
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
from maro.cli.process.executor import ProcessExecutor
|
||||
|
||||
|
||||
def start_schedule(deployment_path: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.start_schedule(deployment_path=deployment_path)
|
||||
|
||||
|
||||
def stop_schedule(schedule_name: str, **kwargs):
|
||||
executor = ProcessExecutor()
|
||||
executor.stop_schedule(schedule_name=schedule_name)
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from maro.cli.utils.params import LocalPaths
|
||||
|
||||
|
||||
def template(setting_deploy, export_path, **kwargs):
|
||||
deploy_files = os.listdir(LocalPaths.MARO_PROCESS_DEPLOYMENT)
|
||||
if not setting_deploy:
|
||||
deploy_files.remove("process_setting_deployment.yml")
|
||||
export_path = os.path.abspath(export_path)
|
||||
for file_name in deploy_files:
|
||||
if os.path.isfile(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}"):
|
||||
shutil.copy(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}", export_path)
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
process_setting = {
|
||||
"redis_info": {
|
||||
"host": "localhost",
|
||||
"port": 19999
|
||||
},
|
||||
"redis_mode": "MARO", # one of MARO, customized. customized Redis won't exit after maro process clear.
|
||||
"parallel_level": 1,
|
||||
"keep_agent_alive": 1, # If 0 (False), agents will exit after 5 minutes of no pending jobs and running jobs.
|
||||
"check_interval": 60, # seconds
|
||||
"agent_countdown": 5 # how many times to shutdown agents about finding no job in Redis.
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from typing import Union
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def close_by_pid(pid: Union[int, list], recursive: bool = False):
|
||||
if isinstance(pid, int):
|
||||
if not psutil.pid_exists(pid):
|
||||
return
|
||||
|
||||
if recursive:
|
||||
current_process = psutil.Process(pid)
|
||||
children_process = current_process.children(recursive=False)
|
||||
# May launch by JobTrackingAgent which is child process, so need close parent process first.
|
||||
current_process.kill()
|
||||
for child_process in children_process:
|
||||
child_process.kill()
|
||||
else:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
else:
|
||||
for p in pid:
|
||||
if psutil.pid_exists(p):
|
||||
os.kill(p, signal.SIGKILL)
|
||||
|
||||
|
||||
def get_child_pid(parent_pid):
|
||||
command = f"ps -o pid --ppid {parent_pid} --noheaders"
|
||||
get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
|
||||
children_pids = get_children_pid_process.stdout.read()
|
||||
get_children_pid_process.wait(timeout=2)
|
||||
|
||||
# Convert into list or int
|
||||
try:
|
||||
children_pids = int(children_pids)
|
||||
except ValueError:
|
||||
children_pids = children_pids.decode().split("\n")
|
||||
children_pids = [int(pid) for pid in children_pids[:-1]]
|
||||
|
||||
return children_pids
|
||||
|
||||
|
||||
def get_redis_pid_by_port(port: int):
|
||||
get_redis_pid_command = f"pidof 'redis-server *:{port}'"
|
||||
get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE)
|
||||
redis_pid = int(get_redis_pid_process.stdout.read())
|
||||
get_redis_pid_process.wait()
|
||||
|
||||
return redis_pid
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def login_acr(acr_name: str) -> None:
|
||||
command = f"az acr login --name {acr_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def list_acr_repositories(acr_name: str) -> list:
|
||||
command = f"az acr repository list -n {acr_name}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import subprocess
|
||||
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.mgmt.authorization import AuthorizationManagementClient
|
||||
from azure.mgmt.containerservice import ContainerServiceClient
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def get_container_service_client(subscription: str):
|
||||
return ContainerServiceClient(DefaultAzureCredential(), subscription)
|
||||
|
||||
|
||||
def get_authorization_client(subscription: str):
|
||||
return AuthorizationManagementClient()
|
||||
|
||||
|
||||
def load_aks_context(resource_group: str, aks_name: str) -> None:
|
||||
command = f"az aks get-credentials -g {resource_group} --name {aks_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def get_aks(subscription: str, resource_group: str, aks_name: str) -> dict:
|
||||
container_service_client = get_container_service_client(subscription)
|
||||
return container_service_client.managed_clusters.get(resource_group, aks_name)
|
||||
|
||||
|
||||
def attach_acr(resource_group: str, aks_name: str, acr_name: str) -> None:
|
||||
subprocess.run(f"az aks update -g {resource_group} -n {aks_name} --attach-acr {acr_name}".split())
|
||||
|
||||
|
||||
def add_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int, node_size: str) -> None:
|
||||
command = (
|
||||
f"az aks nodepool add "
|
||||
f"-g {resource_group} "
|
||||
f"--cluster-name {aks_name} "
|
||||
f"--name {nodepool_name} "
|
||||
f"--node-count {node_count} "
|
||||
f"--node-vm-size {node_size}"
|
||||
)
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def scale_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int) -> None:
|
||||
command = (
|
||||
f"az aks nodepool scale "
|
||||
f"-g {resource_group} "
|
||||
f"--cluster-name {aks_name} "
|
||||
f"--name {nodepool_name} "
|
||||
f"--node-count {node_count}"
|
||||
)
|
||||
_ = Subprocess.run(command=command)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .general import get_resource_client
|
||||
|
||||
|
||||
def create_deployment(
|
||||
subscription: str,
|
||||
resource_group: str,
|
||||
deployment_name: str,
|
||||
template: dict,
|
||||
params: dict,
|
||||
sync: bool = True
|
||||
) -> None:
|
||||
params = {k: {"value": v} for k, v in params.items()}
|
||||
resource_client = get_resource_client(subscription)
|
||||
deployment_params = {"mode": "Incremental", "template": template, "parameters": params}
|
||||
result = resource_client.deployments.begin_create_or_update(
|
||||
resource_group, deployment_name, {"properties": deployment_params}
|
||||
)
|
||||
if sync:
|
||||
result.result()
|
||||
|
||||
|
||||
def delete_deployment(subscription: str, resource_group: str, deployment_name: str) -> None:
|
||||
resource_client = get_resource_client(subscription)
|
||||
resource_client.deployments.begin_delete(resource_group, deployment_name)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.mgmt.resource import ResourceManagementClient
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def set_subscription(subscription: str) -> None:
|
||||
command = f"az account set --subscription {subscription}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def get_version() -> dict:
|
||||
command = "az version"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
|
||||
|
||||
def get_resource_client(subscription: str):
|
||||
return ResourceManagementClient(DefaultAzureCredential(), subscription)
|
||||
|
||||
|
||||
def set_env_credentials(dump_path: str, service_principal_name: str):
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
service_principal_file_path = os.path.join(dump_path, f"{service_principal_name}.json")
|
||||
# If the service principal file does not exist, create one using the az CLI command.
|
||||
# For details on service principals, refer to
|
||||
# https://docs.microsoft.com/en-us/azure/active-directory/develop/app-objects-and-service-principals
|
||||
if not os.path.exists(service_principal_file_path):
|
||||
with open(service_principal_file_path, 'w') as fp:
|
||||
subprocess.run(
|
||||
f"az ad sp create-for-rbac --name {service_principal_name} --sdk-auth --role contributor".split(),
|
||||
stdout=fp
|
||||
)
|
||||
|
||||
with open(service_principal_file_path, 'r') as fp:
|
||||
service_principal = json.load(fp)
|
||||
|
||||
os.environ["AZURE_TENANT_ID"] = service_principal["tenantId"]
|
||||
os.environ["AZURE_CLIENT_ID"] = service_principal["clientId"]
|
||||
os.environ["AZURE_CLIENT_SECRET"] = service_principal["clientSecret"]
|
||||
os.environ["AZURE_SUBSCRIPTION_ID"] = service_principal["subscriptionId"]
|
||||
|
||||
|
||||
def connect_to_aks(resource_group: str, aks: str):
|
||||
subprocess.run(f"az aks get-credentials --resource-group {resource_group} --name {aks}".split())
|
||||
|
||||
|
||||
def get_acr_push_permissions(service_principal_id: str, acr: str):
|
||||
acr_id = json.loads(
|
||||
subprocess.run(f"az acr show --name {acr} --query id".split(), stdout=subprocess.PIPE).stdout
|
||||
)
|
||||
subprocess.run(
|
||||
f"az role assignment create --assignee {service_principal_id} --scope {acr_id} --role acrpush".split()
|
||||
)
|
||||
subprocess.run(f"az acr login --name {acr}".split())
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
from maro.utils.exception.cli_exception import CommandExecutionError
|
||||
|
||||
from .general import get_resource_client
|
||||
|
||||
|
||||
def get_resource_group(resource_group: str) -> dict:
|
||||
command = f"az group show --name {resource_group}"
|
||||
try:
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
except CommandExecutionError:
|
||||
return {}
|
||||
|
||||
|
||||
def delete_resource_group(resource_group: str) -> None:
|
||||
command = f"az group delete --yes --name {resource_group}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
# Chained Azure resource group operations
|
||||
def create_resource_group(subscription: str, resource_group: str, location: str):
|
||||
"""Create the resource group if it does not exist.
|
||||
|
||||
Args:
|
||||
subscription (str): Azure subscription name.
|
||||
resource group (str): Resource group name.
|
||||
location (str): Reousrce group location.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
resource_client = get_resource_client(subscription)
|
||||
return resource_client.resource_groups.create_or_update(resource_group, {"location": location})
|
||||
|
||||
|
||||
def delete_resource_group_under_subscription(subscription: str, resource_group: str):
|
||||
resource_client = get_resource_client(subscription)
|
||||
return resource_client.resource_groups.begin_delete(resource_group)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def list_resources(resource_group: str) -> list:
|
||||
command = f"az resource list -g {resource_group}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
|
||||
|
||||
def delete_resources(resource_ids: list) -> None:
|
||||
command = f"az resource delete --ids {' '.join(resource_ids)}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def cleanup(cluster_name: str, resource_group: str) -> None:
|
||||
# Get resource list
|
||||
resource_list = list_resources(resource_group)
|
||||
|
||||
# Filter resources
|
||||
deletable_ids = []
|
||||
for resource in resource_list:
|
||||
if resource["name"].startswith(cluster_name):
|
||||
deletable_ids.append(resource["id"])
|
||||
|
||||
# Delete resources
|
||||
if deletable_ids:
|
||||
delete_resources(resource_ids=deletable_ids)
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from azure.core.exceptions import ResourceExistsError
|
||||
from azure.storage.fileshare import ShareClient, ShareDirectoryClient
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def get_storage_account_keys(resource_group: str, storage_account_name: str) -> dict:
|
||||
command = f"az storage account keys list -g {resource_group} --account-name {storage_account_name}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
|
||||
|
||||
def get_storage_account_sas(
|
||||
account_name: str,
|
||||
services: str = "bqtf",
|
||||
resource_types: str = "sco",
|
||||
permissions: str = "rwdlacup",
|
||||
expiry: str = (datetime.datetime.utcnow() + datetime.timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%S") + "Z"
|
||||
) -> str:
|
||||
command = (
|
||||
f"az storage account generate-sas --account-name {account_name} --services {services} "
|
||||
f"--resource-types {resource_types} --permissions {permissions} --expiry {expiry}"
|
||||
)
|
||||
sas_str = Subprocess.run(command=command).strip("\n").replace('"', "")
|
||||
# logger.debug(sas_str)
|
||||
return sas_str
|
||||
|
||||
|
||||
def get_connection_string(storage_account_name: str) -> str:
|
||||
"""Get the connection string for a storage account.
|
||||
|
||||
Args:
|
||||
storage_account_name: The storage account name.
|
||||
|
||||
Returns:
|
||||
str: Connection string.
|
||||
"""
|
||||
command = f"az storage account show-connection-string --name {storage_account_name}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)["connectionString"]
|
||||
|
||||
|
||||
def get_fileshare(storage_account_name: str, fileshare_name: str):
|
||||
connection_string = get_connection_string(storage_account_name)
|
||||
share = ShareClient.from_connection_string(connection_string, fileshare_name)
|
||||
try:
|
||||
share.create_share()
|
||||
except ResourceExistsError:
|
||||
pass
|
||||
|
||||
return share
|
||||
|
||||
|
||||
def get_directory(share: Union[ShareClient, ShareDirectoryClient], name: str):
|
||||
if isinstance(share, ShareClient):
|
||||
directory = share.get_directory_client(directory_path=name)
|
||||
try:
|
||||
directory.create_directory()
|
||||
except ResourceExistsError:
|
||||
pass
|
||||
|
||||
return directory
|
||||
elif isinstance(share, ShareDirectoryClient):
|
||||
try:
|
||||
return share.create_subdirectory(name)
|
||||
except ResourceExistsError:
|
||||
return share.get_subdirectory_client(name)
|
||||
|
||||
|
||||
def upload_to_fileshare(share: Union[ShareClient, ShareDirectoryClient], source_path: str, name: str = None):
|
||||
if os.path.isdir(source_path):
|
||||
if not name:
|
||||
name = os.path.basename(source_path)
|
||||
directory = get_directory(share, name)
|
||||
for file in os.listdir(source_path):
|
||||
upload_to_fileshare(directory, os.path.join(source_path, file))
|
||||
else:
|
||||
with open(source_path, "rb") as fp:
|
||||
share.upload_file(file_name=os.path.basename(source_path), data=fp)
|
||||
|
||||
|
||||
def download_from_fileshare(share: ShareDirectoryClient, file_name: str, local_path: str):
|
||||
file = share.get_file_client(file_name=file_name)
|
||||
with open(local_path, "wb") as fp:
|
||||
fp.write(file.download_file().readall())
|
||||
|
||||
|
||||
def delete_directory(share: Union[ShareClient, ShareDirectoryClient], name: str, recursive: bool = True):
|
||||
share.delete_directory(directory_name=name)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
|
||||
|
||||
def list_ip_addresses(resource_group: str, vm_name: str) -> list:
|
||||
command = f"az vm list-ip-addresses -g {resource_group} --name {vm_name}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
|
||||
|
||||
def start_vm(resource_group: str, vm_name: str) -> None:
|
||||
command = f"az vm start -g {resource_group} --name {vm_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def stop_vm(resource_group: str, vm_name: str) -> None:
|
||||
command = f"az vm stop -g {resource_group} --name {vm_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def list_vm_sizes(location: str) -> list:
|
||||
command = f"az vm list-sizes -l {location}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)
|
||||
|
||||
|
||||
def deallocate_vm(resource_group: str, vm_name: str) -> None:
|
||||
command = f"az vm deallocate --resource-group {resource_group} --name {vm_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def generalize_vm(resource_group: str, vm_name: str) -> None:
|
||||
command = f"az vm generalize --resource-group {resource_group} --name {vm_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def create_image_from_vm(resource_group: str, image_name: str, vm_name: str) -> None:
|
||||
command = f"az image create --resource-group {resource_group} --name {image_name} --source {vm_name}"
|
||||
_ = Subprocess.run(command=command)
|
||||
|
||||
|
||||
def get_image_resource_id(resource_group: str, image_name: str) -> str:
|
||||
command = f"az image show --resource-group {resource_group} --name {image_name}"
|
||||
return_str = Subprocess.run(command=command)
|
||||
return json.loads(return_str)["id"]
|
|
@ -1,7 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import deque
|
||||
|
||||
import psutil
|
||||
|
||||
from maro.utils import Logger
|
||||
|
||||
|
||||
def close_by_pid(pid: int, recursive: bool = True):
|
||||
if not psutil.pid_exists(pid):
|
||||
return
|
||||
|
||||
proc = psutil.Process(pid)
|
||||
if recursive:
|
||||
for child in proc.children(recursive=recursive):
|
||||
child.kill()
|
||||
|
||||
proc.kill()
|
||||
|
||||
|
||||
def get_child_pids(parent_pid):
|
||||
# command = f"ps -o pid --ppid {parent_pid} --noheaders"
|
||||
# get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
|
||||
# children_pids = get_children_pid_process.stdout.read()
|
||||
# get_children_pid_process.wait(timeout=2)
|
||||
|
||||
# # Convert into list or int
|
||||
# try:
|
||||
# children_pids = int(children_pids)
|
||||
# except ValueError:
|
||||
# children_pids = children_pids.decode().split("\n")
|
||||
# children_pids = [int(pid) for pid in children_pids[:-1]]
|
||||
|
||||
# return children_pids
|
||||
try:
|
||||
return [child.pid for child in psutil.Process(parent_pid).children(recursive=True)]
|
||||
except psutil.NoSuchProcess:
|
||||
print(f"No process with PID {parent_pid} found")
|
||||
return
|
||||
|
||||
|
||||
def get_redis_pid_by_port(port: int):
|
||||
get_redis_pid_command = f"pidof 'redis-server *:{port}'"
|
||||
get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE)
|
||||
redis_pid = int(get_redis_pid_process.stdout.read())
|
||||
get_redis_pid_process.wait()
|
||||
return redis_pid
|
||||
|
||||
|
||||
def exit(state: int = 0, msg: str = None):
|
||||
|
@ -10,3 +58,75 @@ def exit(state: int = 0, msg: str = None):
|
|||
sys.stderr.write(msg)
|
||||
|
||||
sys.exit(state)
|
||||
|
||||
|
||||
def get_last_k_lines(file_name: str, k: int):
|
||||
"""
|
||||
Helper function to retrieve the last K lines from a file in a memory-efficient way.
|
||||
|
||||
Code slightly adapted from https://thispointer.com/python-get-last-n-lines-of-a-text-file-like-tail-command/
|
||||
"""
|
||||
# Create an empty list to keep the track of last k lines
|
||||
lines = deque()
|
||||
# Open file for reading in binary mode
|
||||
with open(file_name, 'rb') as fp:
|
||||
# Move the cursor to the end of the file
|
||||
fp.seek(0, os.SEEK_END)
|
||||
# Create a buffer to keep the last read line
|
||||
buffer = bytearray()
|
||||
# Get the current position of pointer i.e eof
|
||||
ptr = fp.tell()
|
||||
# Loop till pointer reaches the top of the file
|
||||
while ptr >= 0:
|
||||
# Move the file pointer to the location pointed by ptr
|
||||
fp.seek(ptr)
|
||||
# Shift pointer location by -1
|
||||
ptr -= 1
|
||||
# read that byte / character
|
||||
new_byte = fp.read(1)
|
||||
# If the read byte is new line character then it means one line is read
|
||||
if new_byte != b'\n':
|
||||
# If last read character is not eol then add it in buffer
|
||||
buffer.extend(new_byte)
|
||||
elif buffer:
|
||||
lines.appendleft(buffer.decode()[::-1])
|
||||
if len(lines) == k:
|
||||
return lines
|
||||
# Reinitialize the byte array to save next line
|
||||
buffer.clear()
|
||||
|
||||
# As file is read completely, if there is still data in buffer, then it's the first of the last K lines.
|
||||
if buffer:
|
||||
lines.appendleft(buffer.decode()[::-1])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def show_log(log_path: str, tail: int = -1, logger: Logger = None):
|
||||
print_fn = logger.info if logger else print
|
||||
if tail == -1:
|
||||
with open(log_path, "r") as fp:
|
||||
for line in fp:
|
||||
print_fn(line.rstrip('\n'))
|
||||
else:
|
||||
for line in get_last_k_lines(log_path, tail):
|
||||
print_fn(line)
|
||||
|
||||
|
||||
def format_env_vars(env: dict, mode: str = "proc"):
|
||||
if mode == "proc":
|
||||
return env
|
||||
|
||||
if mode == "docker":
|
||||
env_opt_list = []
|
||||
for key, val in env.items():
|
||||
env_opt_list.extend(["--env", f"{key}={val}"])
|
||||
return env_opt_list
|
||||
|
||||
if mode == "docker-compose":
|
||||
return [f"{key}={val}" for key, val in env.items()]
|
||||
|
||||
if mode == "k8s":
|
||||
return [{"name": key, "value": val} for key, val in env.items()]
|
||||
|
||||
raise ValueError(f"'mode' should be one of 'proc', 'docker', 'docker-compose', 'k8s', got {mode}")
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import docker
|
||||
|
||||
|
||||
def image_exists(image_name: str):
|
||||
try:
|
||||
client = docker.from_env()
|
||||
client.images.get(image_name)
|
||||
return True
|
||||
except docker.errors.ImageNotFound:
|
||||
return False
|
||||
|
||||
|
||||
def build_image(context: str, docker_file_path: str, image_name: str):
|
||||
client = docker.from_env()
|
||||
with open(docker_file_path, "r"):
|
||||
client.images.build(
|
||||
path=context,
|
||||
tag=image_name,
|
||||
quiet=False,
|
||||
rm=True,
|
||||
custom_context=False,
|
||||
dockerfile=docker_file_path
|
||||
)
|
||||
|
||||
|
||||
def push(local_image_name: str, repository: str):
|
||||
client = docker.from_env()
|
||||
image = client.images.get(local_image_name)
|
||||
acr_tag = f"{repository}/{local_image_name}"
|
||||
image.tag(acr_tag)
|
||||
# subprocess.run(f"docker push {acr_tag}".split())
|
||||
client.images.push(acr_tag)
|
||||
print(f"Pushed image to {acr_tag}")
|
|
@ -38,23 +38,3 @@ class LocalParams:
|
|||
CPU_USAGE = "local_resource:cpu_usage_per_core"
|
||||
MEMORY_USAGE = "local_resource:memory_usage"
|
||||
GPU_USAGE = "local_resource:gpu_memory_usage"
|
||||
|
||||
|
||||
class LocalPaths:
|
||||
"""Only use by maro process cli"""
|
||||
MARO_PROCESS = "~/.maro/clusters/process"
|
||||
MARO_PROCESS_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/job_agent.py")
|
||||
MARO_RESOURCE_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/resource_agent.py")
|
||||
MARO_PROCESS_DEPLOYMENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/deployment")
|
||||
MARO_GRASS_LOCAL_AGENT = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../grass/lib/services/master_agent/local_agent.py"
|
||||
)
|
||||
|
||||
|
||||
class ProcessRedisName:
|
||||
"""Record Redis elements name, and only for maro process"""
|
||||
PENDING_JOB_TICKETS = "process:pending_job_tickets"
|
||||
KILLED_JOB_TICKETS = "process:killed_job_tickets"
|
||||
JOB_DETAILS = "process:job_details"
|
||||
SETTING = "process:setting"
|
||||
|
|
|
@ -26,7 +26,7 @@ def dist(proxy: Proxy, handler_dict: {object: Callable}):
|
|||
self.local_instance = cls(*args, **kwargs)
|
||||
self.proxy = proxy
|
||||
self._handler_function = {}
|
||||
self._registry_table = RegisterTable(self.proxy.peers_name)
|
||||
self._registry_table = RegisterTable(self.proxy.peers)
|
||||
# Use functools.partial to freeze handling function's local_instance and proxy
|
||||
# arguments to self.local_instance and self.proxy.
|
||||
for constraint, handler_fun in handler_dict.items():
|
||||
|
|
|
@ -69,7 +69,7 @@ class ZmqDriver(AbsDriver):
|
|||
"""
|
||||
self._unicast_receiver = self._zmq_context.socket(zmq.PULL)
|
||||
unicast_receiver_port = self._unicast_receiver.bind_to_random_port(f"{self._protocol}://*")
|
||||
self._logger.info(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.")
|
||||
self._logger.debug(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.")
|
||||
|
||||
# Dict about zmq.PUSH sockets, fulfills in self.connect.
|
||||
self._unicast_sender_dict = {}
|
||||
|
@ -80,7 +80,7 @@ class ZmqDriver(AbsDriver):
|
|||
self._broadcast_receiver = self._zmq_context.socket(zmq.SUB)
|
||||
self._broadcast_receiver.setsockopt(zmq.SUBSCRIBE, self._component_type.encode())
|
||||
broadcast_receiver_port = self._broadcast_receiver.bind_to_random_port(f"{self._protocol}://*")
|
||||
self._logger.info(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.")
|
||||
self._logger.debug(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.")
|
||||
|
||||
# Record own sockets' address.
|
||||
self._address = {
|
||||
|
@ -122,10 +122,10 @@ class ZmqDriver(AbsDriver):
|
|||
self._unicast_sender_dict[peer_name] = self._zmq_context.socket(zmq.PUSH)
|
||||
self._unicast_sender_dict[peer_name].setsockopt(zmq.SNDTIMEO, self._send_timeout)
|
||||
self._unicast_sender_dict[peer_name].connect(address)
|
||||
self._logger.info(f"Connects to {peer_name} via unicasting.")
|
||||
self._logger.debug(f"Connects to {peer_name} via unicasting.")
|
||||
elif int(socket_type) == zmq.SUB:
|
||||
self._broadcast_sender.connect(address)
|
||||
self._logger.info(f"Connects to {peer_name} via broadcasting.")
|
||||
self._logger.debug(f"Connects to {peer_name} via broadcasting.")
|
||||
else:
|
||||
raise SocketTypeError(f"Unrecognized socket type {socket_type}.")
|
||||
except Exception as e:
|
||||
|
@ -158,13 +158,13 @@ class ZmqDriver(AbsDriver):
|
|||
raise PeersDisconnectionError(f"Driver cannot disconnect to {peer_name}! Due to {str(e)}")
|
||||
|
||||
self._disconnected_peer_name_list.append(peer_name)
|
||||
self._logger.info(f"Disconnected with {peer_name}.")
|
||||
self._logger.debug(f"Disconnected with {peer_name}.")
|
||||
|
||||
def receive(self, is_continuous: bool = True, timeout: int = None):
|
||||
def receive(self, timeout: int = None):
|
||||
"""Receive message from ``zmq.POLLER``.
|
||||
|
||||
Args:
|
||||
is_continuous (bool): Continuously receive message or not. Defaults to True.
|
||||
timeout (int): Timeout for polling. If the first poll times out, the function returns None.
|
||||
|
||||
Yields:
|
||||
recv_message (Message): The received message from the poller.
|
||||
|
@ -184,13 +184,38 @@ class ZmqDriver(AbsDriver):
|
|||
recv_message = pickle.loads(recv_message)
|
||||
self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.")
|
||||
else:
|
||||
self._logger.debug(f"Cannot receive any message within {receive_timeout}.")
|
||||
self._logger.debug(f"No message received within {receive_timeout}.")
|
||||
return
|
||||
|
||||
yield recv_message
|
||||
|
||||
if not is_continuous:
|
||||
break
|
||||
def receive_once(self, timeout: int = None):
|
||||
"""Receive a single message from ``zmq.POLLER``.
|
||||
|
||||
Args:
|
||||
timeout (int): Time-out for ZMQ polling. If the first poll times out, the function returns None.
|
||||
|
||||
Returns:
|
||||
recv_message (Message): The received message from the poller or None if the poller times out.
|
||||
"""
|
||||
receive_timeout = timeout if timeout else self._receive_timeout
|
||||
try:
|
||||
sockets = dict(self._poller.poll(receive_timeout))
|
||||
except Exception as e:
|
||||
raise DriverReceiveError(f"Driver cannot receive message as {e}")
|
||||
|
||||
if self._unicast_receiver in sockets:
|
||||
recv_message = self._unicast_receiver.recv_pyobj()
|
||||
self._logger.debug(f"Receive a message from {recv_message.source} through unicast receiver.")
|
||||
elif self._broadcast_receiver in sockets:
|
||||
_, recv_message = self._broadcast_receiver.recv_multipart()
|
||||
recv_message = pickle.loads(recv_message)
|
||||
self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.")
|
||||
else:
|
||||
self._logger.debug(f"No message received within {receive_timeout}.")
|
||||
return
|
||||
|
||||
return recv_message
|
||||
|
||||
def send(self, message: Message):
|
||||
"""Send message.
|
||||
|
|