MARO v0.3: a new design of the RL Toolkit, CLI refactoring, and corresponding updates. (#539)

* refined proxy coding style

* updated images and refined doc

* updated images

* updated CIM-AC example

* refined proxy retry logic

* call policy update only for AbsCorePolicy

* add limitation of AbsCorePolicy in Actor.collect()

* refined actor to return only experiences for policies that received new experiences

* fix MsgKey issue in rollout_manager

* fix typo in learner

* call exit function for parallel rollout manager

* update supply chain example distributed training scripts

* 1. moved exploration scheduling to rollout manager; 2. fixed bug in lr schedule registration in core model; 3. added parallel policy manager prototype

* reformat render

* fix supply chain business engine action type problem

* reset supply chain example render figsize from 4 to 3

* Add render to all modes of supply chain example

* fix or policy typos

* 1. added parallel policy manager prototype; 2. used training ep for evaluation episodes

* refined parallel policy manager

* updated rl/__init__.py

* fixed lint issues and CIM local learner bugs

* deleted unwanted supply_chain test files

* revised default config for cim-dqn

* removed test_store.py as it is no longer needed

* 1. changed Actor class to rollout_worker function; 2. renamed algorithm to algorithms

* updated figures

* removed unwanted import

* refactored CIM-DQN example

* added MultiProcessRolloutManager and MultiProcessTrainingManager

* updated doc

* lint issue fix

* lint issue fix

* fixed import formatting

* [Feature] Prioritized Experience Replay (#355)

* added prioritized experience replay

* deleted unwanted supply_chain test files

* fixed import order

* import fix

* fixed lint issues

* fixed import formatting

* added note in docstring that rank-based PER has yet to be implemented

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* rm AbsDecisionGenerator

* small fixes

* bug fix

* reorganized training folder structure

* fixed lint issues

* fixed lint issues

* policy manager refined

* lint fix

* restructured CIM-dqn sync code

* added policy version index and used it as a measure of experience staleness

* lint issue fix

* lint issue fix

* switched log_dir and proxy_kwargs order

* cim example refinement

* eval schedule sorted only when it's a list

* eval schedule sorted only when it's a list

* update sc env wrapper

* added docker scripts for cim-dqn

* refactored example folder structure and added workflow templates

* fixed lint issues

* fixed lint issues

* fixed template bugs

* removed unused imports

* refactoring sc in progress

* simplified cim meta

* fixed build.sh path bug

* template refinement

* deleted obsolete svgs

* updated learner logs

* minor edits

* refactored templates for easy merge with async PR

* added component names for rollout manager and policy manager

* fixed incorrect position to add last episode to eval schedule

* added max_lag option in templates

* formatting edit in docker_compose_yml script

* moved local learner and early stopper outside sync_tools

* refactored rl toolkit folder structure

* refactored rl toolkit folder structure

* moved env_wrapper and agent_wrapper inside rl/learner

* refined scripts

* fixed typo in script

* changes needed for running sc

* removed unwanted imports

* config change for testing sc scenario

* changes for perf testing

* Asynchronous Training (#364)

* remote inference code draft

* changed actor to rollout_worker and updated init files

* removed unwanted import

* updated inits

* more async code

* added async scripts

* added async training code & scripts for CIM-dqn

* changed async to async_tools to avoid conflict with python keyword

* reverted unwanted change to dockerfile

* added doc for policy server

* addressed PR comments and fixed a bug in docker_compose_yml.py

* fixed lint issue

* resolved PR comment

* resolved merge conflicts

* added async templates

* added proxy.close() for actor and policy_server

* fixed incorrect position to add last episode to eval schedule

* reverted unwanted changes

* added missing async files

* rm unwanted echo in kill.sh

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* renamed sync to synchronous and async to asynchronous to avoid conflict with keyword

* added missing policy version increment in LocalPolicyManager

* refined rollout manager recv logic

* removed a debugging print

* added sleep in distributed launcher to avoid hanging

* updated api doc and rl toolkit doc

* refined dynamic imports using importlib

* 1. moved policy update triggers to policy manager; 2. added version control in policy manager

* fixed a few bugs and updated cim RL example

* fixed a few more bugs

* added agent wrapper instantiation to workflows

* added agent wrapper instantiation to workflows

* removed abs_block and added max_prob option for DiscretePolicyNet and DiscreteACNet

* fixed incorrect get_ac_policy signature for CIM

* moved exploration inside core policy

* added state to exploration call to support context-dependent exploration

* separated non_rl_policy_index and rl_policy_index in workflows

* modified sc example code according to workflow changes

* modified sc example code according to workflow changes

* added replay_agent_ids parameter to get_env_func for RL examples

* fixed a few bugs

* added maro/simulator/scenarios/supply_chain as bind mount

* added post-step, post-collect, post-eval and post-update callbacks

* fixed lint issues

* fixed lint issues

* moved instantiation of policy manager inside simple learner

* fixed env_wrapper get_reward signature

* minor edits

* removed get_experience kwargs from env_wrapper

* 1. renamed step_callback to post_step in env_wrapper; 2. added get_eval_env_func to RL workflows

* added rollout exp distribution option in RL examples

* removed unwanted files

* 1. made logger internal in learner; 2. removed logger creation in abs classes

* checked out supply chain test files from v0.2_sc

* 1. added missing model.eval() to choose_action; 2. added entropy features to AC

* fixed a bug in ac entropy

* abbreviated coefficient to coeff

* removed -dqn from job name in rl example config

* added tmp patch to dev.df

* renamed image name for running rl examples

* added get_loss interface for core policies

* added policy manager in rl_toolkit.rst

* 1. env_wrapper bug fix; 2. policy manager update logic refinement

* refactored policy and algorithms

* policy interface redesigned

* refined policy interfaces

* fixed typo

* fixed bugs in refactored policy interface

* fixed some bugs

* refactoring in progress

* policy interface and policy manager redesigned

* 1. fixed bugs in ac and pg; 2. fixed bugs in rl workflow scripts

* fixed bug in distributed policy manager

* fixed lint issues

* fixed lint issues

* added scipy in setup

* 1. trimmed rollout manager code; 2. added option to docker scripts

* updated api doc for policy manager

* 1. simplified rl/learning code structure; 2. fixed bugs in rl example docker script

* 1. simplified rl example structure; 2. fixed lint issues

* further rl toolkit code simplifications

* more numpy-based optimization in RL toolkit

* moved replay buffer inside policy

* bug fixes

* numpy optimization and associated refactoring

* extracted shaping logic out of env_sampler

* fixed bug in CIM shaping and lint issues

* preliminary implementation of parallel batch inference

* fixed bug in ddpg transition recording

* put get_state, get_env_actions, get_reward back in EnvSampler

* simplified exploration and core model interfaces

* bug fixes and doc update

* added improve() interface for RLPolicy for single-thread support

* fixed simple policy manager bug

* updated doc, rst, notebook

* updated notebook

* fixed lint issues

* fixed entropy bugs in ac.py

* reverted to simple policy manager as default

* 1. unified single-thread and distributed mode in learning_loop.py; 2. updated api doc for algorithms and rst for rl toolkit

* fixed lint issues and updated rl toolkit images

* removed obsolete images

* added back agent2policy for general workflow use

* V0.2 rl refinement dist (#377)

* Support `slice` operation in ExperienceSet

* Support naive distributed policy training by proxy

* Dynamically allocate trainers according to the number of experiences

* code check

* code check

* code check

* Fix a bug in distributed training with no gradient

* Code check

* Move Back-Propagation from trainer to policy_manager and extract trainer-allocation strategy

* 1. call allocate_trainer() at the start of update(); 2. refine according to code review

* Code check

* Refine code with new interface

* Update docs of PolicyManager and ExperienceSet

* Add images for rl_toolkit docs

* Update diagram of PolicyManager

* Refine with new interface

* Extract allocation strategy into `allocation_strategy.py`

* add `distributed_learn()` in policies for data-parallel training

* Update doc of RL_toolkit

* Add gradient workers for data-parallel

* Refine code and update docs

* Lint check

* Refine by comments

* Rename `trainer` to `worker`

* Rename `distributed_learn` to `learn_with_data_parallel`

* Refine allocator and remove redundant code in policy_manager

* remove arguments in allocate_by_policy and so on

* added checkpointing for simple and multi-process policy managers

* 1. bug fixes in checkpointing; 2. removed version and max_lag in rollout manager

* added missing set_state and get_state for CIM policies

* removed blank line

* updated RL workflow README

* Integrate `data_parallel` arguments into `worker_allocator` (#402)

* 1. simplified workflow config; 2. added comments to CIM shaping

* lint issue fix

* 1. added algorithm type setting in CIM config; 2. added try-except clause for initial policy state loading

* 1. moved post_step callback inside env sampler; 2. updated README for rl workflows

* refined README for CIM

* VM scheduling with RL (#375)

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* added DQN

* added get_experiences func for ac in vm scheduling

* added post_step callback to env wrapper

* moved Aiming's tracking and plotting logic into callbacks

* added eval env wrapper

* renamed AC config variable name for VM

* vm scheduling RL code finished

* updated README

* fixed various bugs and hard coding for vm_scheduling

* uncommented callbacks for VM scheduling

* Minor revision for better code style

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* vm example refactoring

* fixed bugs in vm_scheduling

* removed unwanted files from cim dir

* reverted to simple policy manager as default

* added part of vm scheduling RL code

* refined vm env_wrapper code style

* vm scheduling RL code finished

* added config.py for vm scheduling

* resolved rebase conflicts

* fixed bugs in vm_scheduling

* added get_state and set_state to vm_scheduling policy models

* updated README for vm_scheduling with RL

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* SC refinement (#397)

* Refine test scripts & pending_order_daily logic

* Refactor code for better code style: complete type hint, correct typos, remove unused items.

* Polish test_supply_chain.py

* update import format

* Modify vehicle steps logic & remove outdated test case

* Optimize imports

* Optimize imports

* Lint error

* Lint error

* Lint error

* Add SupplyChainAction

* Lint error

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* refined workflow scripts

* fixed bug in ParallelAgentWrapper

* 1. fixed lint issues; 2. refined main script in workflows

* lint issue fix

* restored default config for rl example

* Update rollout.py

* refined env var processing in policy manager workflow

* added hasattr check in agent wrapper

* updated docker_compose_yml.py

* Minor refinement

* Minor PR. Prepare to merge latest master branch into v0.3 branch. (#412)

* Prepare to merge master_mirror

* Lint error

* Minor

* Merge latest master into v0.3 (#426)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* Minor

* Remove docs/source/examples/multi_agent_dqn_cim.rst

* Update .gitignore

* Update .gitignore

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>

* Change `Env.set_seed()` logic (#456)

* Change Env.set_seed() logic

* Redesign CIM reset logic; fix lint issues;

* Lint

* Seed type assertion

* Remove all SC related files (#473)

* RL Toolkit V3 (#471)

* added daemon=True for multi-process rollout, policy manager and inference

* removed obsolete files

* [REDO][PR#406]V0.2 rl refinement taskq (#408)

* Add a usable task_queue

* Rename some variables

* 1. Add ; 2. Integrate  related files; 3. Remove

* merge `data_parallel` and `num_grad_workers` into `data_parallelism`

* Fix bugs in docker_compose_yml.py and Simple/Multi-process mode.

* Move `grad_worker` into maro/rl/workflows

* 1. Merge data_parallel and num_workers into data_parallelism in config; 2. Assign recently used workers where possible in task_queue.

* Refine code and update docs of `TaskQueue`

* Support priority for tasks in `task_queue`

* Update diagram of policy manager and task queue.

* Add configurable `single_task_limit` and correct docstring about `data_parallelism`

* Fix lint errors in `supply chain`

* RL policy redesign (V2) (#405)

* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint

* refined messy workflow code

* merged 'scenario_dir' and 'scenario' in rl config

* 1. refined env_sampler and agent_wrapper code; 2. added docstrings for env_sampler methods

* 1. temporarily renamed RLPolicy from policy_v2 to RLPolicyV2; 2. merged env_sampler and env_sampler_v2

* merged cim and cim_v2

* lint issue fix

* refined logging logic

* lint issue fix

* reverted unwanted changes

* .

.

.

.

ReplayMemory & IndexScheduler

ReplayMemory & IndexScheduler

.

MultiReplayMemory

get_actions_with_logps

EnvSampler on the road

EnvSampler

Minor

* LearnerManager

* Use batch to transfer data & add SHAPE_CHECK_FLAG

* Rename learner to trainer

* Add property for policy._is_exploring

* CIM test scenario for V3. Manual test passed. Next step: run it, make it work.

* env_sampler.py could run

* env_sampler refine on the way

* First runnable version done

* AC could run, but the result is bad. Need to check the logic

* Refine abstract method & shape check error info.

* Docs

* Very detailed compare. Try again.

* AC done

* DQN check done

* Minor

* DDPG, not tested

* Minors

* A rough draft of MAAC

* Cannot use CIM as the multi-agent scenario.

* Minor

* MAAC refinement on the way

* Remove ActionWithAux

* Refine batch & memory

* MAAC example works

* Reproducibility fix. Policy sharing between env_sampler and trainer_manager.

* Detail refinement

* Simplify the user-configured workflow

* Minor

* Refine example codes

* Minor polish

* Migrate rollout_manager to V3

* Error on the way

* Redesign torch.device management

* Rl v3 maddpg (#418)

* Add MADDPG trainer

* Fit independent critics and shared critic modes.

* Add a new property: num_policies

* Lint

* Fix a bug in `sum(rewards)`

* Rename `MADDPG` to `DiscreteMADDPG` and fix type hint.

* Rename maddpg in examples.

* Preparation for data parallel (#420)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* Rename train worker to train ops; add placeholder for abstract methods;

* Lint

Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>

* [DRAFT] distributed training pipeline based on RL Toolkit V3 (#450)

* Preparation for data parallel

* Minor refinement & lint fix

* Lint

* Lint

* rename atomic_get_batch_grad to get_batch_grad

* Fix an unexpected commit

* distributed maddpg

* Add critic worker

* Minor

* Data parallel related minor changes

* Refine code structure for trainers & add more doc strings

* Revert an unwanted change

* Use TrainWorker to do the actual calculations.

* Some minor redesign of the worker's abstraction

* Add set/get_policy_state_dict back

* Refine set/get_policy_state_dict

* Polish policy trainers

move train_batch_size to abs trainer
delete _train_step_impl()
remove _record_impl
remove unused methods
a minor bug fix in maddpg

* Rl v3 data parallel grad worker (#432)

* Fit new `trainer_worker` in `grad_worker` and `task_queue`.

* Add batch dispatch

* Add `tensor_dict` for task submit interface

* Move `_remote_learn` to `AbsTrainWorker`.

* Complement docstring for task queue and trainer.

* distributed training pipeline draft

* added temporary test files for review purposes

* Several code style refinements (#451)

* Polish rl_v3/utils/

* Polish rl_v3/distributed/

* Polish rl_v3/policy_trainer/abs_trainer.py

* fixed merge conflicts

* unified sync and async interfaces

* refactored rl_v3; refinement in progress

* Finish the runnable pipeline under new design

* Remove outdated files; refine class names; optimize imports;

* Lint

* Minor maddpg related refinement

* Lint

Co-authored-by: Default <huo53926@126.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Minor bug fix

* Coroutine-related bug fix ("get_policy_state") (#452)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* deleted unwanted folder

* removed unwanted changes

* resolved PR452 comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Quick fix

* Redesign experience recording logic (#453)

* Two not important fix

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint

* Calculating advantages / returns (#454)

* V1.0

* Complete DDPG

* Rl v3 hanging issue fix (#455)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* Final test & format. Ready to merge.

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* Rl v3 parallel rollout (#457)

* fixed rebase conflicts

* renamed get_policy_func_dict to policy_creator

* unified worker interfaces

* recovered some files

* dist training + cli code move

* fixed bugs

* added retry logic to client

* 1. refactored CIM with various algos; 2. lint

* lint

* added type hint

* removed some logs

* lint

* Make main.py more IDE friendly

* Make main.py more IDE friendly

* Lint

* load balancing dispatcher

* added parallel rollout

* lint

* Tracker variable type issue; rename to env_sampler_creator;

* Rl v3 parallel rollout follow ups (#458)

* AbsWorker & AbsDispatcher

* Pass env idx to the AbsTrainer.record() method, and let the trainer decide how to record experiences sampled from different worlds.

* Fix policy_creator reuse bug

* Format code

* Merge AbsTrainerManager & SimpleTrainerManager

* AC test passed

* Lint

* Remove AbsTrainer.build() method. Put all initialization operations into __init__

* Redesign AC preprocess batches logic

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Huoran Li <huoranli@microsoft.com>

* MADDPG performance bug fix (#459)

* Fix MARL (MADDPG) terminal recording bug; some other minor refinements;

* Restore Trainer.build() method

* Calculate latest action in the get_actor_grad method in MADDPG.

* Share critic bug fix

* Rl v3 example update (#461)

* updated vm_scheduling example and cim notebook

* fixed bugs in vm_scheduling

* added local train method

* bug fix

* modified async client logic to fix hidden issue

* reverted to default config

* fixed PR comments and some bugs

* removed hardcode

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Done (#462)

* Rl v3 load save (#463)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL Toolkit data parallelism revamp & config utils (#464)

* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

* 1. fixed data parallelism issue; 2. added config validator; 3. refactored cli local

* 1. fixed rollout exit issue; 2. refined config

* removed config file from example

* fixed lint issues

* fixed lint issues

* added main.py under examples/rl

* fixed lint issues

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL doc string (#465)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* Rl config doc (#467)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* resolved more PR comments

* typo fix

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* RL online doc (#469)

* Model, policy, trainer

* RL workflows and env sampler doc in RST (#468)

* First rough draft

* Minors

* Reformat

* Lint

* Resolve PR comments

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* Rl type specific env getter (#466)

* 1. type-sensitive env variable getter; 2. updated READMEs for examples

* fixed bugs

* fixed bugs

* bug fixes

* lint

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Example bug fix

* Optimize parser.py

* Resolve PR comments

* added detailed doc

* lint

* wording refined

* resolved some PR comments

* rewriting rl toolkit rst

* resolved more PR comments

* typo fix

* updated rst

Co-authored-by: Huoran Li <huoranli@microsoft.com>
Co-authored-by: Default <huo53926@126.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Finish docs/source/key_components/rl_toolkit.rst

* API doc

* RL online doc image fix (#470)

* resolved some PR comments

* fix

* fixed PR comments

* added numfig=True setting in conf.py for sphinx

Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* Resolve PR comments

* Add example github link

Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

* Rl v3 pr comment resolution (#474)

* added load/save feature

* 1. resolved pr comments; 2. reverted maro/cli/k8s

* fixed some bugs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>

Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: ysqyang <v-yangqi@microsoft.com>

* RL renaming v2 (#476)

* Change all Logger in RL to LoggerV2

* TrainerManager => TrainingManager

* Add Trainer suffix to all algorithms

* Finish docs

* Update interface names

* Minor fix

* Cherry pick latest RL (#498)

* Cherry pick

* Remove SC related files

* Cherry pick RL changes from `sc_refinement` (latest commit: `2a4869`) (#509)

* Cherry pick RL changes from sc_refinement (2a4869)

* Limit time display precision

* RL incremental refactor (#501)

* Refactor rollout logic. Allow sampling multiple times in one epoch so that we can generate more data for training.

AC & PPO for continuous action policy; refine AC & PPO logic.

Cherry pick RL changes from GYM-DDPG

Cherry pick RL changes from GYM-SAC

Minor error in doc string

* Add min_n_sample in template and parser

* Resolve PR comments. Fix a minor issue in SAC.

* RL component bundle (#513)

* CIM passed

* Update workers

* Refine annotations

* VM passed

* Code formatting.

* Minor import loop issue

* Pass batch in PPO again

* Remove Scenario

* Complete docs

* Minor

* Remove segment

* Optimize logic in RLComponentBundle

* Resolve PR comments

* Move 'post' methods from RLComponentBundle to EnvSampler

* Add method to get mapping of available tick to frame index (#415)

* add method to get mapping of available tick to frame index

* fix lint issue

* fix naming issue

* Cherry pick from sc_refinement (#527)

* Cherry pick from sc_refinement

* Cherry pick from sc_refinement

* Refine `terminal` / `next_agent_state` logic (#531)

* Optimize RL toolkit

* Fix bug in terminal/next_state generation

* Rewrite terminal/next_state logic again

* Minor renaming

* Minor bug fix

* Resolve PR comments

* Merge master into v0.3 (#536)

* update docker hub init (#367)

* update docker hub init

* replace personal account with maro-team

* update hello files for CIM

* update docker repository name

* update docker file name

* fix bugs in notebook, rectify docs

* fix doc build issue

* remove docs from playground; fix citibike lp example Event issue

* update the example for vector env

* update vector env example

* update README due to PR comments

* add link to playground above MARO installation in README

* fix some typos

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* update package version

* update README for package description

* update image links for pypi package description

* update image links for pypi package description

* change the input topology schema for CIM real data mode (#372)

* change the input topology schema for CIM real data mode

* remove unused importing

* update test config file correspondingly

* add Exception for env test

* add cost factors to cim data dump

* update CimDataCollection field name

* update field name of data collection related code

* update package version

* adjust interface to reflect actual signature (#374)

Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>

* update dataclasses requirement to setup

* fix: fix spelling and grammar

* fix: fix typos in code comments and data_model.rst

* Fix Geo vis IP address & SQL logic bugs. (#383)

Fix Geo vis IP address & SQL logic bugs (issue [352](https://github.com/microsoft/maro/issues/352) and [314](https://github.com/microsoft/maro/issues/314)).

* Fix the "Wrong future stop tick predictions" bug (#386)

* Propose my new solution

Refine to the pre-process version

.

* Optimize import

* Fix reset random seed bug (#387)

* update the reset interface of Env and BE

* Try to fix reset routes generation seed issue

* Refine random related logics.

* Minor refinement

* Test check

* Minor

* Remove unused functions so far

* Minor

Co-authored-by: Jinyu Wang <jinywan@microsoft.com>

* update package version

* Add _init_vessel_plans in business_engine.reset (#388)

* update package version

* change the default solver used in Citibike OnlineLP example, from GLPK to CBC (#391)

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* Refine `event_buffer/` module (#389)

* Core & Business Engine code refinement (#392)

* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check

* add higher python version (#398)

* add higher python version

* update pytorch version

* update torchvision version

Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>

* CIM scenario refinement (#400)

* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch

* Add `clear()` function to class `SimRandom` (#401)

* Add SimRandom.clear()

* Minor

* Remove commented codes

* Lint error

* update package version

* add branch v0.3 to github workflow

* update github test workflow

* Update requirements.dev.txt (#444)

Added version numbers for dependencies and resolved some conflicts that occur during installation. Pinning these versions makes the expected dependency versions explicit.

* Bump ipython from 7.10.1 to 7.16.3 in /notebooks (#460)

Bumps [ipython](https://github.com/ipython/ipython) from 7.10.1 to 7.16.3.
- [Release notes](https://github.com/ipython/ipython/releases)
- [Commits](https://github.com/ipython/ipython/compare/7.10.1...7.16.3)

---
updated-dependencies:
- dependency-name: ipython
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Add & sort requirements.dev.txt

Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: Jinyu Wang <jinywan@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Remove random_config.py

* Remove test_trajectory_utils.py

* Pass tests

* Update rl docs

* Remove python 3.6 in test

* Update docs

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: Wang.Jinyu <jinywan@microsoft.com>
Co-authored-by: ysqyang <ysqyang@gmail.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com>
Co-authored-by: GQ.Chen <675865907@qq.com>
Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com>
Co-authored-by: Jeremy Reynolds <jeremy.reynolds@microsoft.com>
Co-authored-by: Jeremy Reynolds <jeremr@microsoft.com>
Co-authored-by: slowy07 <slowy.arfy@gmail.com>
Co-authored-by: GQ.Chen <v-guanchen@microsoft.com>
Co-authored-by: Chaos Yu <chaos.you@gmail.com>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Commit 0cc9a89cca (parent: d25ebf4cb2)
Author: Huoran Li, 2022-06-01 15:05:22 +08:00; committed by GitHub
No key found matching this signature; GPG key ID: 4AEE18F83AFDEB23
216 changed files: 12289 additions and 4677 deletions

.gitignore (12 changed lines)

@@ -3,6 +3,7 @@
*.pyd
*.log
*.csv
*.parquet
*.c
*.cpp
*.DS_Store
@@ -12,15 +13,18 @@
.vs/
build/
log/
logs/
checkpoint/
checkpoints/
streamit/
dist/
*.egg-info/
tools/schedule
docs/_build
test/
data/
.eggs/
maro_venv/
pyvenv.cfg
htmlcov/
.coverage
.coveragerc
.coverage
.coveragerc
.tmp/

docker_files/dev.df (new file, 36 lines)

@@ -0,0 +1,36 @@
FROM python:3.7-buster
WORKDIR /maro
# Install Apt packages
RUN apt-get update --fix-missing
RUN apt-get install -y apt-utils
RUN apt-get install -y sudo
RUN apt-get install -y gcc
RUN apt-get install -y libcurl4 libcurl4-openssl-dev libssl-dev curl
RUN apt-get install -y libzmq3-dev
RUN apt-get install -y python3-pip
RUN apt-get install -y python3-dev libpython3.7-dev python-numpy
RUN rm -rf /var/lib/apt/lists/*
# Install Python packages
RUN pip install --upgrade pip
RUN pip install --no-cache-dir Cython==0.29.14
RUN pip install --no-cache-dir pyaml==20.4.0
RUN pip install --no-cache-dir pyzmq==19.0.2
RUN pip install --no-cache-dir numpy==1.19.1
RUN pip install --no-cache-dir matplotlib
RUN pip install --no-cache-dir torch==1.6.0
RUN pip install --no-cache-dir scipy
RUN pip install --no-cache-dir matplotlib
RUN pip install --no-cache-dir redis
RUN pip install --no-cache-dir networkx
COPY maro /maro/maro
COPY scripts /maro/scripts/
COPY setup.py /maro/
RUN bash /maro/scripts/install_maro.sh
RUN pip cache purge
ENV PYTHONPATH=/maro
CMD ["/bin/bash"]

@@ -1,198 +1,330 @@
Agent
Distributed
================================================================================
maro.rl.agent.abs\_agent
maro.rl.distributed.abs_proxy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.agent.abs_agent
.. automodule:: maro.rl.distributed.abs_proxy
:members:
:undoc-members:
:show-inheritance:
maro.rl.agent.dqn
maro.rl.distributed.abs_worker
--------------------------------------------------------------------------------
.. automodule:: maro.rl.agent.dqn
.. automodule:: maro.rl.distributed.abs_worker
:members:
:undoc-members:
:show-inheritance:
maro.rl.agent.ddpg
--------------------------------------------------------------------------------
.. automodule:: maro.rl.agent.ddpg
:members:
:undoc-members:
:show-inheritance:
maro.rl.agent.policy\_optimization
--------------------------------------------------------------------------------
.. automodule:: maro.rl.agent.policy_optimization
:members:
:undoc-members:
:show-inheritance:
Agent Manager
Exploration
================================================================================
maro.rl.agent.abs\_agent\_manager
maro.rl.exploration.scheduling
--------------------------------------------------------------------------------
.. automodule:: maro.rl.agent.abs_agent_manager
.. automodule:: maro.rl.exploration.scheduling
:members:
:undoc-members:
:show-inheritance:
maro.rl.exploration.strategies
--------------------------------------------------------------------------------
.. automodule:: maro.rl.exploration.strategies
:members:
:undoc-members:
:show-inheritance:
Model
================================================================================
maro.rl.model.learning\_model
maro.rl.model.algorithm_nets
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.torch.learning_model
.. automodule:: maro.rl.model.algorithm_nets
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.abs_net
--------------------------------------------------------------------------------
Explorer
.. automodule:: maro.rl.model.abs_net
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.fc_block
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.fc_block
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.multi_q_net
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.multi_q_net
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.policy_net
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.policy_net
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.q_net
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.q_net
:members:
:undoc-members:
:show-inheritance:
maro.rl.model.v_net
--------------------------------------------------------------------------------
.. automodule:: maro.rl.model.v_net
:members:
:undoc-members:
:show-inheritance:
Policy
================================================================================
maro.rl.exploration.abs\_explorer
maro.rl.policy.abs_policy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.exploration.abs_explorer
.. automodule:: maro.rl.policy.abs_policy
:members:
:undoc-members:
:show-inheritance:
maro.rl.exploration.epsilon\_greedy\_explorer
maro.rl.policy.continuous_rl_policy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.exploration.epsilon_greedy_explorer
.. automodule:: maro.rl.policy.continuous_rl_policy
:members:
:undoc-members:
:show-inheritance:
maro.rl.exploration.noise\_explorer
maro.rl.policy.discrete_rl_policy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.exploration.noise_explorer
.. automodule:: maro.rl.policy.discrete_rl_policy
:members:
:undoc-members:
:show-inheritance:
Scheduler
RL Component
================================================================================
maro.rl.scheduling.scheduler
maro.rl.rl_component.rl_component_bundle
--------------------------------------------------------------------------------
.. automodule:: maro.rl.scheduling.scheduler
.. automodule:: maro.rl.rl_component.rl_component_bundle
:members:
:undoc-members:
:show-inheritance:
maro.rl.scheduling.simple\_parameter\_scheduler
--------------------------------------------------------------------------------
.. automodule:: maro.rl.scheduling.simple_parameter_scheduler
:members:
:undoc-members:
:show-inheritance:
Shaping
Rollout
================================================================================
maro.rl.shaping.abs\_shaper
maro.rl.rollout.batch_env_sampler
--------------------------------------------------------------------------------
.. automodule:: maro.rl.shaping.abs_shaper
.. automodule:: maro.rl.rollout.batch_env_sampler
:members:
:undoc-members:
:show-inheritance:
maro.rl.rollout.env_sampler
--------------------------------------------------------------------------------
Storage
.. automodule:: maro.rl.rollout.env_sampler
:members:
:undoc-members:
:show-inheritance:
maro.rl.rollout.worker
--------------------------------------------------------------------------------
.. automodule:: maro.rl.rollout.worker
:members:
:undoc-members:
:show-inheritance:
Training
================================================================================
maro.rl.storage.abs\_store
maro.rl.training.algorithms
--------------------------------------------------------------------------------
.. automodule:: maro.rl.storage.abs_store
.. automodule:: maro.rl.training.algorithms
:members:
:undoc-members:
:show-inheritance:
maro.rl.storage.simple\_store
maro.rl.training.proxy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.storage.simple_store
.. automodule:: maro.rl.training.proxy
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.replay_memory
--------------------------------------------------------------------------------
Actor
.. automodule:: maro.rl.training.replay_memory
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.trainer
--------------------------------------------------------------------------------
.. automodule:: maro.rl.training.trainer
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.training_manager
--------------------------------------------------------------------------------
.. automodule:: maro.rl.training.training_manager
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.train_ops
--------------------------------------------------------------------------------
.. automodule:: maro.rl.training.train_ops
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.utils
--------------------------------------------------------------------------------
.. automodule:: maro.rl.training.utils
:members:
:undoc-members:
:show-inheritance:
maro.rl.training.worker
--------------------------------------------------------------------------------
.. automodule:: maro.rl.training.worker
:members:
:undoc-members:
:show-inheritance:
Utils
================================================================================
maro.rl.actor.abs\_actor
maro.rl.utils.common
--------------------------------------------------------------------------------
.. automodule:: maro.rl.actor.abs_actor
.. automodule:: maro.rl.utils.common
:members:
:undoc-members:
:show-inheritance:
maro.rl.actor.simple\_actor
maro.rl.utils.message_enums
--------------------------------------------------------------------------------
.. automodule:: maro.rl.actor.simple_actor
.. automodule:: maro.rl.utils.message_enums
:members:
:undoc-members:
:show-inheritance:
maro.rl.utils.objects
--------------------------------------------------------------------------------
Learner
.. automodule:: maro.rl.utils.objects
:members:
:undoc-members:
:show-inheritance:
maro.rl.utils.torch_utils
--------------------------------------------------------------------------------
.. automodule:: maro.rl.utils.torch_utils
:members:
:undoc-members:
:show-inheritance:
maro.rl.utils.trajectory_computation
--------------------------------------------------------------------------------
.. automodule:: maro.rl.utils.trajectory_computation
:members:
:undoc-members:
:show-inheritance:
maro.rl.utils.transition_batch
--------------------------------------------------------------------------------
.. automodule:: maro.rl.utils.transition_batch
:members:
:undoc-members:
:show-inheritance:
Workflows
================================================================================
maro.rl.learner.abs\_learner
maro.rl.workflows.config
--------------------------------------------------------------------------------
.. automodule:: maro.rl.learner.abs_learner
.. automodule:: maro.rl.workflows.config
:members:
:undoc-members:
:show-inheritance:
maro.rl.learner.simple\_learner
maro.rl.workflows.main
--------------------------------------------------------------------------------
.. automodule:: maro.rl.learner.simple_learner
.. automodule:: maro.rl.workflows.main
:members:
:undoc-members:
:show-inheritance:
Distributed Topologies
================================================================================
maro.rl.dist\_topologies.common
maro.rl.workflows.rollout_worker
--------------------------------------------------------------------------------
.. automodule:: maro.rl.dist_topologies.common
.. automodule:: maro.rl.workflows.rollout_worker
:members:
:undoc-members:
:show-inheritance:
maro.rl.dist\_topologies.single\_learner\_multi\_actor\_sync\_mode
maro.rl.workflows.scenario
--------------------------------------------------------------------------------
.. automodule:: maro.rl.dist_topologies.single_learner_multi_actor_sync_mode
.. automodule:: maro.rl.workflows.scenario
:members:
:undoc-members:
:show-inheritance:
maro.rl.workflows.train_proxy
--------------------------------------------------------------------------------
.. automodule:: maro.rl.workflows.train_proxy
:members:
:undoc-members:
:show-inheritance:
maro.rl.workflows.train_worker
--------------------------------------------------------------------------------
.. automodule:: maro.rl.workflows.train_worker
:members:
:undoc-members:
:show-inheritance:

@@ -100,3 +100,5 @@ source_parsers = {
}
source_suffix = [".md", ".rst"]
numfig = True

@@ -1,75 +0,0 @@
Example Scenario: Bike Repositioning (Citi Bike)
================================================
In this example we demonstrate using a simple greedy policy for `Citi Bike <https://maro.readthedocs.io/en/latest/scenarios/citi_bike.html>`_,
a real-world bike repositioning scenario.
Greedy Policy
-------------
Our greedy policy is simple: if the event type is supply, the policy will make
the current station send as many bikes as possible to one of k stations with the most empty docks. If the event type is
demand, the policy will make the current station request as many bikes as possible from one of k stations with the most
bikes. We use a heap data structure to find the top k supply/demand candidates from the action scope associated with
each decision event.
.. code-block:: python
class GreedyPolicy:
...
def choose_action(self, decision_event: DecisionEvent):
if decision_event.type == DecisionType.Supply:
"""
Find k target stations with the most empty slots, randomly choose one of them and send as many bikes to
it as allowed by the action scope
"""
top_k_demands = []
for demand_candidate, available_docks in decision_event.action_scope.items():
if demand_candidate == decision_event.station_idx:
continue
heapq.heappush(top_k_demands, (available_docks, demand_candidate))
if len(top_k_demands) > self._demand_top_k:
heapq.heappop(top_k_demands)
max_reposition, target_station_idx = random.choice(top_k_demands)
action = Action(decision_event.station_idx, target_station_idx, max_reposition)
else:
"""
Find k source stations with the most bikes, randomly choose one of them and request as many bikes from
it as allowed by the action scope.
"""
top_k_supplies = []
for supply_candidate, available_bikes in decision_event.action_scope.items():
if supply_candidate == decision_event.station_idx:
continue
heapq.heappush(top_k_supplies, (available_bikes, supply_candidate))
if len(top_k_supplies) > self._supply_top_k:
heapq.heappop(top_k_supplies)
max_reposition, source_idx = random.choice(top_k_supplies)
action = Action(source_idx, decision_event.station_idx, max_reposition)
return action
Interaction with the Greedy Policy
----------------------------------
This environment is driven by `real trip history data <https://s3.amazonaws.com/tripdata/index.html>`_ from Citi Bike.
.. code-block:: python
env = Env(scenario=config.env.scenario, topology=config.env.topology, start_tick=config.env.start_tick,
durations=config.env.durations, snapshot_resolution=config.env.resolution)
if config.env.seed is not None:
env.set_seed(config.env.seed)
policy = GreedyPolicy(config.agent.supply_top_k, config.agent.demand_top_k)
metrics, decision_event, done = env.step(None)
while not done:
metrics, decision_event, done = env.step(policy.choose_action(decision_event))
env.reset()

@@ -1,168 +0,0 @@
Multi Agent DQN for CIM
================================================
This example demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container
inventory management (CIM) problem. It is formalized as a multi-agent reinforcement learning problem,
where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions
by transferring a certain amount of containers to / from the vessel. The objective is for the agents to
learn policies that minimize the overall container shortage.
Trajectory
----------
The ``CIMTrajectoryForDQN`` class inherits from ``Trajectory`` and implements methods to be used as callbacks
in the roll-out loop. In this example,
* ``get_state`` converts environment observations to state vectors that encode temporal and spatial information.
The temporal information includes relevant port and vessel information, such as shortage and remaining space,
over the past k days (here k = 7). The spatial information includes features of the downstream ports.
* ``get_action`` converts agents' output (an integer that maps to a percentage of containers to be loaded
to or unloaded from the vessel) to action objects that can be executed by the environment.
* ``get_offline_reward`` computes the reward of a given action as a linear combination of fulfillment and
shortage within a future time frame.
* ``on_finish`` processes a complete trajectory into data that can be used directly by the learning agents.
.. code-block:: python
class CIMTrajectoryForDQN(Trajectory):
def __init__(
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
finite_vessel_space=True, has_early_discharge=True
):
super().__init__(env)
self.port_attributes = port_attributes
self.vessel_attributes = vessel_attributes
self.action_space = action_space
self.look_back = look_back
self.max_ports_downstream = max_ports_downstream
self.reward_time_window = reward_time_window
self.fulfillment_factor = fulfillment_factor
self.shortage_factor = shortage_factor
self.time_decay = time_decay
self.finite_vessel_space = finite_vessel_space
self.has_early_discharge = has_early_discharge
def get_state(self, event):
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]
future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]
vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
return {port_idx: np.concatenate((port_features, vessel_features))}
def get_action(self, action_by_agent, event):
vessel_snapshots = self.env.snapshot_list["vessels"]
action_info = list(action_by_agent.values())[0]
model_action = action_info[0] if isinstance(action_info, tuple) else action_info
scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx
zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero.
vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf")
early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
percent = abs(self.action_space[model_action])
if model_action < zero_action_idx:
action_type = ActionType.LOAD
actual_action = min(round(percent * scope.load), vessel_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
else:
actual_action, action_type = 0, ActionType.LOAD
return {port: Action(vessel, port, actual_action, action_type)}
def get_offline_reward(self, event):
port_snapshots = self.env.snapshot_list["ports"]
start_tick = event.tick + 1
ticks = list(range(start_tick, start_tick + self.reward_time_window))
future_fulfillment = port_snapshots[ticks::"fulfillment"]
future_shortage = port_snapshots[ticks::"shortage"]
decay_list = [
self.time_decay ** i for i in range(self.reward_time_window)
for _ in range(future_fulfillment.shape[0] // self.reward_time_window)
]
tot_fulfillment = np.dot(future_fulfillment, decay_list)
tot_shortage = np.dot(future_shortage, decay_list)
return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)
def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):
self.trajectory["event"].append(event)
self.trajectory["state"].append(state_by_agent)
self.trajectory["action"].append(action_by_agent)
def on_finish(self):
exp_by_agent = defaultdict(lambda: defaultdict(list))
for i in range(len(self.trajectory["state"]) - 1):
agent_id = list(self.trajectory["state"][i].keys())[0]
exp = exp_by_agent[agent_id]
exp["S"].append(self.trajectory["state"][i][agent_id])
exp["A"].append(self.trajectory["action"][i][agent_id])
exp["R"].append(self.get_offline_reward(self.trajectory["event"][i]))
exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0])
return dict(exp_by_agent)
Agent
-----
The out-of-the-box DQN is used as our agent.
.. code-block:: python
agent_config = {
"model": ...,
"optimization": ...,
"hyper_params": ...
}
def get_dqn_agent():
q_model = SimpleMultiHeadModel(
FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"]
)
return DQN(q_model, DQNConfig(**agent_config["hyper_params"]))
Training
--------
The distributed training consists of one learner process and multiple actor processes. The learner optimizes
the policy by collecting roll-out data from the actors to train the underlying agents.
The actor process must create a roll-out executor for performing the requested roll-outs, which means that the
environment simulator and shapers should be created here. In this example, inference is performed on the
actor's side, so a set of DQN agents must be created in order to load the models (and exploration parameters)
from the learner.
.. code-block:: python
def cim_dqn_actor():
env = Env(**training_config["env"])
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config)
actor.as_worker(training_config["group"])
The learner's side requires a concrete learner class that inherits from ``AbsLearner`` and implements the ``run``
method which contains the main training loop. Here the implementation is similar to the single-threaded version
except that the ``collect`` method is used to obtain roll-out data from the actors (since the roll-out executors
are located on the actors' side). The agents created here are where training occurs and hence always contain the
latest policies.
.. code-block:: python
def cim_dqn_learner():
env = Env(**training_config["env"])
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"])
actor = ActorProxy(
training_config["group"], training_config["num_actors"],
update_trigger=training_config["learner_update_trigger"]
)
learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"])
learner.run()
.. note::
All related code snippets are supported in `maro playground <https://hub.docker.com/r/maro2020/playground>`_.


@ -89,7 +89,6 @@ Contents
:maxdepth: 2
:caption: Examples
examples/multi_agent_dqn_cim.rst
examples/greedy_policy_citi_bike.rst
.. toctree::


@ -43,7 +43,7 @@ The main attributes of a message instance include:
message = Message(tag="check_in",
source="worker_001",
destination="master",
payload="")
body="")
Session Message
^^^^^^^^^^^^^^^
@ -71,13 +71,13 @@ The stages of each session are maintained internally by the proxy.
task_message = SessionMessage(tag="sum",
source="master",
destination="worker_001",
payload=[0, 1, 2, ...],
body=[0, 1, 2, ...],
session_type=SessionType.TASK)
notification_message = SessionMessage(tag="check_out",
source="worker_001",
destination="master",
payload="",
body="",
session_type=SessionType.NOTIFICATION)
Communication Primitives


@ -259,3 +259,799 @@ For better data access, we also provide some advanced features, including:
# With the dynamic backend implementation, we can also query const attributes, which are shared across the snapshot list,
# even without any snapshot (one tick still needs to be provided for padding).
states = test_nodes_snapshots[0: [0, 1]: ["const_attribute", "const_attribute_2"]]
States in built-in scenarios' snapshot list
-------------------------------------------
.. TODO: move to environment part?
Currently there are 3 ways to expose states in built-in scenarios:
Summary
~~~~~~~~~~~
Summary (env.summary) is used to expose static states. It provides three items by default:
node_mapping, node_detail and event_payload.
The "node_mapping" item usually contains node names and their related indices, but its structure may differ
from scenario to scenario.
The "node_detail" item is usually used to expose node definitions, such as node names, attribute names and slot numbers;
this is useful if you want to know which attributes are supported for a scenario.
The "event_payload" item shows the payload attributes of events in a scenario. For example, the "RETURN_FULL" event in
the CIM scenario contains "src_port_idx", "dest_port_idx" and "quantity".
Metrics
~~~~~~~
Metrics (env.metrics) is designed to expose raw reward-related states, since reward support was removed in v0.2.
It can also be used to export states that the snapshot list does not support, such as dictionaries or other complex
structures. Currently there are two ways to get the metrics from the environment: env.metrics, or the first result of env.step.
The metrics object is usually a dictionary with several keys, but its exact content is determined by the business engine.
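As a minimal sketch (the environment parameters below are taken from the CIM example configs elsewhere in this document), both items can be read directly from an ``Env`` instance:
.. code-block:: python

    from maro.simulator import Env

    env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=100)

    # Static information: node mappings, node definitions and event payload attributes.
    print(env.summary)

    # Take one step (passing None as the action), then read the scenario-defined metrics.
    metrics, decision_event, is_done = env.step(None)
    print(metrics)      # first return value of env.step
    print(env.metrics)  # the same dictionary, accessed through the property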
Snapshot_list
~~~~~~~~~~~~~
The snapshot list is the history of nodes (the data model) for a scenario; it only supports numeric data types for now.
It supports slicing queries that return numpy arrays, so batch operations are possible, which makes it much faster than
using raw Python objects.
Nodes and attributes may differ between scenarios; below we introduce those used in the
built-in scenarios.
NOTE:
Per tick state means that the attribute value will be reset to 0 after each step.
CIM
---
Default settings for snapshot list
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Snapshot resolution: 1
Max snapshot number: same as durations
Nodes and attributes in scenario
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In CIM scenario, there are 3 node types:
port
++++
capacity
********
type: int
slots: 1
The capacity of port for stocking containers.
empty
*****
type: int
slots: 1
Empty container volume on the port.
full
****
type: int
slots: 1
Laden container volume on the port.
on_shipper
**********
type: int
slots: 1
Empty containers, which are released to the shipper.
on_consignee
************
type: int
slots: 1
Laden containers, which are delivered to the consignee.
shortage
********
type: int
slots: 1
Per tick state. Shortage of empty container at current tick.
acc_shortage
************
type: int
slots: 1
Accumulated shortage number to the current tick.
booking
*******
type: int
slots: 1
Per tick state. Order booking number of a port at the current tick.
acc_booking
***********
type: int
slots: 1
Accumulated order booking number of a port to the current tick.
fulfillment
***********
type: int
slots: 1
Fulfilled order number of a port at the current tick.
acc_fulfillment
***************
type: int
slots: 1
Accumulated fulfilled order number of a port to the current tick.
transfer_cost
*************
type: float
slots: 1
Cost of transferring container, which also covers loading and discharging cost.
vessel
++++++
capacity
********
type: int
slots: 1
The capacity of vessel for transferring containers.
NOTE:
This attribute is ignored in current implementation.
empty
*****
type: int
slots: 1
Empty container volume on the vessel.
full
****
type: int
slots: 1
Laden container volume on the vessel.
remaining_space
***************
type: int
slots: 1
Remaining space of the vessel.
early_discharge
***************
type: int
slots: 1
Number of empty containers discharged in order to load laden containers.
route_idx
*********
type: int
slots: 1
Index of the route the current vessel belongs to.
last_loc_idx
************
type: int
slots: 1
Index of the last port the vessel stopped at on its route; it is used to identify where the vessel currently is.
next_loc_idx
************
type: int
slots: 1
Index of the next port the vessel will stop at on its route; it is used to identify where the vessel currently is.
past_stop_list
**************
type: int
slots: dynamic
NOTE:
This attribute and the following ones are special in that their slot number is determined by the configuration;
however, unlike a list attribute, the slot number is fixed at runtime.
Port indices of the stops visited in the past.
past_stop_tick_list
*******************
type: int
slots: dynamic
Ticks that we stopped at the port in the past.
future_stop_list
****************
type: int
slots: dynamic
Port indices of the stops the vessel will visit in the future.
future_stop_tick_list
*********************
type: int
slots: dynamic
Ticks that we will stop in the future.
matrices
++++++++
The matrices node is used to store large matrices for ports, vessels and containers.
full_on_ports
*************
type: int
slots: port number * port number
Distribution of laden (full) containers from port to port.
full_on_vessels
***************
type: int
slots: vessel number * port number
Distribution of laden (full) containers from vessel to port.
vessel_plans
************
type: int
slots: vessel number * port number
Planned route info for vessels.
How to
~~~~~~
How to use the matrices
++++++++++++++++++++++++
A matrix is special in that it only has one instance (index 0), and its value is saved as a flat 1-dimensional array; we can reshape it after querying.
.. code-block:: python
# assuming that we want to use full_on_ports attribute.
tick = 0
# we can get the instance number of a node by calling the len method
port_number = len(env.snapshot_list["ports"])
# this is a 1 dim numpy array
full_on_ports = env.snapshot_list["matrices"][tick::"full_on_ports"]
# reshape it, then this is a 2 dim array that from port to port.
full_on_ports = full_on_ports.reshape(port_number, port_number)
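Port and vessel attributes can be queried with the same slicing syntax. A small sketch (the tick range, port indices and attribute names below are illustrative, taken from the tables above):
.. code-block:: python

    # query "shortage" and "fulfillment" of ports 0 and 1 over the first 10 ticks
    ticks = list(range(10))
    states = env.snapshot_list["ports"][ticks: [0, 1]: ["shortage", "fulfillment"]]

    # the result is a flat 1-dim numpy array ordered by tick, node and attribute;
    # reshape it for readability
    states = states.reshape(len(ticks), 2, 2)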
Citi-Bike
---------
Default settings for snapshot list
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Snapshot resolution: 60
Max snapshot number: same as durations
Nodes and attributes in scenario
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
station
+++++++
bikes
*****
type: int
slots: 1
How many bikes are available at the current station.
shortage
********
type: int
slots: 1
Per tick state. Number of bikes the current station is short of.
trip_requirement
****************
type: int
slots: 1
Per tick state. Number of trip requirements at the current station.
fulfillment
***********
type: int
slots: 1
Number of trip requirements fulfilled at the current station.
capacity
********
type: int
slots: 1
Max number of bikes this station can take.
id
**
type: int
slots: 1
Id of current station.
weekday
*******
type: short
slots: 1
Weekday at current tick.
temperature
***********
type: short
slots: 1
Temperature at current tick.
weather
*******
type: short
slots: 1
Weather at current tick.
0: sunny, 1: rainy, 2: snowy, 3: sleet.
holiday
*******
type: short
slots: 1
Whether it is a holiday at the current tick.
0: holiday, 1: not holiday.
extra_cost
**********
type: int
slots: 1
Cost incurred when the station exceeds its capacity after executing an action, so the extra bikes
have to be moved to other stations.
transfer_cost
*************
type: int
slots: 1
Cost of executing an action that transfers bikes to another station.
failed_return
*************
type: int
slots: 1
Per tick state. Number of bikes that failed to be returned to the current station.
min_bikes
*********
type: int
slots: 1
Minimum number of bikes in a frame.
matrices
++++++++
trips_adj
*********
type: int
slots: station number * station number
Used to store the number of trip requirements between each pair of stations.
VM-scheduling
-------------
Default settings for snapshot list
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Snapshot resolution: 1
Max snapshot number: same as durations
Nodes and attributes in scenario
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Cluster
+++++++
id
***
type: short
slots: 1
Id of the cluster.
region_id
*********
type: short
slots: 1
Region id of the current cluster.
data_center_id
**************
type: short
slots: 1
Data center id of current cluster.
total_machine_num
******************
type: int
slots: 1
Total number of machines in the cluster.
empty_machine_num
******************
type: int
slots: 1
The number of empty machines in this cluster. An empty machine is one whose allocated CPU cores are 0.
data_centers
++++++++++++
id
***
type: short
slots: 1
Id of current data center.
region_id
*********
type: short
slots: 1
Region id of current data center.
zone_id
*******
type: short
slots: 1
Zone id of current data center.
total_machine_num
*****************
type: int
slots: 1
Total number of machines in the current data center.
empty_machine_num
*****************
type: int
slots: 1
The number of empty machines in current data center.
pms
+++
Physical machine node.
id
***
type: int
slots: 1
Id of current machine.
cpu_cores_capacity
******************
type: short
slots: 1
Maximum number of CPU cores that can be used by the current machine.
memory_capacity
***************
type: short
slots: 1
Maximum amount of memory that can be used by the current machine.
pm_type
*******
type: short
slots: 1
Type of current machine.
cpu_cores_allocated
*******************
type: short
slots: 1
Number of CPU cores currently allocated.
memory_allocated
****************
type: short
slots: 1
Amount of memory currently allocated.
cpu_utilization
***************
type: float
slots: 1
CPU utilization of current machine.
energy_consumption
******************
type: float
slots: 1
Energy consumption of current machine.
oversubscribable
****************
type: short
slots: 1
Physical machine type: -1 for non-oversubscribable, 0 for empty, 1 for oversubscribable.
region_id
*********
type: short
slots: 1
Region id of current machine.
zone_id
*******
type: short
slots: 1
Zone id of current machine.
data_center_id
**************
type: short
slots: 1
Data center id of current machine.
cluster_id
**********
type: short
slots: 1
Cluster id of current machine.
rack_id
*******
type: short
slots: 1
Rack id of current machine.
Rack
++++
id
***
type: int
slots: 1
Id of current rack.
region_id
*********
type: short
slots: 1
Region id of current rack.
zone_id
*******
type: short
slots: 1
Zone id of current rack.
data_center_id
**************
type: short
slots: 1
Data center id of current rack.
cluster_id
**********
type: short
slots: 1
Cluster id of current rack.
total_machine_num
*****************
type: int
slots: 1
Total number of machines on this rack.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that are not in use on this rack.
regions
+++++++
id
***
type: short
slots: 1
Id of the current region.
total_machine_num
*****************
type: int
slots: 1
Total number of machines in this region.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that are not in use in this region.
zones
+++++
id
***
type: short
slots: 1
Id of this zone.
total_machine_num
*****************
type: int
slots: 1
Total number of machines in this zone.
empty_machine_num
*****************
type: int
slots: 1
Number of machines that are not in use in this zone.
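As with the other scenarios, these attributes are queried from the snapshot list by slicing. A small sketch (the tick and machine indices are illustrative):
.. code-block:: python

    # assuming `env` is a VM-scheduling Env instance that has run past tick 100
    cpu_util = env.snapshot_list["pms"][100: [0, 1, 2, 3]: "cpu_utilization"]
    allocated = env.snapshot_list["pms"][100: [0, 1, 2, 3]: "cpu_cores_allocated"]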


@ -1,121 +1,198 @@
RL Toolkit
==========
MARO provides a full-stack abstraction for reinforcement learning (RL), which enables users to
apply predefined and customized components to various scenarios. The main abstractions include
fundamental components such as `Agent <#agent>`_\ and `Shaper <#shaper>`_\ , and training routine
controllers such as `Actor <#actor>`_ and `Learner <#learner>`_.
MARO provides a full-stack abstraction for reinforcement learning (RL) which includes various customizable
components. In order to provide a gentle introduction for the RL toolkit, we cover the components in a top-down
manner, starting from the learning workflow.
Agent
-----
The Agent is the kernel abstraction of the RL formulation for a real-world problem.
Our abstraction decouples agent and its underlying model so that an agent can exist
as an RL paradigm independent of the inner workings of the models it uses to generate
actions or estimate values. For example, the actor-critic algorithm does not need to
concern itself with the structures and optimizing schemes of the actor and critic models.
This decoupling is achieved by the Core Model abstraction described below.
.. image:: ../images/rl/agent.svg
:target: ../images/rl/agent.svg
:alt: Agent
.. code-block:: python
class AbsAgent(ABC):
def __init__(self, model: AbsCoreModel, config, experience_pool=None):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = model.to(self.device)
self.config = config
self._experience_pool = experience_pool
Core Model
----------
MARO provides an abstraction for the underlying models used by agents to form policies and estimate values.
The abstraction consists of ``AbsBlock`` and ``AbsCoreModel``, both of which subclass torch's nn.Module.
The ``AbsBlock`` represents the smallest structural unit of an NN-based model. For instance, the ``FullyConnectedBlock``
provided in the toolkit is a stack of fully connected layers with features like batch normalization,
drop-out and skip connection. The ``AbsCoreModel`` is a collection of network components with
embedded optimizers and serves as an agent's "brain" by providing a unified interface to it, regardless of how many individual models it requires and how
complex the model architecture might be.
As an example, the initialization of the actor-critic algorithm may look like this:
.. code-block:: python
actor_stack = FullyConnectedBlock(...)
critic_stack = FullyConnectedBlock(...)
model = SimpleMultiHeadModel(
{"actor": actor_stack, "critic": critic_stack},
optim_option={
"actor": OptimizerOption(cls=Adam, params={"lr": 0.001})
"critic": OptimizerOption(cls=RMSprop, params={"lr": 0.0001})
}
)
agent = ActorCritic("actor_critic", model, config)
Choosing an action is simply:
.. code-block:: python
model(state, task_name="actor", training=False)
And performing one gradient step is simply:
.. code-block:: python
model.learn(critic_loss + actor_loss)
Explorer
Workflow
--------
MARO provides an abstraction for exploration in RL. Some RL algorithms such as DQN and DDPG require
explicit exploration governed by a set of parameters. The ``AbsExplorer`` class is designed to cater
to these needs. Simple exploration schemes, such as ``EpsilonGreedyExplorer`` for discrete action space
and ``UniformNoiseExplorer`` and ``GaussianNoiseExplorer`` for continuous action space, are provided in
the toolkit.
The nice thing about MARO's RL workflow is that it is neatly abstracted from business logic, policies and learning algorithms,
making it applicable to practically any scenario that utilizes standard reinforcement learning paradigms. The workflow is
controlled by a main process that executes 2-phase learning cycles: roll-out and training (:numref:`1`). The roll-out phase
collects data from one or more environment simulators for training. There can be a single environment simulator located in the same thread as the main
loop, or multiple environment simulators running in parallel on a set of remote workers (:numref:`2`) if you need to collect large amounts of data
fast. The training phase uses the data collected during the roll-out phase to train models involved in RL policies and algorithms.
In the case of multiple large models, this phase can be made faster by having the computationally intensive gradient-related tasks
sent to a set of remote workers for parallel processing (:numref:`3`).
As an example, the exploration for DQN may be carried out with the aid of an ``EpsilonGreedyExplorer``:
.. _1:
.. figure:: ../images/rl/learning_workflow.svg
:alt: Overview
:align: center
Learning Workflow
.. _2:
.. figure:: ../images/rl/parallel_rollout.svg
:alt: Overview
:align: center
Parallel Roll-out
.. _3:
.. figure:: ../images/rl/distributed_training.svg
:alt: Overview
:align: center
Distributed Training
Environment Sampler
-------------------
An environment sampler is an entity that contains an environment simulator and a set of policies used by agents to
interact with the environment (:numref:`4`). When creating an RL formulation for a scenario, it is necessary to define an environment
sampler class that includes these key elements:
- how observations / snapshots of the environment are encoded into state vectors as input to the policy models. This
is sometimes referred to as state shaping in applied reinforcement learning;
- how model outputs are converted to action objects defined by the environment simulator;
- how rewards / penalties are evaluated. This is sometimes referred to as reward shaping.
In parallel roll-out, each roll-out worker should have its own environment sampler instance.
.. _4:
.. figure:: ../images/rl/env_sampler.svg
:alt: Overview
:align: center
Environment Sampler
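As a minimal sketch, an environment sampler skeleton might look like the following (the method names mirror the CIM environment sampler shown later in this document; the bodies are left as placeholders):
.. code-block:: python

    from maro.rl.rollout import AbsEnvSampler

    class MyEnvSampler(AbsEnvSampler):
        def _get_global_and_agent_state_impl(self, event, tick=None):
            # state shaping: encode the decision event / snapshots into state vectors
            ...

        def _translate_to_env_action(self, action_dict, event):
            # action shaping: convert model output into simulator-defined action objects
            ...

        def _get_reward(self, env_action_dict, event, tick):
            # reward shaping: evaluate rewards for the actions taken at the given tick
            ...

        # other hooks such as _post_step may also be implemented, as in the CIM example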
Policy
------
``Policy`` is the most important concept in reinforcement learning. In MARO, the highest level abstraction of a policy
object is ``AbsPolicy``. It defines the interface ``get_actions()`` which takes a batch of states as input and returns
corresponding actions.
The action format is defined by the policy itself. It could be a scalar, a vector, or any other type.
The env sampler is responsible for converting the action into a format acceptable to the environment before passing
it on.
The simplest type of policy is ``RuleBasedPolicy`` which generates actions by pre-defined rules. ``RuleBasedPolicy``
is mostly used in naive scenarios. However, in most cases where we need to train the policy by interacting with the
environment, we need to use ``RLPolicy``. In MARO's design, a policy cannot train itself. Instead,
policies can only be trained by a :ref:`trainer` (we will introduce trainers later on this page). Therefore, in addition
to ``get_actions()``, ``RLPolicy`` also has a set of training-related interfaces, such as ``step()``, ``get_gradients()``
and ``set_gradients()``. These interfaces will be called by trainers for training. As you may have noticed, currently
we assume policies are built upon deep learning models, so the training-related interfaces are specifically
designed for gradient descent.
``RLPolicy`` is further divided into three types:
- ``ValueBasedPolicy``: For value-based policies.
- ``DiscretePolicyGradient``: For gradient-based policies that generate discrete actions.
- ``ContinuousPolicyGradient``: For gradient-based policies that generate continuous actions.
The above classes are all concrete classes. Users do not need to implement any new classes, but can directly
create a policy object by configuring parameters. Here is a simple example:
.. code-block:: python
explorer = EpsilonGreedyExplorer(num_actions=10)
greedy_action = learning_model(state, training=False).argmax(dim=1).data
exploration_action = explorer(greedy_action)
ValueBasedPolicy(
name="policy",
q_net=MyQNet(state_dim=128, action_num=64),
)
Tools for Training
------------------------------
For now, you may have no idea about the ``q_net`` parameter, but don't worry, we will introduce it in the next section.
.. image:: ../images/rl/learner_actor.svg
:target: ../images/rl/learner_actor.svg
:alt: RL Overview
Model
-----
The RL toolkit provides tools that make local and distributed training easy:
* Learner, the central controller of the learning process, which consists of collecting simulation data from
remote actors and training the agents with them. The training data collection can be done in local or
distributed fashion by loading an ``Actor`` or ``ActorProxy`` instance, respectively.
* Actor, which implements the ``roll_out`` method where the agent interacts with the environment for one
episode. It consists of an environment instance and an agent (a single agent or multiple agents wrapped by
``MultiAgentWrapper``). The class provides the ``as_worker()`` method which turns it into an event loop where roll-outs
are performed at the learner's demand. In distributed RL, there are typically many actor processes running
simultaneously to parallelize training data collection.
* Actor proxy, which also implements the ``roll_out`` method with the same signature, but manages a set of remote
actors for parallel data collection.
* Trajectory, which is primarily responsible for translating between scenario-specific information and model
input / output. It implements the following methods which are used as callbacks in the actor's roll-out loop:
* ``get_state``, which converts observations of an environment into model input. For example, the observation
may be represented by a multi-level data structure, which gets encoded by a state shaper to a one-dimensional
vector as input to a neural network. The state shaper usually goes hand in hand with the underlying policy
or value models.
* ``get_action``, which provides model output with necessary context so that it can be executed by the
environment simulator.
* ``get_reward``, which computes a reward for a given action.
* ``on_env_feedback``, which defines things to do upon getting feedback from the environment.
* ``on_finish``, which defines things to do upon completion of a roll-out episode.
The above code snippet creates a ``ValueBasedPolicy`` object. Let's pay attention to the parameter ``q_net``.
``q_net`` accepts a ``DiscreteQNet`` object, and it serves as the core part of a ``ValueBasedPolicy`` object. In
other words, ``q_net`` defines the model structure of the Q-network in the value-based policy, and further determines
the policy's behavior. ``DiscreteQNet`` is an abstract class, and ``MyQNet`` is a user-defined implementation
of ``DiscreteQNet``. It can be a simple MLP, a multi-head transformer, or any other structure that the user wants.
MARO provides a set of abstractions of basic & commonly used PyTorch models like ``DiscreteQNet``, which enables
users to implement their own deep learning models in a handy way. They are:
- ``DiscreteQNet``: For ``ValueBasedPolicy``.
- ``DiscretePolicyNet``: For ``DiscretePolicyGradient``.
- ``ContinuousPolicyNet``: For ``ContinuousPolicyGradient``.
Users should choose the proper types of models according to the type of policies, and then implement their own
models by inheriting the abstract ones (just like ``MyQNet``).
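For instance, a minimal ``DiscreteQNet`` implementation might look like the sketch below (the network configuration and learning rate are illustrative; a complete ``MyQNet`` appears in the CIM example later on this page):
.. code-block:: python

    import torch
    from torch.optim import RMSprop

    from maro.rl.model import DiscreteQNet, FullyConnected

    class MyQNet(DiscreteQNet):
        def __init__(self, state_dim: int, action_num: int) -> None:
            super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
            # a simple MLP that maps a state to one Q-value per action
            self._fc = FullyConnected(
                input_dim=state_dim,
                output_dim=action_num,
                hidden_dims=[256, 128, 64],
                activation=torch.nn.LeakyReLU,
                softmax=False,
                batch_norm=True,
                skip_connection=False,
                head=True,
                dropout_p=0.0,
            )
            self._optim = RMSprop(self._fc.parameters(), lr=0.05)

        def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
            return self._fc(states)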
There are also some other models for training purposes. For example:
- ``VNet``: Used in the critic part in the actor-critic algorithm.
- ``MultiQNet``: Used in the critic part in the MADDPG algorithm.
- ...
The way to use these models is exactly the same as the way to use the policy models.
.. _trainer:
Algorithm (Trainer)
-------------------
When introducing policies, we mentioned that policies cannot train themselves. Instead, they have to be trained
by external algorithms, which are also called trainers.
In MARO, a trainer represents an RL algorithm, such as DQN, actor-critic,
and so on. These two concepts are equivalent in the MARO context.
Trainers take interaction experiences and store them in the internal memory, and then use the experiences
in the memory to train the policies. Like ``RLPolicy``, trainers are also concrete classes, which means they can
be used simply by configuring parameters. Currently, MARO provides the following trainers (algorithms):
- ``DiscreteActorCriticTrainer``: Actor-critic algorithm for policies that generate discrete actions.
- ``DiscretePPOTrainer``: PPO algorithm for policies that generate discrete actions.
- ``DDPGTrainer``: DDPG algorithm for policies that generate continuous actions.
- ``DQNTrainer``: DQN algorithm for policies that generate discrete actions.
- ``DiscreteMADDPGTrainer``: MADDPG algorithm for policies that generate discrete actions.
Each trainer has a corresponding ``Param`` class to manage all related parameters. For example,
``DiscreteActorCriticParams`` contains all parameters used in ``DiscreteActorCriticTrainer``:
.. code-block:: python
@dataclass
class DiscreteActorCriticParams(TrainerParams):
get_v_critic_net_func: Callable[[], VNet] = None
reward_discount: float = 0.9
grad_iters: int = 1
critic_loss_cls: Callable = None
clip_ratio: float = None
lam: float = 0.9
min_logp: Optional[float] = None
An example of creating an actor-critic trainer:
.. code-block:: python
DiscreteActorCriticTrainer(
name='ac',
params=DiscreteActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim=128),
reward_discount=.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0
)
)
In order to indicate which trainer each policy is trained by, MARO requires that the name of a policy
start with the name of the trainer responsible for training it. For example, policy ``ac_1.policy_1`` is trained
by the trainer named ``ac_1``. Violating this convention will prevent MARO from correctly establishing the
correspondence between policies and trainers.
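A minimal illustration of this naming convention (the names below are hypothetical and mirror the CIM example at the end of this page):
.. code-block:: python

    num_agents = 4

    # the trainer named "dqn_0" trains every policy whose name starts with "dqn_0"
    trainer_names = [f"dqn_{i}" for i in range(num_agents)]
    policy_names = [f"dqn_{i}.policy" for i in range(num_agents)]

    for policy_name in policy_names:
        assert policy_name.split(".")[0] in trainer_names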
More details and examples can be found in the code base (`link`_).
.. _link: https://github.com/microsoft/maro/blob/master/examples/rl/cim/policy_trainer.py
As a summary, the relationship among policy, model, and trainer is demonstrated in :numref:`5`:
.. _5:
.. figure:: ../images/rl/policy_model_trainer.svg
:alt: Overview
:align: center
Summary of policy, model, and trainer


@ -1,11 +0,0 @@
# Container Inventory Management
Container inventory management (CIM) is a scenario where reinforcement learning (RL) can potentially prove useful. Three algorithms are used to learn the multi-agent policy in given environments. Each algorithm has a ``config`` folder which contains ``agent_config.py`` and ``training_config.py``. The former contains parameters for the underlying models and algorithm specific hyper-parameters. The latter contains parameters for the environment and the main training loop. The file ``common.py`` contains parameters and utility functions shared by some or all of these algorithms.
In the ``ac`` folder, the policy is trained using the Actor-Critic algorithm in single-threaded fashion. The example can be run by simply executing ``python3 main.py``. Logs will be saved in a file named ``cim-ac.CURRENT_TIME_STAMP.log`` under the ``ac/logs`` folder, where ``CURRENT_TIME_STAMP`` is the time of executing the script.
In the ``dqn`` folder, the policy is trained using the DQN algorithm in multi-process / distributed mode. This example can be run in three ways.
* ``python3 main.py`` or ``python3 main.py -w 0`` runs the example in multi-process mode, in which a main process spawns one learner process and a number of actor processes as specified in ``config/training_config.py``.
* ``python3 main.py -w 1`` launches the learner process only. This is for distributed training and expects a number of actor processes (as specified in ``config/training_config.py``) running on some other node(s).
* ``python3 main.py -w 2`` launches the actor process only. This is for distributed training and expects a learner process running on some other node.
Logs will be saved in a file named ``GROUP_NAME.log`` under the ``{ac_gnn, dqn}/logs`` folder, where ``GROUP_NAME`` is specified in the "group" field in ``config/training_config.py``.


@ -1,7 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .agent_config import agent_config
from .training_config import training_config
__all__ = ["agent_config", "training_config"]


@ -1,52 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch import nn
from torch.optim import Adam, RMSprop
from maro.rl import OptimOption
from examples.cim.common import common_config
input_dim = (
(common_config["look_back"] + 1) *
(common_config["max_ports_downstream"] + 1) *
len(common_config["port_attributes"]) +
len(common_config["vessel_attributes"])
)
agent_config = {
"model": {
"actor": {
"input_dim": input_dim,
"output_dim": len(common_config["action_space"]),
"hidden_dims": [256, 128, 64],
"activation": nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True
},
"critic": {
"input_dim": input_dim,
"output_dim": 1,
"hidden_dims": [256, 128, 64],
"activation": nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True
}
},
"optimization": {
"actor": OptimOption(optim_cls=Adam, optim_params={"lr": 0.001}),
"critic": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.001})
},
"hyper_params": {
"reward_discount": .0,
"critic_loss_func": nn.SmoothL1Loss(),
"train_iters": 10,
"actor_loss_coefficient": 0.1,
"k": 1,
"lam": 0.0
# "clip_ratio": 0.8
}
}


@ -1,11 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
training_config = {
"env": {
"scenario": "cim",
"topology": "toy.4p_ssdd_l0.0",
"durations": 1120,
},
"max_episode": 50
}


@ -1,53 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
from maro.rl import (
Actor, ActorCritic, ActorCriticConfig, FullyConnectedBlock, MultiAgentWrapper, SimpleMultiHeadModel,
Scheduler, OnPolicyLearner
)
from maro.simulator import Env
from maro.utils import set_seeds
from examples.cim.ac.config import agent_config, training_config
from examples.cim.common import CIMTrajectory, common_config
def get_ac_agent():
actor_net = FullyConnectedBlock(**agent_config["model"]["actor"])
critic_net = FullyConnectedBlock(**agent_config["model"]["critic"])
ac_model = SimpleMultiHeadModel(
{"actor": actor_net, "critic": critic_net}, optim_option=agent_config["optimization"],
)
return ActorCritic(ac_model, ActorCriticConfig(**agent_config["hyper_params"]))
class CIMTrajectoryForAC(CIMTrajectory):
def on_finish(self):
training_data = {}
for event, state, action in zip(self.trajectory["event"], self.trajectory["state"], self.trajectory["action"]):
agent_id = list(state.keys())[0]
data = training_data.setdefault(agent_id, {"args": [[] for _ in range(4)]})
data["args"][0].append(state[agent_id]) # state
data["args"][1].append(action[agent_id][0]) # action
data["args"][2].append(action[agent_id][1]) # log_p
data["args"][3].append(self.get_offline_reward(event)) # reward
for agent_id in training_data:
training_data[agent_id]["args"] = [
np.asarray(vals, dtype=np.float32 if i == 3 else None)
for i, vals in enumerate(training_data[agent_id]["args"])
]
return training_data
# Single-threaded launcher
if __name__ == "__main__":
set_seeds(1024) # for reproducibility
env = Env(**training_config["env"])
agent = MultiAgentWrapper({name: get_ac_agent() for name in env.agent_idx_list})
actor = Actor(env, agent, CIMTrajectoryForAC, trajectory_kwargs=common_config) # local actor
learner = OnPolicyLearner(actor, training_config["max_episode"])
learner.run()


@ -1,99 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import defaultdict
import numpy as np
from maro.rl import Trajectory
from maro.simulator.scenarios.cim.common import Action, ActionType
common_config = {
"port_attributes": ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"],
"vessel_attributes": ["empty", "full", "remaining_space"],
"action_space": list(np.linspace(-1.0, 1.0, 21)),
# Parameters for computing states
"look_back": 7,
"max_ports_downstream": 2,
# Parameters for computing actions
"finite_vessel_space": True,
"has_early_discharge": True,
# Parameters for computing rewards
"reward_time_window": 99,
"fulfillment_factor": 1.0,
"shortage_factor": 1.0,
"time_decay": 0.97
}
class CIMTrajectory(Trajectory):
def __init__(
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
finite_vessel_space=True, has_early_discharge=True
):
super().__init__(env)
self.port_attributes = port_attributes
self.vessel_attributes = vessel_attributes
self.action_space = action_space
self.look_back = look_back
self.max_ports_downstream = max_ports_downstream
self.reward_time_window = reward_time_window
self.fulfillment_factor = fulfillment_factor
self.shortage_factor = shortage_factor
self.time_decay = time_decay
self.finite_vessel_space = finite_vessel_space
self.has_early_discharge = has_early_discharge
def get_state(self, event):
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]
future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]
vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
return {port_idx: np.concatenate((port_features, vessel_features))}
def get_action(self, action_by_agent, event):
vessel_snapshots = self.env.snapshot_list["vessels"]
action_info = list(action_by_agent.values())[0]
model_action = action_info[0] if isinstance(action_info, tuple) else action_info
scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx
zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero.
vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf")
early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
percent = abs(self.action_space[model_action])
if model_action < zero_action_idx:
action_type = ActionType.LOAD
actual_action = min(round(percent * scope.load), vessel_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
else:
actual_action, action_type = 0, ActionType.LOAD
return {port: Action(vessel, port, actual_action, action_type)}
def get_offline_reward(self, event):
port_snapshots = self.env.snapshot_list["ports"]
start_tick = event.tick + 1
ticks = list(range(start_tick, start_tick + self.reward_time_window))
future_fulfillment = port_snapshots[ticks::"fulfillment"]
future_shortage = port_snapshots[ticks::"shortage"]
decay_list = [
self.time_decay ** i for i in range(self.reward_time_window)
for _ in range(future_fulfillment.shape[0] // self.reward_time_window)
]
tot_fulfillment = np.dot(future_fulfillment, decay_list)
tot_shortage = np.dot(future_shortage, decay_list)
return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)
def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):
self.trajectory["event"].append(event)
self.trajectory["state"].append(state_by_agent)
self.trajectory["action"].append(action_by_agent)


@ -1,7 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .agent_config import agent_config
from .training_config import training_config
__all__ = ["agent_config", "training_config"]


@ -1,38 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch import nn
from torch.optim import RMSprop
from maro.rl import DQN, DQNConfig, FullyConnectedBlock, OptimOption, PolicyGradient, SimpleMultiHeadModel
from examples.cim.common import common_config
input_dim = (
(common_config["look_back"] + 1) *
(common_config["max_ports_downstream"] + 1) *
len(common_config["port_attributes"]) +
len(common_config["vessel_attributes"])
)
agent_config = {
"model": {
"input_dim": input_dim,
"output_dim": len(common_config["action_space"]), # number of possible actions
"hidden_dims": [256, 128, 64],
"activation": nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0
},
"optimization": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.05}),
"hyper_params": {
"reward_discount": .0,
"loss_cls": nn.SmoothL1Loss,
"target_update_freq": 5,
"tau": 0.1,
"double": False
}
}


@ -1,27 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
training_config = {
"env": {
"scenario": "cim",
"topology": "toy.4p_ssdd_l0.0",
"durations": 1120,
},
"max_episode": 100,
"exploration": {
"parameter_names": ["epsilon"],
"split": 0.5,
"start": 0.4,
"mid": 0.32,
"end": 0.0
},
"training": {
"min_experiences_to_train": 1024,
"train_iter": 10,
"batch_size": 128,
"prioritized_sampling_by_loss": True
},
"group": "cim-dqn",
"learner_update_trigger": 2,
"num_actors": 2
}


@ -1,96 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import sys
from collections import defaultdict
from multiprocessing import Process
from os import makedirs
from os.path import dirname, join, realpath
from maro.rl import (
Actor, ActorProxy, DQN, DQNConfig, FullyConnectedBlock, MultiAgentWrapper, OffPolicyLearner,
SimpleMultiHeadModel, TwoPhaseLinearParameterScheduler
)
from maro.simulator import Env
from maro.utils import set_seeds
cim_dqn_path = dirname(realpath(__file__))
cim_example_path = dirname(cim_dqn_path)
sys.path.insert(0, cim_example_path)
from common import CIMTrajectory, common_config
from dqn.config import agent_config, training_config
log_dir = join(cim_dqn_path, "log")
makedirs(log_dir, exist_ok=True)
def get_dqn_agent():
q_model = SimpleMultiHeadModel(
FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"]
)
return DQN(q_model, DQNConfig(**agent_config["hyper_params"]))
class CIMTrajectoryForDQN(CIMTrajectory):
def on_finish(self):
exp_by_agent = defaultdict(lambda: defaultdict(list))
for i in range(len(self.trajectory["state"]) - 1):
agent_id = list(self.trajectory["state"][i].keys())[0]
exp = exp_by_agent[agent_id]
exp["S"].append(self.trajectory["state"][i][agent_id])
exp["A"].append(self.trajectory["action"][i][agent_id])
exp["R"].append(self.get_offline_reward(self.trajectory["event"][i]))
exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0])
return dict(exp_by_agent)
def cim_dqn_learner():
env = Env(**training_config["env"])
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"])
actor = ActorProxy(
training_config["group"], training_config["num_actors"],
update_trigger=training_config["learner_update_trigger"],
log_dir=log_dir
)
learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"], log_dir=log_dir)
learner.run()
def cim_dqn_actor():
env = Env(**training_config["env"])
agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list})
actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config)
actor.as_worker(training_config["group"], log_dir=log_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-w", "--whoami", type=int, choices=[0, 1, 2], default=0,
help="Identity of this process: 0 - multi-process mode, 1 - learner, 2 - actor"
)
args = parser.parse_args()
if args.whoami == 0:
actor_processes = [Process(target=cim_dqn_actor) for _ in range(training_config["num_actors"])]
learner_process = Process(target=cim_dqn_learner)
for i, actor_process in enumerate(actor_processes):
set_seeds(i) # this is to ensure that the actors explore differently.
actor_process.start()
learner_process.start()
for actor_process in actor_processes:
actor_process.join()
learner_process.join()
elif args.whoami == 1:
cim_dqn_learner()
elif args.whoami == 2:
cim_dqn_actor()


@ -0,0 +1,9 @@
# Container Inventory Management
This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find:
* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details);
* ``config.py``, which contains general configurations for the scenario;
* ``algorithms/``, which contains configurations for the PPO, Actor-Critic, DQN and discrete-MADDPG algorithms, including network configurations;
* ``rl_component_bundle.py``, which defines all necessary components to run an RL job. You can go through the docstring of `RLComponentBundle` for a detailed explanation, or just read `CIMBundle` to learn its basic usage.
We recommend that you follow this example to write your own scenarios.


@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .rl_component_bundle import CIMBundle as rl_component_bundle_cls
__all__ = [
"rl_component_bundle_cls",
]


@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import Adam, RMSprop
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
actor_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True,
}
critic_net_conf = {
"hidden_dims": [256, 128, 64],
"output_dim": 1,
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True,
}
actor_learning_rate = 0.001
critic_learning_rate = 0.001
class MyActorNet(DiscreteACBasedNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._actor(states)
class MyCriticNet(VNet):
def __init__(self, state_dim: int) -> None:
super(MyCriticNet, self).__init__(state_dim=state_dim)
self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)
self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
return self._critic(states).squeeze(-1)
def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0,
),
)


@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import RMSprop
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNTrainer, DQNParams
q_net_conf = {
"hidden_dims": [256, 128, 64, 32],
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}
learning_rate = 0.05
class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
return self._fc(states)
def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num),
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
exploration_scheduling_options=[(
"epsilon", MultiLinearExplorationScheduler, {
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
}
)],
warmup=100,
)
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
reward_discount=.0,
update_target_every=5,
num_epochs=10,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
),
)


@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Dict, List
import torch
from torch.optim import Adam, RMSprop
from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams
actor_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True
}
critic_net_conf = {
"hidden_dims": [256, 128, 64],
"output_dim": 1,
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True
}
actor_learning_rate = 0.001
critic_learning_rate = 0.001
# #####################################################################################################################
class MyActorNet(DiscreteACBasedNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._actor(states)
class MyMultiCriticNet(MultiQNet):
def __init__(self, state_dim: int, action_dims: List[int]) -> None:
super(MyMultiCriticNet, self).__init__(state_dim=state_dim, action_dims=action_dims)
self._critic = FullyConnected(input_dim=state_dim + sum(action_dims), **critic_net_conf)
self._optim = RMSprop(self._critic.parameters(), critic_learning_rate)
def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1)
def get_multi_critic_net(state_dim: int, action_dims: List[int]) -> MyMultiCriticNet:
return MyMultiCriticNet(state_dim, action_dims)
def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
return DiscreteMADDPGTrainer(
name=name,
params=DiscreteMADDPGParams(
reward_discount=.0,
num_epoch=10,
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
shared_critic=False
)
)


@ -0,0 +1,25 @@
import torch
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import PPOParams, PPOTrainer
from .ac import MyActorNet, MyCriticNet
def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
return PPOTrainer(
name=name,
params=PPOParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0,
clip_ratio=0.1,
),
)

examples/cim/rl/config.py Normal file

@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
env_conf = {
"scenario": "cim",
"topology": "toy.4p_ssdd_l0.0",
"durations": 560
}
if env_conf["topology"].startswith("toy"):
num_agents = int(env_conf["topology"].split(".")[1][0])
else:
num_agents = int(env_conf["topology"].split(".")[1][:2])
port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]
state_shaping_conf = {
"look_back": 7,
"max_ports_downstream": 2
}
action_shaping_conf = {
"action_space": [(i - 10) / 10 for i in range(21)],
"finite_vessel_space": True,
"has_early_discharge": True
}
reward_shaping_conf = {
"time_window": 99,
"fulfillment_factor": 1.0,
"shortage_factor": 1.0,
"time_decay": 0.97
}
# obtain state dimension from a temporary env_wrapper instance
state_dim = (
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
+ len(vessel_attributes)
)
action_num = len(action_shaping_conf["action_space"])
algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg


@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, List, Tuple, Union
import numpy as np
from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
from .config import (
action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
vessel_attributes,
)
class CIMEnvSampler(AbsEnvSampler):
def _get_global_and_agent_state_impl(
self, event: DecisionEvent, tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
tick = self._env.tick
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
port_idx, vessel_idx = event.port_idx, event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
state = np.concatenate([
port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
vessel_snapshots[tick: vessel_idx: vessel_attributes]
])
return state, {port_idx: state}
def _translate_to_env_action(
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,
) -> Dict[Any, object]:
action_space = action_shaping_conf["action_space"]
finite_vsl_space = action_shaping_conf["finite_vessel_space"]
has_early_discharge = action_shaping_conf["has_early_discharge"]
port_idx, model_action = list(action_dict.items()).pop()
vsl_idx, action_scope = event.vessel_idx, event.action_scope
vsl_snapshots = self._env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")
percent = abs(action_space[model_action[0]])
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
if model_action < zero_action_idx:
action_type = ActionType.LOAD
actual_action = min(round(percent * action_scope.load), vsl_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
else:
actual_action, action_type = 0, None
return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
start_tick = tick + 1
ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))
# Get the ports that took actions at the given tick
ports = [int(port) for port in list(env_action_dict.keys())]
port_snapshots = self._env.snapshot_list["ports"]
future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)
decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
rewards = np.float32(
reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
)
return {agent_id: reward for agent_id, reward in zip(ports, rewards)}
def _post_step(self, cache_element: CacheElement) -> None:
self._info["env_metric"] = self._env.metrics
def _post_eval_step(self, cache_element: CacheElement) -> None:
self._post_step(cache_element)
def post_collect(self, info_list: list, ep: int) -> None:
# print the env metric from each rollout worker
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")
# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
def post_evaluate(self, info_list: list, ep: int) -> None:
self.post_collect(info_list, ep)


@ -0,0 +1,84 @@
from functools import partial
from typing import Any, Callable, Dict, Optional
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
from .algorithms.ppo import get_ppo, get_ppo_policy
class CIMBundle(RLComponentBundle):
def get_env_config(self) -> dict:
return env_conf
def get_test_env_config(self) -> Optional[dict]:
return None
def get_env_sampler(self) -> AbsEnvSampler:
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
def get_agent2policy(self) -> Dict[Any, str]:
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
if algorithm == "ac":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "ppo":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "dqn":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
elif algorithm == "discrete_maddpg":
policy_creator = {
f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
for i in range(num_agents)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return policy_creator
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}")
for i in range(num_agents)
}
elif algorithm == "ppo":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}")
for i in range(num_agents)
}
elif algorithm == "dqn":
trainer_creator = {
f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}")
for i in range(num_agents)
}
elif algorithm == "discrete_maddpg":
trainer_creator = {
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}")
for i in range(num_agents)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return trainer_creator


@ -99,7 +99,7 @@ demand is 34 (at a specific station, during a time interval of 20 minutes), the
corresponding demand distribution shows that demand exceeding 10 bikes per time
interval (20 minutes) is only 2%.
![Demand Distribution Between Tick 2400 ~ Tick 2519](./LogDemand.ny201910.2400.png)
![Demand Distribution Between Tick 2400 ~ Tick 2519](LogDemand.ny201910.2400.png)
Besides, we can also find that the percentage of forecasting results that differ
from the data extracted from the trip log is not low. To dive deeper into the practical
@ -110,9 +110,9 @@ show the distribution of the forecasting difference to the trip log. One for the
interval with the *Max Diff* (16:00-18:00), one for the interval with the highest
percentage of *Diff > 5* (10:00-12:00).
![Demand Distribution Between Tick 2400 ~ Tick 2519](./DemandDiff.ny201910.2400.png)
![Demand Distribution Between Tick 2400 ~ Tick 2519](DemandDiff.ny201910.2400.png)
![Demand Distribution Between Tick 2040 ~ Tick 2159](./DemandDiff.ny201910.2040.png)
![Demand Distribution Between Tick 2040 ~ Tick 2159](DemandDiff.ny201910.2040.png)
Maybe due to the *sparse* and *small* trip demand, and the *small* difference
between the forecasting results and data extracted from the trip log data, the


@ -75,10 +75,10 @@ class MaIlpAgent():
event_type = finished_events[self._next_event_idx].event_type
if event_type == CitiBikeEvents.RequireBike:
# TODO: Replace it with a pre-defined PayLoad.
payload = finished_events[self._next_event_idx].payload
payload = finished_events[self._next_event_idx].body
demand_history[interval_idx, payload.src_station] += 1
elif event_type == CitiBikeEvents.ReturnBike:
payload: BikeReturnPayload = finished_events[self._next_event_idx].payload
payload: BikeReturnPayload = finished_events[self._next_event_idx].body
supply_history[interval_idx, payload.to_station_idx] += payload.number
# Update the index to the finished event that has not been processed.
@ -129,7 +129,7 @@ class MaIlpAgent():
# Process to get the future supply from Pending Events.
for pending_event in ENV.get_pending_events(tick=tick):
if pending_event.event_type == CitiBikeEvents.ReturnBike:
payload: BikeReturnPayload = pending_event.payload
payload: BikeReturnPayload = pending_event.body
supply[interval_idx, payload.to_station_idx] += payload.number
return demand, supply


@ -21,13 +21,13 @@ def worker(group_name):
print(f"{proxy.name}'s counter is {counter}.")
# Nonrecurring receive the message from the proxy.
for msg in proxy.receive(is_continuous=False):
print(f"{proxy.name} receive message from {msg.source}.")
msg = proxy.receive_once()
print(f"{proxy.name} received message from {msg.source}.")
if msg.tag == "INC":
counter += 1
print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.")
proxy.reply(message=msg, tag="done")
if msg.tag == "INC":
counter += 1
print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.")
proxy.reply(message=msg, tag="done")
def master(group_name: str, worker_num: int, is_immediate: bool = False):


@ -21,12 +21,12 @@ def summation_worker(group_name):
expected_peers={"master": 1})
# Nonrecurring receive the message from the proxy.
for msg in proxy.receive(is_continuous=False):
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
msg = proxy.receive_once()
print(f"{proxy.name} received message from {msg.source}. the payload is {msg.body}.")
if msg.tag == "job":
replied_payload = sum(msg.payload)
proxy.reply(message=msg, tag="sum", payload=replied_payload)
if msg.tag == "job":
replied_payload = sum(msg.body)
proxy.reply(message=msg, tag="sum", body=replied_payload)
def multiplication_worker(group_name):
@ -41,12 +41,12 @@ def multiplication_worker(group_name):
expected_peers={"master": 1})
# Nonrecurring receive the message from the proxy.
for msg in proxy.receive(is_continuous=False):
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
msg = proxy.receive_once()
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.body}.")
if msg.tag == "job":
replied_payload = np.prod(msg.payload)
proxy.reply(message=msg, tag="multiply", payload=replied_payload)
if msg.tag == "job":
replied_payload = np.prod(msg.body)
proxy.reply(message=msg, tag="multiply", body=replied_payload)
def master(group_name: str, sum_worker_number: int, multiply_worker_number: int, is_immediate: bool = False):
@ -73,13 +73,13 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
# Assign sum tasks for summation workers.
destination_payload_list = []
for idx, peer in enumerate(proxy.peers_name["sum_worker"]):
data_length_per_peer = int(len(sum_list) / len(proxy.peers_name["sum_worker"]))
for idx, peer in enumerate(proxy.peers["sum_worker"]):
data_length_per_peer = int(len(sum_list) / len(proxy.peers["sum_worker"]))
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
# Assign multiply tasks for multiplication workers.
for idx, peer in enumerate(proxy.peers_name["multiply_worker"]):
data_length_per_peer = int(len(multiple_list) / len(proxy.peers_name["multiply_worker"]))
for idx, peer in enumerate(proxy.peers["multiply_worker"]):
data_length_per_peer = int(len(multiple_list) / len(proxy.peers["multiply_worker"]))
destination_payload_list.append(
(peer, multiple_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
@ -98,11 +98,11 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
sum_result, multiply_result = 0, 1
for msg in replied_msgs:
if msg.tag == "sum":
print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.payload}.")
sum_result += msg.payload
print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.body}.")
sum_result += msg.body
elif msg.tag == "multiply":
print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.payload}.")
multiply_result *= msg.payload
print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.body}.")
multiply_result *= msg.body
# Check task result correctness.
assert(sum(sum_list) == sum_result)


@ -21,12 +21,12 @@ def worker(group_name):
expected_peers={"master": 1})
# Nonrecurring receive the message from the proxy.
for msg in proxy.receive(is_continuous=False):
print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.")
msg = proxy.receive_once()
print(f"{proxy.name} received message from {msg.source}. the payload is {msg.body}.")
if msg.tag == "sum":
replied_payload = sum(msg.payload)
proxy.reply(message=msg, tag="sum", payload=replied_payload)
if msg.tag == "sum":
replied_payload = sum(msg.body)
proxy.reply(message=msg, tag="sum", body=replied_payload)
def master(group_name: str, is_immediate: bool = False):
@ -47,11 +47,11 @@ def master(group_name: str, is_immediate: bool = False):
random_integer_list = np.random.randint(0, 100, 5)
print(f"generate random integer list: {random_integer_list}.")
for peer in proxy.peers_name["worker"]:
for peer in proxy.peers["worker"]:
message = SessionMessage(tag="sum",
source=proxy.name,
destination=peer,
payload=random_integer_list,
body=random_integer_list,
session_type=SessionType.TASK)
if is_immediate:
session_id = proxy.isend(message)
@ -61,7 +61,7 @@ def master(group_name: str, is_immediate: bool = False):
replied_msgs = proxy.send(message, timeout=-1)
for msg in replied_msgs:
print(f"{proxy.name} receive {msg.source}, replied payload is {msg.payload}.")
print(f"{proxy.name} receive {msg.source}, replied payload is {msg.body}.")
if __name__ == "__main__":

examples/rl/README.md

@ -0,0 +1,19 @@
# Reinforcement Learning (RL) Examples
This folder contains scenarios that employ reinforcement learning. MARO's RL toolkit provides scenario-agnostic workflows to run a variety of scenarios in single-thread, multi-process or distributed modes.
## How to Run
The entry point of an RL workflow is a YAML config file. For readers' convenience, we call this config file `config.yml` in the rest of this doc. `config.yml` specifies the paths of all necessary resources, definitions, and configurations to run the job. MARO provides a comprehensive template of the config file with detailed explanations (`maro/maro/rl/workflows/config/template.yml`). Meanwhile, MARO also provides several simple examples of `config.yml` under the current folder.
There are two ways to start the RL job:
- If you only need a quick look and want to try an out-of-box workflow, just run `python .\examples\rl\run_rl_example.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml` will run the complete RL training workflow of the CIM scenario. If you only want to run the evaluation workflow, you can start the job with `--evaluate_only`.
- (**Requires installing MARO from source**) You can also start the job through the MARO CLI. Use the command `maro local run [-c] path/to/your/config` to run in containerized (with `-c`) or non-containerized (without `-c`) environments. Similarly, add `--evaluate_only` if you only need to run the evaluation workflow.
## Create Your Own Scenarios
You can create your own scenarios by supplying the necessary ingredients without worrying about putting them together in a workflow. It is necessary to create an ``__init__.py`` under your scenario folder (so that it can be treated as a package) and expose a `rl_component_bundle_cls` interface. MARO's RL workflow will use this interface to create an `RLComponentBundle` instance and start the RL workflow based on it. An `RLComponentBundle` instance defines all the necessary components to run an RL job. You can go through the docstring of `RLComponentBundle` for a detailed explanation, or just read one of the examples (see the sketch below) to learn its basic usage.
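For instance, a minimal `__init__.py` can simply re-export your bundle class under the expected name, mirroring the VM scheduling example in this repo. `MyScenarioBundle` below is a placeholder for your own `RLComponentBundle` subclass:
```python
# examples/<your_scenario>/rl/__init__.py -- hypothetical module layout
from .rl_component_bundle import MyScenarioBundle as rl_component_bundle_cls

__all__ = ["rl_component_bundle_cls"]
```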
## Example
For a complete example, please check `examples/cim/rl`.

examples/rl/cim.yml

@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for CIM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml
job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
eval_schedule: 5
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
interval: 5
logging:
stdout: INFO
file: DEBUG


@ -0,0 +1,15 @@
import argparse
from maro.cli.local.commands import run
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("conf_path", help='Path of the job deployment')
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
run(conf_path=args.conf_path, containerize=False, evaluate_only=args.evaluate_only)


@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Example RL config file for VM scheduling scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\vm_scheduling.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\vm_scheduling.yml
job: vm_scheduling_rl_workflow
scenario_path: "examples/vm_scheduling/rl"
log_path: "log/rl_job/vm_scheduling.txt"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
eval_schedule: 5
logging:
stdout: INFO
file: DEBUG
rollout:
logging:
stdout: INFO
file: DEBUG
training:
mode: simple
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/vm_scheduling"
interval: 5
logging:
stdout: INFO
file: DEBUG


@ -22,14 +22,11 @@ with io.open(CONFIG_PATH, "r") as in_file:
config = convert_dottable(raw_config)
LOG_PATH = os.path.join(FILE_PATH, "log", config.experiment_name)
if not os.path.exists(LOG_PATH):
os.makedirs(LOG_PATH)
simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False)
ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False)
simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w")
ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w")
if __name__ == "__main__":
start_time = timeit.default_timer()
env = Env(
scenario=config.env.scenario,
topology=config.env.topology,


@ -0,0 +1,24 @@
# Virtual Machine Scheduling
A virtual machine (VM) scheduler is a cloud computing service component responsible for providing compute resources to satisfy user demands. A good resource allocation policy should aim to optimize several metrics at the same time, such as user wait time, profit, energy consumption and physical machine (PM) overload. Many commercial cloud providers use rule-based policies. Alternatively, the policy can also be optimized using reinforcement learning (RL) techniques, which involves simulating with historical data. This example demonstrates how DQN and Actor-Critic algorithms can be applied to this scenario. In this folder, you can find:
* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details);
* ``config.py``, which contains general configurations for the scenario;
* ``algorithms/``, which contains configurations for the algorithms, including network configurations;
* ``rl_component_bundle.py``, which defines all necessary components to run an RL job. You can go through the docstring of `RLComponentBundle` for a detailed explanation, or just read `VMBundle` to learn its basic usage.
We recommend that you follow this example to write your own scenarios.
# Some Comments About the Results
This example is meant to serve as a demonstration of using MARO's RL toolkit in a real-life scenario. In fact, we have yet to find a configuration that makes the policy learned by either DQN or Actor-Critic perform reasonably well in our experimental settings.
For reference, the best results have been achieved by the ``Best Fit`` algorithm (see ``examples/vm_scheduling/rule_based_algorithm/best_fit.py`` for details). The over-subscription rate is 115% in the over-subscription settings.
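For intuition, a best-fit rule packs VMs tightly by choosing, among the PMs that can host the incoming VM, the one with the least remaining capacity. Below is a minimal sketch of such a rule; names like `valid_pms` and `remaining_cpu` are illustrative placeholders, not the actual interface of `best_fit.py`:
```python
# Illustrative best-fit selection: prefer the valid PM with the least remaining CPU.
def pick_best_fit_pm(valid_pms, remaining_cpu):
    return min(valid_pms, key=lambda pm_id: remaining_cpu[pm_id])
```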
| Topology | PM Setting | Time Spent (s) | Total VM Requests | Successful Allocation | Energy Consumption | Total Oversubscriptions | Total Overload PMs |
|:----:|:----|:--------:|:---:|:-------:|:----:|:---:|:---:|
| 10k | 100 PMs, 32 Cores, 128 GB | 104.98 | 10,000 | 10,000 | 2,399,610 | 0 | 0 |
| 10k.oversubscription | 100 PMs, 32 Cores, 128 GB | 101.00 | 10,000 | 10,000 | 2,386,371 | 279,331 | 0 |
| 336k | 880 PMs, 16 Cores, 112 GB | 7,896.37 | 335,985 | 109,249 | 26,425,878 | 0 | 0 |
| 336k.oversubscription | 880 PMs, 16 Cores, 112 GB | 7,903.33 | 335,985 | 115,008 | 27,440,946 | 3,868,475 | 0 |


@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .rl_component_bundle import VMBundle as rl_component_bundle_cls
__all__ = [
"rl_component_bundle_cls",
]


@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import Adam, SGD
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
actor_net_conf = {
"hidden_dims": [64, 32, 32],
"activation": torch.nn.LeakyReLU,
"softmax": True,
"batch_norm": False,
"head": True,
}
critic_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": False,
"head": True,
}
actor_learning_rate = 0.0001
critic_learning_rate = 0.001
class MyActorNet(DiscreteACBasedNet):
def __init__(self, state_dim: int, action_num: int, num_features: int) -> None:
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._num_features = num_features
self._actor = FullyConnected(input_dim=num_features, output_dim=action_num, **actor_net_conf)
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
features, masks = states[:, :self._num_features], states[:, self._num_features:]
masks += 1e-8 # this is to prevent zero probability and infinite logP.
return self._actor(features) * masks
class MyCriticNet(VNet):
def __init__(self, state_dim: int, num_features: int) -> None:
super(MyCriticNet, self).__init__(state_dim=state_dim)
self._num_features = num_features
self._critic = FullyConnected(input_dim=num_features, output_dim=1, **critic_net_conf)
self._optim = SGD(self._critic.parameters(), lr=critic_learning_rate)
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
features, masks = states[:, :self._num_features], states[:, self._num_features:]
masks += 1e-8 # this is to prevent zero probability and infinite logP.
return self._critic(features).squeeze(-1)
def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num, num_features))
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
reward_discount=0.9,
grad_iters=100,
critic_loss_cls=torch.nn.MSELoss,
min_logp=-20,
lam=.0,
),
)


@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from maro.rl.exploration import MultiLinearExplorationScheduler
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer
q_net_conf = {
"hidden_dims": [64, 128, 256],
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": False,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}
q_net_learning_rate = 0.0005
q_net_lr_scheduler_params = {"T_0": 500, "T_mult": 2}
class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int, num_features: int) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._num_features = num_features
self._fc = FullyConnected(input_dim=num_features, output_dim=action_num, **q_net_conf)
self._optim = SGD(self._fc.parameters(), lr=q_net_learning_rate)
self._lr_scheduler = CosineAnnealingWarmRestarts(self._optim, **q_net_lr_scheduler_params)
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
masks = states[:, self._num_features:]
q_for_all_actions = self._fc(states[:, :self._num_features])
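# Push the Q-values of illegal actions (mask == 0) down by 1e8 so they are never selected.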
return q_for_all_actions + (masks - 1) * 1e8
class MaskedEpsGreedy:
def __init__(self, state_dim: int, num_features: int) -> None:
self._state_dim = state_dim
self._num_features = num_features
def __call__(self, states, actions, num_actions, *, epsilon):
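# Masked epsilon-greedy: with probability epsilon, replace the greedy action with a random action drawn from that sample's legal-action mask; otherwise keep the greedy action.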
masks = states[:, self._num_features:]
return np.array([
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
for action, mask in zip(actions, masks)
])
def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str) -> ValueBasedPolicy:
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num, num_features),
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
exploration_scheduling_options=[(
"epsilon", MultiLinearExplorationScheduler, {
"splits": [(100, 0.32)],
"initial_value": 0.4,
"last_ep": 400,
"final_value": 0.0,
}
)],
warmup=100,
)
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
reward_discount=0.9,
update_target_every=5,
num_epochs=100,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
data_parallelism=2,
),
)


@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.simulator import Env
env_conf = {
"scenario": "vm_scheduling",
"topology": "azure.2019.10k",
"start_tick": 0,
"durations": 300, # 8638
"snapshot_resolution": 1,
}
num_pms = Env(**env_conf).business_engine.pm_amount
pm_window_size = 1
num_features = 2 * num_pms * pm_window_size + 4
state_dim = num_features + num_pms + 1
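# num_features: 2 PM features (remaining CPU & memory) per PM per window tick, plus 4 VM features; state_dim additionally accounts for the legal-PM mask of length num_pms + 1.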
pm_attributes = ["cpu_cores_capacity", "memory_capacity", "cpu_cores_allocated", "memory_allocated"]
# vm_attributes = ["cpu_cores_requirement", "memory_requirement", "lifetime", "remain_time", "total_income"]
reward_shaping_conf = {
"alpha": 0.0,
"beta": 1.0,
}
seed = 666
test_env_conf = {
"scenario": "vm_scheduling",
"topology": "azure.2019.10k.oversubscription",
"start_tick": 0,
"durations": 300,
"snapshot_resolution": 1,
}
test_reward_shaping_conf = {
"alpha": 0.0,
"beta": 1.0,
}
test_seed = 1024
algorithm = "ac" # "dqn" or "ac"


@ -0,0 +1,200 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
from collections import defaultdict
from os import makedirs
from os.path import dirname, join, realpath
from typing import Any, Dict, List, Tuple, Union
import numpy as np
from matplotlib import pyplot as plt
from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from .config import (
num_features, pm_attributes, pm_window_size, reward_shaping_conf, seed, test_reward_shaping_conf, test_seed,
)
timestamp = str(time.time())
plt_path = join(dirname(realpath(__file__)), "plots", timestamp)
makedirs(plt_path, exist_ok=True)
class VMEnvSampler(AbsEnvSampler):
def __init__(self, learn_env: Env, test_env: Env) -> None:
super(VMEnvSampler, self).__init__(learn_env, test_env)
self._learn_env.set_seed(seed)
self._test_env.set_seed(test_seed)
# adjust the ratio of the success allocation and the total income when computing the reward
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
self._durations = self._learn_env.business_engine._max_tick
self._pm_state_history = np.zeros((pm_window_size - 1, self.num_pms, 2))
self._legal_pm_mask = None
def _get_global_and_agent_state_impl(
self, event: DecisionPayload, tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
# Build the legal PM (action) mask.
legal_pm_mask = np.zeros(self.num_pms + 1)
if len(event.valid_pms) <= 0:
# no pm available
legal_pm_mask[self.num_pms] = 1
else:
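# Postponement (index num_pms) is always legal; among the valid PMs, only the first PM seen for each distinct remaining-CPU value is marked legal.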
legal_pm_mask[self.num_pms] = 1
remain_cpu_dict = dict()
for pm in event.valid_pms:
# If two pms have the same remaining cpu, choose the one with the smaller id
if pm_state[-1, pm, 0] not in remain_cpu_dict:
remain_cpu_dict[pm_state[-1, pm, 0]] = 1
legal_pm_mask[pm] = 1
else:
legal_pm_mask[pm] = 0
self._legal_pm_mask = legal_pm_mask
state = np.concatenate((pm_state.flatten(), vm_state.flatten(), legal_pm_mask)).astype(np.float32)
return None, {"AGENT": state}
def _translate_to_env_action(
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionPayload,
) -> Dict[Any, object]:
if action_dict["AGENT"] == self.num_pms:
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
else:
return {"AGENT": AllocateAction(vm_id=event.vm_id, pm_id=action_dict["AGENT"][0])}
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
action = env_action_dict["AGENT"]
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
if isinstance(action, PostponeAction): # postponement
if np.sum(self._legal_pm_mask) != 1:
reward = -0.1 * conf["alpha"] + 0.0 * conf["beta"]
else:
reward = 0.0 * conf["alpha"] + 0.0 * conf["beta"]
else:
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else .0
return {"AGENT": np.float32(reward)}
def _get_pm_state(self):
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index::pm_attributes]
total_pm_info = total_pm_info.reshape(self.num_pms, len(pm_attributes))
# normalize the attributes of pms' cpu and memory
self._max_cpu_capacity = np.max(total_pm_info[:, 0])
self._max_memory_capacity = np.max(total_pm_info[:, 1])
total_pm_info[:, 2] /= self._max_cpu_capacity
total_pm_info[:, 3] /= self._max_memory_capacity
# get the remaining cpu and memory of the pms
remain_cpu = (1 - total_pm_info[:, 2]).reshape(1, self.num_pms, 1)
remain_memory = (1 - total_pm_info[:, 3]).reshape(1, self.num_pms, 1)
# get the pms' information
total_pm_info = np.concatenate((remain_cpu, remain_memory), axis=2) # (1, num_pms, 2)
# get the sequence pms' information
self._pm_state_history = np.concatenate((self._pm_state_history, total_pm_info), axis=0)
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
def _get_vm_state(self, event):
return np.array([
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
event.vm_memory_requirement / self._max_memory_capacity,
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement)
])
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
vm_unit_price = self._env.business_engine._get_unit_price(
event.vm_cpu_cores_requirement, event.vm_memory_requirement
)
return (alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time))
def _post_step(self, cache_element: CacheElement) -> None:
self._info["env_metric"] = {k: v for k, v in self._env.metrics.items() if k != "total_latency"}
self._info["env_metric"]["latency_due_to_agent"] = self._env.metrics["total_latency"].due_to_agent
self._info["env_metric"]["latency_due_to_resource"] = self._env.metrics["total_latency"].due_to_resource
if "actions_by_core_requirement" not in self._info:
self._info["actions_by_core_requirement"] = defaultdict(list)
if "action_sequence" not in self._info:
self._info["action_sequence"] = []
action = cache_element.action_dict["AGENT"]
if cache_element.state:
mask = cache_element.state[num_features:]
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append([action, mask])
self._info["action_sequence"].append(action)
def _post_eval_step(self, cache_element: CacheElement) -> None:
self._post_step(cache_element)
def post_collect(self, info_list: list, ep: int) -> None:
# print the env metric from each rollout worker
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")
# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
print(f"average env metric (episode {ep}): {avg_metric}")
def post_evaluate(self, info_list: list, ep: int) -> None:
# print the env metric from each rollout worker
for info in info_list:
print(f"env summary (evaluation episode {ep}): {info['env_metric']}")
# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
print(f"average env metric (evaluation episode {ep}): {avg_metric}")
for info in info_list:
core_requirement = info["actions_by_core_requirement"]
action_sequence = info["action_sequence"]
# plot action sequence
fig = plt.figure(figsize=(40, 32))
ax = fig.add_subplot(1, 1, 1)
ax.plot(action_sequence)
fig.savefig(f"{plt_path}/action_sequence_{ep}")
plt.cla()
plt.close("all")
# plot with legal action mask
fig = plt.figure(figsize=(40, 32))
for idx, key in enumerate(core_requirement.keys()):
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
for i in range(len(core_requirement[key])):
if i == 0:
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1], label=str(key))
ax.legend()
else:
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1])
fig.savefig(f"{plt_path}/values_with_legal_action_{ep}")
plt.cla()
plt.close("all")
# plot without legal action mask
fig = plt.figure(figsize=(40, 32))
for idx, key in enumerate(core_requirement.keys()):
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
for i in range(len(core_requirement[key])):
if i == 0:
ax.plot(core_requirement[key][i][0], label=str(key))
ax.legend()
else:
ax.plot(core_requirement[key][i][0])
fig.savefig(f"{plt_path}/values_without_legal_action_{ep}")
plt.cla()
plt.close("all")


@ -0,0 +1,57 @@
from functools import partial
from typing import Any, Callable, Dict, Optional
from examples.vm_scheduling.rl.algorithms.ac import get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn_policy
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
class VMBundle(RLComponentBundle):
def get_env_config(self) -> dict:
return env_conf
def get_test_env_config(self) -> Optional[dict]:
return test_env_conf
def get_env_sampler(self) -> AbsEnvSampler:
return VMEnvSampler(self.env, self.test_env)
def get_agent2policy(self) -> Dict[Any, str]:
return {"AGENT": f"{algorithm}.policy"}
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1
if algorithm == "ac":
policy_creator = {
f"{algorithm}.policy": partial(
get_ac_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
)
}
elif algorithm == "dqn":
policy_creator = {
f"{algorithm}.policy": partial(
get_dqn_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return policy_creator
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
from .algorithms.ac import get_ac, get_ac_policy
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
elif algorithm == "dqn":
from .algorithms.dqn import get_dqn, get_dqn_policy
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return trainer_creator


@ -41,13 +41,13 @@
.. image:: https://github.com/microsoft/maro/workflows/test/badge.svg
:target: https://github.com/microsoft/maro/actions?query=workflow%3Atest
:alt: test
:target: https://github.com/microsoft/maro/actions?query=workflow%3Atest
:alt: test
.. image:: https://github.com/microsoft/maro/workflows/build/badge.svg
:target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild
:alt: build
:target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild
:alt: build
.. image:: https://github.com/microsoft/maro/workflows/docker/badge.svg
@ -56,8 +56,8 @@
.. image:: https://readthedocs.org/projects/maro/badge/?version=latest
:target: https://maro.readthedocs.io/
:alt: docs
:target: https://maro.readthedocs.io/
:alt: docs
.. image:: https://img.shields.io/pypi/v/pymaro
@ -142,6 +142,69 @@
================================================================================================================
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/vm_scheduling.svg
:target: https://maro.readthedocs.io/en/latest/scenarios/vm_scheduling.html
:alt: VM Scheduling
.. image:: https://img.shields.io/gitter/room/microsoft/maro
:target: https://gitter.im/Microsoft/MARO#
:alt: Gitter
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/stack_overflow.svg
:target: https://stackoverflow.com/questions/ask?tags=maro
:alt: Stack Overflow
.. image:: https://img.shields.io/github/release-date-pre/microsoft/maro
:target: https://github.com/microsoft/maro/releases
:alt: Releases
.. image:: https://img.shields.io/github/commits-since/microsoft/maro/latest/master
:target: https://github.com/microsoft/maro/commits/master
:alt: Commits
.. image:: https://github.com/microsoft/maro/workflows/vulnerability%20scan/badge.svg
:target: https://github.com/microsoft/maro/actions?query=workflow%3A%22vulnerability+scan%22
:alt: Vulnerability Scan
.. image:: https://github.com/microsoft/maro/workflows/lint/badge.svg
:target: https://github.com/microsoft/maro/actions?query=workflow%3Alint
:alt: Lint
.. image:: https://img.shields.io/codecov/c/github/microsoft/maro
:target: https://codecov.io/gh/microsoft/maro
:alt: Coverage
.. image:: https://img.shields.io/pypi/dm/pymaro
:target: https://pypi.org/project/pymaro/#files
:alt: Downloads
.. image:: https://img.shields.io/docker/pulls/maro2020/maro
:target: https://hub.docker.com/repository/docker/maro2020/maro
:alt: Docker Pulls
.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/play_with_maro.svg
:target: https://hub.docker.com/r/maro2020/maro
:alt: Play with MARO
.. image:: https://github.com/microsoft/maro/blob/master/docs/source/images/logo.svg
:target: https://maro.readthedocs.io/en/latest/
:alt: MARO LOGO
================================================================================================================
Multi-Agent Resource Optimization (MARO) platform is an instance of Reinforcement
learning as a Service (RaaS) for real-world resource optimization. It can be
applied to many important industrial domains, such as `container inventory
@ -172,18 +235,18 @@ Contents
--------
.. list-table::
:header-rows: 1
:header-rows: 1
* - File/folder
- Description
* - ``maro``
- MARO source code.
* - ``docs``
- MARO docs, hosted on `readthedocs <https://maro.readthedocs.io/en/latest/>`_.
* - ``examples``
- Showcase of MARO.
* - ``notebooks``
- MARO quick-start notebooks.
* - File/folder
- Description
* - ``maro``
- MARO source code.
* - ``docs``
- MARO docs, hosted on `readthedocs <https://maro.readthedocs.io/en/latest/>`_.
* - ``examples``
- Showcase of MARO.
* - ``notebooks``
- MARO quick-start notebooks.
*Try `MARO playground <#run-playground>`_ to have a quick experience.*
@ -199,17 +262,17 @@ Install MARO from `PyPI <https://pypi.org/project/pymaro/#files>`_
.. code-block:: sh
pip install pymaro
pip install pymaro
*
Windows
.. code-block:: powershell
# Install torch first, if you don't have one.
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
# Install torch first, if you don't have one.
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pymaro
pip install pymaro
Install MARO from Source
------------------------
@ -235,9 +298,9 @@ Install MARO from Source
.. code-block:: sh
# If your environment is not clean, create a virtual environment firstly.
python -m venv maro_venv
source ./maro_venv/bin/activate
# If your environment is not clean, create a virtual environment firstly.
python -m venv maro_venv
source ./maro_venv/bin/activate
*
Windows
@ -267,16 +330,16 @@ Install MARO from Source
.. code-block:: sh
# Install MARO from source.
bash scripts/install_maro.sh
# Install MARO from source.
bash scripts/install_maro.sh
*
Windows
.. code-block:: powershell
# Install MARO from source.
.\scripts\install_maro.bat
# Install MARO from source.
.\scripts\install_maro.bat
*
*Notes: If your package is not found, remember to set your PYTHONPATH*
@ -300,16 +363,16 @@ Quick Example
.. code-block:: python
from maro.simulator import Env
from maro.simulator import Env
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100)
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100)
metrics, decision_event, is_done = env.step(None)
metrics, decision_event, is_done = env.step(None)
while not is_done:
metrics, decision_event, is_done = env.step(None)
while not is_done:
metrics, decision_event, is_done = env.step(None)
print(f"environment metrics: {env.metrics}")
print(f"environment metrics: {env.metrics}")
`Environment Visualization <https://maro.readthedocs.io/en/latest/>`_
-------------------------------------------------------------------------
@ -382,8 +445,8 @@ Run Playground
.. code-block:: sh
# Build playground image.
bash ./scripts/build_playground.sh
# Build playground image.
bash ./scripts/build_playground.sh
# Run playground container.
# Redis commander (GUI for redis) -> http://127.0.0.1:40009
@ -395,8 +458,8 @@ Run Playground
.. code-block:: powershell
# Build playground image.
.\scripts\build_playground.bat
# Build playground image.
.\scripts\build_playground.bat
# Run playground container.
# Redis commander (GUI for redis) -> http://127.0.0.1:40009


@ -74,6 +74,15 @@ def node(name: str):
return node_dec
def try_get_attribute(target, name, default=None):
try:
attr = object.__getattribute__(target, name)
return attr
except AttributeError:
return default
cdef class NodeAttribute:
def __cinit__(self, object dtype = None, SLOT_INDEX slot_num = 1, is_const = False, is_list = False):
# Check the type of dtype, used to stay compatible with the old version
@ -532,6 +541,8 @@ cdef class FrameBase:
else:
node._is_deleted = False
# Also
cpdef void take_snapshot(self, INT tick) except *:
"""Take snapshot for specified point (tick) for current frame.


@ -0,0 +1,317 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import base64
import json
import os
import shutil
from os.path import abspath, dirname, expanduser, join
import yaml
from maro.cli.utils import docker as docker_utils
from maro.cli.utils.azure import storage as azure_storage_utils
from maro.cli.utils.azure.aks import attach_acr
from maro.cli.utils.azure.deployment import create_deployment
from maro.cli.utils.azure.general import connect_to_aks, get_acr_push_permissions, set_env_credentials
from maro.cli.utils.azure.resource_group import create_resource_group, delete_resource_group
from maro.cli.utils.common import show_log
from maro.rl.workflows.config import ConfigParser
from maro.utils.logger import CliLogger
from maro.utils.utils import LOCAL_MARO_ROOT
from ..utils import k8s_manifest_generator, k8s_ops
# metadata
CLI_AKS_PATH = dirname(abspath(__file__))
TEMPLATE_PATH = join(CLI_AKS_PATH, "template.json")
NVIDIA_PLUGIN_PATH = join(CLI_AKS_PATH, "create_nvidia_plugin", "nvidia-device-plugin.yml")
LOCAL_ROOT = expanduser("~/.maro/aks")
DEPLOYMENT_CONF_PATH = os.path.join(LOCAL_ROOT, "conf.json")
DOCKER_FILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df")
DOCKER_IMAGE_NAME = "maro-aks"
REDIS_HOST = "maro-redis"
REDIS_PORT = 6379
ADDRESS_REGISTRY_NAME = "address-registry"
ADDRESS_REGISTRY_PORT = 6379
K8S_SECRET_NAME = "azure-secret"
# display
NO_DEPLOYMENT_MSG = "No Kubernetes deployment on Azure found. Use 'maro aks init' to create a deployment first"
NO_JOB_MSG = "No job named {} has been scheduled. Use 'maro aks job add' to add the job first."
JOB_EXISTS_MSG = "A job named {} has already been scheduled."
logger = CliLogger(name=__name__)
# helper functions
def get_resource_group_name(deployment_name: str):
return f"rg-{deployment_name}"
def get_acr_name(deployment_name: str):
return f"crmaro{deployment_name}"
def get_acr_server_name(acr_name: str):
return f"{acr_name}.azurecr.io"
def get_docker_image_name_in_acr(acr_name: str, docker_image_name: str):
return f"{get_acr_server_name(acr_name)}/{docker_image_name}"
def get_aks_name(deployment_name: str):
return f"aks-maro-{deployment_name}"
def get_agentpool_name(deployment_name: str):
return f"ap{deployment_name}"
def get_fileshare_name(deployment_name: str):
return f"fs-{deployment_name}"
def get_storage_account_name(deployment_name: str):
return f"stscenario{deployment_name}"
def get_virtual_network_name(location: str, deployment_name: str):
return f"vnet-prod-{location}-{deployment_name}"
def get_local_job_path(job_name: str):
return os.path.join(LOCAL_ROOT, job_name)
def get_storage_account_secret(resource_group_name: str, storage_account_name: str, namespace: str):
storage_account_keys = azure_storage_utils.get_storage_account_keys(resource_group_name, storage_account_name)
storage_key = storage_account_keys[0]["value"]
secret_data = {
"azurestorageaccountname": base64.b64encode(storage_account_name.encode()).decode(),
"azurestorageaccountkey": base64.b64encode(bytes(storage_key.encode())).decode()
}
k8s_ops.create_secret(K8S_SECRET_NAME, secret_data, namespace)
def get_resource_params(deployment_conf: dict) -> dict:
"""Create ARM parameters for Azure resource deployment ().
See https://docs.microsoft.com/en-us/azure/azure-resource-manager/templates/overview for details.
Args:
deployment_conf (dict): Configuration dict for deployment on Azure.
Returns:
dict: parameter dict, should be exported to json.
"""
name = deployment_conf["name"]
return {
"acrName": get_acr_name(name),
"acrSku": deployment_conf["container_registry_service_tier"],
"systemPoolVMCount": deployment_conf["resources"]["k8s"]["vm_count"],
"systemPoolVMSize": deployment_conf["resources"]["k8s"]["vm_size"],
"userPoolName": get_agentpool_name(name),
"userPoolVMCount": deployment_conf["resources"]["app"]["vm_count"],
"userPoolVMSize": deployment_conf["resources"]["app"]["vm_size"],
"aksName": get_aks_name(name),
"location": deployment_conf["location"],
"storageAccountName": get_storage_account_name(name),
"fileShareName": get_fileshare_name(name)
# "virtualNetworkName": get_virtual_network_name(deployment_conf["location"], name)
}
def prepare_docker_image_and_push_to_acr(image_name: str, context: str, docker_file_path: str, acr_name: str):
# build and tag docker image locally and push to the Azure Container Registry
if not docker_utils.image_exists(image_name):
docker_utils.build_image(context, docker_file_path, image_name)
get_acr_push_permissions(os.environ["AZURE_CLIENT_ID"], acr_name)
docker_utils.push(image_name, get_acr_server_name(acr_name))
def start_redis_service_in_aks(host: str, port: int, namespace: str):
k8s_ops.load_config()
k8s_ops.create_namespace(namespace)
k8s_ops.create_deployment(k8s_manifest_generator.get_redis_deployment_manifest(host, port), namespace)
k8s_ops.create_service(k8s_manifest_generator.get_redis_service_manifest(host, port), namespace)
# CLI command functions
def init(deployment_conf_path: str, **kwargs):
"""Prepare Azure resources needed for an AKS cluster using a YAML configuration file.
The configuration file template can be found in cli/k8s/aks/conf.yml. Use the Azure CLI to log into
your Azure account (az login ...) and the Azure Container Registry (az acr login ...) first.
Args:
deployment_conf_path (str): Path to the deployment configuration file.
"""
with open(deployment_conf_path, "r") as fp:
deployment_conf = yaml.safe_load(fp)
subscription = deployment_conf["azure_subscription"]
name = deployment_conf["name"]
if os.path.isfile(DEPLOYMENT_CONF_PATH):
logger.warning(f"Deployment {name} has already been created")
return
os.makedirs(LOCAL_ROOT, exist_ok=True)
resource_group_name = get_resource_group_name(name)
try:
# Set credentials as environment variables
set_env_credentials(LOCAL_ROOT, f"sp-{name}")
# create resource group
resource_group = create_resource_group(subscription, resource_group_name, deployment_conf["location"])
logger.info_green(f"Provisioned resource group {resource_group.name} in {resource_group.location}")
# Create ARM parameters and start deployment
logger.info("Creating Azure resources...")
resource_params = get_resource_params(deployment_conf)
with open(TEMPLATE_PATH, 'r') as fp:
template = json.load(fp)
create_deployment(subscription, resource_group_name, name, template, resource_params)
# Attach ACR to AKS
aks_name, acr_name = resource_params["aksName"], resource_params["acrName"]
attach_acr(resource_group_name, aks_name, acr_name)
connect_to_aks(resource_group_name, aks_name)
# build and tag docker image locally and push to the Azure Container Registry
logger.info("Preparing docker image...")
prepare_docker_image_and_push_to_acr(DOCKER_IMAGE_NAME, LOCAL_MARO_ROOT, DOCKER_FILE_PATH, acr_name)
# start the Redis service in the k8s cluster in the deployment namespace and expose it
logger.info("Starting Redis service in the k8s cluster...")
start_redis_service_in_aks(REDIS_HOST, REDIS_PORT, name)
# Dump the deployment configuration
with open(DEPLOYMENT_CONF_PATH, "w") as fp:
json.dump({
"name": name,
"subscription": subscription,
"resource_group": resource_group_name,
"resources": resource_params
}, fp)
logger.info_green(f"Cluster '{name}' is created")
except Exception as e:
# If failed, remove details folder, then raise
shutil.rmtree(LOCAL_ROOT)
logger.error_red(f"Deployment {name} failed due to {e}, rolling back...")
delete_resource_group(subscription, resource_group_name)
except KeyboardInterrupt:
shutil.rmtree(LOCAL_ROOT)
logger.error_red(f"Deployment {name} aborted, rolling back...")
delete_resource_group(subscription, resource_group_name)
def add_job(conf_path: dict, **kwargs):
if not os.path.isfile(DEPLOYMENT_CONF_PATH):
logger.error_red(NO_DEPLOYMENT_MSG)
return
parser = ConfigParser(conf_path)
job_name = parser.config["job"]
local_job_path = get_local_job_path(job_name)
if os.path.isdir(local_job_path):
logger.error_red(JOB_EXISTS_MSG.format(job_name))
return
os.makedirs(local_job_path)
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
deployment_conf = json.load(fp)
resource_group_name, resource_name = deployment_conf["resource_group"], deployment_conf["resources"]
fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"])
job_dir = azure_storage_utils.get_directory(fileshare, job_name)
scenario_path = parser.config["scenario_path"]
logger.info(f"Uploading local directory {scenario_path}...")
azure_storage_utils.upload_to_fileshare(job_dir, scenario_path, name="scenario")
azure_storage_utils.get_directory(job_dir, "checkpoints")
azure_storage_utils.get_directory(job_dir, "logs")
# Define mount volumes, i.e., scenario code, checkpoints, logs and load point
job_path_in_share = f"{resource_name['fileShareName']}/{job_name}"
volumes = [
k8s_manifest_generator.get_azurefile_volume_spec(name, f"{job_path_in_share}/{name}", K8S_SECRET_NAME)
for name in ["scenario", "logs", "checkpoints"]
]
if "load_path" in parser.config["training"]:
load_path = parser.config["training"]["load_path"]
logger.info(f"Uploading local model directory {load_path}...")
azure_storage_utils.upload_to_fileshare(job_dir, load_path, name="loadpoint")
volumes.append(
k8s_manifest_generator.get_azurefile_volume_spec(
"loadpoint", f"{job_path_in_share}/loadpoint", K8S_SECRET_NAME)
)
# Start k8s jobs
k8s_ops.load_config()
k8s_ops.create_namespace(job_name)
get_storage_account_secret(resource_group_name, resource_name["storageAccountName"], job_name)
k8s_ops.create_service(
k8s_manifest_generator.get_cross_namespace_service_access_manifest(
ADDRESS_REGISTRY_NAME, REDIS_HOST, deployment_conf["name"], ADDRESS_REGISTRY_PORT
), job_name
)
for component_name, (script, env) in parser.get_job_spec(containerize=True).items():
container_spec = k8s_manifest_generator.get_container_spec(
get_docker_image_name_in_acr(resource_name["acrName"], DOCKER_IMAGE_NAME),
component_name,
script,
env,
volumes
)
manifest = k8s_manifest_generator.get_job_manifest(
resource_name["userPoolName"],
component_name,
container_spec,
volumes
)
k8s_ops.create_job(manifest, job_name)
def remove_jobs(job_names: str, **kwargs):
if not os.path.isfile(DEPLOYMENT_CONF_PATH):
logger.error_red(NO_DEPLOYMENT_MSG)
return
k8s_ops.load_config()
for job_name in job_names:
local_job_path = get_local_job_path(job_name)
if not os.path.isdir(local_job_path):
logger.error_red(NO_JOB_MSG.format(job_name))
return
k8s_ops.delete_job(job_name)
def get_job_logs(job_name: str, tail: int = -1, **kwargs):
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
deployment_conf = json.load(fp)
local_log_path = os.path.join(get_local_job_path(job_name), "log")
resource_name = deployment_conf["resources"]
fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"])
job_dir = azure_storage_utils.get_directory(fileshare, job_name)
log_dir = azure_storage_utils.get_directory(job_dir, "logs")
azure_storage_utils.download_from_fileshare(log_dir, f"{job_name}.log", local_log_path)
show_log(local_log_path, tail=tail)
def exit(**kwargs):
try:
with open(DEPLOYMENT_CONF_PATH, "r") as fp:
deployment_conf = json.load(fp)
except FileNotFoundError:
logger.error(NO_DEPLOYMENT_MSG)
return
name = deployment_conf["name"]
set_env_credentials(LOCAL_ROOT, f"sp-{name}")
delete_resource_group(deployment_conf["subscription"], deployment_conf["resource_group"])

maro/cli/k8s/aks/conf.yml

@ -0,0 +1,12 @@
mode: ""
azure_subscription: your_azure_subscription_id
name: your_deployment_name
location: your_azure_service_location
container_registry_service_tier: Standard # "Basic", "Standard", "Premium", see https://docs.microsoft.com/en-us/azure/container-registry/container-registry-skus for details
resources:
k8s:
vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions
vm_count: 1 # must be at least 2 for k8s to function properly.
app:
vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions
vm_count: 1


@ -0,0 +1,33 @@
{
"$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentParameters.json#",
"contentVersion": "1.1.0.0",
"parameters": {
"acrName": {
"value": "myacr"
},
"acrSku": {
"value": "Basic"
},
"agentCount": {
"value": 1
},
"agentVMSize": {
"value": "standard_a2_v2"
},
"clusterName": {
"value": "myaks"
},
"fileShareName": {
"value": "myfileshare"
},
"location": {
"value": "East US"
},
"storageAccountName": {
"value": "mystorage"
},
"virtualNetworkName": {
"value": "myvnet"
}
}
}


@ -0,0 +1,157 @@
{
"$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
"contentVersion": "1.1.0.0",
"parameters": {
"acrName": {
"type": "string",
"minLength": 5,
"maxLength": 50,
"metadata": {
"description": "Name of your Azure Container Registry"
}
},
"acrSku": {
"type": "string",
"metadata": {
"description": "Tier of your Azure Container Registry."
},
"defaultValue": "Standard",
"allowedValues": [
"Basic",
"Standard",
"Premium"
]
},
"systemPoolVMCount": {
"type": "int",
"metadata": {
"description": "The number of VMs allocated for running the k8s system components."
},
"minValue": 1,
"maxValue": 50
},
"systemPoolVMSize": {
"type": "string",
"metadata": {
"description": "Virtual Machine size for running the k8s system components."
}
},
"userPoolName": {
"type": "string",
"metadata": {
"description": "Name of the user node pool."
}
},
"userPoolVMCount": {
"type": "int",
"metadata": {
"description": "The number of VMs allocated for running the user appplication."
},
"minValue": 1,
"maxValue": 50
},
"userPoolVMSize": {
"type": "string",
"metadata": {
"description": "Virtual Machine size for running the user application."
}
},
"aksName": {
"type": "string",
"metadata": {
"description": "Name of the Managed Cluster resource."
}
},
"location": {
"type": "string",
"metadata": {
"description": "Location of the Managed Cluster resource."
}
},
"storageAccountName": {
"type": "string",
"metadata": {
"description": "Azure storage account name."
}
},
"fileShareName": {
"type": "string",
"metadata": {
"description": "Azure file share name."
}
}
},
"resources": [
{
"name": "[parameters('acrName')]",
"type": "Microsoft.ContainerRegistry/registries",
"apiVersion": "2021-09-01",
"location": "[parameters('location')]",
"sku": {
"name": "[parameters('acrSku')]"
},
"properties": {
}
},
{
"name": "[parameters('aksName')]",
"type": "Microsoft.ContainerService/managedClusters",
"apiVersion": "2021-10-01",
"location": "[parameters('location')]",
"properties": {
"dnsPrefix": "maro",
"agentPoolProfiles": [
{
"name": "system",
"osDiskSizeGB": 0,
"count": "[parameters('systemPoolVMCount')]",
"vmSize": "[parameters('systemPoolVMSize')]",
"osType": "Linux",
"storageProfile": "ManagedDisks",
"mode": "System",
"type": "VirtualMachineScaleSets"
},
{
"name": "[parameters('userPoolName')]",
"osDiskSizeGB": 0,
"count": "[parameters('userPoolVMCount')]",
"vmSize": "[parameters('userPoolVMSize')]",
"osType": "Linux",
"storageProfile": "ManagedDisks",
"mode": "User",
"type": "VirtualMachineScaleSets"
}
],
"networkProfile": {
"networkPlugin": "azure",
"loadBalancerSku": "standard"
}
},
"identity": {
"type": "SystemAssigned"
}
},
{
"type": "Microsoft.Storage/storageAccounts",
"apiVersion": "2021-08-01",
"name": "[parameters('storageAccountName')]",
"location": "[parameters('location')]",
"kind": "StorageV2",
"sku": {
"name": "Standard_LRS",
"tier": "Standard"
},
"properties": {
"accessTier": "Hot"
}
},
{
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
"apiVersion": "2021-04-01",
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
"dependsOn": [
"[resourceId('Microsoft.Storage/storageAccounts', parameters('storageAccountName'))]"
]
}
]
}
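
For reference, a parameter set for this template needs to supply each of the declared parameters. A hypothetical example in Python (all values are placeholders):

# Hypothetical parameter values for the template above; names mirror its "parameters" section.
aks_template_params = {
    "acrName": "myacr",
    "acrSku": "Basic",
    "systemPoolVMCount": 1,
    "systemPoolVMSize": "Standard_DS2_v2",
    "userPoolName": "userpool",
    "userPoolVMCount": 1,
    "userPoolVMSize": "Standard_DS2_v2",
    "aksName": "myaks",
    "location": "East US",
    "storageAccountName": "mystorage",
    "fileShareName": "myfileshare",
}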


@ -22,18 +22,6 @@
"Premium"
]
},
"adminPublicKey": {
"type": "string",
"metadata": {
"description": "Configure all linux machines with the SSH RSA public key string. Your key should include three parts, for example 'ssh-rsa AAAAB...snip...UcyupgH azureuser@linuxvm'"
}
},
"adminUsername": {
"type": "string",
"metadata": {
"description": "User name for the Linux Virtual Machines."
}
},
"agentCount": {
"type": "int",
"metadata": {
@ -87,7 +75,7 @@
"resources": [
{
"type": "Microsoft.Storage/storageAccounts/fileServices/shares",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-04-01",
"name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]",
"dependsOn": [
"[variables('stvmId')]"
@ -96,7 +84,7 @@
{
"name": "[parameters('acrName')]",
"type": "Microsoft.ContainerRegistry/registries",
"apiVersion": "2020-11-01-preview",
"apiVersion": "2021-09-01",
"location": "[parameters('location')]",
"sku": {
"name": "[parameters('acrSku')]"
@ -107,7 +95,7 @@
{
"name": "[parameters('clusterName')]",
"type": "Microsoft.ContainerService/managedClusters",
"apiVersion": "2020-03-01",
"apiVersion": "2021-10-01",
"location": "[parameters('location')]",
"dependsOn": [
"[variables('vnetId')]"
@ -127,16 +115,6 @@
"type": "VirtualMachineScaleSets"
}
],
"linuxProfile": {
"adminUsername": "[parameters('adminUsername')]",
"ssh": {
"publicKeys": [
{
"keyData": "[parameters('adminPublicKey')]"
}
]
}
},
"networkProfile": {
"networkPlugin": "azure",
"loadBalancerSku": "standard"
@ -148,7 +126,7 @@
},
{
"type": "Microsoft.Storage/storageAccounts",
"apiVersion": "2020-08-01-preview",
"apiVersion": "2021-08-01",
"name": "[parameters('storageAccountName')]",
"location": "[parameters('location')]",
"kind": "StorageV2",
@ -163,7 +141,7 @@
{
"name": "[parameters('virtualNetworkName')]",
"type": "Microsoft.Network/virtualNetworks",
"apiVersion": "2020-04-01",
"apiVersion": "2020-11-01",
"location": "[parameters('location')]",
"properties": {
"addressSpace": {


@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List
from maro.cli.utils.common import format_env_vars
def get_job_manifest(agent_pool_name: str, component_name: str, container_spec: dict, volumes: List[dict]):
return {
"metadata": {"name": component_name},
"spec": {
"template": {
"spec": {
"nodeSelector": {"agentpool": agent_pool_name},
"restartPolicy": "Never",
"volumes": volumes,
"containers": [container_spec]
}
}
}
}
def get_azurefile_volume_spec(name: str, share_name: str, secret_name: str):
return {
"name": name,
"azureFile": {
"secretName": secret_name,
"shareName": share_name,
"readOnly": False
}
}
def get_container_spec(image_name: str, component_name: str, script: str, env: dict, volumes):
common_container_spec = {
"image": image_name,
"imagePullPolicy": "Always",
"volumeMounts": [{"name": vol["name"], "mountPath": f"/{vol['name']}"} for vol in volumes]
}
return {
**common_container_spec,
**{
"name": component_name,
"command": ["python3", script],
"env": format_env_vars(env, mode="k8s")
}
}
def get_redis_deployment_manifest(host: str, port: int):
return {
"metadata": {
"name": host,
"labels": {"app": "redis"}
},
"spec": {
"selector": {
"matchLabels": {"app": "redis"}
},
"replicas": 1,
"template": {
"metadata": {
"labels": {"app": "redis"}
},
"spec": {
"containers": [
{
"name": "master",
"image": "redis:6",
"ports": [{"containerPort": port}]
}
]
}
}
}
}
def get_redis_service_manifest(host: str, port: int):
return {
"metadata": {
"name": host,
"labels": {"app": "redis"}
},
"spec": {
"ports": [{"port": port, "targetPort": port}],
"selector": {"app": "redis"}
}
}
def get_cross_namespace_service_access_manifest(
service_name: str, target_service_name: str, target_service_namespace: str, target_service_port: int
):
return {
"metadata": {
"name": service_name,
},
"spec": {
"type": "ExternalName",
"externalName": f"{target_service_name}.{target_service_namespace}.svc.cluster.local",
"ports": [{"port": target_service_port}]
}
}
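
A minimal sketch of how these manifest builders compose into a single k8s Job manifest; the share name, secret name, image, script path and pool name below are hypothetical:

# Build the volume, container and job specs step by step (illustrative values only).
volumes = [get_azurefile_volume_spec(name="maro", share_name="myfileshare", secret_name="azure-secret")]
container = get_container_spec(
    image_name="myacr.azurecr.io/maro:latest",
    component_name="main",
    script="/maro/workflows/main.py",
    env={"JOB": "my_job"},
    volumes=volumes,
)
job_manifest = get_job_manifest(
    agent_pool_name="userpool",
    component_name="main",
    container_spec=container,
    volumes=volumes,
)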


@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import kubernetes
from kubernetes import client, config
def load_config():
config.load_kube_config()
def create_namespace(namespace: str):
body = client.V1Namespace()
body.metadata = client.V1ObjectMeta(name=namespace)
try:
client.CoreV1Api().create_namespace(body)
except kubernetes.client.exceptions.ApiException:
pass
def create_deployment(conf: dict, namespace: str):
client.AppsV1Api().create_namespaced_deployment(namespace, conf)
def create_service(conf: dict, namespace: str):
client.CoreV1Api().create_namespaced_service(namespace, conf)
def create_job(conf: dict, namespace: str):
client.BatchV1Api().create_namespaced_job(namespace, conf)
def create_secret(name: str, data: dict, namespace: str):
client.CoreV1Api().create_namespaced_secret(
body=client.V1Secret(metadata=client.V1ObjectMeta(name=name), data=data),
namespace=namespace
)
def delete_job(namespace: str):
client.BatchV1Api().delete_collection_namespaced_job(namespace)
client.CoreV1Api().delete_namespace(namespace)
def describe_job(namespace: str):
client.CoreV1Api().read_namespace(namespace)
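
A possible call sequence that submits the manifest built in the previous sketch; the namespace name is hypothetical:

load_config()                          # read the local kubeconfig, e.g. set up via "az aks get-credentials"
create_namespace("my-rl-job")          # one namespace per job; an existing namespace is silently reused
create_job(job_manifest, "my-rl-job")  # note the (manifest, namespace) argument order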


maro/cli/local/commands.py (new file, 253 lines)

@ -0,0 +1,253 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import shutil
import subprocess
import sys
import time
from os import makedirs
from os.path import abspath, dirname, exists, expanduser, join
import redis
import yaml
from maro.cli.utils.common import close_by_pid, show_log
from maro.rl.workflows.config import ConfigParser
from maro.utils.logger import CliLogger
from maro.utils.utils import LOCAL_MARO_ROOT
from .utils import (
JobStatus, RedisHashKey, start_redis, start_rl_job, start_rl_job_with_docker_compose, stop_redis,
stop_rl_job_with_docker_compose
)
# metadata
LOCAL_ROOT = expanduser("~/.maro/local")
SESSION_STATE_PATH = join(LOCAL_ROOT, "session.json")
DOCKERFILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df")
DOCKER_IMAGE_NAME = "maro-local"
DOCKER_NETWORK = "MAROLOCAL"
# display
NO_JOB_MANAGER_MSG = """No job manager found. Run "maro local init" to start the job manager first."""
NO_JOB_MSG = """No job named {} found. Run "maro local job ls" to view existing jobs."""
JOB_LS_TEMPLATE = "{JOB:12}{STATUS:15}{STARTED:20}"
logger = CliLogger(name="MARO-LOCAL")
# helper functions
def get_redis_conn(port=None):
if port is None:
try:
with open(SESSION_STATE_PATH, "r") as fp:
port = json.load(fp)["port"]
except FileNotFoundError:
logger.error(NO_JOB_MANAGER_MSG)
return
try:
redis_conn = redis.Redis(host="localhost", port=port)
redis_conn.ping()
return redis_conn
except redis.exceptions.ConnectionError:
logger.error(NO_JOB_MANAGER_MSG)
# Functions executed on CLI commands
def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False, **kwargs):
# Load job configuration file
parser = ConfigParser(conf_path)
if containerize:
try:
start_rl_job_with_docker_compose(
parser, LOCAL_MARO_ROOT, DOCKERFILE_PATH, DOCKER_IMAGE_NAME, evaluate_only=evaluate_only,
)
except KeyboardInterrupt:
stop_rl_job_with_docker_compose(parser.config["job"], LOCAL_MARO_ROOT)
else:
try:
start_rl_job(parser, LOCAL_MARO_ROOT, evaluate_only=evaluate_only)
except KeyboardInterrupt:
sys.exit(1)
def init(
port: int = 19999,
max_running: int = 3,
query_every: int = 5,
timeout: int = 3,
containerize: bool = False,
**kwargs
):
if exists(SESSION_STATE_PATH):
with open(SESSION_STATE_PATH, "r") as fp:
session_state = json.load(fp)
logger.warning(
f"Local job manager is already running at port {session_state['port']}. "
f"Run 'maro local job add/rm' to add / remove jobs."
)
return
start_redis(port)
# Start job manager
command = ["python", join(dirname(abspath(__file__)), 'job_manager.py')]
job_manager = subprocess.Popen(
command,
env={
"PYTHONPATH": LOCAL_MARO_ROOT,
"MAX_RUNNING": str(max_running),
"QUERY_EVERY": str(query_every),
"SIGTERM_TIMEOUT": str(timeout),
"CONTAINERIZE": str(containerize),
"REDIS_PORT": str(port),
"LOCAL_MARO_ROOT": LOCAL_MARO_ROOT,
"DOCKER_IMAGE_NAME": DOCKER_IMAGE_NAME,
"DOCKERFILE_PATH": DOCKERFILE_PATH
}
)
# Dump environment setting
makedirs(LOCAL_ROOT, exist_ok=True)
with open(SESSION_STATE_PATH, "w") as fp:
json.dump({"port": port, "job_manager_pid": job_manager.pid, "containerized": containerize}, fp)
# Create log folder
logger.info("Local job manager started")
def exit(**kwargs):
try:
with open(SESSION_STATE_PATH, "r") as fp:
session_state = json.load(fp)
except FileNotFoundError:
logger.error(NO_JOB_MANAGER_MSG)
return
redis_conn = get_redis_conn()
# Mark all jobs as REMOVED and let the job manager terminate them properly.
job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS)
if job_details:
for job_name, details in job_details.items():
details = json.loads(details)
details["status"] = JobStatus.REMOVED
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
logger.info(f"Gracefully terminating job {job_name.decode()}")
# Stop job manager
close_by_pid(int(session_state["job_manager_pid"]))
# Stop Redis
stop_redis(session_state["port"])
# Remove dump folder.
shutil.rmtree(LOCAL_ROOT, True)
logger.info("Local job manager terminated.")
def add_job(conf_path: str, **kwargs):
redis_conn = get_redis_conn()
if not redis_conn:
return
# Load job configuration file
with open(conf_path, "r") as fr:
conf = yaml.safe_load(fr)
job_name = conf["job"]
if redis_conn.hexists(RedisHashKey.JOB_DETAILS, job_name):
logger.error(f"A job named '{job_name}' has already been added.")
return
# Push job config to redis
redis_conn.hset(RedisHashKey.JOB_CONF, job_name, json.dumps(conf))
details = {
"status": JobStatus.PENDING,
"added": time.time()
}
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
def remove_jobs(job_names, **kwargs):
redis_conn = get_redis_conn()
if not redis_conn:
return
for job_name in job_names:
details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)
if not details:
logger.error(f"No job named '{job_name}' has been scheduled or started.")
else:
details = json.loads(details)
details["status"] = JobStatus.REMOVED
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
logger.info(f"Removed job {job_name}")
def describe_job(job_name, **kwargs):
redis_conn = get_redis_conn()
if not redis_conn:
return
details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)
if not details:
logger.error(NO_JOB_MSG.format(job_name))
return
details = json.loads(details)
err = "error_message" in details
if err:
err_msg = details["error_message"].split('\n')
del details["error_message"]
logger.info(details)
if err:
for line in err_msg:
logger.info(line)
def get_job_logs(job_name: str, tail: int = -1, **kwargs):
    redis_conn = get_redis_conn()
    if not redis_conn:
        return
    if not redis_conn.hexists(RedisHashKey.JOB_CONF, job_name):
logger.error(NO_JOB_MSG.format(job_name))
return
conf = json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name))
show_log(conf["log_path"], tail=tail)
def list_jobs(**kwargs):
redis_conn = get_redis_conn()
if not redis_conn:
return
def get_time_diff_string(time_diff):
time_diff = int(time_diff)
days = time_diff // (3600 * 24)
if days:
return f"{days} days"
hours = time_diff // 3600
if hours:
return f"{hours} hours"
minutes = time_diff // 60
if minutes:
return f"{minutes} minutes"
return f"{time_diff} seconds"
# Header
logger.info(JOB_LS_TEMPLATE.format(JOB="JOB", STATUS="STATUS", STARTED="STARTED"))
for job_name, details in redis_conn.hgetall(RedisHashKey.JOB_DETAILS).items():
job_name = job_name.decode()
details = json.loads(details)
if "start_time" in details:
time_diff = f"{get_time_diff_string(time.time() - details['start_time'])} ago"
logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=time_diff))
else:
logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=JobStatus.PENDING))
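
The commands above keep all bookkeeping in two Redis hashes (see RedisHashKey in the utils module below): one holding the raw job configuration, one holding the status record. A sketch of what add_job writes, with illustrative values:

import json
import redis

conn = redis.Redis(host="localhost", port=19999)
conn.hset("job_conf", "my_job", json.dumps({"job": "my_job", "log_path": "/tmp/my_job"}))
conn.hset("job_details", "my_job", json.dumps({"status": "pending", "added": 1650000000.0}))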


@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import threading
import time
import redis
from maro.cli.local.utils import JobStatus, RedisHashKey, poll, start_rl_job, start_rl_job_in_containers, term
from maro.cli.utils.docker import build_image, image_exists
from maro.rl.workflows.config import ConfigParser
if __name__ == "__main__":
redis_port = int(os.getenv("REDIS_PORT", default=19999))
redis_conn = redis.Redis(host="localhost", port=redis_port)
started, max_running = {}, int(os.getenv("MAX_RUNNING", default=1))
query_every = int(os.getenv("QUERY_EVERY", default=5))
sigterm_timeout = int(os.getenv("SIGTERM_TIMEOUT", default=3))
containerize = os.getenv("CONTAINERIZE", default="False") == "True"
local_maro_root = os.getenv("LOCAL_MARO_ROOT")
docker_file_path = os.getenv("DOCKERFILE_PATH")
docker_image_name = os.getenv("DOCKER_IMAGE_NAME")
# thread to monitor a job
def monitor(job_name):
removed, error, err_out, running = False, False, None, started[job_name]
while running:
error, err_out, running = poll(running)
# check if the job has been marked as REMOVED before termination
details = json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name))
if details["status"] == JobStatus.REMOVED:
removed = True
break
if error:
break
if removed:
term(started[job_name], job_name, timeout=sigterm_timeout)
redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name)
redis_conn.hdel(RedisHashKey.JOB_CONF, job_name)
return
if error:
term(started[job_name], job_name, timeout=sigterm_timeout)
details["status"] = JobStatus.ERROR
details["error_message"] = err_out
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
else: # all job processes terminated normally
details["status"] = JobStatus.FINISHED
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
# Continue to monitor if the job is marked as REMOVED
while json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name))["status"] != JobStatus.REMOVED:
time.sleep(query_every)
term(started[job_name], job_name, timeout=sigterm_timeout)
redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name)
redis_conn.hdel(RedisHashKey.JOB_CONF, job_name)
while True:
# check for pending jobs
job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS)
if job_details:
num_running, pending = 0, []
for job_name, details in job_details.items():
job_name, details = job_name.decode(), json.loads(details)
if details["status"] == JobStatus.RUNNING:
num_running += 1
elif details["status"] == JobStatus.PENDING:
pending.append((job_name, json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name))))
for job_name, conf in pending[:max(0, max_running - num_running)]:
if containerize and not image_exists(docker_image_name):
redis_conn.hset(
RedisHashKey.JOB_DETAILS, job_name, json.dumps({"status": JobStatus.IMAGE_BUILDING})
)
build_image(local_maro_root, docker_file_path, docker_image_name)
parser = ConfigParser(conf)
details = {"status": JobStatus.RUNNING, "start_time": time.time()}
if containerize:
    path_mapping = parser.get_path_mapping(containerize=True)
    started[job_name] = start_rl_job_in_containers(parser, docker_image_name)
    details["containers"] = started[job_name]
else:
    started[job_name] = start_rl_job(parser, local_maro_root, evaluate_only=False, background=True)
    details["pids"] = [proc.pid for proc in started[job_name]]
redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
threading.Thread(target=monitor, args=(job_name,)).start() # start job monitoring thread
time.sleep(query_every)
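
The job manager is normally launched by "maro local init" (see commands.py above), but for debugging it can be started directly. A sketch with the environment variables it reads, using its own defaults and placeholder paths:

import os
import subprocess

subprocess.Popen(
    ["python", "job_manager.py"],
    env={
        **os.environ,                       # keep PATH etc. so the interpreter can be located
        "REDIS_PORT": "19999",
        "MAX_RUNNING": "1",
        "QUERY_EVERY": "5",
        "SIGTERM_TIMEOUT": "3",
        "CONTAINERIZE": "False",
        "PYTHONPATH": "/path/to/maro",      # placeholder: MARO source root
        "LOCAL_MARO_ROOT": "/path/to/maro", # placeholder
    },
)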

maro/cli/local/utils.py (new file, 195 lines)

@ -0,0 +1,195 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import subprocess
from copy import deepcopy
from typing import List
import docker
import yaml
from maro.cli.utils.common import format_env_vars
from maro.rl.workflows.config.parser import ConfigParser
class RedisHashKey:
"""Record Redis elements name, and only for maro process"""
JOB_CONF = "job_conf"
JOB_DETAILS = "job_details"
class JobStatus:
PENDING = "pending"
IMAGE_BUILDING = "image_building"
RUNNING = "running"
ERROR = "error"
REMOVED = "removed"
FINISHED = "finished"
def start_redis(port: int):
subprocess.Popen(["redis-server", "--port", str(port)], stdout=subprocess.DEVNULL)
def stop_redis(port: int):
subprocess.Popen(["redis-cli", "-p", str(port), "shutdown"], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
def extract_error_msg_from_docker_log(container: docker.models.containers.Container):
logs = container.logs().decode().splitlines()
for i, log in enumerate(logs):
if "Traceback (most recent call last):" in log:
return "\n".join(logs[i:])
return "\n".join(logs)
def check_proc_status(proc):
if isinstance(proc, subprocess.Popen):
if proc.poll() is None:
return True, 0, None
_, err_out = proc.communicate()
return False, proc.returncode, err_out
else:
client = docker.from_env()
container_state = client.api.inspect_container(proc.id)["State"]
return container_state["Running"], container_state["ExitCode"], extract_error_msg_from_docker_log(proc)
def poll(procs):
    error, err_out, running = False, None, []
for proc in procs:
is_running, exit_code, err_out = check_proc_status(proc)
if is_running:
running.append(proc)
elif exit_code:
error = True
break
return error, err_out, running
def term(procs, job_name: str, timeout: int = 3):
if isinstance(procs[0], subprocess.Popen):
for proc in procs:
if proc.poll() is None:
try:
proc.terminate()
proc.wait(timeout=timeout)
except subprocess.TimeoutExpired:
proc.kill()
else:
for proc in procs:
try:
proc.stop(timeout=timeout)
proc.remove()
except Exception:
pass
client = docker.from_env()
try:
job_network = client.networks.get(job_name)
job_network.remove()
except Exception:
pass
def exec(cmd: str, env: dict, debug: bool = False) -> subprocess.Popen:
stream = None if debug else subprocess.PIPE
return subprocess.Popen(
cmd.split(), env={**os.environ.copy(), **env}, stdout=stream, stderr=stream, encoding="utf8"
)
def start_rl_job(
parser: ConfigParser, maro_root: str, evaluate_only: bool, background: bool = False,
) -> List[subprocess.Popen]:
procs = [
exec(
f"python {script}" + ("" if not evaluate_only else " --evaluate_only"),
format_env_vars({**env, "PYTHONPATH": maro_root}, mode="proc"),
debug=not background
)
for script, env in parser.get_job_spec().values()
]
if not background:
for proc in procs:
proc.communicate()
return procs
def start_rl_job_in_containers(parser: ConfigParser, image_name: str) -> list:
job_name = parser.config["job"]
client, containers = docker.from_env(), []
training_mode = parser.config["training"]["mode"]
if "parallelism" in parser.config["rollout"]:
rollout_parallelism = max(
parser.config["rollout"]["parallelism"]["sampling"],
parser.config["rollout"]["parallelism"].get("eval", 1)
)
else:
rollout_parallelism = 1
if training_mode != "simple" or rollout_parallelism > 1:
# create the exclusive network for the job
client.networks.create(job_name, driver="bridge")
for component, (script, env) in parser.get_job_spec(containerize=True).items():
# volume mounts for scenario folder, policy loading, checkpointing and logging
container = client.containers.run(
image_name,
command=f"python3 {script}",
detach=True,
name=component,
environment=env,
volumes=[f"{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()],
network=job_name
)
containers.append(container)
return containers
def get_docker_compose_yml_path(maro_root: str) -> str:
return os.path.join(maro_root, ".tmp", "docker-compose.yml")
def start_rl_job_with_docker_compose(
parser: ConfigParser, context: str, dockerfile_path: str, image_name: str, evaluate_only: bool,
) -> None:
common_spec = {
"build": {"context": context, "dockerfile": dockerfile_path},
"image": image_name,
"volumes": [f"./{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()]
}
job_name = parser.config["job"]
manifest = {
"version": "3.9",
"services": {
component: {
**deepcopy(common_spec),
**{
"container_name": component,
"command": f"python3 {script}" + ("" if not evaluate_only else " --evaluate_only"),
"environment": format_env_vars(env, mode="docker-compose")
}
}
for component, (script, env) in parser.get_job_spec(containerize=True).items()
},
}
docker_compose_file_path = get_docker_compose_yml_path(maro_root=context)
with open(docker_compose_file_path, "w") as fp:
yaml.safe_dump(manifest, fp)
subprocess.run(
["docker-compose", "--project-name", job_name, "-f", docker_compose_file_path, "up", "--remove-orphans"]
)
def stop_rl_job_with_docker_compose(job_name: str, context: str):
subprocess.run(["docker-compose", "--project-name", job_name, "down"])
os.remove(get_docker_compose_yml_path(maro_root=context))
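
A minimal sketch of driving these helpers directly, outside the CLI; the configuration path and MARO root are placeholders:

parser = ConfigParser("conf.yml")
procs = start_rl_job(parser, "/path/to/maro", evaluate_only=False, background=True)
error, err_out, running = poll(procs)   # non-blocking status check
term(procs, parser.config["job"])       # SIGTERM each process, SIGKILL after the timeout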


@ -90,6 +90,15 @@ def main():
parser_k8s.set_defaults(func=_help_func(parser=parser_k8s))
load_parser_k8s(prev_parser=parser_k8s, global_parser=global_parser)
# maro aks
parser_aks = subparsers.add_parser(
"aks",
help="Manage distributed cluster with Kubernetes.",
parents=[global_parser]
)
parser_aks.set_defaults(func=_help_func(parser=parser_aks))
load_parser_aks(prev_parser=parser_aks, global_parser=global_parser)
# maro inspector
parser_inspector = subparsers.add_parser(
'inspector',
@ -99,13 +108,13 @@ def main():
parser_inspector.set_defaults(func=_help_func(parser=parser_inspector))
load_parser_inspector(parser_inspector, global_parser)
# maro process
parser_process = subparsers.add_parser(
"process",
help="Run application by mulit-process to simulate distributed mode."
# maro local
parser_local = subparsers.add_parser(
"local",
help="Run jobs locally."
)
parser_process.set_defaults(func=_help_func(parser=parser_process))
load_parser_process(prev_parser=parser_process, global_parser=global_parser)
parser_local.set_defaults(func=_help_func(parser=parser_local))
load_parser_local(prev_parser=parser_local, global_parser=global_parser)
# maro project
parser_project = subparsers.add_parser(
@ -151,152 +160,128 @@ def main():
logger.error_red(f"{e.__class__.__name__}: {e.get_message()}")
def load_parser_process(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
def load_parser_local(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
subparsers = prev_parser.add_subparsers()
# maro process create
from maro.cli.process.create import create
parser_setup = subparsers.add_parser(
"create",
help="Create local process environment.",
# maro local run
from maro.cli.local.commands import run
parser = subparsers.add_parser(
"run",
help="Run a job in debug mode.",
examples=CliExamples.MARO_PROCESS_SETUP,
parents=[global_parser]
)
parser_setup.add_argument(
'deployment_path',
help='Path of the local process setting deployment.',
nargs='?',
default=None)
parser_setup.set_defaults(func=create)
parser.add_argument("conf_path", help='Path of the job deployment')
parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers")
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
parser.add_argument("-p", "--port", type=int, default=20000, help="")
parser.set_defaults(func=run)
# maro process delete
from maro.cli.process.delete import delete
parser_setup = subparsers.add_parser(
"delete",
help="Delete the local process environment. Including closing agents and maro Redis.",
# maro local init
from maro.cli.local.commands import init
parser = subparsers.add_parser(
"init",
help="Initialize local job manager.",
examples=CliExamples.MARO_PROCESS_SETUP,
parents=[global_parser]
)
parser_setup.set_defaults(func=delete)
parser.add_argument(
"-p", "--port", type=int, default=19999,
help="Port on local machine to launch the Redis server at. Defaults to 19999."
)
parser.add_argument(
"-m", "--max-running", type=int, default=3,
help="Maximum number of jobs to allow running at the same time. Defaults to 3."
)
parser.add_argument(
"-q", "--query-every", type=int, default=5,
help="Number of seconds to wait between queries to the Redis server for pending or removed jobs. Defaults to 5."
)
parser.add_argument(
"-t", "--timeout", type=int, default=3,
help="""
Number of seconds to wait after sending SIGTERM to a process. If the process does not terminate
during this time, the process will be force-killed through SIGKILL. Defaults to 3.
"""
)
parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers")
parser.set_defaults(func=init)
# maro process job
parser_job = subparsers.add_parser(
# maro local exit
from maro.cli.local.commands import exit
parser = subparsers.add_parser(
"exit",
help="Terminate the local job manager",
parents=[global_parser]
)
parser.set_defaults(func=exit)
# maro local job
parser = subparsers.add_parser(
"job",
help="Manage jobs",
parents=[global_parser]
)
parser_job.set_defaults(func=_help_func(parser=parser_job))
parser_job_subparsers = parser_job.add_subparsers()
parser.set_defaults(func=_help_func(parser=parser))
job_subparsers = parser.add_subparsers()
# maro process job start
from maro.cli.process.job import start_job
parser_job_start = parser_job_subparsers.add_parser(
'start',
help='Start a training job',
# maro local job add
from maro.cli.local.commands import add_job
job_add_parser = job_subparsers.add_parser(
"add",
help="Start an RL job",
examples=CliExamples.MARO_PROCESS_JOB_START,
parents=[global_parser]
)
parser_job_start.add_argument(
'deployment_path', help='Path of the job deployment')
parser_job_start.set_defaults(func=start_job)
job_add_parser.add_argument("conf_path", help='Path of the job deployment')
job_add_parser.set_defaults(func=add_job)
# maro process job stop
from maro.cli.process.job import stop_job
parser_job_stop = parser_job_subparsers.add_parser(
'stop',
help='Stop a training job',
# maro local job rm
from maro.cli.local.commands import remove_jobs
job_stop_parser = job_subparsers.add_parser(
"rm",
help='Stop an RL job',
examples=CliExamples.MARO_PROCESS_JOB_STOP,
parents=[global_parser]
)
parser_job_stop.add_argument(
'job_name', help='Name of the job')
parser_job_stop.set_defaults(func=stop_job)
job_stop_parser.add_argument('job_names', help="Job names", nargs="*")
job_stop_parser.set_defaults(func=remove_jobs)
# maro process job delete
from maro.cli.process.job import delete_job
parser_job_delete = parser_job_subparsers.add_parser(
'delete',
help='delete a stopped job',
examples=CliExamples.MARO_PROCESS_JOB_DELETE,
# maro local job describe
from maro.cli.local.commands import describe_job
job_stop_parser = job_subparsers.add_parser(
"describe",
help="Get the status of an RL job and the error information if the job fails due to some error",
examples=CliExamples.MARO_PROCESS_JOB_STOP,
parents=[global_parser]
)
parser_job_delete.add_argument(
'job_name', help='Name of the job or the schedule')
parser_job_delete.set_defaults(func=delete_job)
job_stop_parser.add_argument('job_name', help='Job name')
job_stop_parser.set_defaults(func=describe_job)
# maro process job list
from maro.cli.process.job import list_jobs
parser_job_list = parser_job_subparsers.add_parser(
'list',
# maro local job ls
from maro.cli.local.commands import list_jobs
job_list_parser = job_subparsers.add_parser(
"ls",
help='List all jobs',
examples=CliExamples.MARO_PROCESS_JOB_LIST,
parents=[global_parser]
)
parser_job_list.set_defaults(func=list_jobs)
job_list_parser.set_defaults(func=list_jobs)
# maro process job logs
from maro.cli.process.job import get_job_logs
parser_job_logs = parser_job_subparsers.add_parser(
'logs',
help='Get logs of the job',
# maro local job logs
from maro.cli.local.commands import get_job_logs
job_logs_parser = job_subparsers.add_parser(
"logs",
help="Get job logs",
examples=CliExamples.MARO_PROCESS_JOB_LOGS,
parents=[global_parser]
)
parser_job_logs.add_argument(
'job_name', help='Name of the job')
parser_job_logs.set_defaults(func=get_job_logs)
# maro process schedule
parser_schedule = subparsers.add_parser(
'schedule',
help='Manage schedules',
parents=[global_parser]
job_logs_parser.add_argument("job_name", help="job name")
job_logs_parser.add_argument(
"-n", "--tail", type=int, default=-1,
help="Number of lines to show from the end of the given job's logs"
)
parser_schedule.set_defaults(func=_help_func(parser=parser_schedule))
parser_schedule_subparsers = parser_schedule.add_subparsers()
# maro process schedule start
from maro.cli.process.schedule import start_schedule
parser_schedule_start = parser_schedule_subparsers.add_parser(
'start',
help='Start a schedule',
examples=CliExamples.MARO_PROCESS_SCHEDULE_START,
parents=[global_parser]
)
parser_schedule_start.add_argument(
'deployment_path', help='Path of the schedule deployment')
parser_schedule_start.set_defaults(func=start_schedule)
# maro process schedule stop
from maro.cli.process.schedule import stop_schedule
parser_schedule_stop = parser_schedule_subparsers.add_parser(
'stop',
help='Stop a schedule',
examples=CliExamples.MARO_PROCESS_SCHEDULE_STOP,
parents=[global_parser]
)
parser_schedule_stop.add_argument(
'schedule_name', help='Name of the schedule')
parser_schedule_stop.set_defaults(func=stop_schedule)
# maro process template
from maro.cli.process.template import template
parser_template = subparsers.add_parser(
"template",
help="Get deployment templates",
examples=CliExamples.MARO_PROCESS_TEMPLATE,
parents=[global_parser]
)
parser_template.add_argument(
"--setting_deploy",
action="store_true",
help="Get environment setting templates"
)
parser_template.add_argument(
"export_path",
default="./",
nargs='?',
help="Path of the export directory")
parser_template.set_defaults(func=template)
job_logs_parser.set_defaults(func=get_job_logs)
def load_parser_grass(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
@ -922,6 +907,81 @@ def load_parser_k8s(prev_parser: ArgumentParser, global_parser: ArgumentParser)
parser_template.set_defaults(func=template)
def load_parser_aks(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None:
subparsers = prev_parser.add_subparsers()
# maro aks create
from maro.cli.k8s.aks.aks_commands import init
parser_create = subparsers.add_parser(
"init",
help="""
Deploy resources and start required services on Azure. The configuration file template can be found
in cli/k8s/aks/conf.yml. Use the Azure CLI to log into your Azure account (az login ...) and the
Azure Container Registry (az acr login ...) first.
""",
examples=CliExamples.MARO_K8S_CREATE,
parents=[global_parser]
)
parser_create.add_argument("deployment_conf_path", help="Path of the deployment configuration file")
parser_create.set_defaults(func=init)
# maro aks exit
from maro.cli.k8s.aks.aks_commands import exit
parser_create = subparsers.add_parser(
"exit",
help="Delete deployed resources",
examples=CliExamples.MARO_K8S_DELETE,
parents=[global_parser]
)
parser_create.set_defaults(func=exit)
# maro aks job
parser_job = subparsers.add_parser(
"job",
help="AKS job-related commands",
parents=[global_parser]
)
parser_job.set_defaults(func=_help_func(parser=parser_job))
job_subparsers = parser_job.add_subparsers()
# maro aks job add
from maro.cli.k8s.aks.aks_commands import add_job
parser_job_start = job_subparsers.add_parser(
"add",
help="Add an RL job to the AKS cluster",
examples=CliExamples.MARO_K8S_JOB_START,
parents=[global_parser]
)
parser_job_start.add_argument("conf_path", help="Path to the job configuration file")
parser_job_start.set_defaults(func=add_job)
# maro aks job rm
from maro.cli.k8s.aks.aks_commands import remove_jobs
parser_job_start = job_subparsers.add_parser(
"rm",
help="Remove previously scheduled RL jobs from the AKS cluster",
examples=CliExamples.MARO_K8S_JOB_START,
parents=[global_parser]
)
parser_job_start.add_argument("job_names", help="Name of job to be removed", nargs="*")
parser_job_start.set_defaults(func=remove_jobs)
# maro aks job logs
from maro.cli.k8s.aks.aks_commands import get_job_logs
job_logs_parser = job_subparsers.add_parser(
"logs",
help="Get job logs",
examples=CliExamples.MARO_PROCESS_JOB_LOGS,
parents=[global_parser]
)
job_logs_parser.add_argument("job_name", help="job name")
job_logs_parser.add_argument(
"-n", "--tail", type=int, default=-1,
help="Number of lines to show from the end of the given job's logs"
)
job_logs_parser.set_defaults(func=get_job_logs)
def load_parser_data(prev_parser: ArgumentParser, global_parser: ArgumentParser):
data_cmd_sub_parsers = prev_parser.add_subparsers()


@ -1,206 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import multiprocessing as mp
import os
import subprocess
import time
import psutil
import redis
from maro.cli.grass.lib.services.utils.params import JobStatus
from maro.cli.process.utils.details import close_by_pid, get_child_pid
from maro.cli.utils.details_reader import DetailsReader
from maro.cli.utils.params import LocalPaths, ProcessRedisName
class PendingJobAgent(mp.Process):
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
super().__init__()
self.cluster_detail = cluster_detail
self.redis_connection = redis_connection
self.check_interval = check_interval
def run(self):
while True:
self._check_pending_ticket()
time.sleep(self.check_interval)
def _check_pending_ticket(self):
# Check pending job ticket
pending_jobs = self.redis_connection.lrange(ProcessRedisName.PENDING_JOB_TICKETS, 0, -1)
running_jobs_length = len(JobTrackingAgent.get_running_jobs(
self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
))
parallel_level = self.cluster_detail["parallel_level"]
for job_name in pending_jobs:
job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name))
# Start pending job only if current running job's number less than parallel level.
if int(parallel_level) > running_jobs_length:
self._start_job(job_detail)
self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name)
running_jobs_length += 1
def _start_job(self, job_details: dict):
command_pid_list = []
for component_type, command_info in job_details["components"].items():
component_number = command_info["num"]
component_command = f"JOB_NAME={job_details['name']} " + command_info["command"]
for number in range(component_number):
job_local_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_details['name']}")
if not os.path.exists(job_local_path):
os.makedirs(job_local_path)
with open(f"{job_local_path}/{component_type}_{number}.log", "w") as log_file:
proc = subprocess.Popen(component_command, shell=True, stdout=log_file)
command_pid = get_child_pid(proc.pid)
if not command_pid:
command_pid_list.append(proc.pid)
else:
command_pid_list.append(command_pid)
job_details["status"] = JobStatus.RUNNING
job_details["pid_list"] = command_pid_list
self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_details["name"], json.dumps(job_details))
class JobTrackingAgent(mp.Process):
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
super().__init__()
self.cluster_detail = cluster_detail
self.redis_connection = redis_connection
self.check_interval = check_interval
self._shutdown_count = 0
self._countdown = cluster_detail["agent_countdown"]
def run(self):
while True:
self._check_job_status()
time.sleep(self.check_interval)
keep_alive = self.cluster_detail["keep_agent_alive"]
if not keep_alive:
self._close_agents()
def _check_job_status(self):
running_jobs = self.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS))
for running_job_name, running_job_detail in running_jobs.items():
# Check pid status
still_alive = False
for pid in running_job_detail["pid_list"]:
if psutil.pid_exists(pid):
still_alive = True
# Update if no pid exists
if not still_alive:
running_job_detail["status"] = JobStatus.FINISH
del running_job_detail["pid_list"]
self.redis_connection.hset(
ProcessRedisName.JOB_DETAILS,
running_job_name,
json.dumps(running_job_detail)
)
@staticmethod
def get_running_jobs(job_details: dict):
running_jobs = {}
for job_name, job_detail in job_details.items():
job_detail = json.loads(job_detail)
if job_detail["status"] == JobStatus.RUNNING:
running_jobs[job_name.decode()] = job_detail
return running_jobs
def _close_agents(self):
if (
not len(
JobTrackingAgent.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS))
) and
not self.redis_connection.llen(ProcessRedisName.PENDING_JOB_TICKETS)
):
self._shutdown_count += 1
else:
self._shutdown_count = 0
if self._shutdown_count >= self._countdown:
agent_pid = int(self.redis_connection.hget(ProcessRedisName.SETTING, "agent_pid"))
# close agent
close_by_pid(pid=agent_pid, recursive=True)
# Set agent status to 0
self.redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 0)
class KilledJobAgent(mp.Process):
def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60):
super().__init__()
self.cluster_detail = cluster_detail
self.redis_connection = redis_connection
self.check_interval = check_interval
def run(self):
while True:
self._check_killed_tickets()
time.sleep(self.check_interval)
def _check_killed_tickets(self):
# Check pending job ticket
killed_job_names = self.redis_connection.lrange(ProcessRedisName.KILLED_JOB_TICKETS, 0, -1)
for job_name in killed_job_names:
job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name))
if job_detail["status"] == JobStatus.RUNNING:
close_by_pid(pid=job_detail["pid_list"], recursive=False)
del job_detail["pid_list"]
elif job_detail["status"] == JobStatus.PENDING:
self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name)
elif job_detail["status"] == JobStatus.FINISH:
continue
job_detail["status"] = JobStatus.KILLED
self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_name, json.dumps(job_detail))
self.redis_connection.lrem(ProcessRedisName.KILLED_JOB_TICKETS, 0, job_name)
class MasterAgent:
def __init__(self):
self.cluster_detail = DetailsReader.load_cluster_details("process")
self.check_interval = self.cluster_detail["check_interval"]
self.redis_connection = redis.Redis(
host=self.cluster_detail["redis_info"]["host"],
port=self.cluster_detail["redis_info"]["port"]
)
self.redis_connection.hset(ProcessRedisName.SETTING, "agent_pid", os.getpid())
def start(self) -> None:
"""Start agents."""
pending_job_agent = PendingJobAgent(
cluster_detail=self.cluster_detail,
redis_connection=self.redis_connection,
check_interval=self.check_interval
)
pending_job_agent.start()
killed_job_agent = KilledJobAgent(
cluster_detail=self.cluster_detail,
redis_connection=self.redis_connection,
check_interval=self.check_interval
)
killed_job_agent.start()
job_tracking_agent = JobTrackingAgent(
cluster_detail=self.cluster_detail,
redis_connection=self.redis_connection,
check_interval=self.check_interval
)
job_tracking_agent.start()
if __name__ == "__main__":
master_agent = MasterAgent()
master_agent.start()


@ -1,93 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import multiprocessing as mp
import os
import time
import redis
from maro.cli.utils.params import LocalParams
from maro.cli.utils.resource_executor import ResourceInfo
from maro.utils.exception.cli_exception import BadRequestError
class ResourceTrackingAgent(mp.Process):
def __init__(
self,
check_interval: int = 30
):
super().__init__()
self._redis_connection = redis.Redis(host="localhost", port=LocalParams.RESOURCE_REDIS_PORT)
try:
if self._redis_connection.hexists(LocalParams.RESOURCE_INFO, "check_interval"):
self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval"))
else:
self._check_interval = check_interval
except Exception:
raise BadRequestError(
"Failure to connect to Resource Redis."
"Please make sure at least one cluster running."
)
self._set_resource_info()
def _set_resource_info(self):
# Set resource agent pid.
self._redis_connection.hset(
LocalParams.RESOURCE_INFO,
"agent_pid",
os.getpid()
)
# Set resource agent check interval.
self._redis_connection.hset(
LocalParams.RESOURCE_INFO,
"check_interval",
json.dumps(self._check_interval)
)
# Push static resource information into Redis.
resource = ResourceInfo.get_static_info()
self._redis_connection.hset(
LocalParams.RESOURCE_INFO,
"resource",
json.dumps(resource)
)
def run(self) -> None:
"""Start tracking node status and updating details.
Returns:
None.
"""
while True:
start_time = time.time()
self.push_local_resource_usage()
time.sleep(max(self._check_interval - (time.time() - start_time), 0))
self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval"))
def push_local_resource_usage(self):
resource_usage = ResourceInfo.get_dynamic_info(self._check_interval)
self._redis_connection.rpush(
LocalParams.CPU_USAGE,
json.dumps(resource_usage["cpu_usage_per_core"])
)
self._redis_connection.rpush(
LocalParams.MEMORY_USAGE,
json.dumps(resource_usage["memory_usage"])
)
self._redis_connection.rpush(
LocalParams.GPU_USAGE,
json.dumps(resource_usage["gpu_memory_usage"])
)
if __name__ == "__main__":
resource_agent = ResourceTrackingAgent()
resource_agent.start()


@ -1,18 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import yaml
from maro.cli.process.executor import ProcessExecutor
from maro.cli.process.utils.default_param import process_setting
def create(deployment_path: str, **kwargs):
if deployment_path is not None:
with open(deployment_path, "r") as fr:
create_deployment = yaml.safe_load(fr)
else:
create_deployment = process_setting
executor = ProcessExecutor(create_deployment)
executor.create()


@ -1,9 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.cli.process.executor import ProcessExecutor
def delete(**kwargs):
executor = ProcessExecutor()
executor.delete()


@ -1,10 +0,0 @@
mode: process
name: MyJobName # str: name of the training job
components: # component config
actor:
num: 5 # int: number of this component
command: "python /target/path/run_actor.py" # str: command to be executed
learner:
num: 1
command: "python /target/path/run_learner.py"


@ -1,16 +0,0 @@
mode: process
name: MyScheduleName # str: name of the training schedule
job_names: # list: names of the training job
- MyJobName2
- MyJobName3
- MyJobName4
- MyJobName5
components: # component config
actor:
num: 5 # int: number of this component
command: "python /target/path/run_actor.py" # str: command to be executed
learner:
num: 1
command: "python /target/path/run_learner.py"


@ -1,8 +0,0 @@
redis_info:
host: "localhost"
port: 19999
redis_mode: MARO # one of MARO, customized. customized Redis won't be exited after maro process clear.
parallel_level: 1 # Represented the maximum number of running jobs in the same times.
keep_agent_alive: True # If True represented the agents won't exit until the environment delete; otherwise, False.
agent_countdown: 5 # After agent_countdown times checks, still no jobs will close agents. Available only if keep_agent_alive is 0.
check_interval: 60 # The time interval (seconds) of agents check with Redis


@ -1,248 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import json
import os
import shutil
import subprocess
import redis
import yaml
from maro.cli.grass.lib.services.utils.params import JobStatus
from maro.cli.process.utils.details import close_by_pid, get_redis_pid_by_port
from maro.cli.utils.abs_visible_executor import AbsVisibleExecutor
from maro.cli.utils.details_reader import DetailsReader
from maro.cli.utils.details_writer import DetailsWriter
from maro.cli.utils.params import GlobalPaths, LocalPaths, ProcessRedisName
from maro.cli.utils.resource_executor import LocalResourceExecutor
from maro.utils.logger import CliLogger
logger = CliLogger(name=__name__)
class ProcessExecutor(AbsVisibleExecutor):
def __init__(self, details: dict = None):
self.details = details if details else \
DetailsReader.load_cluster_details("process")
# Connection with Redis
redis_port = self.details["redis_info"]["port"]
self._redis_connection = redis.Redis(host="localhost", port=redis_port)
try:
self._redis_connection.ping()
except Exception:
redis_process = subprocess.Popen(
["redis-server", "--port", str(redis_port), "--daemonize yes"]
)
redis_process.wait(timeout=2)
# Connection with Resource Redis
self._resource_redis = LocalResourceExecutor()
def create(self):
logger.info("Starting MARO Multi-Process Mode.")
if os.path.isdir(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process"):
logger.warning("Process mode has been created.")
# Get environment setting
DetailsWriter.save_cluster_details(
cluster_name="process",
cluster_details=self.details
)
# Start agents
command = f"python {LocalPaths.MARO_PROCESS_AGENT}"
_ = subprocess.Popen(command, shell=True)
self._redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 1)
# Add connection to resource Redis.
self._resource_redis.add_cluster()
logger.info(f"MARO process mode setting: {self.details}")
def delete(self):
process_setting = self._redis_connection.hgetall(ProcessRedisName.SETTING)
process_setting = {
key.decode(): json.loads(value) for key, value in process_setting.items()
}
# Stop running jobs
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
if jobs:
for job_name, job_detail in jobs.items():
job_detail = json.loads(job_detail)
if job_detail["status"] == JobStatus.RUNNING:
close_by_pid(pid=job_detail["pid_list"], recursive=False)
logger.info(f"Stop running job {job_name.decode()}.")
# Stop agents
agent_status = int(process_setting["agent_status"])
if agent_status:
agent_pid = int(process_setting["agent_pid"])
close_by_pid(pid=agent_pid, recursive=True)
logger.info("Close agents.")
else:
logger.info("Agents is already closed.")
# Stop Redis or clear Redis
redis_mode = self.details["redis_mode"]
if redis_mode == "MARO":
redis_pid = get_redis_pid_by_port(self.details["redis_info"]["port"])
close_by_pid(pid=redis_pid, recursive=False)
else:
self._redis_clear()
# Rm connection from resource redis.
self._resource_redis.sub_cluster()
logger.info("Redis cleared.")
# Remove local process file.
shutil.rmtree(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process", True)
logger.info("Process mode has been deleted.")
def _redis_clear(self):
redis_keys = self._redis_connection.keys("process:*")
for key in redis_keys:
self._redis_connection.delete(key)
def start_job(self, deployment_path: str):
# Load start_job_deployment
with open(deployment_path, "r") as fr:
start_job_deployment = yaml.safe_load(fr)
job_name = start_job_deployment["name"]
start_job_deployment["status"] = JobStatus.PENDING
# Push job details to redis
self._redis_connection.hset(
ProcessRedisName.JOB_DETAILS,
job_name,
json.dumps(start_job_deployment)
)
self._push_pending_job(job_name)
def _push_pending_job(self, job_name: str):
# Push job name to pending_job_tickets
self._redis_connection.lpush(
ProcessRedisName.PENDING_JOB_TICKETS,
job_name
)
logger.info(f"Sending {job_name} into pending job tickets.")
def stop_job(self, job_name: str):
if not self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, job_name):
logger.error(f"No such job '{job_name}' in Redis.")
return
# push job_name into kill_job_tickets
self._redis_connection.lpush(
ProcessRedisName.KILLED_JOB_TICKETS,
job_name
)
logger.info(f"Sending {job_name} into killed job tickets.")
def delete_job(self, job_name: str):
# Stop job for running and pending job.
self.stop_job(job_name)
# Rm job details in Redis
self._redis_connection.hdel(ProcessRedisName.JOB_DETAILS, job_name)
# Rm job's log folder
job_folder = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}")
shutil.rmtree(job_folder, True)
logger.info(f"Remove local temporary log folder {job_folder}.")
def get_job_logs(self, job_name):
source_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}")
if not os.path.exists(source_path):
logger.error(f"Cannot find the logs of {job_name}.")
destination = os.path.join(os.getcwd(), job_name)
if os.path.exists(destination):
shutil.rmtree(destination)
shutil.copytree(source_path, destination)
logger.info(f"Dump logs in path: {destination}.")
def list_job(self):
# Get all jobs
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
for job_name, job_detail in jobs.items():
job_name = job_name.decode()
job_detail = json.loads(job_detail)
logger.info(job_detail)
def start_schedule(self, deployment_path: str):
with open(deployment_path, "r") as fr:
schedule_detail = yaml.safe_load(fr)
# push schedule details to Redis
self._redis_connection.hset(
ProcessRedisName.JOB_DETAILS,
schedule_detail["name"],
json.dumps(schedule_detail)
)
job_list = schedule_detail["job_names"]
# switch schedule details into job details
job_detail = copy.deepcopy(schedule_detail)
del job_detail["job_names"]
for job_name in job_list:
job_detail["name"] = job_name
# Push job details to redis
self._redis_connection.hset(
ProcessRedisName.JOB_DETAILS,
job_name,
json.dumps(job_detail)
)
self._push_pending_job(job_name)
def stop_schedule(self, schedule_name: str):
if self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, schedule_name):
schedule_details = json.loads(self._redis_connection.hget(ProcessRedisName.JOB_DETAILS, schedule_name))
else:
logger.error(f"Cannot find {schedule_name} in Redis. Please check schedule name.")
return
if "job_names" not in schedule_details.keys():
logger.error(f"'{schedule_name}' is not a schedule.")
return
job_list = schedule_details["job_names"]
for job_name in job_list:
self.stop_job(job_name)
def get_job_details(self):
jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)
for job_name, job_details_str in jobs.items():
jobs[job_name] = json.loads(job_details_str)
return list(jobs.values())
def get_job_queue(self):
pending_job_queue = self._redis_connection.lrange(
ProcessRedisName.PENDING_JOB_TICKETS,
0, -1
)
killed_job_queue = self._redis_connection.lrange(
ProcessRedisName.KILLED_JOB_TICKETS,
0, -1
)
return {
"pending_jobs": pending_job_queue,
"killed_jobs": killed_job_queue
}
def get_resource(self):
return self._resource_redis.get_local_resource()
def get_resource_usage(self, previous_length: int):
return self._resource_redis.get_local_resource_usage(previous_length)


@ -1,30 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.cli.process.executor import ProcessExecutor
def start_job(deployment_path: str, **kwargs):
executor = ProcessExecutor()
executor.start_job(deployment_path=deployment_path)
def stop_job(job_name: str, **kwargs):
executor = ProcessExecutor()
executor.stop_job(job_name=job_name)
def delete_job(job_name: str, **kwargs):
executor = ProcessExecutor()
executor.delete_job(job_name=job_name)
def list_jobs(**kwargs):
executor = ProcessExecutor()
executor.list_job()
def get_job_logs(job_name: str, **kwargs):
executor = ProcessExecutor()
executor.get_job_logs(job_name=job_name)


@ -1,15 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.cli.process.executor import ProcessExecutor
def start_schedule(deployment_path: str, **kwargs):
executor = ProcessExecutor()
executor.start_schedule(deployment_path=deployment_path)
def stop_schedule(schedule_name: str, **kwargs):
executor = ProcessExecutor()
executor.stop_schedule(schedule_name=schedule_name)


@ -1,17 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import shutil
from maro.cli.utils.params import LocalPaths
def template(setting_deploy, export_path, **kwargs):
deploy_files = os.listdir(LocalPaths.MARO_PROCESS_DEPLOYMENT)
if not setting_deploy:
deploy_files.remove("process_setting_deployment.yml")
export_path = os.path.abspath(export_path)
for file_name in deploy_files:
if os.path.isfile(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}"):
shutil.copy(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}", export_path)


@ -1,15 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
process_setting = {
"redis_info": {
"host": "localhost",
"port": 19999
},
"redis_mode": "MARO", # one of MARO, customized. customized Redis won't exit after maro process clear.
"parallel_level": 1,
"keep_agent_alive": 1, # If 0 (False), agents will exit after 5 minutes of no pending jobs and running jobs.
"check_interval": 60, # seconds
"agent_countdown": 5 # how many times to shutdown agents about finding no job in Redis.
}


@ -1,54 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import signal
import subprocess
from typing import Union
import psutil
def close_by_pid(pid: Union[int, list], recursive: bool = False):
if isinstance(pid, int):
if not psutil.pid_exists(pid):
return
if recursive:
current_process = psutil.Process(pid)
children_process = current_process.children(recursive=False)
# May launch by JobTrackingAgent which is child process, so need close parent process first.
current_process.kill()
for child_process in children_process:
child_process.kill()
else:
os.kill(pid, signal.SIGKILL)
else:
for p in pid:
if psutil.pid_exists(p):
os.kill(p, signal.SIGKILL)
def get_child_pid(parent_pid):
command = f"ps -o pid --ppid {parent_pid} --noheaders"
get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
children_pids = get_children_pid_process.stdout.read()
get_children_pid_process.wait(timeout=2)
# Convert into list or int
try:
children_pids = int(children_pids)
except ValueError:
children_pids = children_pids.decode().split("\n")
children_pids = [int(pid) for pid in children_pids[:-1]]
return children_pids
def get_redis_pid_by_port(port: int):
get_redis_pid_command = f"pidof 'redis-server *:{port}'"
get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE)
redis_pid = int(get_redis_pid_process.stdout.read())
get_redis_pid_process.wait()
return redis_pid


@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from maro.cli.utils.subprocess import Subprocess
def login_acr(acr_name: str) -> None:
command = f"az acr login --name {acr_name}"
_ = Subprocess.run(command=command)
def list_acr_repositories(acr_name: str) -> list:
command = f"az acr repository list -n {acr_name}"
return_str = Subprocess.run(command=command)
return json.loads(return_str)
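
Hypothetical usage of the two helpers above, with a placeholder registry name:

login_acr("myacr")                          # authenticate the local Docker client against the registry
repositories = list_acr_repositories("myacr")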


@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import subprocess
from azure.identity import DefaultAzureCredential
from azure.mgmt.authorization import AuthorizationManagementClient
from azure.mgmt.containerservice import ContainerServiceClient
from maro.cli.utils.subprocess import Subprocess
def get_container_service_client(subscription: str):
return ContainerServiceClient(DefaultAzureCredential(), subscription)
def get_authorization_client(subscription: str):
    return AuthorizationManagementClient(DefaultAzureCredential(), subscription)
def load_aks_context(resource_group: str, aks_name: str) -> None:
command = f"az aks get-credentials -g {resource_group} --name {aks_name}"
_ = Subprocess.run(command=command)
def get_aks(subscription: str, resource_group: str, aks_name: str) -> dict:
container_service_client = get_container_service_client(subscription)
return container_service_client.managed_clusters.get(resource_group, aks_name)
def attach_acr(resource_group: str, aks_name: str, acr_name: str) -> None:
subprocess.run(f"az aks update -g {resource_group} -n {aks_name} --attach-acr {acr_name}".split())
def add_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int, node_size: str) -> None:
command = (
f"az aks nodepool add "
f"-g {resource_group} "
f"--cluster-name {aks_name} "
f"--name {nodepool_name} "
f"--node-count {node_count} "
f"--node-vm-size {node_size}"
)
_ = Subprocess.run(command=command)
def scale_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int) -> None:
command = (
f"az aks nodepool scale "
f"-g {resource_group} "
f"--cluster-name {aks_name} "
f"--name {nodepool_name} "
f"--node-count {node_count}"
)
_ = Subprocess.run(command=command)
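A hedged sketch of how these AKS helpers might be combined; the resource group, cluster, node pool, and VM size names are placeholders:

# Hypothetical example: load cluster credentials, add a node pool, then scale it out.
load_aks_context(resource_group="my-rg", aks_name="my-aks")
add_nodepool(resource_group="my-rg", aks_name="my-aks", nodepool_name="workers", node_count=2, node_size="Standard_D2s_v3")
scale_nodepool(resource_group="my-rg", aks_name="my-aks", nodepool_name="workers", node_count=5)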


@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .general import get_resource_client


def create_deployment(
    subscription: str,
    resource_group: str,
    deployment_name: str,
    template: dict,
    params: dict,
    sync: bool = True
) -> None:
    params = {k: {"value": v} for k, v in params.items()}
    resource_client = get_resource_client(subscription)
    deployment_params = {"mode": "Incremental", "template": template, "parameters": params}
    result = resource_client.deployments.begin_create_or_update(
        resource_group, deployment_name, {"properties": deployment_params}
    )
    if sync:
        result.result()


def delete_deployment(subscription: str, resource_group: str, deployment_name: str) -> None:
    resource_client = get_resource_client(subscription)
    resource_client.deployments.begin_delete(resource_group, deployment_name)
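A minimal sketch of calling create_deployment with an ARM template; the subscription, resource group, deployment name, and template contents are illustrative only:

# Hypothetical example: deploy an empty ARM template and block until it completes (sync=True).
template = {
    "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
    "contentVersion": "1.0.0.0",
    "parameters": {"location": {"type": "string"}},
    "resources": []
}
create_deployment(
    subscription="my-subscription-id",
    resource_group="my-rg",
    deployment_name="maro-test-deployment",
    template=template,
    params={"location": "eastus"},
    sync=True
)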


@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import subprocess

from azure.identity import DefaultAzureCredential
from azure.mgmt.resource import ResourceManagementClient

from maro.cli.utils.subprocess import Subprocess


def set_subscription(subscription: str) -> None:
    command = f"az account set --subscription {subscription}"
    _ = Subprocess.run(command=command)


def get_version() -> dict:
    command = "az version"
    return_str = Subprocess.run(command=command)
    return json.loads(return_str)


def get_resource_client(subscription: str):
    return ResourceManagementClient(DefaultAzureCredential(), subscription)


def set_env_credentials(dump_path: str, service_principal_name: str):
    os.makedirs(dump_path, exist_ok=True)
    service_principal_file_path = os.path.join(dump_path, f"{service_principal_name}.json")
    # If the service principal file does not exist, create one using the az CLI command.
    # For details on service principals, refer to
    # https://docs.microsoft.com/en-us/azure/active-directory/develop/app-objects-and-service-principals
    if not os.path.exists(service_principal_file_path):
        with open(service_principal_file_path, 'w') as fp:
            subprocess.run(
                f"az ad sp create-for-rbac --name {service_principal_name} --sdk-auth --role contributor".split(),
                stdout=fp
            )

    with open(service_principal_file_path, 'r') as fp:
        service_principal = json.load(fp)

    os.environ["AZURE_TENANT_ID"] = service_principal["tenantId"]
    os.environ["AZURE_CLIENT_ID"] = service_principal["clientId"]
    os.environ["AZURE_CLIENT_SECRET"] = service_principal["clientSecret"]
    os.environ["AZURE_SUBSCRIPTION_ID"] = service_principal["subscriptionId"]


def connect_to_aks(resource_group: str, aks: str):
    subprocess.run(f"az aks get-credentials --resource-group {resource_group} --name {aks}".split())


def get_acr_push_permissions(service_principal_id: str, acr: str):
    acr_id = json.loads(
        subprocess.run(f"az acr show --name {acr} --query id".split(), stdout=subprocess.PIPE).stdout
    )
    subprocess.run(
        f"az role assignment create --assignee {service_principal_id} --scope {acr_id} --role acrpush".split()
    )
    subprocess.run(f"az acr login --name {acr}".split())


@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json

from maro.cli.utils.subprocess import Subprocess
from maro.utils.exception.cli_exception import CommandExecutionError

from .general import get_resource_client


def get_resource_group(resource_group: str) -> dict:
    command = f"az group show --name {resource_group}"
    try:
        return_str = Subprocess.run(command=command)
        return json.loads(return_str)
    except CommandExecutionError:
        return {}


def delete_resource_group(resource_group: str) -> None:
    command = f"az group delete --yes --name {resource_group}"
    _ = Subprocess.run(command=command)


# Chained Azure resource group operations
def create_resource_group(subscription: str, resource_group: str, location: str):
    """Create the resource group if it does not exist.

    Args:
        subscription (str): Azure subscription name.
        resource_group (str): Resource group name.
        location (str): Resource group location.

    Returns:
        The created or updated resource group.
    """
    resource_client = get_resource_client(subscription)
    return resource_client.resource_groups.create_or_update(resource_group, {"location": location})


def delete_resource_group_under_subscription(subscription: str, resource_group: str):
    resource_client = get_resource_client(subscription)
    return resource_client.resource_groups.begin_delete(resource_group)


@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json

from maro.cli.utils.subprocess import Subprocess


def list_resources(resource_group: str) -> list:
    command = f"az resource list -g {resource_group}"
    return_str = Subprocess.run(command=command)
    return json.loads(return_str)


def delete_resources(resource_ids: list) -> None:
    command = f"az resource delete --ids {' '.join(resource_ids)}"
    _ = Subprocess.run(command=command)


def cleanup(cluster_name: str, resource_group: str) -> None:
    # Get resource list
    resource_list = list_resources(resource_group)

    # Filter resources
    deletable_ids = []
    for resource in resource_list:
        if resource["name"].startswith(cluster_name):
            deletable_ids.append(resource["id"])

    # Delete resources
    if deletable_ids:
        delete_resources(resource_ids=deletable_ids)
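A short usage sketch with placeholder names. Note the design choice: cleanup deletes only resources whose names start with the cluster name, so unrelated resources in the same group are left untouched.

# Hypothetical example: remove all resources belonging to cluster "maro-test" in resource group "my-rg".
cleanup(cluster_name="maro-test", resource_group="my-rg")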


@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import datetime
import json
import os
from typing import Union
from azure.core.exceptions import ResourceExistsError
from azure.storage.fileshare import ShareClient, ShareDirectoryClient
from maro.cli.utils.subprocess import Subprocess
def get_storage_account_keys(resource_group: str, storage_account_name: str) -> dict:
command = f"az storage account keys list -g {resource_group} --account-name {storage_account_name}"
return_str = Subprocess.run(command=command)
return json.loads(return_str)
def get_storage_account_sas(
account_name: str,
services: str = "bqtf",
resource_types: str = "sco",
permissions: str = "rwdlacup",
expiry: str = (datetime.datetime.utcnow() + datetime.timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%S") + "Z"
) -> str:
command = (
f"az storage account generate-sas --account-name {account_name} --services {services} "
f"--resource-types {resource_types} --permissions {permissions} --expiry {expiry}"
)
sas_str = Subprocess.run(command=command).strip("\n").replace('"', "")
# logger.debug(sas_str)
return sas_str
def get_connection_string(storage_account_name: str) -> str:
"""Get the connection string for a storage account.
Args:
storage_account_name: The storage account name.
Returns:
str: Connection string.
"""
command = f"az storage account show-connection-string --name {storage_account_name}"
return_str = Subprocess.run(command=command)
return json.loads(return_str)["connectionString"]
def get_fileshare(storage_account_name: str, fileshare_name: str):
connection_string = get_connection_string(storage_account_name)
share = ShareClient.from_connection_string(connection_string, fileshare_name)
try:
share.create_share()
except ResourceExistsError:
pass
return share
def get_directory(share: Union[ShareClient, ShareDirectoryClient], name: str):
if isinstance(share, ShareClient):
directory = share.get_directory_client(directory_path=name)
try:
directory.create_directory()
except ResourceExistsError:
pass
return directory
elif isinstance(share, ShareDirectoryClient):
try:
return share.create_subdirectory(name)
except ResourceExistsError:
return share.get_subdirectory_client(name)
def upload_to_fileshare(share: Union[ShareClient, ShareDirectoryClient], source_path: str, name: str = None):
if os.path.isdir(source_path):
if not name:
name = os.path.basename(source_path)
directory = get_directory(share, name)
for file in os.listdir(source_path):
upload_to_fileshare(directory, os.path.join(source_path, file))
else:
with open(source_path, "rb") as fp:
share.upload_file(file_name=os.path.basename(source_path), data=fp)
def download_from_fileshare(share: ShareDirectoryClient, file_name: str, local_path: str):
file = share.get_file_client(file_name=file_name)
with open(local_path, "wb") as fp:
fp.write(file.download_file().readall())
def delete_directory(share: Union[ShareClient, ShareDirectoryClient], name: str, recursive: bool = True):
share.delete_directory(directory_name=name)
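A hedged sketch of these fileshare helpers in use; the storage account, share name, and paths are placeholders. upload_to_fileshare recurses into subdirectories by creating a matching directory client for each one, so uploading a folder mirrors its layout on the share.

# Hypothetical example: create (or reuse) a file share, upload a local folder, then fetch one file back.
share = get_fileshare(storage_account_name="marostorageacct", fileshare_name="maro-share")
upload_to_fileshare(share, source_path="/tmp/checkpoints", name="checkpoints")
ckpt_dir = get_directory(share, "checkpoints")
download_from_fileshare(ckpt_dir, file_name="policy_0.pt", local_path="/tmp/policy_0.pt")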


@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json

from maro.cli.utils.subprocess import Subprocess


def list_ip_addresses(resource_group: str, vm_name: str) -> list:
    command = f"az vm list-ip-addresses -g {resource_group} --name {vm_name}"
    return_str = Subprocess.run(command=command)
    return json.loads(return_str)


def start_vm(resource_group: str, vm_name: str) -> None:
    command = f"az vm start -g {resource_group} --name {vm_name}"
    _ = Subprocess.run(command=command)


def stop_vm(resource_group: str, vm_name: str) -> None:
    command = f"az vm stop -g {resource_group} --name {vm_name}"
    _ = Subprocess.run(command=command)


def list_vm_sizes(location: str) -> list:
    command = f"az vm list-sizes -l {location}"
    return_str = Subprocess.run(command=command)
    return json.loads(return_str)


def deallocate_vm(resource_group: str, vm_name: str) -> None:
    command = f"az vm deallocate --resource-group {resource_group} --name {vm_name}"
    _ = Subprocess.run(command=command)


def generalize_vm(resource_group: str, vm_name: str) -> None:
    command = f"az vm generalize --resource-group {resource_group} --name {vm_name}"
    _ = Subprocess.run(command=command)


def create_image_from_vm(resource_group: str, image_name: str, vm_name: str) -> None:
    command = f"az image create --resource-group {resource_group} --name {image_name} --source {vm_name}"
    _ = Subprocess.run(command=command)


def get_image_resource_id(resource_group: str, image_name: str) -> str:
    command = f"az image show --resource-group {resource_group} --name {image_name}"
    return_str = Subprocess.run(command=command)
    return json.loads(return_str)["id"]
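A sketch of a typical image-capture sequence built from these helpers (the VM and image names are placeholders): an Azure VM is deallocated and generalized before a managed image is created from it.

# Hypothetical example: capture a managed image from an existing VM.
deallocate_vm(resource_group="my-rg", vm_name="maro-build-vm")
generalize_vm(resource_group="my-rg", vm_name="maro-build-vm")
create_image_from_vm(resource_group="my-rg", image_name="maro-node-image", vm_name="maro-build-vm")
image_id = get_image_resource_id(resource_group="my-rg", image_name="maro-node-image")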


@ -1,7 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import subprocess
import sys
from collections import deque

import psutil

from maro.utils import Logger


def close_by_pid(pid: int, recursive: bool = True):
    if not psutil.pid_exists(pid):
        return

    proc = psutil.Process(pid)
    if recursive:
        for child in proc.children(recursive=recursive):
            child.kill()

    proc.kill()


def get_child_pids(parent_pid):
    # command = f"ps -o pid --ppid {parent_pid} --noheaders"
    # get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
    # children_pids = get_children_pid_process.stdout.read()
    # get_children_pid_process.wait(timeout=2)

    # # Convert into list or int
    # try:
    #     children_pids = int(children_pids)
    # except ValueError:
    #     children_pids = children_pids.decode().split("\n")
    #     children_pids = [int(pid) for pid in children_pids[:-1]]

    # return children_pids
    try:
        return [child.pid for child in psutil.Process(parent_pid).children(recursive=True)]
    except psutil.NoSuchProcess:
        print(f"No process with PID {parent_pid} found")
        return


def get_redis_pid_by_port(port: int):
    get_redis_pid_command = f"pidof 'redis-server *:{port}'"
    get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE)
    redis_pid = int(get_redis_pid_process.stdout.read())
    get_redis_pid_process.wait()
    return redis_pid


def exit(state: int = 0, msg: str = None):
@ -10,3 +58,75 @@ def exit(state: int = 0, msg: str = None):
    sys.stderr.write(msg)
    sys.exit(state)


def get_last_k_lines(file_name: str, k: int):
    """
    Helper function to retrieve the last K lines from a file in a memory-efficient way.

    Code slightly adapted from https://thispointer.com/python-get-last-n-lines-of-a-text-file-like-tail-command/
    """
    # Create an empty list to keep track of the last k lines.
    lines = deque()
    # Open the file for reading in binary mode.
    with open(file_name, 'rb') as fp:
        # Move the cursor to the end of the file.
        fp.seek(0, os.SEEK_END)
        # Create a buffer to keep the last read line.
        buffer = bytearray()
        # Get the current position of the pointer, i.e., EOF.
        ptr = fp.tell()
        # Loop until the pointer reaches the top of the file.
        while ptr >= 0:
            # Move the file pointer to the location pointed to by ptr.
            fp.seek(ptr)
            # Shift the pointer location by -1.
            ptr -= 1
            # Read that byte / character.
            new_byte = fp.read(1)
            # A newline character means one complete line has been read.
            if new_byte != b'\n':
                # The last read character is not EOL, so add it to the buffer.
                buffer.extend(new_byte)
            elif buffer:
                lines.appendleft(buffer.decode()[::-1])
                if len(lines) == k:
                    return lines
                # Reinitialize the byte array to save the next line.
                buffer.clear()

    # The file has been read completely; if there is still data in the buffer, it is the first of the last K lines.
    if buffer:
        lines.appendleft(buffer.decode()[::-1])

    return lines


def show_log(log_path: str, tail: int = -1, logger: Logger = None):
    print_fn = logger.info if logger else print
    if tail == -1:
        with open(log_path, "r") as fp:
            for line in fp:
                print_fn(line.rstrip('\n'))
    else:
        for line in get_last_k_lines(log_path, tail):
            print_fn(line)


def format_env_vars(env: dict, mode: str = "proc"):
    if mode == "proc":
        return env

    if mode == "docker":
        env_opt_list = []
        for key, val in env.items():
            env_opt_list.extend(["--env", f"{key}={val}"])
        return env_opt_list

    if mode == "docker-compose":
        return [f"{key}={val}" for key, val in env.items()]

    if mode == "k8s":
        return [{"name": key, "value": val} for key, val in env.items()]

    raise ValueError(f"'mode' should be one of 'proc', 'docker', 'docker-compose', 'k8s', got {mode}")
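A quick illustration of the four formatting modes, assuming a small environment dict:

# Hypothetical example: the same environment rendered for each launcher type.
env = {"JOB_NAME": "cim-dqn", "REDIS_PORT": "6379"}
format_env_vars(env, mode="proc")            # {"JOB_NAME": "cim-dqn", "REDIS_PORT": "6379"}
format_env_vars(env, mode="docker")          # ["--env", "JOB_NAME=cim-dqn", "--env", "REDIS_PORT=6379"]
format_env_vars(env, mode="docker-compose")  # ["JOB_NAME=cim-dqn", "REDIS_PORT=6379"]
format_env_vars(env, mode="k8s")             # [{"name": "JOB_NAME", "value": "cim-dqn"}, ...]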

maro/cli/utils/docker.py

@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import docker


def image_exists(image_name: str):
    try:
        client = docker.from_env()
        client.images.get(image_name)
        return True
    except docker.errors.ImageNotFound:
        return False


def build_image(context: str, docker_file_path: str, image_name: str):
    client = docker.from_env()
    with open(docker_file_path, "r"):
        client.images.build(
            path=context,
            tag=image_name,
            quiet=False,
            rm=True,
            custom_context=False,
            dockerfile=docker_file_path
        )


def push(local_image_name: str, repository: str):
    client = docker.from_env()
    image = client.images.get(local_image_name)
    acr_tag = f"{repository}/{local_image_name}"
    image.tag(acr_tag)
    # subprocess.run(f"docker push {acr_tag}".split())
    client.images.push(acr_tag)
    print(f"Pushed image to {acr_tag}")


@ -38,23 +38,3 @@ class LocalParams:
    CPU_USAGE = "local_resource:cpu_usage_per_core"
    MEMORY_USAGE = "local_resource:memory_usage"
    GPU_USAGE = "local_resource:gpu_memory_usage"


class LocalPaths:
    """Only used by the maro process CLI."""
    MARO_PROCESS = "~/.maro/clusters/process"
    MARO_PROCESS_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/job_agent.py")
    MARO_RESOURCE_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/resource_agent.py")
    MARO_PROCESS_DEPLOYMENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/deployment")
    MARO_GRASS_LOCAL_AGENT = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "../grass/lib/services/master_agent/local_agent.py"
    )


class ProcessRedisName:
    """Names of the Redis elements used only by maro process."""
    PENDING_JOB_TICKETS = "process:pending_job_tickets"
    KILLED_JOB_TICKETS = "process:killed_job_tickets"
    JOB_DETAILS = "process:job_details"
    SETTING = "process:setting"


@ -26,7 +26,7 @@ def dist(proxy: Proxy, handler_dict: {object: Callable}):
        self.local_instance = cls(*args, **kwargs)
        self.proxy = proxy
        self._handler_function = {}
        self._registry_table = RegisterTable(self.proxy.peers_name)
        self._registry_table = RegisterTable(self.proxy.peers)

        # Use functools.partial to freeze handling function's local_instance and proxy
        # arguments to self.local_instance and self.proxy.
        for constraint, handler_fun in handler_dict.items():


@ -69,7 +69,7 @@ class ZmqDriver(AbsDriver):
"""
self._unicast_receiver = self._zmq_context.socket(zmq.PULL)
unicast_receiver_port = self._unicast_receiver.bind_to_random_port(f"{self._protocol}://*")
self._logger.info(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.")
self._logger.debug(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.")
# Dict about zmq.PUSH sockets, fulfills in self.connect.
self._unicast_sender_dict = {}
@ -80,7 +80,7 @@ class ZmqDriver(AbsDriver):
self._broadcast_receiver = self._zmq_context.socket(zmq.SUB)
self._broadcast_receiver.setsockopt(zmq.SUBSCRIBE, self._component_type.encode())
broadcast_receiver_port = self._broadcast_receiver.bind_to_random_port(f"{self._protocol}://*")
self._logger.info(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.")
self._logger.debug(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.")
# Record own sockets' address.
self._address = {
@ -122,10 +122,10 @@ class ZmqDriver(AbsDriver):
self._unicast_sender_dict[peer_name] = self._zmq_context.socket(zmq.PUSH)
self._unicast_sender_dict[peer_name].setsockopt(zmq.SNDTIMEO, self._send_timeout)
self._unicast_sender_dict[peer_name].connect(address)
self._logger.info(f"Connects to {peer_name} via unicasting.")
self._logger.debug(f"Connects to {peer_name} via unicasting.")
elif int(socket_type) == zmq.SUB:
self._broadcast_sender.connect(address)
self._logger.info(f"Connects to {peer_name} via broadcasting.")
self._logger.debug(f"Connects to {peer_name} via broadcasting.")
else:
raise SocketTypeError(f"Unrecognized socket type {socket_type}.")
except Exception as e:
@ -158,13 +158,13 @@ class ZmqDriver(AbsDriver):
raise PeersDisconnectionError(f"Driver cannot disconnect to {peer_name}! Due to {str(e)}")
self._disconnected_peer_name_list.append(peer_name)
self._logger.info(f"Disconnected with {peer_name}.")
self._logger.debug(f"Disconnected with {peer_name}.")
def receive(self, is_continuous: bool = True, timeout: int = None):
def receive(self, timeout: int = None):
"""Receive message from ``zmq.POLLER``.
Args:
is_continuous (bool): Continuously receive message or not. Defaults to True.
timeout (int): Timeout for polling. If the first poll times out, the function returns None.
Yields:
recv_message (Message): The received message from the poller.
@ -184,13 +184,38 @@ class ZmqDriver(AbsDriver):
recv_message = pickle.loads(recv_message)
self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.")
else:
self._logger.debug(f"Cannot receive any message within {receive_timeout}.")
self._logger.debug(f"No message received within {receive_timeout}.")
return
yield recv_message
if not is_continuous:
break
def receive_once(self, timeout: int = None):
"""Receive a single message from ``zmq.POLLER``.
Args:
timeout (int): Time-out for ZMQ polling. If the first poll times out, the function returns None.
Returns:
recv_message (Message): The received message from the poller or None if the poller times out.
"""
receive_timeout = timeout if timeout else self._receive_timeout
try:
sockets = dict(self._poller.poll(receive_timeout))
except Exception as e:
raise DriverReceiveError(f"Driver cannot receive message as {e}")
if self._unicast_receiver in sockets:
recv_message = self._unicast_receiver.recv_pyobj()
self._logger.debug(f"Receive a message from {recv_message.source} through unicast receiver.")
elif self._broadcast_receiver in sockets:
_, recv_message = self._broadcast_receiver.recv_multipart()
recv_message = pickle.loads(recv_message)
self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.")
else:
self._logger.debug(f"No message received within {receive_timeout}.")
return
return recv_message
def send(self, message: Message):
"""Send message.

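To clarify the receive / receive_once split introduced above, a hedged usage sketch (driver construction is omitted; the `driver` and `handle` names are illustrative only):

# Hypothetical example: receive() is a generator suited to a long-running message loop,
# while receive_once() polls for at most one message and returns None on timeout.
for msg in driver.receive():
    handle(msg)

single = driver.receive_once(timeout=1000)
if single is not None:
    handle(single)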
Some files were not shown because too many files changed in this diff.