V0.2 update (#262)
* refine readme * feat: refine data push/pull (#138) * feat: refine data push/pull * test: add cli provision testing * fix: style fix * fix: add necessary comments * fix: from code review * add fall back function in weather download (#112) * fix deployment issue in multi envs * fix typo * fix ~/.maro not exist issue in build * skip deploy when build * update for comments * temporarily disable weather info * replace ecr with cim in setup.py * replace ecr in manifest * remove weather check when read data * fix station id issue * fix format * add TODO in comments * add noaa weather source * fix weather reset and weather comment * add comment for weather data url * some format update * add fall back function in weather download * update comment * update for comments * update comment * add period * fix for pylint * update for pylint check * added example docs (#136) * added example docs * added citibike greedy example doc * modified citibike doc * fixed PR comments * fixed more PR comments * fixed small formatting issue Co-authored-by: ysqyang <v-yangqi@microsoft.com> * switch the key and value of handler_dict in decorator (#144) * switch the key and value of handler_dict in decorator * add dist decorator UT and fixed multithreading conflict in maro test suite * pr comments update. * resolved comments about decorator UT * rename handler_fun in dist decorator * change self.attr into class_name.attr * update UT tests comments * V0.1 annotation (#147) * refine the annotation of simulator core * remove reward from env(be) * format refined * white spaces test * left-padding spaces refined * format modifed * update the left-padding spaces of docstrings * code format updated * update according to comments * update according to PR comments Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Event payload details for env.summary (#156) * key_list of events added for env.summary * code refined according to lint * 2 kinds of Payload added for CIM scenario; citi bike summary refined according to comments * code format refined * try trigger the git tests * update github workflow Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * V0.2 online lp for citi bike (#159) * key_list of events added for env.summary * code refined according to lint * 2 kinds of Payload added for CIM scenario; citi bike summary refined according to comments * code format refined * try trigger the git tests * update github workflow * online LP example added for citi bike * infeasible solution * infeasible solution fixed: call snapshot before any env.step() * experiment results of toy topos added * experiment results of toy topos added * experiment result update: better than naive baseline * PuLP version added * greedy experiment results update * citibike result update * modified according to PR comments * update experiment results and forecasting comparison * citi bike lp README updated * README updated * modified according to PR comments * update according to PR comments Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * update according to flake8 * V0.2 Logical operator overloading for EarlyStoppingChecker (#178) * 1. added logical operator overloading for early stopping checker; 2. added mean value checker * fixed PR comments * removed learner.exit() in single_process_launcher * added another early stopping checker in example * fixed PR comments and lint issues * lint issue fix * fixed lint issues * fixed a bug * fixed a bug Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 skip connection (#176) * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * moved reward type casting to exp shaper Co-authored-by: ysqyang <v-yangqi@microsoft.com> * fixed a bug in learner's test() (#193) Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 double dqn (#188) * added dueling action value model * renamed params in dueling_action_value_model * renamed shared_features to features * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * mv dueling_actiovalue_model and fixed some bugs * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * added double DQN and dueling features to DQN * fixed a bug * added DuelingQModelHead enum * fixed a bug * removed unwanted file * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * fixed PR comments * revised cim example according to DQN changes * renamed eval_model to q_value_model in cim example * more fixes * fixed a bug * fixed a bug * added doc per PR comments * removed learner.exit() in single_process_launcher * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * double DQN feature * fixed a bug * fixed a bug * fixed PR comments * fixed lint issue * 1. fixed PR comments related to load/dump; 2. removed abstract load/dump methods from AbsAlgorithm * added load_models in simple_learner * minor docstring edits * minor docstring edits * set is_double to true in DQN config Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> * V0.2 feature predefined image (#183) * feat: support predefined image provision * style: fix linting errors * style: fix linting errors * style: fix linting errors * style: fix linting errors * fix: error scripts invocation after using relative import * fix: missing init.py * fixed a bug in learner's test() * feat: add distributed_config for dqn example * test: update test for grass * test: update test for k8s * feat: add promptings for steps * fix: change relative imports to absolute imports Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> * V0.2 feature proxy rejoin (#158) * update dist decorator * replace proxy.get_peers by proxy.peers * update proxy rejoin (draft, not runable for proxy rejoin) * fix bugs in proxy * add message cache, and redesign rejoin parameter * feat: add checkpoint with test * update proxy.rejoin * fixed rejoin bug, rename func * add test example(temp) * feat: add FaultToleranceAgent, refine other MasterAgents and NodeAgents. * capital env vari name * rm json.dumps; change retries to 10; temp add warning level for rejoin * fix: unable to load FaultToleranceAgent, missing params * fix: delete mapping in StopJob if FaultTolerance is activated, add exception handler for FaultToleranceAgent * feat: add node_id to node_details * fix: add a new dependency for tests * style: meet linting requirements * style: remaining linting problems * lint fixed; rm temp test folder. * fixed lint f-string without placeholder * fix: add a flag for "remove_container", refine restart logic and Redis keys naming * proxy rejoin update. * variable rename. * fixed lint issues * fixed lint issues * add exit code for different error * feat: add special errors handler * add max rejoin times * remove unused import * add rejoin UT; resolve rejoin comments * lint fixed * fixed UT import problem * rm MessageCache in proxy * fix: refine key naming * update proxy rejoin; add topic for broadcast * feat: support predefined image provision * update UT for communication * add docstring for rejoin * fixed isort and zmq driver import * fixed isort and UT test * fix isort issue * proxy rejoin update (comments v2) * fixed isort error * style: fix linting errors * style: fix linting errors * style: fix linting errors * style: fix linting errors * feat: add exists method for checkpoint * fix: error scripts invocation after using relative import * fix: missing init.py * fixed a bug in learner's test() * add driver close and socket SUB disconnect for rejoin * feat: add distributed_config for dqn example * test: update test for grass * test: update test for k8s * feat: add promptings for steps * fix: change relative imports to absolute imports * fixed comments and update logger level * mv driver in proxy.__init__ for issue temp fixed. * Update docstring and comments * style: fix code reviews problems * fix code format Co-authored-by: Lyuchun Huang <romic.kid@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 feature cli windows (#203) * fix: change local mkdir to os.makedirs * fix: add utf8 encoding for logger * fix: add powershell.exe prefix to subprocess functions * feat: add debug_green * fix: use fsutil to create fix-size files in Windows * fix: use universal_newlines=True to handle encoding problem in different operating systems * fix: use temp file to do copy when the operating system is not Linux * fix: linting error * fix: use fsutil in test_k8s.py * feat: dynamic init ABS_PATH in GlobalParams * fix: use -Command to execute Powershell command * fix: refine code style in k8s_azure_executor.py, add Windows support for k8s mode * fix: problems in code review * EventBuffer refine (#197) * merge uniform event changes back * 1st step: move executing events into stack for better removing performance * flush event pool * typo * add option for env to enable event pool * refine stack functions * fix comment issues, add typings * lint fixing * lint fix * add missing fix * linting * lint * use linked list instead original event list and execute stack * add missing file * linting, and fixes * add missing file * linting fix * fixing comments * add missing file * rename event_list to event_linked_list * correct import path * change enable_event_pool to disable_finished_events * add missing file * V0.2 merge master (#214) * fix the visualization of docs/key_components/distributed_toolkit * add examples into isort ignore * refine import path for examples (#195) * refine import path for examples * refine indents * fixed formatting issues * update code style * add editorconfig-checker, add editorconfig path into lint, change super-linter version * change path for code saving in cim.gnn Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> * fix issue that sometimes there is conflict between distutils and setuptools (#208) * fix issue that cython and setuptools conflict * follow the accepted temp workaround * update comment, it should be conflict between setuptools and distutils * fixed bugs related to proxy interface changes Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * typo fix * Bug fix: event buffer issue that cause Actions cannot be passed into business engine (#215) * bug fix * clear the reference after extract sub events, update ut to cover this issue Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * fix flake8 style problem * V0.2 feature refine mode namings (#212) * feat: refine cli exception * feat: refine mode namings * EventBuffer refine (#197) * merge uniform event changes back * 1st step: move executing events into stack for better removing performance * flush event pool * typo * add option for env to enable event pool * refine stack functions * fix comment issues, add typings * lint fixing * lint fix * add missing fix * linting * lint * use linked list instead original event list and execute stack * add missing file * linting, and fixes * add missing file * linting fix * fixing comments * add missing file * rename event_list to event_linked_list * correct import path * change enable_event_pool to disable_finished_events * add missing file * fixed bugs in dist rl * feat: rename files * tests: set longer gracefully wait time * style: fix linting errors * style: fix linting errors * style: fix linting errors * fix: rm redundant variables * fix: refine error message Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 vis new (#210) Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * V0.2 local host process (#221) * Update local process (not ready) * update cli process mode * add setup/clear/template for maro process * fix process stop * add logger and rename parameters * add logger for setup/clear * fixed close not exist pid when given pid list. * Fixed comments and rename setup/clear with create/delete * update ProcessInternalError * V0.2 grass on premises (#220) * feat: refine cli exception * commit on v0.2_grass_on_premises Co-authored-by: Lyuchun Huang <romic.kid@gmail.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 vm scheduling scenario (#189) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * Refine cpu reader and unittest * Lint update * Refine based on PR comment * Add agent index * Add node maping * Refine based on PR comments * Renaming postpone_step * Renaming and refine based on PR comments * Rename config * Update Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Resolve none action problem (#224) * V0.2 vm_scheduling notebook (#223) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * Refine cpu reader and unittest * Lint update * Refine based on PR comment * Add agent index * Add node maping * Init vm shceduling notebook * Add notebook * Refine based on PR comments * Renaming postpone_step * Renaming and refine based on PR comments * Rename config * Update based on the v0.2_datacenter * Update notebook * Update * update filepath * notebook updated Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Update process mode docs and fixed on premises (#226) * V0.2 Add github workflow integration (#222) * test: add github workflow integration * fix: split procedures && bug fixed * test: add training only restriction * fix: add 'approved' restriction * fix: change default ssh port to 22 * style: in one line * feat: add timeout for Subprocess.run * test: change default node_size to Standard_D2s_v3 * style: refine style * fix: add ssh_port param to on-premises mode * fix: add missing init.py * V0.2 explorer (#198) * overhauled exploration abstraction * fixed a bug * fixed a bug * fixed a bug * added exploration related methods to abs_agent * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * separated learning with exploration schedule and without * small fixes * moved explorer logic to actor side * fixed a bug * fixed a bug * fixed a bug * fixed a bug * removed unwanted param from simple agent manager * added noise explorer * fixed formatting * removed unnecessary comma * fixed PR comments * removed unwanted exception and imports * fixed a bug * fixed PR comments * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issue * fixed a bug * fixed lint issue * fixed naming * combined exploration param generation and early stopping in scheduler * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * moved logger inside scheduler * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * removed epsilon parameter from choose_action * fixed some PR comments * fixed some PR comments * bug fix * bug fix * bug fix * removed explorer abstraction from agent * refined dqn example * fixed lint issues * simplified scheduler * removed early stopping from CIM dqn example * removed early stopping from cim example config * renamed early_stopping_callback to early_stopping_checker * removed action_dim from noise explorer classes and added some shape checks * modified NoiseExplorer's __call__ logic to batch processing * made NoiseExplorer's __call__ return type np array * renamed update to set_parameters in explorer * fixed old naming in test_grass Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 embedded optim (#191) * added dueling action value model * renamed params in dueling_action_value_model * renamed shared_features to features * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * mv dueling_actiovalue_model and fixed some bugs * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * added double DQN and dueling features to DQN * fixed a bug * added DuelingQModelHead enum * fixed a bug * removed unwanted file * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * fixed PR comments * revised cim example according to DQN changes * renamed eval_model to q_value_model in cim example * more fixes * fixed a bug * fixed a bug * added doc per PR comments * removed learner.exit() in single_process_launcher * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * double DQN feature * fixed a bug * fixed a bug * fixed PR comments * fixed lint issue * embedded optimizer into SingleHeadLearningModel * 1. fixed PR comments related to load/dump; 2. removed abstract load/dump methods from AbsAlgorithm * added load_models in simple_learner * minor docstring edits * minor docstring edits * minor docstring edits * mv optimizer options inside LearningMode * modified example accordingly * fixed a bug * fixed a bug * fixed a bug * added dueling DQN feature * revised and refined docstrings * fixed a bug * fixed lint issues * added load/dump functions to LearningModel * fixed a bug * fixed a bug * fixed lint issues * refined DQN docstrings * removed load/dump functions from DQN * added task validator * fixed decorator use * fixed a typo * fixed a bug * fixed lint issues * changed LearningModel's step() to take a single loss * revised learning model design * revised example * fixed a bug * fixed a bug * fixed a bug * fixed a bug * added decorator utils to algorithm * fixed a bug * renamed core_model to model * fixed a bug * 1. fixed lint formatting issues; 2. refined learning model docstrings * rm trailing whitespaces * added decorator for choose_action * fixed a bug * fixed a bug * fixed version-related issues * renamed add_zeroth_dim decorator to expand_dim * overhauled exploration abstraction * fixed a bug * fixed a bug * fixed a bug * added exploration related methods to abs_agent * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * separated learning with exploration schedule and without * small fixes * moved explorer logic to actor side * fixed a bug * fixed a bug * fixed a bug * fixed a bug * removed unwanted param from simple agent manager * small fixes * added shared_module property to LearningModel * added shared_module property to LearningModel * revised __getstate__ for LearningModel * fixed a bug * added soft_update function to learningModel * fixed a bug * revised learningModel * rm __getstate__ and __setstate__ from LearningModel * added noise explorer * fixed formatting * removed unnecessary comma * removed unnecessary comma * fixed PR comments * removed unwanted exception and imports * removed unwanted exception and imports * fixed a bug * fixed PR comments * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issue * fixed a bug * fixed lint issue * fixed naming * combined exploration param generation and early stopping in scheduler * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * moved logger inside scheduler * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * removed epsilon parameter from choose_action * removed epsilon parameter from choose_action * changed agent manager's train parameter to experience_by_agent * fixed some PR comments * renamed zero_grad to zero_gradients in LearningModule * fixed some PR comments * bug fix * bug fix * bug fix * removed explorer abstraction from agent * added DEVICE env variable as first choice for torch device * refined dqn example * fixed lint issues * removed unwanted import in cim example * updated cim-dqn notebook * simplified scheduler * edited notebook according to merged scheduler changes * refined dimension check for learning module manager and removed num_actions from DQNConfig * bug fix for cim example * added notebook output * removed early stopping from CIM dqn example * removed early stopping from cim example config * moved decorator logic inside algorithms * renamed early_stopping_callback to early_stopping_checker * removed action_dim from noise explorer classes and added some shape checks * modified NoiseExplorer's __call__ logic to batch processing * made NoiseExplorer's __call__ return type np array * renamed update to set_parameters in explorer * fixed old naming in test_grass Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 VM scheduling docs (#228) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * vm doc init * Update docs * Update docs * Update docs * Update docs * Remove old notebook * Update docs * Update docs * Add figure * Update docs Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * v0.2 VM Scheduling docs refinement (#231) * Fix typo * Refining vm scheduling docs * V0.2 store refinement (#234) * updated docs and images for rl toolkit * 1. fixed import formats for maro/rl; 2. changed decorators to hypers in store * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Fix bug (#237) vm scenario: fix the event type bug of the postpone event * V0.2 rl toolkit doc (#235) * updated docs and images for rl toolkit * updated cim example doc * updated cim exmaple docs * updated cim example rst * updated rl_toolkit and cim example docs * replaced q_module with q_net in example rst * refined doc * refined doc * updated figures * updated figures Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Merge V0.2 vis into V0.2 (#233) * Implemented dump snapshots and convert to CSV. * Let BE supports params when dump snapshot. * Refactor dump code to core.py * Implemented decision event dump. * replace is not '' with !='' * Fixed issues that code review mentioned. * removed path from hello.py * Changed import sort. * Fix import sorting in citi_bike/business_engine * visualization 0.1 * Updated lint configurations. * Fixed formatting error that caused lint errors. * render html title function * Try to fix lint errors. * flake-8 style fix * remove space around 18,35 * dump_csv_converter.py re-formatting. * files re-formatting. * style fixed * tab delete * white space fix * white space fix-2 * vis redundant function delete * refine * re-formatting after merged upstream. * Updated import section. * Updated import section. * pr refine * isort fix * white space * lint error * \n error * test continuation * indent * continuation of indent * indent 0.3 * comment update * comment update 0.2 * f-string update * f-string 0.2 * lint 0.3 * lint 0.4 * lint 0.4 * lint 0.5 * lint 0.6 * docstring update * data version deploy update * condition update * add whitespace * V0.2 vis dump feature enhancement. (#190) * Dumps added manifest file. * Code updated format by flake8 * Changed manifest file format for easy reading. * deploy info update; docs update * weird white space * Update dashboard_visualization.md * new endline? * delete dependency * delete irrelevant file * change scenario to enum, divide file path into a separated class * doc refine * doc update * params type * data structure update * doc&enum, formula refine * refine * add ut, refine doc * style refine * isort * strong type fix * os._exit delete * revert datalib * import new line * change test case * change file name & doc * change deploy path * delete params * revert file * delete duplicate file * delete single process * update naming * manually change import order * delete blank * edit error * requirement txt * style fix & refine * comments&docstring refine * add parameter name * test & dump * comments update * Added manifest file. (#201) Only a few changes that need to meet requirements of manifest file format. * comments fix * delete toolkit change * doc update * citi bike update * deploy path * datalib update * revert datalib * revert * maro file format * comments update * doc update * update param name * doc update * new link * image update * V0.2 visualization-0.1 (#181) * visualization 0.1 * render html title function * flake-8 style fix * style fixed * tab delete * white space fix * white space fix-2 * vis redundant function delete * refine * pr refine * isort fix * white space * lint error * \n error * test continuation * indent * continuation of indent * indent 0.3 * comment update * comment update 0.2 * f-string update * f-string 0.2 * lint 0.3 * lint 0.4 * lint 0.4 * lint 0.5 * lint 0.6 * docstring update * data version deploy update * condition update * add whitespace * deploy info update; docs update * weird white space * Update dashboard_visualization.md * new endline? * delete dependency * delete irrelevant file * change scenario to enum, divide file path into a separated class * fix the visualization of docs/key_components/distributed_toolkit * doc refine * doc update * params type * add examples into isort ignore * data structure update * doc&enum, formula refine * refine * add ut, refine doc * style refine * isort * strong type fix * os._exit delete * revert datalib * import new line * change test case * change file name & doc * change deploy path * delete params * revert file * delete duplicate file * delete single process * update naming * manually change import order * delete blank * edit error * requirement txt * style fix & refine * comments&docstring refine * add parameter name * test & dump * comments update * comments fix * delete toolkit change * doc update * citi bike update * deploy path * datalib update * revert datalib * revert * maro file format * comments update * doc update * update param name * doc update * new link * image update Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> * image change * add reset snapshot * delete dump * add new line * add next steps * import change * relative import * add init file * import change * change utils file * change cliexpcetion to clierror * dashboard test * change result * change assertation * move not * unit test change * core change * unit test delete name_mapping_file * update cim business engine * doc update * change relative path * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * duc update * duc update * duc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * change import sequence * comments update * doc add pic * add dependency * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * Update dashboard_visualization.rst * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * delete white space * doc update * doc update * update doc * update doc * update doc Co-authored-by: Michael Li <mic_lee2000@hotmail.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * V0.2 docs process mode (#230) * Update process mode docs and fixed on premises * Update orchestration docs * Update process mode docs add JOB_NAME as env variable * fixed bugs * fixed isort issue * update docs index Co-authored-by: kaiqli <v-kaiqli@microsoft.com> * V0.2 learning model refinement (#236) * moved optimizer options to LearningModel * typo fix * fixed lint issues * updated notebook * misc edits * 1. renamed CIMAgent to DQNAgent; 2. moved create_dqn_agents to Agent section in notebook * renamed single_host_cim_learner ot cim_learner in notebook * updated notebook output * typo fix * removed dimension check in absence of shared stack * fixed a typo * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Update vm docs (#241) Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * V0.2 info update (#240) * update readme * update version * refine reademe format * add vis gif * add citation * update citation * update badge Co-authored-by: Arthur Jiang <sjian@microsoft.com> * Fix typo (#242) * Fix typo * fix typo * fix * syntax fix (#253) * syntax fix * syntax fix * syntax fix * rm unwanted import Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 vm oversubscription (#246) * Remove topology * Update pipeline * Update pipeline * Update pipeline * Modify metafile * Add two attributes of VM * Update pipeline * Add vm category * Add todo * Add oversub config * Add oversubscription feature * Lint fix * Update based on PR comment. * Update pipeline * Update pipeline * Update config. * Update based on PR comment * Update * Add pm sku feature * Add sku setting * Add sku feature * Lint fix * Lint style * Update sku, overloading * Lint fix * Lint style * Fix bug * Modify config * Remove sky and replaced it by pm stype * Add and refactor vm category * Comment out cofig * Unify the enum format * Fix lint style * Fix import order * Update based on PR comment Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * V0.2 vm scheduling decision event (#257) * Fix data preparation bug * Add frame index * V0.2 PG, K-step and lambda return utils (#155) * fixed a bug * fixed lint issues * added load/dump functions to LearningModel * fixed a bug * fixed a bug * fixed lint issues * merged with v0.2_embedded_optims * refined DQN docstrings * removed load/dump functions from DQN * added task validator * fixed decorator use * fixed a typo * fixed a bug * revised * fixed lint issues * changed LearningModel's step() to take a single loss * revised learning model design * revised example * fixed a bug * fixed a bug * fixed a bug * fixed a bug * added decorator utils to algorithm * fixed a bug * renamed core_model to model * fixed a bug * 1. fixed lint formatting issues; 2. refined learning model docstrings * rm trailing whitespaces * added decorator for choose_action * fixed a bug * fixed a bug * fixed version-related issues * renamed add_zeroth_dim decorator to expand_dim * overhauled exploration abstraction * fixed a bug * fixed a bug * fixed a bug * added exploration related methods to abs_agent * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * separated learning with exploration schedule and without * small fixes * moved explorer logic to actor side * fixed a bug * fixed a bug * fixed a bug * fixed a bug * removed unwanted param from simple agent manager * small fixes * revised code based on revised abstractions * fixed some bugs * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * added shared_module property to LearningModel * added shared_module property to LearningModel * fixed a bug with k-step return in AC * fixed a bug * fixed a bug * merged pg, ac and ppo examples * fixed a bug * fixed a bug * fixed naming for ppo * renamed some variables in PPO * added ActionWithLogProbability return type for PO-type algorithms * fixed a bug * fixed a bug * fixed lint issues * revised __getstate__ for LearningModel * fixed a bug * added soft_update function to learningModel * fixed a bug * revised learningModel * rm __getstate__ and __setstate__ from LearningModel * added noise explorer * formatting * fixed formatting * removed unnecessary comma * removed unnecessary comma * removed unnecessary comma * fixed PR comments * removed unwanted exception and imports * removed unwanted exception and imports * fixed a bug * fixed PR comments * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issue * fixed a bug * fixed lint issue * fixed naming * combined exploration param generation and early stopping in scheduler * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * moved logger inside scheduler * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * removed epsilon parameter from choose_action * removed epsilon parameter from choose_action * changed agent manager's train parameter to experience_by_agent * fixed some PR comments * renamed zero_grad to zero_gradients in LearningModule * fixed some PR comments * bug fix * bug fix * bug fix * removed explorer abstraction from agent * added DEVICE env variable as first choice for torch device * refined dqn example * fixed lint issues * removed unwanted import in cim example * updated cim-dqn notebook * simplified scheduler * edited notebook according to merged scheduler changes * refined dimension check for learning module manager and removed num_actions from DQNConfig * bug fix for cim example * added notebook output * updated cim PO example code according to changes in maro/rl * removed early stopping from CIM dqn example * combined ac and ppo and simplified example code and config * removed early stopping from cim example config * moved decorator logic inside algorithms * renamed early_stopping_callback to early_stopping_checker * put PG and AC under PolicyOptimization class and refined examples accordingly * fixed lint issues * removed action_dim from noise explorer classes and added some shape checks * modified NoiseExplorer's __call__ logic to batch processing * made NoiseExplorer's __call__ return type np array * renamed update to set_parameters in explorer * fixed old naming in test_grass * moved optimizer options to LearningModel * typo fix * fixed lint issues * updated notebook * updated cim example for policy optimization * typo fix * typo fix * typo fix * typo fix * misc edits * minor edits to rl_toolkit.rst * checked out docs from master * fixed typo in k-step shaper * fixed lint issues * bug fix in store * lint issue fix * changed default max_ep to 100 for policy_optimization algos * vis doc update to master (#244) * refine readme * feat: refine data push/pull (#138) * feat: refine data push/pull * test: add cli provision testing * fix: style fix * fix: add necessary comments * fix: from code review * add fall back function in weather download (#112) * fix deployment issue in multi envs * fix typo * fix ~/.maro not exist issue in build * skip deploy when build * update for comments * temporarily disable weather info * replace ecr with cim in setup.py * replace ecr in manifest * remove weather check when read data * fix station id issue * fix format * add TODO in comments * add noaa weather source * fix weather reset and weather comment * add comment for weather data url * some format update * add fall back function in weather download * update comment * update for comments * update comment * add period * fix for pylint * update for pylint check * added example docs (#136) * added example docs * added citibike greedy example doc * modified citibike doc * fixed PR comments * fixed more PR comments * fixed small formatting issue Co-authored-by: ysqyang <v-yangqi@microsoft.com> * switch the key and value of handler_dict in decorator (#144) * switch the key and value of handler_dict in decorator * add dist decorator UT and fixed multithreading conflict in maro test suite * pr comments update. * resolved comments about decorator UT * rename handler_fun in dist decorator * change self.attr into class_name.attr * update UT tests comments * V0.1 annotation (#147) * refine the annotation of simulator core * remove reward from env(be) * format refined * white spaces test * left-padding spaces refined * format modifed * update the left-padding spaces of docstrings * code format updated * update according to comments * update according to PR comments Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Event payload details for env.summary (#156) * key_list of events added for env.summary * code refined according to lint * 2 kinds of Payload added for CIM scenario; citi bike summary refined according to comments * code format refined * try trigger the git tests * update github workflow Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Implemented dump snapshots and convert to CSV. * Let BE supports params when dump snapshot. * Refactor dump code to core.py * Implemented decision event dump. * V0.2 online lp for citi bike (#159) * key_list of events added for env.summary * code refined according to lint * 2 kinds of Payload added for CIM scenario; citi bike summary refined according to comments * code format refined * try trigger the git tests * update github workflow * online LP example added for citi bike * infeasible solution * infeasible solution fixed: call snapshot before any env.step() * experiment results of toy topos added * experiment results of toy topos added * experiment result update: better than naive baseline * PuLP version added * greedy experiment results update * citibike result update * modified according to PR comments * update experiment results and forecasting comparison * citi bike lp README updated * README updated * modified according to PR comments * update according to PR comments Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * replace is not '' with !='' * Fixed issues that code review mentioned. * removed path from hello.py * Changed import sort. * Fix import sorting in citi_bike/business_engine * visualization 0.1 * Updated lint configurations. * Fixed formatting error that caused lint errors. * render html title function * Try to fix lint errors. * flake-8 style fix * remove space around 18,35 * dump_csv_converter.py re-formatting. * files re-formatting. * style fixed * tab delete * white space fix * white space fix-2 * vis redundant function delete * refine * update according to flake8 * re-formatting after merged upstream. * Updated import section. * Updated import section. * V0.2 Logical operator overloading for EarlyStoppingChecker (#178) * 1. added logical operator overloading for early stopping checker; 2. added mean value checker * fixed PR comments * removed learner.exit() in single_process_launcher * added another early stopping checker in example * fixed PR comments and lint issues * lint issue fix * fixed lint issues * fixed a bug * fixed a bug Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 skip connection (#176) * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * moved reward type casting to exp shaper Co-authored-by: ysqyang <v-yangqi@microsoft.com> * pr refine * isort fix * white space * lint error * \n error * test continuation * indent * continuation of indent * indent 0.3 * comment update * comment update 0.2 * f-string update * f-string 0.2 * lint 0.3 * lint 0.4 * lint 0.4 * lint 0.5 * lint 0.6 * docstring update * data version deploy update * condition update * add whitespace * V0.2 vis dump feature enhancement. (#190) * Dumps added manifest file. * Code updated format by flake8 * Changed manifest file format for easy reading. * deploy info update; docs update * weird white space * Update dashboard_visualization.md * new endline? * delete dependency * delete irrelevant file * change scenario to enum, divide file path into a separated class * fixed a bug in learner's test() (#193) Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 double dqn (#188) * added dueling action value model * renamed params in dueling_action_value_model * renamed shared_features to features * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * mv dueling_actiovalue_model and fixed some bugs * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * added double DQN and dueling features to DQN * fixed a bug * added DuelingQModelHead enum * fixed a bug * removed unwanted file * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * fixed PR comments * revised cim example according to DQN changes * renamed eval_model to q_value_model in cim example * more fixes * fixed a bug * fixed a bug * added doc per PR comments * removed learner.exit() in single_process_launcher * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * double DQN feature * fixed a bug * fixed a bug * fixed PR comments * fixed lint issue * 1. fixed PR comments related to load/dump; 2. removed abstract load/dump methods from AbsAlgorithm * added load_models in simple_learner * minor docstring edits * minor docstring edits * set is_double to true in DQN config Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> * V0.2 feature predefined image (#183) * feat: support predefined image provision * style: fix linting errors * style: fix linting errors * style: fix linting errors * style: fix linting errors * fix: error scripts invocation after using relative import * fix: missing init.py * fixed a bug in learner's test() * feat: add distributed_config for dqn example * test: update test for grass * test: update test for k8s * feat: add promptings for steps * fix: change relative imports to absolute imports Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> * doc refine * doc update * params type * data structure update * doc&enum, formula refine * refine * add ut, refine doc * style refine * isort * strong type fix * os._exit delete * revert datalib * import new line * change test case * change file name & doc * change deploy path * delete params * revert file * delete duplicate file * delete single process * update naming * manually change import order * delete blank * edit error * requirement txt * style fix & refine * comments&docstring refine * add parameter name * test & dump * comments update * V0.2 feature proxy rejoin (#158) * update dist decorator * replace proxy.get_peers by proxy.peers * update proxy rejoin (draft, not runable for proxy rejoin) * fix bugs in proxy * add message cache, and redesign rejoin parameter * feat: add checkpoint with test * update proxy.rejoin * fixed rejoin bug, rename func * add test example(temp) * feat: add FaultToleranceAgent, refine other MasterAgents and NodeAgents. * capital env vari name * rm json.dumps; change retries to 10; temp add warning level for rejoin * fix: unable to load FaultToleranceAgent, missing params * fix: delete mapping in StopJob if FaultTolerance is activated, add exception handler for FaultToleranceAgent * feat: add node_id to node_details * fix: add a new dependency for tests * style: meet linting requirements * style: remaining linting problems * lint fixed; rm temp test folder. * fixed lint f-string without placeholder * fix: add a flag for "remove_container", refine restart logic and Redis keys naming * proxy rejoin update. * variable rename. * fixed lint issues * fixed lint issues * add exit code for different error * feat: add special errors handler * add max rejoin times * remove unused import * add rejoin UT; resolve rejoin comments * lint fixed * fixed UT import problem * rm MessageCache in proxy * fix: refine key naming * update proxy rejoin; add topic for broadcast * feat: support predefined image provision * update UT for communication * add docstring for rejoin * fixed isort and zmq driver import * fixed isort and UT test * fix isort issue * proxy rejoin update (comments v2) * fixed isort error * style: fix linting errors * style: fix linting errors * style: fix linting errors * style: fix linting errors * feat: add exists method for checkpoint * fix: error scripts invocation after using relative import * fix: missing init.py * fixed a bug in learner's test() * add driver close and socket SUB disconnect for rejoin * feat: add distributed_config for dqn example * test: update test for grass * test: update test for k8s * feat: add promptings for steps * fix: change relative imports to absolute imports * fixed comments and update logger level * mv driver in proxy.__init__ for issue temp fixed. * Update docstring and comments * style: fix code reviews problems * fix code format Co-authored-by: Lyuchun Huang <romic.kid@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 feature cli windows (#203) * fix: change local mkdir to os.makedirs * fix: add utf8 encoding for logger * fix: add powershell.exe prefix to subprocess functions * feat: add debug_green * fix: use fsutil to create fix-size files in Windows * fix: use universal_newlines=True to handle encoding problem in different operating systems * fix: use temp file to do copy when the operating system is not Linux * fix: linting error * fix: use fsutil in test_k8s.py * feat: dynamic init ABS_PATH in GlobalParams * fix: use -Command to execute Powershell command * fix: refine code style in k8s_azure_executor.py, add Windows support for k8s mode * fix: problems in code review * EventBuffer refine (#197) * merge uniform event changes back * 1st step: move executing events into stack for better removing performance * flush event pool * typo * add option for env to enable event pool * refine stack functions * fix comment issues, add typings * lint fixing * lint fix * add missing fix * linting * lint * use linked list instead original event list and execute stack * add missing file * linting, and fixes * add missing file * linting fix * fixing comments * add missing file * rename event_list to event_linked_list * correct import path * change enable_event_pool to disable_finished_events * add missing file * V0.2 merge master (#214) * fix the visualization of docs/key_components/distributed_toolkit * add examples into isort ignore * refine import path for examples (#195) * refine import path for examples * refine indents * fixed formatting issues * update code style * add editorconfig-checker, add editorconfig path into lint, change super-linter version * change path for code saving in cim.gnn Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> * fix issue that sometimes there is conflict between distutils and setuptools (#208) * fix issue that cython and setuptools conflict * follow the accepted temp workaround * update comment, it should be conflict between setuptools and distutils * fixed bugs related to proxy interface changes Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * typo fix * Bug fix: event buffer issue that cause Actions cannot be passed into business engine (#215) * bug fix * clear the reference after extract sub events, update ut to cover this issue Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * fix flake8 style problem * V0.2 feature refine mode namings (#212) * feat: refine cli exception * feat: refine mode namings * EventBuffer refine (#197) * merge uniform event changes back * 1st step: move executing events into stack for better removing performance * flush event pool * typo * add option for env to enable event pool * refine stack functions * fix comment issues, add typings * lint fixing * lint fix * add missing fix * linting * lint * use linked list instead original event list and execute stack * add missing file * linting, and fixes * add missing file * linting fix * fixing comments * add missing file * rename event_list to event_linked_list * correct import path * change enable_event_pool to disable_finished_events * add missing file * fixed bugs in dist rl * feat: rename files * tests: set longer gracefully wait time * style: fix linting errors * style: fix linting errors * style: fix linting errors * fix: rm redundant variables * fix: refine error message Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 vis new (#210) Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * V0.2 local host process (#221) * Update local process (not ready) * update cli process mode * add setup/clear/template for maro process * fix process stop * add logger and rename parameters * add logger for setup/clear * fixed close not exist pid when given pid list. * Fixed comments and rename setup/clear with create/delete * update ProcessInternalError * comments fix * delete toolkit change * doc update * citi bike update * deploy path * datalib update * revert datalib * revert * maro file format * comments update * doc update * V0.2 grass on premises (#220) * feat: refine cli exception * commit on v0.2_grass_on_premises Co-authored-by: Lyuchun Huang <romic.kid@gmail.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 vm scheduling scenario (#189) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * Refine cpu reader and unittest * Lint update * Refine based on PR comment * Add agent index * Add node maping * Refine based on PR comments * Renaming postpone_step * Renaming and refine based on PR comments * Rename config * Update Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Resolve none action problem (#224) * V0.2 vm_scheduling notebook (#223) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * Refine cpu reader and unittest * Lint update * Refine based on PR comment * Add agent index * Add node maping * Init vm shceduling notebook * Add notebook * Refine based on PR comments * Renaming postpone_step * Renaming and refine based on PR comments * Rename config * Update based on the v0.2_datacenter * Update notebook * Update * update filepath * notebook updated Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * Update process mode docs and fixed on premises (#226) * V0.2 Add github workflow integration (#222) * test: add github workflow integration * fix: split procedures && bug fixed * test: add training only restriction * fix: add 'approved' restriction * fix: change default ssh port to 22 * style: in one line * feat: add timeout for Subprocess.run * test: change default node_size to Standard_D2s_v3 * style: refine style * fix: add ssh_port param to on-premises mode * fix: add missing init.py * update param name * V0.2 explorer (#198) * overhauled exploration abstraction * fixed a bug * fixed a bug * fixed a bug * added exploration related methods to abs_agent * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * separated learning with exploration schedule and without * small fixes * moved explorer logic to actor side * fixed a bug * fixed a bug * fixed a bug * fixed a bug * removed unwanted param from simple agent manager * added noise explorer * fixed formatting * removed unnecessary comma * fixed PR comments * removed unwanted exception and imports * fixed a bug * fixed PR comments * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issue * fixed a bug * fixed lint issue * fixed naming * combined exploration param generation and early stopping in scheduler * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * moved logger inside scheduler * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * removed epsilon parameter from choose_action * fixed some PR comments * fixed some PR comments * bug fix * bug fix * bug fix * removed explorer abstraction from agent * refined dqn example * fixed lint issues * simplified scheduler * removed early stopping from CIM dqn example * removed early stopping from cim example config * renamed early_stopping_callback to early_stopping_checker * removed action_dim from noise explorer classes and added some shape checks * modified NoiseExplorer's __call__ logic to batch processing * made NoiseExplorer's __call__ return type np array * renamed update to set_parameters in explorer * fixed old naming in test_grass Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 embedded optim (#191) * added dueling action value model * renamed params in dueling_action_value_model * renamed shared_features to features * replaced IdentityLayers with nn.Identity * 1. added skip connection option in FC_net; 2. generalized learning model * added skip_connection option in config * removed type casting in fc_net * fixed lint formatting issues * refined docstring * mv dueling_actiovalue_model and fixed some bugs * added multi-head functionality to LearningModel * refined learning model docstring * added head_key param in learningModel forward * added double DQN and dueling features to DQN * fixed a bug * added DuelingQModelHead enum * fixed a bug * removed unwanted file * fixed PR comments * added top layer logic and is_top option in fc_net * fixed a bug * fixed a bug * reverted some changes in learning model * reverted some changes in learning model * added members to learning model to fix the mode issue * fixed a bug * fixed mode setting issue in learning model * fixed PR comments * revised cim example according to DQN changes * renamed eval_model to q_value_model in cim example * more fixes * fixed a bug * fixed a bug * added doc per PR comments * removed learner.exit() in single_process_launcher * removed learner.exit() in single_process_launcher * fixed PR comments * fixed rl/__init__ * fixed issues in example * fixed a bug * fixed a bug * fixed lint formatting issues * double DQN feature * fixed a bug * fixed a bug * fixed PR comments * fixed lint issue * embedded optimizer into SingleHeadLearningModel * 1. fixed PR comments related to load/dump; 2. removed abstract load/dump methods from AbsAlgorithm * added load_models in simple_learner * minor docstring edits * minor docstring edits * minor docstring edits * mv optimizer options inside LearningMode * modified example accordingly * fixed a bug * fixed a bug * fixed a bug * added dueling DQN feature * revised and refined docstrings * fixed a bug * fixed lint issues * added load/dump functions to LearningModel * fixed a bug * fixed a bug * fixed lint issues * refined DQN docstrings * removed load/dump functions from DQN * added task validator * fixed decorator use * fixed a typo * fixed a bug * fixed lint issues * changed LearningModel's step() to take a single loss * revised learning model design * revised example * fixed a bug * fixed a bug * fixed a bug * fixed a bug * added decorator utils to algorithm * fixed a bug * renamed core_model to model * fixed a bug * 1. fixed lint formatting issues; 2. refined learning model docstrings * rm trailing whitespaces * added decorator for choose_action * fixed a bug * fixed a bug * fixed version-related issues * renamed add_zeroth_dim decorator to expand_dim * overhauled exploration abstraction * fixed a bug * fixed a bug * fixed a bug * added exploration related methods to abs_agent * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * separated learning with exploration schedule and without * small fixes * moved explorer logic to actor side * fixed a bug * fixed a bug * fixed a bug * fixed a bug * removed unwanted param from simple agent manager * small fixes * added shared_module property to LearningModel * added shared_module property to LearningModel * revised __getstate__ for LearningModel * fixed a bug * added soft_update function to learningModel * fixed a bug * revised learningModel * rm __getstate__ and __setstate__ from LearningModel * added noise explorer * fixed formatting * removed unnecessary comma * removed unnecessary comma * fixed PR comments * removed unwanted exception and imports * removed unwanted exception and imports * fixed a bug * fixed PR comments * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issue * fixed a bug * fixed lint issue * fixed naming * combined exploration param generation and early stopping in scheduler * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * moved logger inside scheduler * fixed a bug * fixed a bug * fixed a bug * fixed lint issues * fixed lint issue * removed epsilon parameter from choose_action * removed epsilon parameter from choose_action * changed agent manager's train parameter to experience_by_agent * fixed some PR comments * renamed zero_grad to zero_gradients in LearningModule * fixed some PR comments * bug fix * bug fix * bug fix * removed explorer abstraction from agent * added DEVICE env variable as first choice for torch device * refined dqn example * fixed lint issues * removed unwanted import in cim example * updated cim-dqn notebook * simplified scheduler * edited notebook according to merged scheduler changes * refined dimension check for learning module manager and removed num_actions from DQNConfig * bug fix for cim example * added notebook output * removed early stopping from CIM dqn example * removed early stopping from cim example config * moved decorator logic inside algorithms * renamed early_stopping_callback to early_stopping_checker * removed action_dim from noise explorer classes and added some shape checks * modified NoiseExplorer's __call__ logic to batch processing * made NoiseExplorer's __call__ return type np array * renamed update to set_parameters in explorer * fixed old naming in test_grass Co-authored-by: ysqyang <v-yangqi@microsoft.com> * V0.2 VM scheduling docs (#228) * Initialize * Data center scenario init * Code style modification * V0.2 event buffer subevents expand (#180) * V0.2 rl toolkit refinement (#165) * refined rl abstractions * fixed formattin issues * checked out error-code related code from v0.2_pg * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * fixed a bug * renamed save_models to dump_models * 1. set default batch_norm_enabled to True; 2. used state_dict in dqn model saving * renamed dump_experience_store to dump_experience_pool * fixed a bug in the dump_experience_pool method * fixed some PR comments * fixed more PR comments * 1.fixed some PR comments; 2.added early_stopping_checker; 3.revised explorer class * fixed cim example according to rl toolkit changes * fixed some more PR comments * rewrote multi_process_launcher to eliminate the distributed section in config * 1. fixed a typo; 2. added logging before early stopping * fixed a bug * fixed a bug * fixed a bug * added early stopping feature to CIM exmaple * fixed a typo * fixed some issues with early stopping * changed early stopping metric func * fixed a bug * fixed a bug * added early stopping to dist mode cim * added experience collecting func * edited notebook according to changes in CIM example * fixed bugs in nb * fixed lint formatting issues * fixed a typo * fixed some PR comments * fixed more PR comments * revised docs * removed nb output * fixed a bug in simple_learner * fixed a typo in nb * fixed a bug * fixed a bug * fixed a bug * removed unused import * fixed a bug * 1. changed early stopping default config; 2. renamed param in early stopping checker and added typing * fixed some doc issues * added output to nb Co-authored-by: ysqyang <v-yangqi@microsoft.com> * unfold sub-events, insert after parent * remove event category, use different class instead, add helper functions to gen decision and action event * add a method to support add immediate event to cascade event with tick validation * fix ut issue * add action as 1st sub event to ensure the executing order Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Data center scenario update * Code style update * Data scenario business engine update * Isort update * Fix lint code check * Fix based on PR comments. * Update based on PR comments. * Add decision payload * Add config file * Update utilization series logic * Update based on PR comment * Update based on PR * Update * Update * Add the ValidPm class * Update docs string and naming * Add energy consumption * Lint code fixed * Refining postpone function * Lint style update * Init data pipeline * Update based on PR comment * Add data pipeline download * Lint style update * Code style fix * Temp update * Data pipeline update * Add aria2p download function * Update based on PR comment * Update based on PR comment * Update based on PR comment * Update naming of variables * Rename topology * Renaming * Fix valid pm list * Pylint fix * Update comment * Update docstring and comment * Fix init import * Update tick issue * fix merge problem * update style * V0.2 datacenter data pipeline (#199) * Data pipeline update * Data pipeline update * Lint update * Update pipeline * Add vmid mapping * Update lint style * Add VM data analytics * Update notebook * Add binary converter * Modift vmtable yaml * Update binary meta file * Add cpu reader * random example added for data center * Fix bugs * Fix pylint * Add launcher * Fix pylint * best fit policy added * Add reset * Add config * Add config * Modify action object * Modify config * Fix naming * Modify config * Add snapshot list * Modify a spelling typo * Update based on PR comments. * Rename scenario to vm scheduling * Rename scenario * Update print messages * Lint fix * Lint fix * Rename scenario * Modify the calculation of cpu utilization * Add comment * Modify data pipeline path * Fix typo * Modify naming * Add unittest * Add comment * Unify naming * Fix data path typo * Update comments * Update snapshot features * Add take snapshot * Add summary keys * Update cpu reader * Update naming * Add unit test * Rename snapshot node * Add processed data pipeline * Modify config * Add comment * Lint style fix Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> * Add package used in vm_scheduling * add aria2p to test requirement * best fit example: update the usage of snapshot * Add aria2p to test requriement * Remove finish event * Fix unittest * Add test dataset * Update based on PR comment * vm doc init * Update docs * Update docs * Update docs * Update docs * Remove old notebook * Update docs * Update docs * Add figure * Update docs Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * doc update * new link * image update * v0.2 VM Scheduling docs refinement (#231) * Fix typo * Refining vm scheduling docs * image change * V0.2 store refinement (#234) * updated docs and images for rl toolkit * 1. fixed import formats for maro/rl; 2. changed decorators to hypers in store * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Fix bug (#237) vm scenario: fix the event type bug of the postpone event * V0.2 rl toolkit doc (#235) * updated docs and images for rl toolkit * updated cim example doc * updated cim exmaple docs * updated cim example rst * updated rl_toolkit and cim example docs * replaced q_module with q_net in example rst * refined doc * refined doc * updated figures * updated figures Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Merge V0.2 vis into V0.2 (#233) * Implemented dump snapshots and convert to CSV. * Let BE supports params when dump snapshot. * Refactor dump code to core.py * Implemented decision event dump. * replace is not '' with !='' * Fixed issues that code review mentioned. * removed path from hello.py * Changed import sort. * Fix import sorting in citi_bike/business_engine * visualization 0.1 * Updated lint configurations. * Fixed formatting error that caused lint errors. * render html title function * Try to fix lint errors. * flake-8 style fix * remove space around 18,35 * dump_csv_converter.py re-formatting. * files re-formatting. * style fixed * tab delete * white space fix * white space fix-2 * vis redundant function delete * refine * re-formatting after merged upstream. * Updated import section. * Updated import section. * pr refine * isort fix * white space * lint error * \n error * test continuation * indent * continuation of indent * indent 0.3 * comment update * comment update 0.2 * f-string update * f-string 0.2 * lint 0.3 * lint 0.4 * lint 0.4 * lint 0.5 * lint 0.6 * docstring update * data version deploy update * condition update * add whitespace * V0.2 vis dump feature enhancement. (#190) * Dumps added manifest file. * Code updated format by flake8 * Changed manifest file format for easy reading. * deploy info update; docs update * weird white space * Update dashboard_visualization.md * new endline? * delete dependency * delete irrelevant file * change scenario to enum, divide file path into a separated class * doc refine * doc update * params type * data structure update * doc&enum, formula refine * refine * add ut, refine doc * style refine * isort * strong type fix * os._exit delete * revert datalib * import new line * change test case * change file name & doc * change deploy path * delete params * revert file * delete duplicate file * delete single process * update naming * manually change import order * delete blank * edit error * requirement txt * style fix & refine * comments&docstring refine * add parameter name * test & dump * comments update * Added manifest file. (#201) Only a few changes that need to meet requirements of manifest file format. * comments fix * delete toolkit change * doc update * citi bike update * deploy path * datalib update * revert datalib * revert * maro file format * comments update * doc update * update param name * doc update * new link * image update * V0.2 visualization-0.1 (#181) * visualization 0.1 * render html title function * flake-8 style fix * style fixed * tab delete * white space fix * white space fix-2 * vis redundant function delete * refine * pr refine * isort fix * white space * lint error * \n error * test continuation * indent * continuation of indent * indent 0.3 * comment update * comment update 0.2 * f-string update * f-string 0.2 * lint 0.3 * lint 0.4 * lint 0.4 * lint 0.5 * lint 0.6 * docstring update * data version deploy update * condition update * add whitespace * deploy info update; docs update * weird white space * Update dashboard_visualization.md * new endline? * delete dependency * delete irrelevant file * change scenario to enum, divide file path into a separated class * fix the visualization of docs/key_components/distributed_toolkit * doc refine * doc update * params type * add examples into isort ignore * data structure update * doc&enum, formula refine * refine * add ut, refine doc * style refine * isort * strong type fix * os._exit delete * revert datalib * import new line * change test case * change file name & doc * change deploy path * delete params * revert file * delete duplicate file * delete single process * update naming * manually change import order * delete blank * edit error * requirement txt * style fix & refine * comments&docstring refine * add parameter name * test & dump * comments update * comments fix * delete toolkit change * doc update * citi bike update * deploy path * datalib update * revert datalib * revert * maro file format * comments update * doc update * update param name * doc update * new link * image update Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> * image change * add reset snapshot * delete dump * add new line * add next steps * import change * relative import * add init file * import change * change utils file * change cliexpcetion to clierror * dashboard test * change result * change assertation * move not * unit test change * core change * unit test delete name_mapping_file * update cim business engine * doc update * change relative path * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * duc update * duc update * duc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * change import sequence * comments update * doc add pic * add dependency * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * Update dashboard_visualization.rst * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * delete white space * doc update * doc update * update doc * update doc * update doc Co-authored-by: Michael Li <mic_lee2000@hotmail.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * V0.2 docs process mode (#230) * Update process mode docs and fixed on premises * Update orchestration docs * Update process mode docs add JOB_NAME as env variable * fixed bugs * fixed isort issue * update docs index Co-authored-by: kaiqli <v-kaiqli@microsoft.com> * V0.2 learning model refinement (#236) * moved optimizer options to LearningModel * typo fix * fixed lint issues * updated notebook * misc edits * 1. renamed CIMAgent to DQNAgent; 2. moved create_dqn_agents to Agent section in notebook * renamed single_host_cim_learner ot cim_learner in notebook * updated notebook output * typo fix * removed dimension check in absence of shared stack * fixed a typo * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Update vm docs (#241) Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * V0.2 info update (#240) * update readme * update version * refine reademe format * add vis gif * add citation * update citation * update badge Co-authored-by: Arthur Jiang <sjian@microsoft.com> * Fix typo (#242) * Fix typo * fix typo * fix * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update * doc update Co-authored-by: Arthur Jiang <sjian@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> Co-authored-by: Romic Huang <romic.kid@gmail.com> Co-authored-by: zhanyu wang <pocket_2001@163.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: kaiqli <59279714+kaiqli@users.noreply.github.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Michael Li <mic_lee2000@hotmail.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: kyu-kuanwei <72911362+kyu-kuanwei@users.noreply.github.com> Co-authored-by: kaiqli <v-kaiqli@microsoft.com> * bug fix related to np array divide (#245) Co-authored-by: ysqyang <v-yangqi@microsoft.com> * Master.simple bike (#250) * notebook for simple bike repositioning added * add simple rule-based algorithms * unify input * add policy based on statistics * update be for simple bike scenario to fit latest event buffer changes (#247) * change rendered graph * figures updated * change notebook * matplot updated * figures updated Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: wesley <Wenlei.Shi@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> * simple bike repositioning article: formula updated * checked out docs/source from v0.2 * aligned with v0.2 * rm unwanted import * added references in policy_optimization.py * fixed lint issues Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: Meroy Chen <39452768+Meroy9819@users.noreply.github.com> Co-authored-by: Arthur Jiang <sjian@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> Co-authored-by: Romic Huang <romic.kid@gmail.com> Co-authored-by: zhanyu wang <pocket_2001@163.com> Co-authored-by: kaiqli <59279714+kaiqli@users.noreply.github.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Michael Li <mic_lee2000@hotmail.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: kyu-kuanwei <72911362+kyu-kuanwei@users.noreply.github.com> Co-authored-by: kaiqli <v-kaiqli@microsoft.com> * V0.2 backend dynamic node support (#172) * update lint workflow * fix workflow issue * Update lint.yml * Create tox.ini * Update lint.yml * Update lint.yml * Update tox.ini * Update lint.yml * Delete tox.ini from root folder, move it to .github/linters * Update CONTRIBUTING.md * add more comments * update lint conf to ignore cli banner issue * change extension implementation from c to cpp * update script to gen cpp files * backend base interface redefine * interface revamp for np backend * 1st step for revamp * bug fix * draft * implementation of attribute * implementation of backend * remove backend switching * draft raw backend wrapper * correct function parameter type * 1st runable version * bug fix for types * ut passed * change CRLF to LF * fix get_node_info interface * add raw test in frame ut * return np.array for all query result * use ticks from backend * set init value * snapshot ut passed * support set default backend by environemnt variable * env ut with different backend * fix take snapshot index bug * test under both backends * ignore generated cpp file * fix lint isues * more lint fix * use ordered map to store ticks to keep the order * remove test code * refine dup code * refine code to avoid too much if/else * handle and raise exception for attr getter * change the way to handle cpp exception, use cython runtimeerror instead * add missing function, and fix bug in np impl * fix lint issue * specify c++11 flag for compilers * use normal field assignment instead initializer list, as linux gcc will complain it * add np ignore macro * try to refine token pasting operator to avoid error on linux * more pasting operator issue fix * remove un-used options * update workflow files to fit new backend * 1st version of dynamic backend structure * setup ut for cpp using lest * bitset complete * attributestore and ut * arrange * copy_to * current frame * ut for frame * bug fix and ut correct * fix issue that value not correct after arrange * fix bug in test case * frame update * change the way to add nodes, support add node from middle * frame in backend * snapshotlist code complete * add size method for snapshotlist, add ut template * make sure snapshot max size not be 0 * add max size * fix query parameters * fix attribute store extend error * add function to retrieve attribute from snapshotlist * return nan for invalid index * add function to check if nan for float attribute only * fix bug that not update _last_tick for snapshot list, that cause take snapshot for same tick crash * add functions to expose internal state under debug mode, make it easy to do unit test * fix issue that cause overlap logic skiped * ut passed for all implemented functions * remove query in ut, as it not completed yet * refine querying interfaces, use 2 functions for 1 querying * snapshot query, * use pointer instead weak_ptr * backend impl * set default parameters value * query bug fix, * bug fix: new_attr should return attr id not node id * use macro to create attribute getters * add reset support * change the way to reset, avoid allocation time * test reset for attributestore * use Bitset instead vector<bool> to make it easy to reset * refine backend interfaces to make it compact with old one * correct quering interface, cython compile passed * bug fix: get_ticks not set correct index * correct cpp backend binding, add type for frame * correct ut for snapshot * bug fix: query cause crash after snapshot reset * fix env test * bug fix: is_nan should check data type first * fix cim ut issues with raw backend * fix citibike ut issues for raw backend * add interfaces to support dynamic nodes, not tested * bug fix: access cpp object without cdef * bug fix: missing impl for dynamic methods * ut for append nodes * return node number dynamiclly * remove unused parameters for snapshot * remove unused code * allow get attribute for deleted node * ut for delete and resume node * function to set attribute slot * bug fix: set attribute will cause crash * bug fix: remove append node when reset cause exception * bug fix: frame.backend_type return incorrect name * backends performance comparison * correct internal type * correct warnings * missing ; * formating * fix lint issue * simple the way to copy mapping * add dump interfaces * frame dump * ignore if dump path is not exist * bug fix: use max slots instead of current slots for padding in snapshot querying * use max slot number in history instead of current for padding * dump for snapshot * close file at the end * refine snapshot dump function * fix lint issue * avoid too much allocate operation * use pointer instead reference for furthure changes * avoid 2 times map copy * add comments for missing functions * performance optimize * use emplace instead push * use emplace instead push * remove cpp files * add missing lisence * ignore .vs folder * add lest lisence for cpp unittest * Delete CMakeLists.txt * add error msg for exception, make it easy to identify error at python side * remove old codes * replace with new code * change IDENTIER to NODE_TYPE and ATTR_TYPE * build pass * fix attr type not correct bug * reomve unused comment * make frame ut pass * correct the max snapshots checking * fix test case * add missing file * correct performance test * refine attribute code * refine bitset code * update FrameBase doc about switch backend * correct the exception name * refine frame code * refine node code * refine snapshot list code * add is_const and is_list when adding attribute * support query const attribute without tick exist * add operations for list attribute * remove cache as we have list attribute * add remove and insert for list attribute * add for-loop support for list attribute * fix bug that not update list attribute slot number after operations * test for dynamic features * frame dump * dump for snapshot list * fix issue on gcc compiler * add missing file * fix lint issues * refine the exception, more comments * fix lint issue * fix lint issue * use simulate enum instead of str * Use new type instead old in tests * using mapping instead if-else * remove generated code * use mapping to reduce too much if-else * add default attribute type int if not provided or invalid provided * remove generated code * update workflow with code gen * more frame test * add missing files * test: cover maro.simulator.utils.common * update test with new scenario * comments * tests * update doc * fix lint and comments * CRLF to LF * fix lint issue Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> * V0.2 vm oversub docs (#256) * Remove topology * Update pipeline * Update pipeline * Update pipeline * Modify metafile * Add two attributes of VM * Update pipeline * Add vm category * Add todo * Add oversub config * Add oversubscription feature * Lint fix * Update based on PR comment. * Update pipeline * Update pipeline * Update config. * Update based on PR comment * Update * Add pm sku feature * Add sku setting * Add sku feature * Lint fix * Lint style * Update sku, overloading * Lint fix * Lint style * Fix bug * Modify config * Remove sky and replaced it by pm stype * Add and refactor vm category * Comment out cofig * Unify the enum format * Fix lint style * Fix import order * Update based on PR comment * Update overload to the VM docs * Update docs * Update vm docs Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Arthur Jiang <sjian@microsoft.com> Co-authored-by: Arthur Jiang <ArthurSJiang@gmail.com> Co-authored-by: Romic Huang <romic.kid@gmail.com> Co-authored-by: zhanyu wang <pocket_2001@163.com> Co-authored-by: ysqyang <ysqyang@gmail.com> Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: kaiqli <59279714+kaiqli@users.noreply.github.com> Co-authored-by: Jinyu Wang <Wang.Jinyu@microsoft.com> Co-authored-by: Jinyu Wang <jinywan@microsoft.com> Co-authored-by: Chaos Yu <chaos.you@gmail.com> Co-authored-by: Wenlei Shi <Wenlei.Shi@microsoft.com> Co-authored-by: Michael Li <mic_lee2000@hotmail.com> Co-authored-by: kyu-kuanwei <72911362+kyu-kuanwei@users.noreply.github.com> Co-authored-by: Meroy Chen <39452768+Meroy9819@users.noreply.github.com> Co-authored-by: Miaoran Chen (Wicresoft) <v-miaorc@microsoft.com> Co-authored-by: kaiqli <v-kaiqli@microsoft.com> Co-authored-by: Kuan Wei Yu <v-kyu@microsoft.com>
This commit is contained in:
Родитель
376d573f71
Коммит
fa092f35b1
|
@ -19,6 +19,7 @@ exclude =
|
|||
.github,
|
||||
scripts,
|
||||
tests,
|
||||
maro/backends/*.cpp
|
||||
setup.py
|
||||
|
||||
max-line-length = 120
|
||||
|
|
|
@ -34,7 +34,8 @@ jobs:
|
|||
|
||||
- name: Compile cython files
|
||||
run: |
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
python ./scripts/code_gen.py
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
|
||||
- name: Build wheel on Windows and macOS
|
||||
if: runner.os == 'Windows' || runner.os == 'macOS'
|
||||
|
|
|
@ -29,7 +29,8 @@ jobs:
|
|||
- name: Build image
|
||||
run: |
|
||||
pip install -r ./maro/requirements.build.txt
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
python ./scripts/code_gen.py
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
cat ./maro/__misc__.py | grep __version__ | egrep -o [0-9].[0-9].[0-9,a-z]+ | { read version; docker build -f ./docker_files/cpu.play.df . -t ${{ secrets.DOCKER_HUB_USERNAME }}/maro:cpu -t ${{ secrets.DOCKER_HUB_USERNAME }}/maro:latest -t ${{ secrets.DOCKER_HUB_USERNAME }}/maro:cpu-$version; }
|
||||
|
||||
- name: Login docker hub
|
||||
|
|
|
@ -32,7 +32,8 @@ jobs:
|
|||
|
||||
- name: Compile cython files
|
||||
run: |
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
python ./scripts/code_gen.py
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
|
||||
- name: Build maro inplace
|
||||
run: |
|
||||
|
|
|
@ -30,7 +30,8 @@ jobs:
|
|||
|
||||
- name: Compile cython files
|
||||
run: |
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
python ./scripts/code_gen.py
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
|
||||
- name: Build maro inplace
|
||||
run: |
|
||||
|
|
|
@ -8,6 +8,14 @@ the backend language for improving the execution reference. What's more,
|
|||
the backend store is a pluggable design, user can choose different backend
|
||||
implementation based on their real performance requirement and device limitation.
|
||||
|
||||
Currenty there are two data model backend implementation: static and dynamic.
|
||||
Static implementation used Numpy as its data store, do not support dynamic
|
||||
attribute length, the advance of this version is that its memory size is same as its
|
||||
declaration.
|
||||
Dynamic implementation is hand-craft c++.
|
||||
It supports dynamic attribute (list) which will take more memory than the static implementation
|
||||
but is faster for querying snapshot states and accessing attributes.
|
||||
|
||||
Key Concepts
|
||||
------------
|
||||
|
||||
|
@ -28,6 +36,12 @@ As shown in the figure above, there are some key concepts in the data model:
|
|||
The ``slot`` number can indicate the attribute values (e.g. the three different
|
||||
container types in CIM scenario) or the detailed categories (e.g. the ten specific
|
||||
products in the `Use Case <#use-case>`_ below). By default, the ``slot`` value is one.
|
||||
As for the dynamic backend implementation, an attribute can be marked as is_list or is_const to identify
|
||||
it is a list attribute or a const attribute respectively.
|
||||
A list attribute's default slot number is 0, and can be increased as demand, max number is 2^32.
|
||||
A const attribute is designed for the value that will not change after initialization,
|
||||
e.g. the capacity of a port/station. The value is shared between frames and will not be copied
|
||||
when taking a snapshot.
|
||||
* **Frame** is the collection of all nodes in the environment. The historical frames
|
||||
present the aggregated state of the environment during a specific period, while
|
||||
the current frame hosts the latest state of the environment at the current time point.
|
||||
|
@ -41,6 +55,7 @@ Use Case
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from maro.backends.backend import AttributeType
|
||||
from maro.backends.frame import node, NodeAttribute, NodeBase, FrameNode, FrameBase
|
||||
|
||||
TOTAL_PRODUCT_CATEGORIES = 10
|
||||
|
@ -51,8 +66,8 @@ Use Case
|
|||
|
||||
@node("warehouse")
|
||||
class Warehouse(NodeBase):
|
||||
inventories = NodeAttribute("i", TOTAL_PRODUCT_CATEGORIES)
|
||||
shortages = NodeAttribute("i", TOTAL_PRODUCT_CATEGORIES)
|
||||
inventories = NodeAttribute(AttributeType.Int, TOTAL_PRODUCT_CATEGORIES)
|
||||
shortages = NodeAttribute(AttributeType.Int, TOTAL_PRODUCT_CATEGORIES)
|
||||
|
||||
def __init__(self):
|
||||
self._init_inventories = [100 * (i + 1) for i in range(TOTAL_PRODUCT_CATEGORIES)]
|
||||
|
@ -65,9 +80,9 @@ Use Case
|
|||
|
||||
@node("store")
|
||||
class Store(NodeBase):
|
||||
inventories = NodeAttribute("i", TOTAL_PRODUCT_CATEGORIES)
|
||||
shortages = NodeAttribute("i", TOTAL_PRODUCT_CATEGORIES)
|
||||
sales = NodeAttribute("i", TOTAL_PRODUCT_CATEGORIES)
|
||||
inventories = NodeAttribute(AttributeType.Int, TOTAL_PRODUCT_CATEGORIES)
|
||||
shortages = NodeAttribute(AttributeType.Int, TOTAL_PRODUCT_CATEGORIES)
|
||||
sales = NodeAttribute(AttributeType.Int, TOTAL_PRODUCT_CATEGORIES)
|
||||
|
||||
def __init__(self):
|
||||
self._init_inventories = [10 * (i + 1) for i in range(TOTAL_PRODUCT_CATEGORIES)]
|
||||
|
@ -86,7 +101,8 @@ Use Case
|
|||
|
||||
def __init__(self):
|
||||
# If your actual frame number was more than the total snapshot number, the old snapshots would be rolling replaced.
|
||||
super().__init__(enable_snapshot=True, total_snapshot=TOTAL_SNAPSHOT)
|
||||
# You can select a backend implementation that will fit your requirement.
|
||||
super().__init__(enable_snapshot=True, total_snapshot=TOTAL_SNAPSHOT, backend_name="static/dynamic")
|
||||
|
||||
* The operations on the retail frame.
|
||||
|
||||
|
@ -139,19 +155,34 @@ All supported data types for the attribute of the node:
|
|||
* - Attribute Data Type
|
||||
- C Type
|
||||
- Range
|
||||
* - i2
|
||||
- int16_t
|
||||
* - Attribute.Byte
|
||||
- char
|
||||
- -128 .. 127
|
||||
* - Attribute.UByte
|
||||
- unsigned char
|
||||
- 0 .. 255
|
||||
* - Attribute.Short (i2)
|
||||
- short
|
||||
- -32,768 .. 32,767
|
||||
* - i, i4
|
||||
* - Attribute.UShort
|
||||
- unsigned short
|
||||
- 0 .. 65,535
|
||||
* - Attribute.Int (i4)
|
||||
- int32_t
|
||||
- -2,147,483,648 .. 2,147,483,647
|
||||
* - i8
|
||||
* - Attribute.UInt (i4)
|
||||
- uint32_t
|
||||
- 0 .. 4,294,967,295
|
||||
* - Attribute.Long (i8)
|
||||
- int64_t
|
||||
- -9,223,372,036,854,775,808 .. 9,223,372,036,854,775,807
|
||||
* - f
|
||||
* - Attribute.ULong (i8)
|
||||
- uint64_t
|
||||
- 0 .. 18,446,744,073,709,551,615
|
||||
* - Attribute.Float (f)
|
||||
- float
|
||||
- -3.4E38 .. 3.4E38
|
||||
* - d
|
||||
* - Attribute.Double (d)
|
||||
- double
|
||||
- -1.7E308 .. 1.7E308
|
||||
|
||||
|
@ -216,3 +247,15 @@ For better data access, we also provide some advanced features, including:
|
|||
|
||||
# Query attribute by frame index list.
|
||||
states = test_nodes_snapshots[[0, 1, 2]: 0: "int_attribute"]
|
||||
|
||||
# The querying states is different between static and dynamic implementation
|
||||
# Static implementation will return a 1-dim numpy array, as the shape is known according to the parameters.
|
||||
# Dynamic implementation will return a 4-dim numpy array, that shape is (ticks, node_indices, attributes, slots).
|
||||
# Usually we can just flatten the state from dynamic implementation, then it will be same as static implementation,
|
||||
# except for list attributes.
|
||||
# List attribute only support one tick, one node index and one attribute name to query, cannot mix with normal attributes
|
||||
states = test_nodes_snapshots[0: 0: "list_attribute"]
|
||||
|
||||
# Also with dynamic implementation, we can get the const attributes which is shared between snapshot list, even without
|
||||
# any snapshot (need to provided one tick for padding).
|
||||
states = test_nodes_snapshots[0: [0, 1]: ["const_attribute", "const_attribute_2"]]
|
||||
|
|
|
@ -63,19 +63,19 @@ Learner and Actor
|
|||
Scheduler
|
||||
---------
|
||||
|
||||
A ``Scheduler`` is the driver of an episodic learning process. The learner uses the scheduler to repeat the
|
||||
rollout-training cycle a set number of episodes. For algorithms that require explicit exploration (e.g.,
|
||||
A ``Scheduler`` is the driver of an episodic learning process. The learner uses the scheduler to repeat the
|
||||
rollout-training cycle a set number of episodes. For algorithms that require explicit exploration (e.g.,
|
||||
DQN and DDPG), there are two types of schedules that a learner may follow:
|
||||
|
||||
* Static schedule, where the exploration parameters are generated using a pre-defined function of episode
|
||||
number. See ``LinearParameterScheduler`` and ``TwoPhaseLinearParameterScheduler`` provided in the toolkit
|
||||
for example.
|
||||
* Static schedule, where the exploration parameters are generated using a pre-defined function of episode
|
||||
number. See ``LinearParameterScheduler`` and ``TwoPhaseLinearParameterScheduler`` provided in the toolkit
|
||||
for example.
|
||||
* Dynamic schedule, where the exploration parameters for the next episode are determined based on the performance
|
||||
history. Such a mechanism is possible in our abstraction because the scheduler provides a ``record_performance``
|
||||
interface that allows it to keep track of roll-out performances.
|
||||
interface that allows it to keep track of roll-out performances.
|
||||
|
||||
Optionally, an early stopping checker may be registered if one wishes to terminate training when certain performance
|
||||
requirements are satisfied, possibly before reaching the prescribed number of episodes.
|
||||
Optionally, an early stopping checker may be registered if one wishes to terminate training when certain performance
|
||||
requirements are satisfied, possibly before reaching the prescribed number of episodes.
|
||||
|
||||
Agent Manager
|
||||
-------------
|
||||
|
@ -125,11 +125,11 @@ scenario agnostic.
|
|||
Algorithm
|
||||
---------
|
||||
|
||||
The algorithm is the kernel abstraction of the RL formulation for a real-world problem. Our abstraction
|
||||
decouples algorithm and model so that an algorithm can exist as an RL paradigm independent of the inner
|
||||
workings of the models it uses to generate actions or estimate values. For example, the actor-critic
|
||||
The algorithm is the kernel abstraction of the RL formulation for a real-world problem. Our abstraction
|
||||
decouples algorithm and model so that an algorithm can exist as an RL paradigm independent of the inner
|
||||
workings of the models it uses to generate actions or estimate values. For example, the actor-critic
|
||||
algorithm does not need to concern itself with the structures and optimizing schemes of the actor and
|
||||
critic models. This decoupling is achieved by the ``LearningModel`` abstraction described below.
|
||||
critic models. This decoupling is achieved by the ``LearningModel`` abstraction described below.
|
||||
|
||||
|
||||
.. image:: ../images/rl/algorithm.svg
|
||||
|
@ -153,18 +153,18 @@ Block, NNStack and LearningModel
|
|||
--------------------------------
|
||||
|
||||
MARO provides an abstraction for the underlying models used by agents to form policies and estimate values.
|
||||
The abstraction consists of a 3-level hierachy formed by ``AbsBlock``, ``NNStack`` and ``LearningModel`` from
|
||||
The abstraction consists of a 3-level hierachy formed by ``AbsBlock``, ``NNStack`` and ``LearningModel`` from
|
||||
the bottom up, all of which subclass torch's nn.Module. An ``AbsBlock`` is the smallest structural
|
||||
unit of an NN-based model. For instance, the ``FullyConnectedBlock`` provided in the toolkit represents a stack
|
||||
of fully connected layers with features like batch normalization, drop-out and skip connection. An ``NNStack`` is
|
||||
a composite network that consists of one or more such blocks, each with its own set of network features.
|
||||
The complete model as used directly by an ``Algorithm`` is represented by a ``LearningModel``, which consists of
|
||||
one or more task stacks as "heads" and an optional shared stack at the bottom (which serves to produce representations
|
||||
as input to all task stacks). It also contains one or more optimizers responsible for applying gradient steps to the
|
||||
trainable parameters within each stack, which is the smallest trainable unit from the perspective of a ``LearningModel``.
|
||||
The assignment of optimizers is flexible: it is possible to freeze certain stacks while optimizing others. Such an
|
||||
abstraction presents a unified interface to the algorithm, regardless of how many individual models it requires and how
|
||||
complex the model architecture might be.
|
||||
unit of an NN-based model. For instance, the ``FullyConnectedBlock`` provided in the toolkit represents a stack
|
||||
of fully connected layers with features like batch normalization, drop-out and skip connection. An ``NNStack`` is
|
||||
a composite network that consists of one or more such blocks, each with its own set of network features.
|
||||
The complete model as used directly by an ``Algorithm`` is represented by a ``LearningModel``, which consists of
|
||||
one or more task stacks as "heads" and an optional shared stack at the bottom (which serves to produce representations
|
||||
as input to all task stacks). It also contains one or more optimizers responsible for applying gradient steps to the
|
||||
trainable parameters within each stack, which is the smallest trainable unit from the perspective of a ``LearningModel``.
|
||||
The assignment of optimizers is flexible: it is possible to freeze certain stacks while optimizing others. Such an
|
||||
abstraction presents a unified interface to the algorithm, regardless of how many individual models it requires and how
|
||||
complex the model architecture might be.
|
||||
|
||||
.. image:: ../images/rl/learning_model.svg
|
||||
:target: ../images/rl/learning_model.svg
|
||||
|
@ -196,11 +196,11 @@ And performing one gradient step is simply:
|
|||
Explorer
|
||||
-------
|
||||
|
||||
MARO provides an abstraction for exploration in RL. Some RL algorithms such as DQN and DDPG require
|
||||
explicit exploration, the extent of which is usually determined by a set of parameters whose values
|
||||
MARO provides an abstraction for exploration in RL. Some RL algorithms such as DQN and DDPG require
|
||||
explicit exploration, the extent of which is usually determined by a set of parameters whose values
|
||||
are generated by the scheduler. The ``AbsExplorer`` class is designed to cater to these needs. Simple
|
||||
exploration schemes, such as ``EpsilonGreedyExplorer`` for discrete action space and ``UniformNoiseExplorer``
|
||||
and ``GaussianNoiseExplorer`` for continuous action space, are provided in the toolkit.
|
||||
exploration schemes, such as ``EpsilonGreedyExplorer`` for discrete action space and ``UniformNoiseExplorer``
|
||||
and ``GaussianNoiseExplorer`` for continuous action space, are provided in the toolkit.
|
||||
|
||||
As an example, the exploration for DQN may be carried out with the aid of an ``EpsilonGreedyExplorer``:
|
||||
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
Virtual Machine Scheduling (VM Scheduling)
|
||||
===========================================
|
||||
|
||||
In 21th century, the business needs of cloud computing are dramatically increasing.
|
||||
In 21th century, the business needs of cloud computing are dramatically increasing.
|
||||
During the cloud service, users request Virtual Machine (VM) with a certain amount of resources (eg. CPU, memory, etc).
|
||||
The following important issue is how to allocate the physical resources for these VMs?
|
||||
The VM Scheduling scenario aims to find a better solution of the VM scheduling problem
|
||||
in cloud data centers.
|
||||
Now, consider a specific time, the number of VM
|
||||
in cloud data centers.
|
||||
Now, consider a specific time, the number of VM
|
||||
requests and arrival pattern is fixed. Given a cluster of limited physical
|
||||
machines(PM) with limited physical resources, different VM allocation strategeies result in
|
||||
machines(PM) with limited physical resources, different VM allocation strategies result in
|
||||
different amount of
|
||||
successful completion and different operating cost of the data center. For cloud providers, a
|
||||
good VM allocation strategy can maximize the resource utilization and thus can increase the profit by
|
||||
providing more VMs to users. For cloud users, a good VM allocation strategy can
|
||||
minimize the VM response time and have a better using experience. We hope this scenario can meet
|
||||
successful completion and different operating cost of the data center. For cloud providers, a
|
||||
good VM allocation strategy can maximize the resource utilization and thus can increase the profit by
|
||||
providing more VMs to users. For cloud users, a good VM allocation strategy can
|
||||
minimize the VM response time and have a better using experience. We hope this scenario can meet
|
||||
the real needs and provide you with a demand simulation that is closest to the real situation.
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ resource life cycle always contains the steps below:
|
|||
- The VM's resource utilization changes dynamically and the PM's real-time energy consumption
|
||||
will be simulated in the runtime simulation.
|
||||
- After a period of execution, the VM completes its tasks. The simulator will release the resources
|
||||
allocated to this VM and deallocate this VM from the PM.
|
||||
allocated to this VM and deallocate this VM from the PM.
|
||||
Finally, the resource is free and is ready to serve the next VM request.
|
||||
|
||||
VM Request
|
||||
|
@ -40,24 +40,33 @@ VM Request
|
|||
In the VM scheduling scenario, the VM requests are uniformly sampled from real
|
||||
workloads. As long as the original dataset is large enough and the sample ratio
|
||||
is not too small, the sampled VM requests can follow a similar distribution to the
|
||||
original ones.
|
||||
original ones.
|
||||
|
||||
Given a fixed time interval, a VM request will arise according to the real VM workload data.
|
||||
The request contains the VM information of the required resources, including the required CPU cores,
|
||||
the required memory, and the remaining buffer time.
|
||||
Given a fixed time interval, a VM request will arise according to the real VM workload data.
|
||||
The request contains the VM information, such as the subscription id, the deployment id, and the
|
||||
VM category, VM's required resources, including the required CPU cores and
|
||||
the required memory, and the remaining buffer time.
|
||||
|
||||
* Whenever receive a VM request, the MARO simulator will first calculate the
|
||||
remaining resources of each PM, filtering out the valid PMs (valid PMs means that the remaining
|
||||
* Whenever receive a VM request, the MARO simulator will first calculate the
|
||||
remaining resources of each PM, filtering out the valid PMs (valid PMs means that the remaining
|
||||
resources of PM are enough for the required resources of the VM).
|
||||
* Then, the simulator delivers all valid PMs and the required resources of the awaiting VM
|
||||
* Then, the simulator delivers all valid PMs and the required resources of the awaiting VM
|
||||
to the VM scheduler (agent) for a decision.
|
||||
|
||||
We have two categories of VM. One is interactive, and the other one is
|
||||
delay-insensitive.
|
||||
|
||||
* Interactive: The interactive VMs usually require low response time, so we set this kind of VMs can
|
||||
only be allocated to the non-oversubscribable PM server.
|
||||
* Delay-insensitive: The delay-insensitive VMs usually serve for batch-tasks or development workload. This kind of VMs can
|
||||
be allocated to the over-subscribable PM server.
|
||||
|
||||
VM Allocation
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Based on the valid PM list, the histortical information recorded by the simulator, and the detailed
|
||||
required resources of the VM, the VM scheduler (decision agent) will make the decision according to its
|
||||
allocation strategy.
|
||||
Based on the valid PM list, the historical information recorded by the simulator, and the detailed
|
||||
required resources of the VM, the VM scheduler (decision agent) will make the decision according to its
|
||||
allocation strategy.
|
||||
|
||||
There are two types of meaningful actions:
|
||||
|
||||
|
@ -67,21 +76,39 @@ There are two types of meaningful actions:
|
|||
|
||||
See the detailed attributes of `Action <#id1>`_.
|
||||
|
||||
|
||||
Oversubscription
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
To maximize each PM's utilization, cloud providers will oversubscribe the physical resource.
|
||||
Considering the various service level, the physical machines are then divided into the over-subscribable ones and non-oversubscribable ones.
|
||||
For the over-subscription, there are several parameters can be set in the config.yml.
|
||||
In our scenario, there are two resources could be oversubscribed, CPU and memory, so we have two maximum oversubscription rate can be set.
|
||||
|
||||
* ``MAX_CPU_OVERSUBSCRIPTION_RATE``: The oversubscription rate of CPU. For example, the default setting
|
||||
is 1.15, that means each PM can be allocated at most 1.15 times of its resource capacity.
|
||||
* ``MAX_MEM_OVERSUBSCRIPTION_RATE``: The oversubscription rate of memory. Similar to the CPU rate.
|
||||
|
||||
To protect the PM from the overloading, we need to consider the CPU utilization. The ``MAX_UTILIZATION_RATE``
|
||||
is used as the security mechanism, that can be set in the config.yml.
|
||||
|
||||
* ``MAX_UTILIZATION_RATE``: The default setting is 1, which means that when filtering the valid PMs,
|
||||
the maximum allowed physical CPU utilization is 100%.
|
||||
|
||||
Runtime Simulation
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Dynamic Utilization
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
To make the simulated environment closest to the real situation. We also simulate the resource utilization
|
||||
(currently only CPU utilization) of each VM. The CPU utilization of the VM varies every tick based on
|
||||
the real VM workload readings. We will also regularly update the real-time resource utilization of
|
||||
To make the simulated environment closest to the real situation. We also simulate the resource utilization
|
||||
(currently only CPU utilization) of each VM. The CPU utilization of the VM varies every tick based on
|
||||
the real VM workload readings. We will also regularly update the real-time resource utilization of
|
||||
each PM based on the live VMs in it.
|
||||
|
||||
Real-time Energy Consumption
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
One of the most important characteristics that cloud providers concern is the enery consumption of the
|
||||
One of the most important characteristics that cloud providers concern is the energy consumption of the
|
||||
data center. The different VM allocation can result in different energy consumption of the PM cluster,
|
||||
we also simulate the energy usage based on the CPU utilization.
|
||||
|
||||
|
@ -92,19 +119,37 @@ Energy Curve
|
|||
:target: ../images/scenario/vm.energy_curve.svg
|
||||
:alt: data center energy curve
|
||||
|
||||
As we mention before, the lower energy consumption of the PMs, the lower cost to maintain the physical
|
||||
As we mention before, the lower energy consumption of the PMs, the lower cost to maintain the physical
|
||||
servers. In our simulation, we currently use a non-linear energy curve like the one in the above
|
||||
`figure <https://dl.acm.org/doi/10.1145/1273440.1250665>`_ to
|
||||
simulate the energy based on the CPU utilization.
|
||||
simulate the energy based on the CPU utilization.
|
||||
|
||||
Overload
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Since the VM's CPU utilization varies along the time, when enabling the oversubscription, it might
|
||||
happen that the sum of VM's CPU usage exceed the capacity of the physical resource. This situation called
|
||||
overload.
|
||||
|
||||
Overloading may lead to VM's performance degradation or service level agreements (SLAs) violations
|
||||
in real production (We will support these features in the future).
|
||||
Currently, for the situation of overloading, we only support quiescing (killing) all VMs or just recording
|
||||
the times of overloading, which can also be set in config.yml.
|
||||
|
||||
* ``KILL_ALL_VMS_IF_OVERLOAD``: If this action is enable,
|
||||
once overloading happens, all VMs located at the overloading PMs will be deallocated. To consider the
|
||||
effect of overloading, we will still count the energy consumed by the high utilization.
|
||||
The impact of the quiescing action on the PM's utilization will be reflected in the next tick.
|
||||
|
||||
No matter enable killing all VMs or not, we will calculate the number of overload PMs and the number
|
||||
of overload VMs. These two metrics are cumulative values and will be recorded as the environment metrics.
|
||||
|
||||
VM Deallocation
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
The MARO simulator regularly checks the finished VMs in every tick.
|
||||
The MARO simulator regularly checks the finished VMs in every tick.
|
||||
A finished VM means that it goes through a complete life cycle, is ready to be terminated, and
|
||||
the resources it occupies will be available again in the end.
|
||||
The simulator will then release the finished VM's resources, and finally remove the VM from the PM.
|
||||
The simulator will then release the finished VM's resources, and finally remove the VM from the PM.
|
||||
|
||||
Topologies
|
||||
-----------
|
||||
|
@ -112,22 +157,23 @@ Topologies
|
|||
Azure Topologies
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
The original data comes from `Azure public dataset <https://github.com/Azure/AzurePublicDataset>`_.
|
||||
The dataset contains real Azure VM workloads, including the information of VMs and their
|
||||
The original data comes from `Azure public dataset <https://github.com/Azure/AzurePublicDataset>`_.
|
||||
The dataset contains real Azure VM workloads, including the information of VMs and their
|
||||
utilization readings in 2019 lasting for 30 days. Total number of VM recorded is 2,695,548.
|
||||
|
||||
In our scenario, we pre-processed the AzurePublicDatasetV2.
|
||||
In our scenario, we pre-processed the AzurePublicDatasetV2.
|
||||
The detailed information of the data schema can be found
|
||||
`here <https://github.com/Azure/AzurePublicDataset/blob/master/AzurePublicDatasetV2.md>`_.
|
||||
`here <https://github.com/Azure/AzurePublicDataset/blob/master/AzurePublicDatasetV2.md>`_.
|
||||
After pre-processed, the data contains
|
||||
|
||||
* Renumbered VM ID
|
||||
* VM cores and memory(GB) requirements
|
||||
* Real VM creation and deletion time (converted to the tick, 1 tick means 5 minutes in real time)
|
||||
|
||||
As for the utilization readings part, we sort the renumbered VM ID and CPU utilization pairs by the timestamp (tick).
|
||||
|
||||
To provide system workloads from light to heavy, two kinds of simple topologies are designed and
|
||||
provided in VM Scheduling scenario.
|
||||
To provide system workloads from light to heavy, two kinds of simple topologies are designed and
|
||||
provided in VM Scheduling scenario.
|
||||
|
||||
azure.2019.10k
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
@ -166,10 +212,10 @@ PM setting (Given by the /[topologies]/config.yml):
|
|||
Naive Baseline
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
Belows are the final environment metrics of the method **Random Allocation** and
|
||||
**Best-Fit Allocation** in different topologies.
|
||||
Belows are the final environment metrics of the method **Random Allocation** and
|
||||
**Best-Fit Allocation** in different topologies.
|
||||
For each experiment, we setup the environment and test for a duration of 30 days.
|
||||
Besides, we use several settings of PM capacity to test performance under different
|
||||
Besides, we use several settings of PM capacity to test performance under different
|
||||
initial resources.
|
||||
|
||||
|
||||
|
@ -188,14 +234,14 @@ Randomly allocate to a valid PM.
|
|||
- Successful Allocation
|
||||
- Successful completion
|
||||
- Failed Allocation
|
||||
* - Azure.2019.10k
|
||||
* - Azure.2019.10k
|
||||
- 100 PMs, 32 Cores, 128 GB
|
||||
- 10,000
|
||||
- 2,430,651.6
|
||||
- 9,850
|
||||
- 9,030
|
||||
- 150
|
||||
* -
|
||||
* -
|
||||
- 100 PMs, 16 Cores, 112 GB
|
||||
- 10,000
|
||||
- 2,978,445.0
|
||||
|
@ -209,7 +255,7 @@ Randomly allocate to a valid PM.
|
|||
- 176,468
|
||||
- 165,715
|
||||
- 159,517
|
||||
* -
|
||||
* -
|
||||
- 880 PMs, 16 Cores, 112 GB
|
||||
- 335,985
|
||||
- 26,367,238.7
|
||||
|
@ -232,28 +278,28 @@ Choose the valid PM with the least remaining resources (only consider CPU cores
|
|||
- Successful Allocation
|
||||
- Successful completion
|
||||
- Failed Allocation
|
||||
* - Azure.2019.10k
|
||||
* - Azure.2019.10k
|
||||
- 100 PMs, 32 Cores, 128 GB
|
||||
- 10,000
|
||||
- 2,395,328.7
|
||||
- 10,000
|
||||
- 9,180
|
||||
- 0
|
||||
* -
|
||||
* -
|
||||
- 100 PMs, 16 Cores, 112 GB
|
||||
- 10,000
|
||||
- 2,987,086.6
|
||||
- 7,917
|
||||
- 7,313
|
||||
- 2,083
|
||||
* - Azure.2019.336k
|
||||
* - Azure.2019.336k
|
||||
- 880 PMs, 32 Cores, 128 GB
|
||||
- 335,985
|
||||
- 26,695,470.8
|
||||
- 171,044
|
||||
- 160,495
|
||||
- 164,941
|
||||
* -
|
||||
* -
|
||||
- 880 PMs, 16 Cores, 112 GB
|
||||
- 335,985
|
||||
- 26,390,972.9
|
||||
|
@ -269,25 +315,25 @@ Quick Start
|
|||
Data Preparation
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
When the environment is first created, the system will automatically trigger the pipeline to download
|
||||
When the environment is first created, the system will automatically trigger the pipeline to download
|
||||
and process the data files. Afterwards, if you want to run multiple simulations, the system will detect
|
||||
whether the processed data files exist or not. If not, it will then trigger the pipeline again. Otherwise,
|
||||
the system will reuse the processed data files.
|
||||
the system will reuse the processed data files.
|
||||
|
||||
|
||||
Environment Interface
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Before starting interaction with the environment, we need to know the definition of ``DecisionPayload`` and
|
||||
``Action`` in VM Scheduling scenario first. Besides, you can query the environment
|
||||
`snapshot list <../key_components/data_model.html#advanced-features>`_ to get more
|
||||
Before starting interaction with the environment, we need to know the definition of ``DecisionPayload`` and
|
||||
``Action`` in VM Scheduling scenario first. Besides, you can query the environment
|
||||
`snapshot list <../key_components/data_model.html#advanced-features>`_ to get more
|
||||
detailed information for the decision making.
|
||||
|
||||
DecisionPayload
|
||||
~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
Once the environment need the agent's response to promote the simulation, it will throw an ``PendingDecision``
|
||||
event with the ``DecisionPayload``. In the scenario of VM Scheduling, the information of ``DecisionPayload`` is
|
||||
event with the ``DecisionPayload``. In the scenario of VM Scheduling, the information of ``DecisionPayload`` is
|
||||
listed as below:
|
||||
|
||||
* **valid_pms** (List[int]): The list of the PM ID that is considered as valid (Its CPU and memory resource is enough for the incoming VM request).
|
||||
|
@ -299,35 +345,36 @@ listed as below:
|
|||
Action
|
||||
~~~~~~~
|
||||
|
||||
Once get a ``PendingDecision`` event from the envirionment, the agent should respond with an Action. Valid
|
||||
Once get a ``PendingDecision`` event from the environment, the agent should respond with an Action. Valid
|
||||
``Action`` includes:
|
||||
|
||||
* **None**. It means do nothing but ignore this VM request.
|
||||
* ``AllocateAction``: If the MARO simulator receives the ``AllocateAction``, the VM's creation time will be
|
||||
fixed at the tick it receives. Besides, the simulator will update the workloads (the workloads include
|
||||
CPU cores, the memory, and the energy consumption) of the target PM.
|
||||
* ``AllocateAction``: If the MARO simulator receives the ``AllocateAction``, the VM's creation time will be
|
||||
fixed at the tick it receives. Besides, the simulator will update the workloads (the workloads include
|
||||
CPU cores, the memory, and the energy consumption) of the target PM.
|
||||
The ``AllocateAction`` includes:
|
||||
|
||||
* vm_id (int): The ID of the VM that is waiting for the allocation.
|
||||
* pm_id (int): The ID of the PM where the VM is scheduled to allocate to.
|
||||
* ``PostponeAction``: If the MARO simulator receives the ``PostponeAction``, it will calculate the
|
||||
remaining buffer time.
|
||||
* ``PostponeAction``: If the MARO simulator receives the ``PostponeAction``, it will calculate the
|
||||
remaining buffer time.
|
||||
|
||||
* If the time is still enough, the simulator will re-generate a new request
|
||||
event and insert it to the corresponding tick (based on the ``Postpone Step`` and ``DELAY_DURATION``).
|
||||
The ``DecisionPayload`` of the new requirement event only differs in the remaining buffer time from the
|
||||
event and insert it to the corresponding tick (based on the ``Postpone Step`` and ``DELAY_DURATION``).
|
||||
The ``DecisionPayload`` of the new requirement event only differs in the remaining buffer time from the
|
||||
old ones.
|
||||
* If the time is exhausted, the simulator will note it as a failed allocation.
|
||||
|
||||
The ``PostponeAction`` includes:
|
||||
|
||||
* vm_id (int): The ID of the VM that is waiting for the allocation.
|
||||
* postpone_step (int): The number of times that the allocation to be postponed. The unit
|
||||
* postpone_step (int): The number of times that the allocation to be postponed. The unit
|
||||
is ``DELAY_DURATION``. 1 means delay 1 ``DELAY_DURATION``, which can be set in the config.yml.
|
||||
|
||||
Example
|
||||
^^^^^^^^
|
||||
|
||||
Here we will show you a simple example of interaction with the environment in random mode, we
|
||||
Here we will show you a simple example of interaction with the environment in random mode, we
|
||||
hope this could help you learn how to use the environment interfaces:
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -350,7 +397,7 @@ hope this could help you learn how to use the environment interfaces:
|
|||
decision_event: DecisionPayload = None
|
||||
is_done: bool = False
|
||||
action: AllocateAction = None
|
||||
|
||||
|
||||
# Start the env with a None Action
|
||||
metrics, decision_event, is_done = env.step(None)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
This file is used to load config and convert it into a dotted dictionary.
|
||||
This file is used to load the configuration and convert it into a dotted dictionary.
|
||||
"""
|
||||
|
||||
import io
|
||||
|
|
|
@ -38,7 +38,6 @@ def launch(config):
|
|||
|
||||
# Step 4: Create an actor and a learner to start the training process.
|
||||
scheduler = TwoPhaseLinearParameterScheduler(config.main_loop.max_episode, **config.main_loop.exploration)
|
||||
|
||||
actor = SimpleActor(env, agent_manager)
|
||||
learner = SimpleLearner(
|
||||
agent_manager, actor, scheduler,
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Overview
|
||||
|
||||
The CIM problem is one of the quintessential use cases of MARO. The example can
|
||||
be run with a set of scenario configurations that can be found under
|
||||
maro/simulator/scenarios/cim. General experimental parameters (e.g., type of
|
||||
topology, type of algorithm to use, number of training episodes) can be configured
|
||||
through config.yml. Each RL formulation has a dedicated folder, e.g., dqn, and
|
||||
all algorithm-specific parameters can be configured through
|
||||
the config.py file in that folder.
|
||||
|
||||
## Single-host Single-process Mode
|
||||
|
||||
To run the CIM example using the DQN algorithm under single-host mode, go to
|
||||
examples/cim/dqn and run single_process_launcher.py. You may play around with
|
||||
the configuration if you want to try out different settings.
|
||||
|
||||
## Distributed Mode
|
||||
|
||||
The examples/cim/dqn/components folder contains dist_learner.py and dist_actor.py
|
||||
for distributed training. For debugging purposes, we provide a script that
|
||||
simulates distributed mode using multi-processing. Simply go to examples/cim/dqn
|
||||
and run multi_process_launcher.py to start the learner and actor processes.
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .action_shaper import CIMActionShaper
|
||||
from .agent_manager import POAgentManager, create_po_agents
|
||||
from .experience_shaper import TruncatedExperienceShaper
|
||||
from .state_shaper import CIMStateShaper
|
||||
|
||||
__all__ = [
|
||||
"CIMActionShaper",
|
||||
"POAgentManager", "create_po_agents",
|
||||
"TruncatedExperienceShaper",
|
||||
"CIMStateShaper"
|
||||
]
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from maro.rl import ActionShaper
|
||||
from maro.simulator.scenarios.cim.common import Action
|
||||
|
||||
|
||||
class CIMActionShaper(ActionShaper):
|
||||
def __init__(self, action_space):
|
||||
super().__init__()
|
||||
self._action_space = action_space
|
||||
self._zero_action_index = action_space.index(0)
|
||||
|
||||
def __call__(self, model_action, decision_event, snapshot_list):
|
||||
scope = decision_event.action_scope
|
||||
tick = decision_event.tick
|
||||
port_idx = decision_event.port_idx
|
||||
vessel_idx = decision_event.vessel_idx
|
||||
|
||||
port_empty = snapshot_list["ports"][tick: port_idx: ["empty", "full", "on_shipper", "on_consignee"]][0]
|
||||
vessel_remaining_space = snapshot_list["vessels"][tick: vessel_idx: ["empty", "full", "remaining_space"]][2]
|
||||
early_discharge = snapshot_list["vessels"][tick:vessel_idx: "early_discharge"][0]
|
||||
assert 0 <= model_action < len(self._action_space)
|
||||
|
||||
if model_action < self._zero_action_index:
|
||||
actual_action = max(round(self._action_space[model_action] * port_empty), -vessel_remaining_space)
|
||||
elif model_action > self._zero_action_index:
|
||||
plan_action = self._action_space[model_action] * (scope.discharge + early_discharge) - early_discharge
|
||||
actual_action = round(plan_action) if plan_action > 0 else round(self._action_space[model_action] * scope.discharge)
|
||||
else:
|
||||
actual_action = 0
|
||||
|
||||
return Action(vessel_idx, port_idx, actual_action)
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl import (
|
||||
AbsAgent, ActorCritic, ActorCriticConfig, FullyConnectedBlock, LearningModel, NNStack,
|
||||
OptimizerOptions, PolicyGradient, PolicyOptimizationConfig, SimpleAgentManager
|
||||
)
|
||||
from maro.utils import set_seeds
|
||||
|
||||
|
||||
class POAgent(AbsAgent):
|
||||
def train(self, states: np.ndarray, actions: np.ndarray, log_action_prob: np.ndarray, rewards: np.ndarray):
|
||||
self._algorithm.train(states, actions, log_action_prob, rewards)
|
||||
|
||||
|
||||
def create_po_agents(agent_id_list, config):
|
||||
input_dim, num_actions = config.input_dim, config.num_actions
|
||||
set_seeds(config.seed)
|
||||
agent_dict = {}
|
||||
for agent_id in agent_id_list:
|
||||
actor_net = NNStack(
|
||||
"actor",
|
||||
FullyConnectedBlock(
|
||||
input_dim=input_dim,
|
||||
output_dim=num_actions,
|
||||
activation=nn.Tanh,
|
||||
is_head=True,
|
||||
**config.actor_model
|
||||
)
|
||||
)
|
||||
|
||||
if config.type == "actor_critic":
|
||||
critic_net = NNStack(
|
||||
"critic",
|
||||
FullyConnectedBlock(
|
||||
input_dim=config.input_dim,
|
||||
output_dim=1,
|
||||
activation=nn.LeakyReLU,
|
||||
is_head=True,
|
||||
**config.critic_model
|
||||
)
|
||||
)
|
||||
|
||||
hyper_params = config.actor_critic_hyper_parameters
|
||||
hyper_params.update({"reward_discount": config.reward_discount})
|
||||
learning_model = LearningModel(
|
||||
actor_net, critic_net,
|
||||
optimizer_options={
|
||||
"actor": OptimizerOptions(cls=Adam, params=config.actor_optimizer),
|
||||
"critic": OptimizerOptions(cls=RMSprop, params=config.critic_optimizer)
|
||||
}
|
||||
)
|
||||
algorithm = ActorCritic(
|
||||
learning_model, ActorCriticConfig(critic_loss_func=nn.SmoothL1Loss(), **hyper_params)
|
||||
)
|
||||
else:
|
||||
learning_model = LearningModel(
|
||||
actor_net,
|
||||
optimizer_options=OptimizerOptions(cls=Adam, params=config.actor_optimizer)
|
||||
)
|
||||
algorithm = PolicyGradient(learning_model, PolicyOptimizationConfig(config.reward_discount))
|
||||
|
||||
agent_dict[agent_id] = POAgent(name=agent_id, algorithm=algorithm)
|
||||
|
||||
return agent_dict
|
||||
|
||||
|
||||
class POAgentManager(SimpleAgentManager):
|
||||
def train(self, experiences_by_agent: dict):
|
||||
for agent_id, exp in experiences_by_agent.items():
|
||||
if not isinstance(exp, list):
|
||||
exp = [exp]
|
||||
for trajectory in exp:
|
||||
self.agent_dict[agent_id].train(
|
||||
trajectory["state"],
|
||||
trajectory["action"],
|
||||
trajectory["log_action_probability"],
|
||||
trajectory["reward"]
|
||||
)
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
This file is used to load the configuration and convert it into a dotted dictionary.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
CONFIG_PATH = os.path.join(os.path.split(os.path.realpath(__file__))[0], "../config.yml")
|
||||
with io.open(CONFIG_PATH, "r") as in_file:
|
||||
config = yaml.safe_load(in_file)
|
||||
|
||||
DISTRIBUTED_CONFIG_PATH = os.path.join(os.path.split(os.path.realpath(__file__))[0], "../distributed_config.yml")
|
||||
with io.open(DISTRIBUTED_CONFIG_PATH, "r") as in_file:
|
||||
distributed_config = yaml.safe_load(in_file)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl import ExperienceShaper
|
||||
|
||||
|
||||
class TruncatedExperienceShaper(ExperienceShaper):
|
||||
def __init__(self, *, time_window: int, time_decay_factor: float, fulfillment_factor: float,
|
||||
shortage_factor: float):
|
||||
super().__init__(reward_func=None)
|
||||
self._time_window = time_window
|
||||
self._time_decay_factor = time_decay_factor
|
||||
self._fulfillment_factor = fulfillment_factor
|
||||
self._shortage_factor = shortage_factor
|
||||
|
||||
def __call__(self, trajectory, snapshot_list):
|
||||
agent_ids = np.asarray(trajectory.get_by_key("agent_id"))
|
||||
states = np.asarray(trajectory.get_by_key("state"))
|
||||
actions = np.asarray(trajectory.get_by_key("action"))
|
||||
log_action_probabilities = np.asarray(trajectory.get_by_key("log_action_probability"))
|
||||
rewards = np.fromiter(
|
||||
map(self._compute_reward, trajectory.get_by_key("event"), [snapshot_list] * len(trajectory)),
|
||||
dtype=np.float32
|
||||
)
|
||||
return {agent_id: {
|
||||
"state": states[agent_ids == agent_id],
|
||||
"action": actions[agent_ids == agent_id],
|
||||
"log_action_probability": log_action_probabilities[agent_ids == agent_id],
|
||||
"reward": rewards[agent_ids == agent_id],
|
||||
}
|
||||
for agent_id in set(agent_ids)}
|
||||
|
||||
def _compute_reward(self, decision_event, snapshot_list):
|
||||
start_tick = decision_event.tick + 1
|
||||
end_tick = decision_event.tick + self._time_window
|
||||
ticks = list(range(start_tick, end_tick))
|
||||
|
||||
# calculate tc reward
|
||||
future_fulfillment = snapshot_list["ports"][ticks::"fulfillment"]
|
||||
future_shortage = snapshot_list["ports"][ticks::"shortage"]
|
||||
decay_list = [self._time_decay_factor ** i for i in range(end_tick - start_tick)
|
||||
for _ in range(future_fulfillment.shape[0]//(end_tick-start_tick))]
|
||||
|
||||
tot_fulfillment = np.dot(future_fulfillment, decay_list)
|
||||
tot_shortage = np.dot(future_shortage, decay_list)
|
||||
|
||||
return np.float(self._fulfillment_factor * tot_fulfillment - self._shortage_factor * tot_shortage)
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl import StateShaper
|
||||
|
||||
PORT_ATTRIBUTES = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
|
||||
VESSEL_ATTRIBUTES = ["empty", "full", "remaining_space"]
|
||||
|
||||
|
||||
class CIMStateShaper(StateShaper):
|
||||
def __init__(self, *, look_back, max_ports_downstream):
|
||||
super().__init__()
|
||||
self._look_back = look_back
|
||||
self._max_ports_downstream = max_ports_downstream
|
||||
self._dim = (look_back + 1) * (max_ports_downstream + 1) * len(PORT_ATTRIBUTES) + len(VESSEL_ATTRIBUTES)
|
||||
|
||||
def __call__(self, decision_event, snapshot_list):
|
||||
tick, port_idx, vessel_idx = decision_event.tick, decision_event.port_idx, decision_event.vessel_idx
|
||||
ticks = [tick - rt for rt in range(self._look_back - 1)]
|
||||
future_port_idx_list = snapshot_list["vessels"][tick: vessel_idx: 'future_stop_list'].astype('int')
|
||||
port_features = snapshot_list["ports"][ticks: [port_idx] + list(future_port_idx_list): PORT_ATTRIBUTES]
|
||||
vessel_features = snapshot_list["vessels"][tick: vessel_idx: VESSEL_ATTRIBUTES]
|
||||
state = np.concatenate((port_features, vessel_features))
|
||||
return str(port_idx), state
|
||||
|
||||
@property
|
||||
def dim(self):
|
||||
return self._dim
|
|
@ -0,0 +1,50 @@
|
|||
env:
|
||||
scenario: "cim"
|
||||
topology: "toy.4p_ssdd_l0.0"
|
||||
durations: 1120
|
||||
state_shaping:
|
||||
look_back: 7
|
||||
max_ports_downstream: 2
|
||||
experience_shaping:
|
||||
time_window: 100
|
||||
fulfillment_factor: 1.0
|
||||
shortage_factor: 1.0
|
||||
time_decay_factor: 0.97
|
||||
main_loop:
|
||||
max_episode: 100
|
||||
early_stopping:
|
||||
warmup_ep: 20
|
||||
last_k: 5
|
||||
perf_threshold: 0.95 # minimum performance (fulfillment ratio) required to trigger early stopping
|
||||
perf_stability_threshold: 0.1 # stability is measured by the maximum of abs(perf_(i+1) - perf_i) / perf_i
|
||||
# over the last k episodes (where perf is short for performance). This value must
|
||||
# be below this threshold to trigger early stopping
|
||||
agents:
|
||||
seed: 1024 # for reproducibility
|
||||
type: "actor_critic" # "actor_critic" or "policy_gradient"
|
||||
num_actions: 21
|
||||
actor_model:
|
||||
hidden_dims:
|
||||
- 256
|
||||
- 128
|
||||
- 64
|
||||
softmax_enabled: true
|
||||
batch_norm_enabled: false
|
||||
actor_optimizer:
|
||||
lr: 0.001
|
||||
critic_model:
|
||||
hidden_dims:
|
||||
- 256
|
||||
- 128
|
||||
- 64
|
||||
softmax_enabled: false
|
||||
batch_norm_enabled: true
|
||||
critic_optimizer:
|
||||
lr: 0.001
|
||||
reward_discount: .0
|
||||
actor_critic_hyper_parameters:
|
||||
train_iters: 10
|
||||
actor_loss_coefficient: 0.1
|
||||
k: 1
|
||||
lam: 0.0
|
||||
# clip_ratio: 0.8
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.rl import AgentManagerMode, SimpleActor, ActorWorker
|
||||
from maro.utils import convert_dottable
|
||||
|
||||
from components import CIMActionShaper, CIMStateShaper, POAgentManager, TruncatedExperienceShaper, create_po_agents
|
||||
|
||||
|
||||
def launch(config):
|
||||
config = convert_dottable(config)
|
||||
env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
|
||||
agent_id_list = [str(agent_id) for agent_id in env.agent_idx_list]
|
||||
state_shaper = CIMStateShaper(**config.env.state_shaping)
|
||||
action_shaper = CIMActionShaper(action_space=list(np.linspace(-1.0, 1.0, config.agents.num_actions)))
|
||||
experience_shaper = TruncatedExperienceShaper(**config.env.experience_shaping)
|
||||
|
||||
config["agents"]["input_dim"] = state_shaper.dim
|
||||
agent_manager = POAgentManager(
|
||||
name="cim_actor",
|
||||
mode=AgentManagerMode.INFERENCE,
|
||||
agent_dict=create_po_agents(agent_id_list, config.agents),
|
||||
state_shaper=state_shaper,
|
||||
action_shaper=action_shaper,
|
||||
experience_shaper=experience_shaper,
|
||||
)
|
||||
proxy_params = {
|
||||
"group_name": os.environ["GROUP"],
|
||||
"expected_peers": {"learner": 1},
|
||||
"redis_address": ("localhost", 6379)
|
||||
}
|
||||
actor_worker = ActorWorker(
|
||||
local_actor=SimpleActor(env=env, agent_manager=agent_manager),
|
||||
proxy_params=proxy_params
|
||||
)
|
||||
actor_worker.launch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from components.config import config
|
||||
launch(config)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
|
||||
from maro.rl import ActorProxy, AgentManagerMode, Scheduler, SimpleLearner, merge_experiences_with_trajectory_boundaries
|
||||
from maro.simulator import Env
|
||||
from maro.utils import Logger, convert_dottable
|
||||
|
||||
from components import CIMStateShaper, POAgentManager, create_po_agents
|
||||
|
||||
|
||||
def launch(config):
|
||||
config = convert_dottable(config)
|
||||
env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
|
||||
agent_id_list = [str(agent_id) for agent_id in env.agent_idx_list]
|
||||
config["agents"]["input_dim"] = CIMStateShaper(**config.env.state_shaping).dim
|
||||
agent_manager = POAgentManager(
|
||||
name="cim_learner",
|
||||
mode=AgentManagerMode.TRAIN,
|
||||
agent_dict=create_po_agents(agent_id_list, config.agents)
|
||||
)
|
||||
|
||||
proxy_params = {
|
||||
"group_name": os.environ["GROUP"],
|
||||
"expected_peers": {"actor": int(os.environ["NUM_ACTORS"])},
|
||||
"redis_address": ("localhost", 6379)
|
||||
}
|
||||
|
||||
learner = SimpleLearner(
|
||||
agent_manager=agent_manager,
|
||||
actor=ActorProxy(
|
||||
proxy_params=proxy_params, experience_collecting_func=merge_experiences_with_trajectory_boundaries
|
||||
),
|
||||
scheduler=Scheduler(config.main_loop.max_episode),
|
||||
logger=Logger("cim_learner", auto_timestamp=False)
|
||||
)
|
||||
learner.learn()
|
||||
learner.test()
|
||||
learner.dump_models(os.path.join(os.getcwd(), "models"))
|
||||
learner.exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from components.config import config
|
||||
launch(config)
|
|
@ -0,0 +1,6 @@
|
|||
redis:
|
||||
hostname: "localhost"
|
||||
port: 6379
|
||||
group: test_group
|
||||
num_actors: 1
|
||||
num_learners: 1
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
This script is used to debug distributed algorithm in single host multi-process mode.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("group_name", help="group name")
|
||||
parser.add_argument("num_actors", type=int, help="number of actors")
|
||||
args = parser.parse_args()
|
||||
|
||||
learner_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_learner.py &"
|
||||
actor_path = f"{os.path.split(os.path.realpath(__file__))[0]}/dist_actor.py &"
|
||||
|
||||
# Launch the learner process
|
||||
os.system(f"GROUP={args.group_name} NUM_ACTORS={args.num_actors} python " + learner_path)
|
||||
|
||||
# Launch the actor processes
|
||||
for _ in range(args.num_actors):
|
||||
os.system(f"GROUP={args.group_name} python " + actor_path)
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.rl import AgentManagerMode, Scheduler, SimpleActor, SimpleLearner
|
||||
from maro.utils import LogFormat, Logger, convert_dottable
|
||||
|
||||
from components import CIMActionShaper, CIMStateShaper, POAgentManager, TruncatedExperienceShaper, create_po_agents
|
||||
|
||||
|
||||
class EarlyStoppingChecker:
|
||||
"""Callable class that checks the performance history to determine early stopping.
|
||||
|
||||
Args:
|
||||
warmup_ep (int): Episode from which early stopping checking is initiated.
|
||||
last_k (int): Number of latest performance records to check for early stopping.
|
||||
perf_threshold (float): The mean of the ``last_k`` performance metric values must be above this value to
|
||||
trigger early stopping.
|
||||
perf_stability_threshold (float): The maximum one-step change over the ``last_k`` performance metrics must be
|
||||
below this value to trigger early stopping.
|
||||
"""
|
||||
def __init__(self, warmup_ep: int, last_k: int, perf_threshold: float, perf_stability_threshold: float):
|
||||
self._warmup_ep = warmup_ep
|
||||
self._last_k = last_k
|
||||
self._perf_threshold = perf_threshold
|
||||
self._perf_stability_threshold = perf_stability_threshold
|
||||
|
||||
def get_metric(record):
|
||||
return 1 - record["container_shortage"] / record["order_requirements"]
|
||||
self._metric_func = get_metric
|
||||
|
||||
def __call__(self, perf_history) -> bool:
|
||||
if len(perf_history) < max(self._last_k, self._warmup_ep):
|
||||
return False
|
||||
|
||||
metric_series = list(map(self._metric_func, perf_history[-self._last_k:]))
|
||||
max_delta = max(
|
||||
abs(metric_series[i] - metric_series[i - 1]) / metric_series[i - 1] for i in range(1, self._last_k)
|
||||
)
|
||||
print(f"mean_metric: {mean(metric_series)}, max_delta: {max_delta}")
|
||||
return mean(metric_series) > self._perf_threshold and max_delta < self._perf_stability_threshold
|
||||
|
||||
|
||||
def launch(config):
|
||||
# First determine the input dimension and add it to the config.
|
||||
config = convert_dottable(config)
|
||||
|
||||
# Step 1: initialize a CIM environment for using a toy dataset.
|
||||
env = Env(config.env.scenario, config.env.topology, durations=config.env.durations)
|
||||
agent_id_list = [str(agent_id) for agent_id in env.agent_idx_list]
|
||||
|
||||
# Step 2: create state, action and experience shapers. We also need to create an explorer here due to the
|
||||
# greedy nature of the DQN algorithm.
|
||||
state_shaper = CIMStateShaper(**config.env.state_shaping)
|
||||
action_shaper = CIMActionShaper(action_space=list(np.linspace(-1.0, 1.0, config.agents.num_actions)))
|
||||
experience_shaper = TruncatedExperienceShaper(**config.env.experience_shaping)
|
||||
|
||||
# Step 3: create an agent manager.
|
||||
config["agents"]["input_dim"] = state_shaper.dim
|
||||
agent_manager = POAgentManager(
|
||||
name="cim_learner",
|
||||
mode=AgentManagerMode.TRAIN_INFERENCE,
|
||||
agent_dict=create_po_agents(agent_id_list, config.agents),
|
||||
state_shaper=state_shaper,
|
||||
action_shaper=action_shaper,
|
||||
experience_shaper=experience_shaper,
|
||||
)
|
||||
|
||||
# Step 4: Create an actor and a learner to start the training process.
|
||||
scheduler = Scheduler(
|
||||
config.main_loop.max_episode,
|
||||
early_stopping_checker=EarlyStoppingChecker(**config.main_loop.early_stopping)
|
||||
)
|
||||
actor = SimpleActor(env, agent_manager)
|
||||
learner = SimpleLearner(
|
||||
agent_manager, actor, scheduler,
|
||||
logger=Logger("cim_learner", format_=LogFormat.simple, auto_timestamp=False)
|
||||
)
|
||||
learner.learn()
|
||||
learner.test()
|
||||
learner.dump_models(os.path.join(os.getcwd(), "models"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from components.config import config
|
||||
launch(config)
|
|
@ -0,0 +1,294 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
#distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
|
||||
from cython cimport view
|
||||
from cython.operator cimport dereference as deref
|
||||
|
||||
from cpython cimport bool
|
||||
from libcpp cimport bool as cppbool
|
||||
from libcpp.map cimport map
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX,
|
||||
ATTR_CHAR, ATTR_UCHAR, ATTR_SHORT, ATTR_USHORT, ATTR_INT, ATTR_UINT,
|
||||
ATTR_LONG, ATTR_ULONG, ATTR_FLOAT, ATTR_DOUBLE)
|
||||
|
||||
|
||||
# Ensure numpy will not crash, as we use numpy as query result
|
||||
np.import_array()
|
||||
|
||||
cdef dict attribute_accessors = {
|
||||
AttributeType.Byte: AttributeCharAccessor,
|
||||
AttributeType.UByte: AttributeUCharAccessor,
|
||||
AttributeType.Short: AttributeShortAccessor,
|
||||
AttributeType.UShort: AttributeUShortAccessor,
|
||||
AttributeType.Int: AttributeIntAccessor,
|
||||
AttributeType.UInt: AttributeUIntAccessor,
|
||||
AttributeType.Long: AttributeLongAccessor,
|
||||
AttributeType.ULong: AttributeULongAccessor,
|
||||
AttributeType.Float: AttributeFloatAccessor,
|
||||
AttributeType.Double: AttributeDoubleAccessor,
|
||||
}
|
||||
|
||||
cdef map[string, AttrDataType] attr_type_mapping
|
||||
|
||||
attr_type_mapping[AttributeType.Byte] = ACHAR
|
||||
attr_type_mapping[AttributeType.UByte] = AUCHAR
|
||||
attr_type_mapping[AttributeType.Short] = ASHORT
|
||||
attr_type_mapping[AttributeType.UShort] = AUSHORT
|
||||
attr_type_mapping[AttributeType.Int] = AINT
|
||||
attr_type_mapping[AttributeType.UInt] = AUINT
|
||||
attr_type_mapping[AttributeType.Long] = ALONG
|
||||
attr_type_mapping[AttributeType.ULong] = AULONG
|
||||
attr_type_mapping[AttributeType.Float] = AFLOAT
|
||||
attr_type_mapping[AttributeType.Double] = ADOUBLE
|
||||
|
||||
|
||||
# Helpers used to access attribute with different data type to avoid to much if-else.
|
||||
cdef class AttributeAccessor:
|
||||
cdef:
|
||||
ATTR_TYPE _attr_type
|
||||
RawBackend _backend
|
||||
|
||||
cdef void setup(self, RawBackend backend, ATTR_TYPE attr_type):
|
||||
self._backend = backend
|
||||
self._attr_type = attr_type
|
||||
|
||||
cdef void set_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index, object value) except +:
|
||||
pass
|
||||
|
||||
cdef object get_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index) except +:
|
||||
pass
|
||||
|
||||
cdef void append_value(self, NODE_INDEX node_index, object value) except +:
|
||||
pass
|
||||
|
||||
cdef void insert_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index, object value) except +:
|
||||
pass
|
||||
|
||||
def __dealloc__(self):
|
||||
self._backend = None
|
||||
|
||||
cdef class RawBackend(BackendAbc):
|
||||
def __cinit__(self):
|
||||
self._node_info = {}
|
||||
self._attr_type_dict = {}
|
||||
|
||||
cdef bool is_support_dynamic_features(self):
|
||||
return True
|
||||
|
||||
cdef NODE_TYPE add_node(self, str name, NODE_INDEX number) except +:
|
||||
cdef NODE_TYPE type = self._frame.add_node(name.encode(), number)
|
||||
|
||||
self._node_info[type] = {"number": number, "name": name, "attrs":{}}
|
||||
|
||||
return type
|
||||
|
||||
cdef ATTR_TYPE add_attr(self, NODE_TYPE node_type, str attr_name, bytes dtype, SLOT_INDEX slot_num, bool is_const, bool is_list) except +:
|
||||
cdef AttrDataType dt = AINT
|
||||
|
||||
cdef map[string, AttrDataType].iterator attr_pair = attr_type_mapping.find(dtype)
|
||||
|
||||
if attr_pair != attr_type_mapping.end():
|
||||
dt = deref(attr_pair).second;
|
||||
|
||||
# Add attribute to frame.
|
||||
cdef ATTR_TYPE attr_type = self._frame.add_attr(node_type, attr_name.encode(), dt, slot_num, is_const, is_list)
|
||||
|
||||
# Initial an access wrapper to this attribute.
|
||||
cdef AttributeAccessor acc = attribute_accessors[dtype]()
|
||||
|
||||
acc.setup(self, attr_type)
|
||||
|
||||
self._attr_type_dict[attr_type] = acc
|
||||
|
||||
# Record the information for output.
|
||||
self._node_info[node_type]["attrs"][attr_type] = {"type": dtype.decode(), "slots": slot_num, "name": attr_name}
|
||||
|
||||
return attr_type
|
||||
|
||||
cdef void set_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +:
|
||||
cdef AttributeAccessor acc = self._attr_type_dict[attr_type]
|
||||
|
||||
acc.set_value(node_index, slot_index, value)
|
||||
|
||||
cdef object get_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +:
|
||||
cdef AttributeAccessor acc = self._attr_type_dict[attr_type]
|
||||
|
||||
return acc.get_value(node_index, slot_index)
|
||||
|
||||
cdef void set_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_index, list value) except +:
|
||||
cdef SLOT_INDEX slot
|
||||
cdef int index
|
||||
|
||||
for index, slot in enumerate(slot_index):
|
||||
self.set_attr_value(node_index, attr_type, slot, value[index])
|
||||
|
||||
cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_indices) except +:
|
||||
cdef AttributeAccessor acc = self._attr_type_dict[attr_type]
|
||||
|
||||
cdef SLOT_INDEX slot
|
||||
|
||||
cdef list result = []
|
||||
|
||||
for slot in slot_indices:
|
||||
result.append(acc.get_value(node_index, slot))
|
||||
|
||||
return result
|
||||
|
||||
cdef void append_node(self, NODE_TYPE node_type, NODE_INDEX number) except +:
|
||||
self._frame.append_node(node_type, number)
|
||||
|
||||
cdef void delete_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +:
|
||||
self._frame.remove_node(node_type, node_index)
|
||||
|
||||
cdef void resume_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +:
|
||||
self._frame.resume_node(node_type, node_index)
|
||||
|
||||
cdef void append_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
|
||||
cdef AttributeAccessor acc = self._attr_type_dict[attr_type]
|
||||
|
||||
acc.append_value(index, value)
|
||||
|
||||
cdef void resize_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX new_size) except +:
|
||||
self._frame.resize_list(index, attr_type, new_size)
|
||||
|
||||
cdef void clear_list(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
|
||||
self._frame.clear_list(index, attr_type)
|
||||
|
||||
cdef void remove_from_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +:
|
||||
self._frame.remove_from_list(index, attr_type, slot_index)
|
||||
|
||||
cdef void insert_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +:
|
||||
cdef AttributeAccessor acc = self._attr_type_dict[attr_type]
|
||||
|
||||
acc.insert_value(index, slot_index, value)
|
||||
|
||||
cdef void reset(self) except +:
|
||||
self._frame.reset()
|
||||
|
||||
cdef void setup(self, bool enable_snapshot, USHORT total_snapshot, dict options) except +:
|
||||
self._frame.setup()
|
||||
|
||||
if enable_snapshot:
|
||||
self.snapshots = RawSnapshotList(self, total_snapshot)
|
||||
|
||||
cdef dict get_node_info(self) except +:
|
||||
cdef dict node_info = {}
|
||||
|
||||
for node_id, node in self._node_info.items():
|
||||
node_info[node["name"]] = {
|
||||
"number": node["number"],
|
||||
"attributes": {
|
||||
attr["name"]: {
|
||||
"type": attr["type"],
|
||||
"slots": attr["slots"]
|
||||
} for _, attr in node["attrs"].items()
|
||||
}
|
||||
}
|
||||
|
||||
return node_info
|
||||
|
||||
cdef void dump(self, str folder) except +:
|
||||
self._frame.dump(folder.encode())
|
||||
|
||||
|
||||
cdef class RawSnapshotList(SnapshotListAbc):
|
||||
def __cinit__(self, RawBackend backend, USHORT total_snapshots):
|
||||
self._snapshots.setup(&backend._frame)
|
||||
self._snapshots.set_max_size(total_snapshots)
|
||||
|
||||
# Query states from snapshot list
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef query(self, NODE_TYPE node_type, list ticks, list node_index_list, list attr_list) except +:
|
||||
cdef int index
|
||||
cdef ATTR_TYPE attr_type
|
||||
|
||||
# NOTE: format must be changed if NODE_INDEX type changed
|
||||
# Node indices parameters passed to raw backend
|
||||
cdef NODE_INDEX[:] node_indices = None
|
||||
# Tick parameter passed to raw backend
|
||||
cdef INT[:] tick_list = None
|
||||
# Attribute list cannot be empty, so we just use it to construct parameter
|
||||
cdef ATTR_TYPE[:] attr_type_list = view.array(shape=(len(attr_list),), itemsize=sizeof(ATTR_TYPE), format="I")
|
||||
|
||||
# Check and construct node indices list
|
||||
if node_index_list is not None and len(node_index_list) > 0:
|
||||
node_indices = view.array(shape=(len(node_index_list),), itemsize=sizeof(NODE_INDEX), format="I")
|
||||
|
||||
cdef USHORT ticks_length = len(ticks)
|
||||
|
||||
# Check ticks, and construct if has value
|
||||
if ticks is not None and ticks_length > 0:
|
||||
tick_list = view.array(shape=(ticks_length,), itemsize=sizeof(INT), format="i")
|
||||
|
||||
for index in range(ticks_length):
|
||||
tick_list[index] = ticks[index]
|
||||
else:
|
||||
ticks_length = self._snapshots.size()
|
||||
|
||||
for index in range(len(node_index_list)):
|
||||
node_indices[index] = node_index_list[index]
|
||||
|
||||
for index in range(len(attr_list)):
|
||||
attr_type_list[index] = attr_list[index]
|
||||
|
||||
# Calc 1 frame length
|
||||
cdef SnapshotQueryResultShape shape = self._snapshots.prepare(node_type, &tick_list[0], ticks_length, &node_indices[0], len(node_indices), &attr_type_list[0], len(attr_type_list))
|
||||
|
||||
cdef size_t result_size = shape.tick_number * shape.max_node_number * shape.attr_number * shape.max_slot_number
|
||||
|
||||
if result_size <= 0:
|
||||
self._snapshots.cancel_query()
|
||||
|
||||
return None
|
||||
|
||||
# Result holder
|
||||
cdef QUERY_FLOAT[:, :, :, :] result = view.array(shape=(shape.tick_number, shape.max_node_number, shape.attr_number, shape.max_slot_number), itemsize=sizeof(QUERY_FLOAT), format="f")
|
||||
|
||||
# Default result value
|
||||
result[:, :, :, :] = np.nan
|
||||
|
||||
# Do query
|
||||
self._snapshots.query(&result[0][0][0][0])
|
||||
|
||||
return np.array(result)
|
||||
|
||||
# Record current backend state into snapshot list
|
||||
cdef void take_snapshot(self, INT tick) except +:
|
||||
self._snapshots.take_snapshot(tick)
|
||||
|
||||
cdef NODE_INDEX get_node_number(self, NODE_TYPE node_type) except +:
|
||||
return self._snapshots.get_max_node_number(node_type)
|
||||
|
||||
# List of available frame index in snapshot list
|
||||
cdef list get_frame_index_list(self) except +:
|
||||
cdef USHORT number = self._snapshots.size()
|
||||
cdef INT[:] result = view.array(shape=(number,), itemsize=sizeof(INT), format="i")
|
||||
|
||||
self._snapshots.get_ticks(&result[0])
|
||||
|
||||
return list(result)
|
||||
|
||||
# Enable history, history will dump backend into files each time take_snapshot called
|
||||
cdef void enable_history(self, str history_folder) except +:
|
||||
pass
|
||||
|
||||
# Reset internal states
|
||||
cdef void reset(self) except +:
|
||||
self._snapshots.reset()
|
||||
|
||||
cdef void dump(self, str folder) except +:
|
||||
self._snapshots.dump(folder.encode())
|
||||
|
||||
def __len__(self):
|
||||
return self._snapshots.size()
|
|
@ -0,0 +1,14 @@
|
|||
|
||||
|
||||
cdef class Attribute{CLSNAME}Accessor(AttributeAccessor):
|
||||
cdef void set_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index, object value) except +:
|
||||
self._backend._frame.set_value[{T}](node_index, self._attr_type, slot_index, value)
|
||||
|
||||
cdef object get_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index) except +:
|
||||
return self._backend._frame.get_value[{T}](node_index, self._attr_type, slot_index)
|
||||
|
||||
cdef void append_value(self, NODE_INDEX node_index, object value) except +:
|
||||
self._backend._frame.append_to_list[{T}](node_index, self._attr_type, value)
|
||||
|
||||
cdef void insert_value(self, NODE_INDEX node_index, SLOT_INDEX slot_index, object value) except +:
|
||||
self._backend._frame.insert_to_list[{T}](node_index, self._attr_type, slot_index, value)
|
|
@ -2,43 +2,133 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
from cpython cimport bool
|
||||
|
||||
from libc.stdint cimport int32_t, int64_t, int16_t, int8_t, uint32_t, uint64_t
|
||||
|
||||
# common types
|
||||
|
||||
ctypedef int INT
|
||||
ctypedef unsigned int UINT
|
||||
ctypedef unsigned long long ULONG
|
||||
ctypedef unsigned short USHORT
|
||||
|
||||
ctypedef char ATTR_CHAR
|
||||
ctypedef unsigned char ATTR_UCHAR
|
||||
ctypedef short ATTR_SHORT
|
||||
ctypedef USHORT ATTR_USHORT
|
||||
ctypedef int32_t ATTR_INT
|
||||
ctypedef uint32_t ATTR_UINT
|
||||
ctypedef int64_t ATTR_LONG
|
||||
ctypedef uint64_t ATTR_ULONG
|
||||
ctypedef float ATTR_FLOAT
|
||||
ctypedef double ATTR_DOUBLE
|
||||
|
||||
# Type for snapshot querying.
|
||||
ctypedef float QUERY_FLOAT
|
||||
|
||||
# TYPE of node and attribute
|
||||
ctypedef unsigned short NODE_TYPE
|
||||
ctypedef uint32_t ATTR_TYPE
|
||||
|
||||
# Index type of node
|
||||
ctypedef ATTR_TYPE NODE_INDEX
|
||||
|
||||
# Index type of slot
|
||||
ctypedef ATTR_TYPE SLOT_INDEX
|
||||
|
||||
|
||||
cdef class AttributeType:
|
||||
pass
|
||||
|
||||
|
||||
# Base of all snapshot accessing implementation
|
||||
cdef class SnapshotListAbc:
|
||||
# query states from snapshot list
|
||||
cdef query(self, str node_name, list ticks, list node_index_list, list attr_name_list)
|
||||
# Query states from snapshot list
|
||||
cdef query(self, NODE_TYPE node_type, list ticks, list node_index_list, list attr_list) except +
|
||||
|
||||
# record specified backend state into snapshot list
|
||||
cdef void take_snapshot(self, int frame_index) except *
|
||||
# Record current backend state into snapshot list
|
||||
cdef void take_snapshot(self, INT tick) except +
|
||||
|
||||
# list of available frame index in snapshot list
|
||||
cdef list get_frame_index_list(self)
|
||||
# List of available frame index in snapshot list
|
||||
cdef list get_frame_index_list(self) except +
|
||||
|
||||
cdef void enable_history(self, str history_folder) except *
|
||||
# Get number of specified node
|
||||
cdef NODE_INDEX get_node_number(self, NODE_TYPE node_type) except +
|
||||
|
||||
cdef void reset(self) except *
|
||||
# Enable history, history will dump backend into files each time take_snapshot called
|
||||
cdef void enable_history(self, str history_folder) except +
|
||||
|
||||
# Reset internal states.
|
||||
cdef void reset(self) except +
|
||||
|
||||
# Dump Snapshot into target folder (without filename).
|
||||
cdef void dump(self, str folder) except +
|
||||
|
||||
|
||||
# Base of all backend implementation
|
||||
cdef class BackendAbc:
|
||||
cdef:
|
||||
public SnapshotListAbc snapshots
|
||||
|
||||
cdef void setup(self, bool enable_snapshot, int total_snapshot, dict options) except *
|
||||
# Is current backend support dynamic features.
|
||||
cdef bool is_support_dynamic_features(self)
|
||||
|
||||
cdef void reset(self) except *
|
||||
# Add a new node to current backend, with specified number (>=0).
|
||||
# Returns an ID of this new node in current backend.
|
||||
cdef NODE_TYPE add_node(self, str name, NODE_INDEX number) except +
|
||||
|
||||
cdef void add_node(self, str name, int number) except *
|
||||
# Add a new attribute to specified node (id).
|
||||
# Returns an ID of this new attribute for current node (id).
|
||||
cdef ATTR_TYPE add_attr(self, NODE_TYPE node_type, str attr_name, bytes dtype, SLOT_INDEX slot_num, bool is_const, bool is_list) except +
|
||||
|
||||
cdef void add_attr(self, str node_name, str attr_name, str dtype, int slot_num) except *
|
||||
# Set value of specified attribute slot.
|
||||
# NOTE: since we already know which node current attribute belongs to, so we just need to specify attribute id
|
||||
cdef void set_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +
|
||||
|
||||
cdef void set_attr_value(self, str node_name, int node_index, str attr_name, int slot_index, value) except *
|
||||
# Get value of specified attribute slot.
|
||||
cdef object get_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +
|
||||
|
||||
cdef void set_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_index, list value) except *
|
||||
# Set values of specified slots.
|
||||
cdef void set_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_index, list value) except +
|
||||
|
||||
cdef object get_attr_value(self, str node_name, int node_index, str attr_name, int slot_index)
|
||||
# Get values of specified slots.
|
||||
cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_indices) except +
|
||||
|
||||
cdef object[object, ndim=1] get_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_indices)
|
||||
# Get node definition of backend.
|
||||
cdef dict get_node_info(self) except +
|
||||
|
||||
cdef dict get_node_info(self)
|
||||
# Setup backend with options.
|
||||
cdef void setup(self, bool enable_snapshot, USHORT total_snapshot, dict options) except +
|
||||
|
||||
cdef void dump(self, str filePath)
|
||||
# Reset internal states.
|
||||
cdef void reset(self) except +
|
||||
|
||||
# Append specified number of nodes.
|
||||
cdef void append_node(self, NODE_TYPE node_type, NODE_INDEX number) except +
|
||||
|
||||
# Delete a node by index.
|
||||
cdef void delete_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +
|
||||
|
||||
# Resume node that been deleted.
|
||||
cdef void resume_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +
|
||||
|
||||
# Append value to specified list attribute.
|
||||
cdef void append_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
|
||||
|
||||
# Resize specified list attribute.
|
||||
cdef void resize_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX new_size) except +
|
||||
|
||||
# Clear specified list attribute.
|
||||
cdef void clear_list(self, NODE_INDEX index, ATTR_TYPE attr_type) except +
|
||||
|
||||
# Remove a slot from list attribute.
|
||||
cdef void remove_from_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +
|
||||
|
||||
# Insert a slot to list attribute.
|
||||
cdef void insert_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +
|
||||
|
||||
# Dump Snapshot into target folder (without filename).
|
||||
cdef void dump(self, str folder) except +
|
||||
|
|
|
@ -2,53 +2,106 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
from enum import Enum
|
||||
from cpython cimport bool
|
||||
|
||||
cdef class AttributeType:
|
||||
Byte = b"byte"
|
||||
UByte = b"ubyte"
|
||||
Short = b"short"
|
||||
UShort = b"ushort"
|
||||
Int = b"int"
|
||||
UInt = b"uint"
|
||||
Long = b"long"
|
||||
ULong = b"ulong"
|
||||
Float = b"float"
|
||||
Double = b"double"
|
||||
|
||||
cdef int raise_get_attr_error() except +:
|
||||
raise Exception("Bad parameters to get attribute value.")
|
||||
|
||||
|
||||
cdef class SnapshotListAbc:
|
||||
cdef query(self, str node_name, list ticks, list node_index_list, list attr_name_list):
|
||||
cdef query(self, NODE_TYPE node_type, list ticks, list node_index_list, list attr_list) except +:
|
||||
pass
|
||||
|
||||
cdef void take_snapshot(self, int tick) except *:
|
||||
cdef void take_snapshot(self, INT tick) except +:
|
||||
pass
|
||||
|
||||
cdef void enable_history(self, str history_folder) except *:
|
||||
cdef NODE_INDEX get_node_number(self, NODE_TYPE node_type) except +:
|
||||
return 0
|
||||
|
||||
cdef void enable_history(self, str history_folder) except +:
|
||||
pass
|
||||
|
||||
cdef void reset(self) except *:
|
||||
cdef void reset(self) except +:
|
||||
pass
|
||||
|
||||
cdef list get_frame_index_list(self):
|
||||
cdef list get_frame_index_list(self) except +:
|
||||
return []
|
||||
|
||||
cdef void dump(self, str folder) except +:
|
||||
pass
|
||||
|
||||
cdef class BackendAbc:
|
||||
|
||||
cdef void add_node(self, str name, int number) except *:
|
||||
cdef bool is_support_dynamic_features(self):
|
||||
return False
|
||||
|
||||
cdef NODE_TYPE add_node(self, str name, NODE_INDEX number) except +:
|
||||
pass
|
||||
|
||||
cdef void add_attr(self, str node_name, str attr_name, str dtype, int slot_num) except *:
|
||||
cdef ATTR_TYPE add_attr(self, NODE_TYPE node_type, str attr_name, bytes dtype, SLOT_INDEX slot_num, bool is_const, bool is_list) except +:
|
||||
pass
|
||||
|
||||
cdef void set_attr_value(self, str node_name, int node_index, str attr_name, int slot_index, value) except *:
|
||||
cdef void set_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_id, SLOT_INDEX slot_index, object value) except +:
|
||||
pass
|
||||
|
||||
cdef object get_attr_value(self, str node_name, int node_index, str attr_name, int slot_index):
|
||||
cdef object get_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_id, SLOT_INDEX slot_index) except +:
|
||||
pass
|
||||
|
||||
cdef void set_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_index, list value) except *:
|
||||
cdef void set_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_id, SLOT_INDEX[:] slot_index, list value) except +:
|
||||
pass
|
||||
|
||||
cdef object[object, ndim=1] get_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_indices):
|
||||
cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_id, SLOT_INDEX[:] slot_indices) except +:
|
||||
pass
|
||||
|
||||
cdef void reset(self) except *:
|
||||
cdef void reset(self) except +:
|
||||
pass
|
||||
|
||||
cdef void setup(self, bool enable_snapshot, int total_snapshot, dict options) except *:
|
||||
cdef void setup(self, bool enable_snapshot, USHORT total_snapshot, dict options) except +:
|
||||
pass
|
||||
|
||||
cdef dict get_node_info(self):
|
||||
cdef dict get_node_info(self) except +:
|
||||
return {}
|
||||
|
||||
cdef void dump(self, str filePath):
|
||||
cdef void append_node(self, NODE_TYPE node_type, NODE_INDEX number) except +:
|
||||
pass
|
||||
|
||||
cdef void delete_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +:
|
||||
pass
|
||||
|
||||
cdef void resume_node(self, NODE_TYPE node_type, NODE_INDEX node_index) except +:
|
||||
pass
|
||||
|
||||
cdef void append_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
|
||||
pass
|
||||
|
||||
# Resize specified list attribute.
|
||||
cdef void resize_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX new_size) except +:
|
||||
pass
|
||||
|
||||
# Clear specified list attribute.
|
||||
cdef void clear_list(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
|
||||
pass
|
||||
|
||||
cdef void remove_from_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +:
|
||||
pass
|
||||
|
||||
cdef void insert_to_list(self, NODE_INDEX index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +:
|
||||
pass
|
||||
|
||||
cdef void dump(self, str folder) except +:
|
||||
pass
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
from cpython cimport bool
|
||||
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, USHORT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
|
||||
|
||||
cdef class SnapshotList:
|
||||
|
@ -78,7 +80,10 @@ cdef class FrameBase:
|
|||
yournodes = FrameNode(YourNode, 12)
|
||||
|
||||
def __init__(self, enable_snapshot:bool=True, snapshot_number: int = 10):
|
||||
super().__init__(self, enable_snapshot, total_snapshots=snapshot_number)
|
||||
super().__init__(self, enable_snapshot, total_snapshots=snapshot_number, backend_name="static or dynamic")
|
||||
|
||||
Currently we support 2 kinds of backend implementation for frame: static and dynamic. Dynamic backend support list attribute
|
||||
which works list a normal python list, but only can hold decleared data type.
|
||||
|
||||
The snapshot list is used to hold snapshot of current frame at specified point (tick or frame index), it can be
|
||||
configured that how many snapshots should be kept in memory, latest snapshot will over-write oldest one if reach
|
||||
|
@ -130,17 +135,29 @@ cdef class FrameBase:
|
|||
|
||||
SnapshotList _snapshot_list
|
||||
|
||||
str _backend_name
|
||||
|
||||
dict _node_cls_dict
|
||||
dict _node_name2attrname_dict
|
||||
dict _node_origin_number_dict
|
||||
|
||||
# enable dynamic fields
|
||||
dict __dict__
|
||||
|
||||
|
||||
cpdef void reset(self) except *
|
||||
|
||||
cpdef void take_snapshot(self, int tick) except *
|
||||
cpdef void take_snapshot(self, INT tick) except *
|
||||
|
||||
cpdef void enable_history(self, str path) except *
|
||||
|
||||
cdef void _setup_backend(self, bool enable_snapshot, int total_snapshot, dict options) except *
|
||||
cpdef void append_node(self, str node_name, NODE_INDEX number) except +
|
||||
|
||||
cpdef void delete_node(self, NodeBase node) except +
|
||||
|
||||
cpdef void resume_node(self, NodeBase node) except +
|
||||
|
||||
cdef void _setup_backend(self, bool enable_snapshot, USHORT total_snapshot, dict options) except *
|
||||
|
||||
|
||||
cdef class FrameNode:
|
||||
|
@ -155,7 +172,7 @@ cdef class FrameNode:
|
|||
cdef:
|
||||
public type _node_cls
|
||||
|
||||
public int _number
|
||||
public NODE_INDEX _number
|
||||
|
||||
|
||||
cdef class NodeBase:
|
||||
|
@ -207,7 +224,8 @@ cdef class NodeBase:
|
|||
def gen_my_node_definition(float_attr_number: int):
|
||||
@node("my nodes")
|
||||
class MyNode(NodeBase):
|
||||
my_int_attr = NodeAttribute("i")
|
||||
# Default attribute type is AttributeType.Int, slot number is 1, so we can leave it empty here
|
||||
my_int_attr = NodeAttribute()
|
||||
my_float_array_attr = NodeAttribute("f", float_attr_number)
|
||||
|
||||
return MyNode
|
||||
|
@ -251,36 +269,56 @@ cdef class NodeBase:
|
|||
my_nodes.my_float_array_attr[(0, 1)] = (0.1, 0.2)
|
||||
"""
|
||||
cdef:
|
||||
# index of current node in frame memory,
|
||||
# Index of current node in frame memory,
|
||||
# all the node/frame operation will base on this property, so user should create a mapping that
|
||||
# map the business model id/name to node index
|
||||
int _index
|
||||
NODE_INDEX _index
|
||||
|
||||
# Node id, used to access backend
|
||||
NODE_TYPE _type
|
||||
|
||||
BackendAbc _backend
|
||||
|
||||
# enable dynamic attributes
|
||||
bool _is_deleted
|
||||
|
||||
# Attriubtes: name -> type.
|
||||
dict _attributes
|
||||
|
||||
# Enable dynamic attributes
|
||||
dict __dict__
|
||||
|
||||
# set up the node for using with frame, and index
|
||||
# Set up the node for using with frame, and index
|
||||
# this is called by Frame after the instance is initialized
|
||||
cdef void setup(self, BackendAbc backend, int index) except *
|
||||
cdef void setup(self, BackendAbc backend, NODE_INDEX index, NODE_TYPE type, dict attr_name_id_dict) except *
|
||||
|
||||
# internal functions, will be called after Frame's setup, used to bind attributes to instance
|
||||
# Internal functions, will be called after Frame's setup, used to bind attributes to instance
|
||||
cdef void _bind_attributes(self) except *
|
||||
|
||||
|
||||
cdef class NodeAttribute:
|
||||
"""Helper class used to declare an attribute in node that inherit from NodeBase.
|
||||
|
||||
Currently we only support these data types: 'i', 'i2', 'i4', 'i8', 'f' and 'd'.
|
||||
|
||||
Args:
|
||||
dtype(str): Type of this attribute, it support following data types.
|
||||
slots(int): If this number greater than 1, then it will be treat as an array, this will be the array size.
|
||||
dtype(str): Type of this attribute, use type from maro.backends.backend.AttributeType to specify valid type,
|
||||
default is AttributeType.Int if not provided, or invalid type provided.
|
||||
slots(int): If this number greater than 1, then it will be treat as an array, this will be the array size,
|
||||
this value cannot be changed after definition, max value is 2^32.
|
||||
is_const(bool): Is this is a const attribute, True means this attribute will not be copied into snapshot list,
|
||||
share between current frame and snapshots. Default is False.
|
||||
is_list(bool): Is this is a list attribute, True means this attribute works like a list (max size is 2^32),
|
||||
without a fixed size like normal attribute. NOTE: a list attribute cannot be const, it will cause exception,
|
||||
and its default slot number will be 0, but can be resized.
|
||||
Default is False.
|
||||
"""
|
||||
cdef:
|
||||
# data type of attribute, same as numpy string dtype
|
||||
public str _dtype
|
||||
# Data type of attribute, same as numpy string dtype.
|
||||
public bytes _dtype
|
||||
|
||||
# array size of tis attribute
|
||||
public int _slot_number
|
||||
# Array size of tis attribute.
|
||||
public SLOT_INDEX _slot_number
|
||||
|
||||
# Is this is a const attribute?
|
||||
public bool _is_const
|
||||
|
||||
# Is this is a list attribute?
|
||||
public bool _is_list
|
||||
|
|
|
@ -2,28 +2,60 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
#distutils: language = c++
|
||||
#distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
|
||||
import os
|
||||
|
||||
cimport cython
|
||||
|
||||
cimport numpy as np
|
||||
import numpy as np
|
||||
|
||||
from cpython cimport bool
|
||||
from typing import Union
|
||||
|
||||
from maro.backends.np_backend cimport NumpyBackend
|
||||
from maro.backends.raw_backend cimport RawBackend
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
|
||||
from maro.utils.exception.backends_exception import (
|
||||
BackendsGetItemInvalidException,
|
||||
BackendsSetItemInvalidException,
|
||||
BackendsArrayAttributeAccessException
|
||||
BackendsArrayAttributeAccessException,
|
||||
BackendsAppendToNonListAttributeException,
|
||||
BackendsResizeNonListAttributeException,
|
||||
BackendsClearNonListAttributeException,
|
||||
BackendsInsertNonListAttributeException,
|
||||
BackendsRemoveFromNonListAttributeException,
|
||||
BackendsAccessDeletedNodeException,
|
||||
BackendsInvalidNodeException,
|
||||
BackendsInvalidAttributeException
|
||||
)
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc
|
||||
|
||||
# NOTE: here to support backend switching
|
||||
IF FRAME_BACKEND == "NUMPY":
|
||||
cimport numpy as np
|
||||
import numpy as np
|
||||
# Old type definition mapping.
|
||||
old_data_type_definitions = {
|
||||
"i": AttributeType.Int,
|
||||
"i4": AttributeType.Int,
|
||||
"i2": AttributeType.Short,
|
||||
"i8": AttributeType.Long,
|
||||
"f": AttributeType.Float,
|
||||
"d": AttributeType.Double
|
||||
}
|
||||
|
||||
from maro.backends.np_backend cimport NumpyBackend as backend
|
||||
# Supported backends.
|
||||
backend_dict = {
|
||||
"dynamic" : RawBackend,
|
||||
"static" : NumpyBackend
|
||||
}
|
||||
|
||||
ELSE:
|
||||
from maro.backends.raw_backend cimport RawBackend as backend
|
||||
# Default backend name.
|
||||
_default_backend_name = "static"
|
||||
|
||||
NP_SLOT_INDEX = np.uint32
|
||||
NP_NODE_INDEX = np.uint32
|
||||
|
||||
|
||||
def node(name: str):
|
||||
|
@ -43,40 +75,154 @@ def node(name: str):
|
|||
|
||||
|
||||
cdef class NodeAttribute:
|
||||
def __init__(self, dtype: str, slot_num: int = 1):
|
||||
self._dtype = dtype
|
||||
def __cinit__(self, object dtype = None, SLOT_INDEX slot_num = 1, is_const = False, is_list = False):
|
||||
# Check the type of dtype, used to compact with old version
|
||||
cdef bytes _type = AttributeType.Int
|
||||
|
||||
if dtype is not None:
|
||||
dtype_type = type(dtype)
|
||||
|
||||
if dtype_type == str:
|
||||
if dtype in old_data_type_definitions:
|
||||
_type = old_data_type_definitions[dtype]
|
||||
elif dtype_type == bytes:
|
||||
_type = dtype
|
||||
|
||||
self._dtype = _type
|
||||
self._slot_number = slot_num
|
||||
self._is_const = is_const
|
||||
self._is_list = is_list
|
||||
|
||||
|
||||
# TODO: A better way to support multiple value get/set for an attribute with more than one slot.
|
||||
#
|
||||
# Wrapper to provide easy way to access attribute value of specified node
|
||||
# with this wrapper, user can get/set attribute value more easily.
|
||||
cdef class _NodeAttributeAccessor:
|
||||
cdef:
|
||||
# target node
|
||||
str _node_name
|
||||
# attribute name
|
||||
str _attr_name
|
||||
BackendAbc _backend
|
||||
int _index
|
||||
# Target node index.
|
||||
NODE_INDEX _node_index
|
||||
|
||||
public NodeAttribute attr
|
||||
# Target attribute type.
|
||||
public ATTR_TYPE _attr_type
|
||||
|
||||
# Slot number of target attribute.
|
||||
public SLOT_INDEX _slot_number
|
||||
|
||||
# Is this is a list attribute?
|
||||
# True to enable append/remove/insert methods.
|
||||
public bool _is_list
|
||||
|
||||
# Target backend.
|
||||
BackendAbc _backend
|
||||
|
||||
# Index used to support for-loop
|
||||
SLOT_INDEX _cur_iter_slot_index
|
||||
|
||||
# Enable dynamic attributes.
|
||||
dict __dict__
|
||||
|
||||
# Slot list cache, used to avoid to much runtime list generation.
|
||||
# slot -> int[:]
|
||||
dict _slot_list_cache
|
||||
|
||||
def __init__(self, attr: NodeAttribute, node_name: str, attr_name: str, backend: BackendAbc, index: int):
|
||||
self.attr = attr
|
||||
self._node_name = node_name
|
||||
self._attr_name = attr_name
|
||||
def __cinit__(self, NodeAttribute attr, ATTR_TYPE attr_type, BackendAbc backend, NODE_INDEX node_index):
|
||||
self._attr_type = attr_type
|
||||
self._node_index = node_index
|
||||
self._slot_number = attr._slot_number
|
||||
self._is_list = attr._is_list
|
||||
self._backend = backend
|
||||
self._index = index
|
||||
self._slot_list_cache = {}
|
||||
|
||||
# Built-in index too support for-loop
|
||||
self._cur_iter_slot_index = 0
|
||||
|
||||
# Special for list attribute, we need to slot number to support __len__
|
||||
# We will count the slot number here, though we can get it from function call
|
||||
if self._is_list:
|
||||
self._slot_number = 0
|
||||
|
||||
def append(self, value):
|
||||
"""Append a value to current attribute.
|
||||
|
||||
NOTE:
|
||||
Current attribute must be a list.
|
||||
|
||||
Args:
|
||||
value(object): Value to append, the data type must fit the decleared one.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsAppendToNonListAttributeException()
|
||||
|
||||
self._backend.append_to_list(self._node_index, self._attr_type, value)
|
||||
|
||||
self._slot_number += 1
|
||||
|
||||
def resize(self, new_size: int):
|
||||
"""Resize current list attribute with specified new size.
|
||||
|
||||
NOTE:
|
||||
Current attribute must be a list.
|
||||
|
||||
Args:
|
||||
new_size(int): New size to resize, max number is 2^32.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsResizeNonListAttributeException()
|
||||
|
||||
self._backend.resize_list(self._node_index, self._attr_type, new_size)
|
||||
|
||||
self._slot_number = new_size
|
||||
|
||||
def clear(self):
|
||||
"""Clear all items in current list attribute.
|
||||
|
||||
NOTE:
|
||||
Current attribute must be a list.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsClearNonListAttributeException()
|
||||
|
||||
self._backend.clear_list(self._node_index, self._attr_type)
|
||||
|
||||
self._slot_number = 0
|
||||
|
||||
def insert(self, slot_index: int, value: object):
|
||||
"""Insert a value to specified slot.
|
||||
|
||||
Args:
|
||||
slot_index(int): Slot index to insert.
|
||||
value(object): Value to insert.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsInsertNonListAttributeException()
|
||||
|
||||
self._backend.insert_to_list(self._node_index, self._attr_type, slot_index, value)
|
||||
|
||||
self._slot_number += 1
|
||||
|
||||
def remove(self, slot_index: int):
|
||||
"""Remove specified slot.
|
||||
|
||||
Args:
|
||||
slot_index(int): Slot index to remove.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsRemoveFromNonListAttributeException()
|
||||
|
||||
self._backend.remove_from_list(self._node_index, self._attr_type, slot_index)
|
||||
|
||||
self._slot_number -= 1
|
||||
|
||||
def __iter__(self):
|
||||
"""Start for-loop."""
|
||||
self._cur_iter_slot_index = 0
|
||||
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Get next slot value."""
|
||||
if self._cur_iter_slot_index >= self._slot_number:
|
||||
raise StopIteration
|
||||
|
||||
value = self._backend.get_attr_value(self._node_index, self._attr_type, self._cur_iter_slot_index)
|
||||
|
||||
self._cur_iter_slot_index += 1
|
||||
|
||||
return value
|
||||
|
||||
def __getitem__(self, slot: Union[int, slice, list, tuple]):
|
||||
"""We use this function to support slice interface to access attribute, like:
|
||||
|
@ -85,15 +231,16 @@ cdef class _NodeAttributeAccessor:
|
|||
b = node.attr1[:]
|
||||
c = node.attr1[(1, 2, 3)]
|
||||
"""
|
||||
# NOTE: we do not support negative indexing now
|
||||
|
||||
cdef int start
|
||||
cdef int stop
|
||||
cdef SLOT_INDEX start
|
||||
cdef SLOT_INDEX stop
|
||||
cdef type slot_type = type(slot)
|
||||
cdef int[:] slot_list
|
||||
cdef SLOT_INDEX[:] slot_list
|
||||
|
||||
# node.attribute[0]
|
||||
# Get only one slot: node.attribute[0].
|
||||
if slot_type == int:
|
||||
return self._backend.get_attr_value(self._node_name, self._index, self._attr_name, slot)
|
||||
return self._backend.get_attr_value(self._node_index, self._attr_type, slot)
|
||||
|
||||
# Try to support following:
|
||||
# node.attribute[1:3]
|
||||
|
@ -101,62 +248,57 @@ cdef class _NodeAttributeAccessor:
|
|||
# node.attribute[(0, 1)]
|
||||
cdef tuple slot_key = tuple(slot) if slot_type != slice else (slot.start, slot.stop, slot.step)
|
||||
|
||||
slot_list = self._slot_list_cache.get(slot_key, None)
|
||||
slot_list = None
|
||||
|
||||
if slot_list is None:
|
||||
if slot_type == slice:
|
||||
start = 0 if slot.start is None else slot.start
|
||||
stop = self.attr._slot_number if slot.stop is None else slot.stop
|
||||
# Parse slice parameters: [:].
|
||||
if slot_type == slice:
|
||||
start = 0 if slot.start is None else slot.start
|
||||
stop = self._slot_number if slot.stop is None else slot.stop
|
||||
|
||||
slot_list = np.arange(start, stop, dtype='i')
|
||||
elif slot_type == list or slot_type == tuple:
|
||||
slot_list = np.array(slot, dtype='i')
|
||||
else:
|
||||
raise BackendsGetItemInvalidException()
|
||||
slot_list = np.arange(start, stop, dtype=NP_SLOT_INDEX)
|
||||
elif slot_type == list or slot_type == tuple:
|
||||
slot_list = np.array(slot, dtype=NP_SLOT_INDEX)
|
||||
else:
|
||||
raise BackendsGetItemInvalidException()
|
||||
|
||||
self._slot_list_cache[slot_key] = slot_list
|
||||
|
||||
return self._backend.get_attr_values(self._node_name, self._index, self._attr_name, slot_list)
|
||||
return self._backend.get_attr_values(self._node_index, self._attr_type, slot_list)
|
||||
|
||||
def __setitem__(self, slot: Union[int, slice, list, tuple], value: Union[object, list, tuple, np.ndarray]):
|
||||
# Check if type match.
|
||||
cdef int[:] slot_list
|
||||
cdef SLOT_INDEX[:] slot_list
|
||||
cdef list values
|
||||
|
||||
# TODO: Use large data type for index.
|
||||
cdef int start
|
||||
cdef int stop
|
||||
cdef SLOT_INDEX start
|
||||
cdef SLOT_INDEX stop
|
||||
|
||||
cdef type slot_type = type(slot)
|
||||
cdef type value_type = type(value)
|
||||
|
||||
cdef int values_length
|
||||
cdef int slot_length
|
||||
cdef SLOT_INDEX values_length
|
||||
cdef SLOT_INDEX slot_length
|
||||
cdef tuple slot_key
|
||||
|
||||
# node.attribute[0] = 1
|
||||
# Set value for one slot: node.attribute[0] = 1.
|
||||
if slot_type == int:
|
||||
self._backend.set_attr_value(self._node_name, self._index, self._attr_name, slot, value)
|
||||
self._backend.set_attr_value(self._node_index, self._attr_type, slot, value)
|
||||
elif slot_type == list or slot_type == tuple or slot_type == slice:
|
||||
# Try to support following:
|
||||
# node.attribute[0: 2] = 1/[1,2]/ (0, 2, 3)
|
||||
# node.attribute[0: 2] = 1 / [1,2] / (0, 2, 3)
|
||||
slot_key = tuple(slot) if slot_type != slice else (slot.start, slot.stop, slot.step)
|
||||
|
||||
slot_list = self._slot_list_cache.get(slot_key, None)
|
||||
slot_list = None
|
||||
|
||||
if slot_list is None:
|
||||
if slot_type == slice:
|
||||
start = 0 if slot.start is None else slot.start
|
||||
stop = self.attr._slot_number if slot.stop is None else slot.stop
|
||||
# Parse slot indices to set.
|
||||
if slot_type == slice:
|
||||
start = 0 if slot.start is None else slot.start
|
||||
stop = self._slot_number if slot.stop is None else slot.stop
|
||||
|
||||
slot_list = np.arange(start, stop, dtype='i')
|
||||
elif slot_type == list or slot_type == tuple:
|
||||
slot_list = np.array(slot, dtype='i')
|
||||
|
||||
self._slot_list_cache[slot_key] = slot_list
|
||||
slot_list = np.arange(start, stop, dtype=NP_SLOT_INDEX)
|
||||
elif slot_type == list or slot_type == tuple:
|
||||
slot_list = np.array(slot, dtype=NP_SLOT_INDEX)
|
||||
|
||||
slot_length = len(slot_list)
|
||||
|
||||
# Parse value, padding if needed.
|
||||
if value_type == list or value_type == tuple or value_type == np.ndarray:
|
||||
values = list(value)
|
||||
|
||||
|
@ -170,14 +312,17 @@ cdef class _NodeAttributeAccessor:
|
|||
else:
|
||||
values = [value] * slot_length
|
||||
|
||||
self._backend.set_attr_values(self._node_name, self._index, self._attr_name, slot_list, values)
|
||||
self._backend.set_attr_values(self._node_index, self._attr_type, slot_list, values)
|
||||
else:
|
||||
raise BackendsSetItemInvalidException()
|
||||
|
||||
# Check and invoke value changed callback.
|
||||
if "_cb" in self.__dict__:
|
||||
# Check and invoke value changed callback, except list attribute.
|
||||
if not self._is_list and "_cb" in self.__dict__:
|
||||
self._cb(value)
|
||||
|
||||
def __len__(self):
|
||||
return self._slot_number
|
||||
|
||||
def on_value_changed(self, cb):
|
||||
"""Set the value changed callback."""
|
||||
self._cb = cb
|
||||
|
@ -186,12 +331,21 @@ cdef class _NodeAttributeAccessor:
|
|||
cdef class NodeBase:
|
||||
@property
|
||||
def index(self):
|
||||
"""int: Index of current node instance."""
|
||||
return self._index
|
||||
|
||||
cdef void setup(self, BackendAbc backend, int index) except *:
|
||||
@property
|
||||
def is_deleted(self):
|
||||
"""bool:: Is this node instance already been deleted."""
|
||||
return self._is_deleted
|
||||
|
||||
cdef void setup(self, BackendAbc backend, NODE_INDEX index, NODE_TYPE node_type, dict attr_name_type_dict) except *:
|
||||
"""Setup frame node, and bind attributes."""
|
||||
self._index = index
|
||||
self._type = node_type
|
||||
self._backend = backend
|
||||
self._is_deleted = False
|
||||
self._attributes = attr_name_type_dict
|
||||
|
||||
self._bind_attributes()
|
||||
|
||||
|
@ -199,31 +353,35 @@ cdef class NodeBase:
|
|||
"""Bind attributes declared in class."""
|
||||
cdef dict __dict__ = object.__getattribute__(self, "__dict__")
|
||||
|
||||
cdef str name
|
||||
cdef str node_name
|
||||
cdef ATTR_TYPE attr_type
|
||||
cdef str cb_name
|
||||
cdef _NodeAttributeAccessor attr_acc
|
||||
|
||||
for name, attr in type(self).__dict__.items():
|
||||
# Append an attribute access wrapper to current instance.
|
||||
if isinstance(attr, NodeAttribute):
|
||||
# TODO: This will override exist attribute of sub-class instance, maybe a warning later.
|
||||
node_name = getattr(type(self), "__node_name__", None)
|
||||
# Register attribute.
|
||||
attr_type = self._attributes[name]
|
||||
attr_acc = _NodeAttributeAccessor(attr, attr_type, self._backend, self._index)
|
||||
|
||||
# NOTE: Here we have to use __dict__ to avoid infinite loop, as we override __getattribute__
|
||||
attr_acc = _NodeAttributeAccessor(attr, node_name, name, self._backend, self._index)
|
||||
|
||||
# NOTE: we use attribute name here to support get attribute value by name from python side.
|
||||
__dict__[name] = attr_acc
|
||||
|
||||
# Bind a value changed callback if available, named as _on_<attr name>_changed.
|
||||
cb_name = f"_on_{name}_changed"
|
||||
cb_func = getattr(self, cb_name, None)
|
||||
# Except list attribute.
|
||||
if not attr_acc._is_list:
|
||||
cb_name = f"_on_{name}_changed"
|
||||
cb_func = getattr(self, cb_name, None)
|
||||
|
||||
if cb_func is not None:
|
||||
attr_acc.on_value_changed(cb_func)
|
||||
if cb_func is not None:
|
||||
attr_acc.on_value_changed(cb_func)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Used to avoid attribute overriding, and an easy way to set for 1 slot attribute."""
|
||||
if self._is_deleted:
|
||||
raise BackendsAccessDeletedNodeException()
|
||||
|
||||
cdef dict __dict__ = self.__dict__
|
||||
cdef str attr_name = name
|
||||
|
||||
|
@ -231,10 +389,10 @@ cdef class NodeBase:
|
|||
attr_acc = __dict__[attr_name]
|
||||
|
||||
if isinstance(attr_acc, _NodeAttributeAccessor):
|
||||
if attr_acc.attr._slot_number > 1:
|
||||
if not attr_acc._is_list and attr_acc._slot_number > 1:
|
||||
raise BackendsArrayAttributeAccessException()
|
||||
else:
|
||||
# short-hand for attributes with 1 slot
|
||||
# Short-hand for attributes with 1 slot.
|
||||
attr_acc[0] = value
|
||||
else:
|
||||
__dict__[attr_name] = value
|
||||
|
@ -250,7 +408,11 @@ cdef class NodeBase:
|
|||
attr_acc = __dict__[attr_name]
|
||||
|
||||
if isinstance(attr_acc, _NodeAttributeAccessor):
|
||||
if attr_acc.attr._slot_number == 1:
|
||||
if self._is_deleted:
|
||||
raise BackendsAccessDeletedNodeException()
|
||||
|
||||
# For list attribute, we do not support ignore index.
|
||||
if not attr_acc._is_list and attr_acc._slot_number == 1:
|
||||
return attr_acc[0]
|
||||
|
||||
return attr_acc
|
||||
|
@ -259,17 +421,35 @@ cdef class NodeBase:
|
|||
|
||||
|
||||
cdef class FrameNode:
|
||||
def __cinit__(self, type node_cls, int number):
|
||||
def __cinit__(self, type node_cls, NODE_INDEX number):
|
||||
self._node_cls = node_cls
|
||||
self._number = number
|
||||
|
||||
|
||||
cdef class FrameBase:
|
||||
def __init__(self, enable_snapshot: bool = False, total_snapshot: int = 0, options: dict = {}):
|
||||
def __init__(self, enable_snapshot: bool = False, total_snapshot: int = 0, options: dict = {}, backend_name=None):
|
||||
# Backend name from parameter has highest priority.
|
||||
if backend_name is None:
|
||||
# Try to get default backend settings from environment settings, or use default.
|
||||
backend_name = os.environ.get("DEFAULT_BACKEND_NAME", _default_backend_name)
|
||||
|
||||
backend = backend_dict.get(backend_name, NumpyBackend)
|
||||
|
||||
self._backend_name = "static" if backend == NumpyBackend else "dynamic"
|
||||
|
||||
self._backend = backend()
|
||||
|
||||
self._node_cls_dict = {}
|
||||
self._node_origin_number_dict = {}
|
||||
self._node_name2attrname_dict = {}
|
||||
|
||||
self._setup_backend(enable_snapshot, total_snapshot, options)
|
||||
|
||||
@property
|
||||
def backend_type(self) -> str:
|
||||
"""str: Type of backend, static or dynamic."""
|
||||
return self._backend_name
|
||||
|
||||
@property
|
||||
def snapshots(self) -> SnapshotList:
|
||||
"""SnapshotList: Snapshots of this frame."""
|
||||
|
@ -291,10 +471,25 @@ cdef class FrameBase:
|
|||
"""
|
||||
self._backend.reset()
|
||||
|
||||
cpdef void take_snapshot(self, int tick) except *:
|
||||
cdef NodeBase node
|
||||
|
||||
if self._backend.is_support_dynamic_features():
|
||||
# We need to make sure node number same as origin after reset.
|
||||
for node_name, node_number in self._node_origin_number_dict.items():
|
||||
node_list = self.__dict__[self._node_name2attrname_dict[node_name]]
|
||||
|
||||
for i in range(len(node_list)-1, -1, -1):
|
||||
node = node_list[i]
|
||||
|
||||
if i >= node_number:
|
||||
del node_list[i]
|
||||
else:
|
||||
node._is_deleted = False
|
||||
|
||||
cpdef void take_snapshot(self, INT tick) except *:
|
||||
"""Take snapshot for specified point (tick) for current frame.
|
||||
|
||||
This method will copy current frame value into snapshot list for later using.
|
||||
This method will copy current frame value (except const attributes) into snapshot list for later using.
|
||||
|
||||
NOTE:
|
||||
Frame and SnapshotList do not know about snapshot_resolution from simulator,
|
||||
|
@ -322,62 +517,142 @@ cdef class FrameBase:
|
|||
if self._backend.snapshots is not None:
|
||||
self._backend.snapshots.enable_history(path)
|
||||
|
||||
cdef void _setup_backend(self, bool enable_snapshot, int total_snapshots, dict options) except *:
|
||||
cpdef void append_node(self, str node_name, NODE_INDEX number) except +:
|
||||
"""Append specified number of node instance to node type.
|
||||
|
||||
Args:
|
||||
node_name (str): Name of the node type to append.
|
||||
number (int): Number of node instance to append.
|
||||
"""
|
||||
cdef NODE_TYPE node_type
|
||||
cdef NodeBase node
|
||||
cdef NodeBase first_node
|
||||
cdef list node_list
|
||||
|
||||
if self._backend.is_support_dynamic_features() and number > 0:
|
||||
node_list = self.__dict__.get(self._node_name2attrname_dict[node_name], None)
|
||||
|
||||
if node_list is None:
|
||||
raise BackendsInvalidNodeException()
|
||||
|
||||
# Get the node type for furthur using.
|
||||
first_node = node_list[0]
|
||||
node_type = first_node._type
|
||||
|
||||
self._backend.append_node(node_type, number)
|
||||
|
||||
# Append instance to list.
|
||||
for i in range(number):
|
||||
node = self._node_cls_dict[node_name]()
|
||||
|
||||
node.setup(self._backend, len(node_list), node_type, first_node._attributes)
|
||||
|
||||
node_list.append(node)
|
||||
|
||||
cpdef void delete_node(self, NodeBase node) except +:
|
||||
"""Delete specified node instance, then any operation on this instance will cause error.
|
||||
|
||||
Args:
|
||||
node (NodeBase): Node instance to delete.
|
||||
"""
|
||||
if self._backend.is_support_dynamic_features():
|
||||
self._backend.delete_node(node._type, node._index)
|
||||
|
||||
node._is_deleted = True
|
||||
|
||||
cpdef void resume_node(self, NodeBase node) except +:
|
||||
"""Resume a deleted node instance, this will enable operations on this node instance.
|
||||
|
||||
Args:
|
||||
node (NodeBase): Node instance to resume.
|
||||
"""
|
||||
if self._backend.is_support_dynamic_features() and node._is_deleted:
|
||||
self._backend.resume_node(node._type, node._index)
|
||||
|
||||
node._is_deleted = False
|
||||
|
||||
def dump(self, folder: str):
|
||||
"""Dump data of current frame into specified folder.
|
||||
|
||||
Args:
|
||||
folder (str): Folder path to dump (without file name).
|
||||
"""
|
||||
if os.path.exists(folder):
|
||||
self._backend.dump(folder)
|
||||
|
||||
cdef void _setup_backend(self, bool enable_snapshot, USHORT total_snapshots, dict options) except *:
|
||||
"""Setup Frame for further using."""
|
||||
cdef str frame_attr_name
|
||||
cdef str node_attr_name
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef ATTR_TYPE attr_type
|
||||
cdef type node_cls
|
||||
|
||||
cdef list node_instance_list
|
||||
# node name -> node number dict
|
||||
cdef dict node_name_num_dict = {}
|
||||
cdef int node_number
|
||||
|
||||
# Attr name -> type.
|
||||
cdef dict attr_name_type_dict = {}
|
||||
|
||||
# Node -> attr -> type.
|
||||
cdef dict node_attr_type_dict = {}
|
||||
cdef dict node_type_dict = {}
|
||||
cdef NODE_INDEX node_number
|
||||
cdef NodeBase node
|
||||
|
||||
# Internal loop indexer
|
||||
cdef int i
|
||||
|
||||
cdef list node_def_list = []
|
||||
# Internal loop indexer.
|
||||
cdef NODE_INDEX i
|
||||
|
||||
# Register node and attribute in backend.
|
||||
#for node_cls in self._node_def_list:
|
||||
for frame_attr_name, frame_attr in type(self).__dict__.items():
|
||||
# We only care about FrameNode instance.
|
||||
if isinstance(frame_attr, FrameNode):
|
||||
node_cls = frame_attr._node_cls
|
||||
node_number = frame_attr._number
|
||||
node_name = node_cls.__node_name__
|
||||
|
||||
# temp list to hold current node instances
|
||||
self._node_cls_dict[node_name] = node_cls
|
||||
self._node_origin_number_dict[node_name] = node_number
|
||||
|
||||
# Temp list to hold current node instances.
|
||||
node_instance_list = [None] * node_number
|
||||
|
||||
node_name_num_dict[node_name] = node_number
|
||||
# Register node.
|
||||
node_type = self._backend.add_node(node_name, node_number)
|
||||
|
||||
# register node
|
||||
self._backend.add_node(node_name, node_number)
|
||||
# Used to collect node type and its name, then initial snapshot list with different nodes.
|
||||
node_type_dict[node_name] = node_type
|
||||
|
||||
# register attribute
|
||||
# Used to collect attributes for current node, then initial node instance with it.
|
||||
attr_name_type_dict = {}
|
||||
|
||||
# Register attributes.
|
||||
for node_attr_name, node_attr in node_cls.__dict__.items():
|
||||
if isinstance(node_attr, NodeAttribute):
|
||||
self._backend.add_attr(node_name, node_attr_name, node_attr._dtype, node_attr._slot_number)
|
||||
attr_type = self._backend.add_attr(node_type, node_attr_name, node_attr._dtype, node_attr._slot_number, node_attr._is_const, node_attr._is_list)
|
||||
|
||||
# create instance
|
||||
attr_name_type_dict[node_attr_name] = attr_type
|
||||
|
||||
node_attr_type_dict[node_name] = attr_name_type_dict
|
||||
|
||||
# Create instance.
|
||||
for i in range(node_number):
|
||||
node = node_cls()
|
||||
|
||||
# pass the backend reference and index
|
||||
node.setup(self._backend, i)
|
||||
# Setup each node instance.
|
||||
node.setup(self._backend, i, node_type, attr_name_type_dict)
|
||||
|
||||
node_instance_list[i] = node
|
||||
|
||||
# add dynamic fields
|
||||
# Make it possible to get node instance list by their's name.
|
||||
self.__dict__[frame_attr_name] = node_instance_list
|
||||
self._node_name2attrname_dict[node_name] = frame_attr_name
|
||||
|
||||
# setup backend to allocate memory
|
||||
# Setup backend to allocate memory.
|
||||
self._backend.setup(enable_snapshot, total_snapshots, options)
|
||||
|
||||
if enable_snapshot:
|
||||
self._snapshot_list = SnapshotList(node_name_num_dict, self._backend.snapshots)
|
||||
self._snapshot_list = SnapshotList(node_type_dict, node_attr_type_dict, self._backend.snapshots)
|
||||
|
||||
def dump(self, filePath):
|
||||
self._backend.dump(filePath)
|
||||
|
@ -386,23 +661,23 @@ cdef class FrameBase:
|
|||
# All the slice interface will start from here to construct final parameters.
|
||||
cdef class SnapshotNode:
|
||||
cdef:
|
||||
# target node number, used for empty node list
|
||||
int _node_number
|
||||
# Target node id.
|
||||
NODE_TYPE _node_type
|
||||
|
||||
# target node name
|
||||
str _node_name
|
||||
# Attributes: name -> id.
|
||||
dict _attributes
|
||||
|
||||
# reference to snapshots for querying
|
||||
# Reference to snapshots for querying.
|
||||
SnapshotListAbc _snapshots
|
||||
|
||||
def __cinit__(self, str node_name, int node_number, SnapshotListAbc snapshots):
|
||||
self._node_name = node_name
|
||||
self._node_number = node_number
|
||||
def __cinit__(self, NODE_TYPE node_type, dict attributes, SnapshotListAbc snapshots):
|
||||
self._node_type = node_type
|
||||
self._snapshots = snapshots
|
||||
self._attributes = attributes
|
||||
|
||||
def __len__(self):
|
||||
"""Number of current node."""
|
||||
return self._node_number
|
||||
return self._snapshots.get_node_number(self._node_type)
|
||||
|
||||
def __getitem__(self, key: slice):
|
||||
"""Used to support states slice querying."""
|
||||
|
@ -415,7 +690,7 @@ cdef class SnapshotNode:
|
|||
cdef type stop_type = type(key.stop)
|
||||
cdef type step_type = type(key.step)
|
||||
|
||||
# ticks
|
||||
# Prepare ticks.
|
||||
if key.start is None:
|
||||
ticks = []
|
||||
elif start_type is tuple or start_type is list:
|
||||
|
@ -423,7 +698,7 @@ cdef class SnapshotNode:
|
|||
else:
|
||||
ticks.append(key.start)
|
||||
|
||||
# node id list
|
||||
# Prepare node index list.
|
||||
if key.stop is None:
|
||||
node_list = []
|
||||
elif stop_type is tuple or stop_type is list:
|
||||
|
@ -431,29 +706,41 @@ cdef class SnapshotNode:
|
|||
else:
|
||||
node_list.append(key.stop)
|
||||
|
||||
# Querying need at least one attribute.
|
||||
if key.step is None:
|
||||
return None
|
||||
|
||||
# attribute names
|
||||
# Prepare attribute names.
|
||||
if step_type is tuple or step_type is list:
|
||||
attr_list = list(key.step)
|
||||
else:
|
||||
attr_list = [key.step]
|
||||
|
||||
return self._snapshots.query(self._node_name, ticks, node_list, attr_list)
|
||||
cdef str attr_name
|
||||
cdef list attr_type_list = []
|
||||
|
||||
# Make sure all attributes exist.
|
||||
for attr_name in attr_list:
|
||||
if attr_name not in self._attributes:
|
||||
raise BackendsInvalidAttributeException()
|
||||
|
||||
attr_type_list.append(self._attributes[attr_name])
|
||||
|
||||
return self._snapshots.query(self._node_type, ticks, node_list, attr_type_list)
|
||||
|
||||
|
||||
cdef class SnapshotList:
|
||||
def __cinit__(self, dict node_name_num_dict, SnapshotListAbc snapshots):
|
||||
def __cinit__(self, dict node_type_dict, dict node_attr_type_dict, SnapshotListAbc snapshots):
|
||||
cdef str node_name
|
||||
cdef int node_number
|
||||
cdef NODE_TYPE node_type
|
||||
|
||||
self._snapshots = snapshots
|
||||
|
||||
self._nodes_dict = {}
|
||||
|
||||
for node_name, node_number in node_name_num_dict.items():
|
||||
self._nodes_dict[node_name] = SnapshotNode(node_name, node_number, snapshots)
|
||||
# Initial for each node type.
|
||||
for node_name, node_type in node_type_dict.items():
|
||||
self._nodes_dict[node_name] = SnapshotNode(node_type, node_attr_type_dict[node_name], snapshots)
|
||||
|
||||
def get_frame_index_list(self)->list:
|
||||
"""Get list of available frame index in snapshot list.
|
||||
|
@ -476,3 +763,12 @@ cdef class SnapshotList:
|
|||
def reset(self):
|
||||
"""Reset current states, this will cause all the values to be 0, make sure call it after states querying."""
|
||||
self._snapshots.reset()
|
||||
|
||||
def dump(self, folder: str):
|
||||
"""Dump data of current snapshots into specified folder.
|
||||
|
||||
Args:
|
||||
folder (str): Folder path to dump (without file name).
|
||||
"""
|
||||
if os.path.exists(folder):
|
||||
self._snapshots.dump(folder)
|
||||
|
|
|
@ -2,29 +2,29 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
|
||||
from cpython cimport bool
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX
|
||||
|
||||
|
||||
cdef class NumpyBackend(BackendAbc):
|
||||
"""Backend using numpy array to hold data, this backend only support fixed size array for now"""
|
||||
cdef:
|
||||
# used to store real data, key is node name, value is np.ndarray
|
||||
dict _node_data_dict
|
||||
|
||||
# node name -> node number in frame
|
||||
dict _node_num_dict
|
||||
# Used to store node information, index is the id (IDENTIFIER), value if NodeInfo
|
||||
list _nodes_list
|
||||
list _attrs_list
|
||||
|
||||
# used to cache attribute by node name
|
||||
# node name -> list of (name, type, slot), used to construct numpy structure array
|
||||
# node id -> list of attribute id, used to construct numpy structure array
|
||||
dict _node_attr_dict
|
||||
|
||||
# quick look up table to query with (node_name, attr_name) -> AttrInfo
|
||||
dict _node_attr_lut
|
||||
# Used to store real data, key is node id, value is np.ndarray
|
||||
dict _node_data_dict
|
||||
|
||||
bool _is_snapshot_enabled
|
||||
|
||||
|
@ -35,7 +35,6 @@ cdef class NumpyBackend(BackendAbc):
|
|||
size_t _data_size
|
||||
|
||||
|
||||
|
||||
cdef class NPBufferedMmap:
|
||||
"""Used to dump snapshot history using memory mapping with a fixed size in-memory buffer"""
|
||||
cdef:
|
||||
|
@ -57,7 +56,7 @@ cdef class NPBufferedMmap:
|
|||
# memory mapping np array
|
||||
np.ndarray _data_arr
|
||||
|
||||
cdef void reload(self) except *
|
||||
cdef void reload(self) except +
|
||||
|
||||
|
||||
cdef class NPSnapshotList(SnapshotListAbc):
|
||||
|
@ -65,10 +64,12 @@ cdef class NPSnapshotList(SnapshotListAbc):
|
|||
cdef:
|
||||
NumpyBackend _backend
|
||||
|
||||
# tick -> index mapping
|
||||
dict _node_name2type_dict
|
||||
|
||||
# frame_index -> index mapping
|
||||
dict _tick2index_dict
|
||||
|
||||
# index -> tick mapping
|
||||
# index -> old_frame_index mapping
|
||||
dict _index2tick_dict
|
||||
|
||||
# current index to insert snapshot, default should be 1, never be 0
|
||||
|
@ -82,4 +83,4 @@ cdef class NPSnapshotList(SnapshotListAbc):
|
|||
# key: node name, value: history buffer
|
||||
dict _history_dict
|
||||
|
||||
cdef void enable_history(self, str history_folder) except *
|
||||
cdef void enable_history(self, str history_folder) except +
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
#distutils: language = c++
|
||||
#distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
|
||||
import os
|
||||
|
||||
|
@ -11,7 +12,23 @@ cimport numpy as np
|
|||
cimport cython
|
||||
|
||||
from cpython cimport bool
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
|
||||
|
||||
# Attribute data type mapping.
|
||||
attribute_type_mapping = {
|
||||
AttributeType.Byte: "b",
|
||||
AttributeType.UByte: "B",
|
||||
AttributeType.Short: "h",
|
||||
AttributeType.UShort: "H",
|
||||
AttributeType.Int: "i",
|
||||
AttributeType.UInt: "I",
|
||||
AttributeType.Long: "q",
|
||||
AttributeType.ULong: "Q",
|
||||
AttributeType.Float: "f",
|
||||
AttributeType.Double: "d"
|
||||
}
|
||||
|
||||
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
|
@ -63,23 +80,43 @@ cdef class NPBufferedMmap:
|
|||
if self._current_record_number >= self._buffer_size:
|
||||
self.reload()
|
||||
|
||||
cdef void reload(self) except *:
|
||||
cdef void reload(self) except +:
|
||||
"""Reload the file with offset to avoid memmap size limitation"""
|
||||
self._data_arr = np.memmap(self._path, self._dtype, "w+", offset=self._offset, shape=(self._buffer_size, self._node_number))
|
||||
|
||||
self._offset += self._dtype.itemsize * self._buffer_size * self._node_number
|
||||
|
||||
|
||||
cdef class NodeInfo:
|
||||
"""Internal structure to hold node info."""
|
||||
cdef:
|
||||
public NODE_TYPE type
|
||||
public str name
|
||||
public NODE_INDEX number
|
||||
|
||||
def __cinit__(self, str name, NODE_TYPE type, NODE_INDEX number):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.number = number
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NodeInfo name: {self.name}, type: {self.type}, number: {self.number}>"
|
||||
|
||||
|
||||
cdef class AttrInfo:
|
||||
"""Internal structure to hold attribute info"""
|
||||
cdef:
|
||||
public str name
|
||||
public str dtype
|
||||
public int slot_number
|
||||
public ATTR_TYPE type
|
||||
public NODE_TYPE node_type
|
||||
public SLOT_INDEX slot_number
|
||||
|
||||
def __cinit__(self, str name, str dtype, int slot_number):
|
||||
def __cinit__(self, str name, ATTR_TYPE type, NODE_TYPE node_type, str dtype, SLOT_INDEX slot_number):
|
||||
self.name = name
|
||||
self.dtype = dtype
|
||||
self.type = type
|
||||
self.node_type = node_type
|
||||
self.slot_number = slot_number
|
||||
|
||||
def gen_numpy_dtype(self):
|
||||
|
@ -89,119 +126,130 @@ cdef class AttrInfo:
|
|||
else:
|
||||
return (self.name, self.dtype, self.slot_number)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AttrInfo name: {self.name}, type: {self.type}, node_type: {self.node_type}, slot_number: {self.slot_number}>"
|
||||
|
||||
cdef class NumpyBackend(BackendAbc):
|
||||
def __cinit__(self):
|
||||
self._node_num_dict = {}
|
||||
self._nodes_list = []
|
||||
self._attrs_list = []
|
||||
self._node_attr_dict = {}
|
||||
self._node_data_dict = {}
|
||||
|
||||
# 2 dict for attribute for different scenario querying
|
||||
self._node_attr_dict = {} # node_name -> attribute list
|
||||
self._node_attr_lut = {} # (node_name, attr_name) -> attribute
|
||||
|
||||
def __dealloc__(self):
|
||||
"""Clear resources before deleted"""
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
self._node_data_dict = None
|
||||
|
||||
PyMem_Free(self._data)
|
||||
ELSE:
|
||||
pass
|
||||
|
||||
cdef dict get_node_info(self):
|
||||
cdef str node_name
|
||||
cdef int node_number
|
||||
cdef dict node_info = {}
|
||||
cdef list attrs
|
||||
|
||||
for node_name, node_number in self._node_num_dict.items():
|
||||
attrs = self._node_attr_dict[node_name]
|
||||
|
||||
node_info[node_name]= {
|
||||
"number": node_number,
|
||||
"attributes": {attr.name:
|
||||
{
|
||||
"type": attr.dtype,
|
||||
"slots": attr.slot_number
|
||||
} for attr in attrs}
|
||||
}
|
||||
|
||||
return node_info
|
||||
|
||||
cdef void add_node(self, str name, int number) except *:
|
||||
cdef NODE_TYPE add_node(self, str name, NODE_INDEX number) except +:
|
||||
"""Add a new node type with name and number in backend"""
|
||||
# TODO: less than 1 checking
|
||||
self._node_num_dict[name] = number
|
||||
self._node_attr_dict[name] = []
|
||||
cdef NodeInfo new_node = NodeInfo(name, len(self._nodes_list), number)
|
||||
|
||||
cdef void add_attr(self, str node_name, str attr_name, str dtype, int slot_num) except *:
|
||||
self._nodes_list.append(new_node)
|
||||
self._node_attr_dict[new_node.type] = []
|
||||
|
||||
return new_node.type
|
||||
|
||||
cdef ATTR_TYPE add_attr(self, NODE_TYPE node_type, str attr_name, bytes dtype, SLOT_INDEX slot_num, bool is_const, bool is_list) except +:
|
||||
"""Add a new attribute for specified node with data type and slot number"""
|
||||
# TODO: type checking, slot_number checking
|
||||
cdef AttrInfo ai = AttrInfo(attr_name, dtype, slot_num)
|
||||
if node_type >= len(self._nodes_list):
|
||||
raise Exception("Invalid node type.")
|
||||
|
||||
self._node_attr_dict[node_name].append(ai)
|
||||
cdef str _dtype = attribute_type_mapping[dtype]
|
||||
|
||||
self._node_attr_lut[(node_name, attr_name)] = ai
|
||||
cdef NodeInfo node = self._nodes_list[node_type]
|
||||
cdef AttrInfo new_attr = AttrInfo(attr_name, len(self._attrs_list), node.type, dtype.decode(), slot_num)
|
||||
|
||||
cdef void set_attr_value(self, str node_name, int node_index, str attr_name, int slot_index, value) except *:
|
||||
self._attrs_list.append(new_attr)
|
||||
self._node_attr_dict[node_type].append(new_attr)
|
||||
|
||||
return new_attr.type
|
||||
|
||||
cdef void set_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, object value) except +:
|
||||
"""Set specified attribute value"""
|
||||
cdef np.ndarray attr_array = self._node_data_dict[node_name][attr_name]
|
||||
cdef AttrInfo attr = self._node_attr_lut[(node_name, attr_name)]
|
||||
if attr_type >= len(self._attrs_list):
|
||||
raise Exception("Invalid attribute type.")
|
||||
|
||||
cdef AttrInfo attr = self._attrs_list[attr_type]
|
||||
|
||||
if attr.node_type >= len(self._nodes_list):
|
||||
raise Exception("Invalid node type.")
|
||||
|
||||
cdef NodeInfo node = self._nodes_list[attr.node_type]
|
||||
|
||||
if node_index >= node.number:
|
||||
raise Exception("Invalid node index.")
|
||||
|
||||
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]
|
||||
|
||||
if attr.slot_number > 1:
|
||||
attr_array[0][node_index, slot_index] = value
|
||||
else:
|
||||
attr_array[0][node_index] = value
|
||||
|
||||
cdef object get_attr_value(self, str node_name, int node_index, str attr_name, int slot_index):
|
||||
cdef object get_attr_value(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) except +:
|
||||
"""Get specified attribute value"""
|
||||
cdef np.ndarray attr_array = self._node_data_dict[node_name][attr_name]
|
||||
cdef AttrInfo attr = self._node_attr_lut[(node_name, attr_name)]
|
||||
if attr_type >= len(self._attrs_list):
|
||||
raise Exception("Invalid attribute type.")
|
||||
|
||||
cdef AttrInfo attr = self._attrs_list[attr_type]
|
||||
|
||||
if attr.node_type >= len(self._nodes_list):
|
||||
raise Exception("Invalid node type.")
|
||||
|
||||
cdef NodeInfo node = self._nodes_list[attr.node_type]
|
||||
|
||||
if node_index >= node.number:
|
||||
raise Exception("Invalid node index.")
|
||||
|
||||
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]
|
||||
|
||||
if attr.slot_number > 1:
|
||||
return attr_array[0][node_index, slot_index]
|
||||
else:
|
||||
return attr_array[0][node_index]
|
||||
|
||||
cdef void set_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_index, list value) except *:
|
||||
cdef np.ndarray attr_array = self._node_data_dict[node_name][attr_name]
|
||||
cdef AttrInfo attr = self._node_attr_lut[(node_name, attr_name)]
|
||||
cdef void set_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_index, list value) except +:
|
||||
cdef AttrInfo attr = self._attrs_list[attr_type]
|
||||
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]
|
||||
|
||||
if attr.slot_number == 1:
|
||||
attr_array[0][node_index, slot_index[0]] = value[0]
|
||||
else:
|
||||
attr_array[0][node_index, slot_index] = value
|
||||
|
||||
cdef object[object, ndim=1] get_attr_values(self, str node_name, int node_index, str attr_name, int[:] slot_indices):
|
||||
cdef np.ndarray attr_array = self._node_data_dict[node_name][attr_name]
|
||||
cdef AttrInfo attr = self._node_attr_lut[(node_name, attr_name)]
|
||||
cdef list get_attr_values(self, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX[:] slot_indices) except +:
|
||||
cdef AttrInfo attr = self._attrs_list[attr_type]
|
||||
cdef np.ndarray attr_array = self._node_data_dict[attr.node_type][attr.name]
|
||||
|
||||
if attr.slot_number == 1:
|
||||
return attr_array[0][node_index, slot_indices[0]]
|
||||
return attr_array[0][node_index, slot_indices[0]].tolist()
|
||||
else:
|
||||
return attr_array[0][node_index, slot_indices]
|
||||
return attr_array[0][node_index, slot_indices].tolist()
|
||||
|
||||
|
||||
cdef void setup(self, bool enable_snapshot, int total_snapshot, dict options) except *:
|
||||
cdef void setup(self, bool enable_snapshot, USHORT total_snapshot, dict options) except +:
|
||||
"""Set up the numpy backend"""
|
||||
self._is_snapshot_enabled = enable_snapshot
|
||||
|
||||
cdef int snapshot_number = 0
|
||||
cdef UINT snapshot_number = 0
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef ATTR_TYPE attr_type
|
||||
cdef list node_attrs
|
||||
cdef np.dtype data_type
|
||||
cdef int node_number
|
||||
cdef UINT node_number
|
||||
cdef AttrInfo ai
|
||||
cdef NodeInfo ni
|
||||
cdef tuple shape
|
||||
cdef int max_tick = 0
|
||||
cdef UINT max_tick = 0
|
||||
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
self._data_size = 0 # total memory size we need to hold nodes in both frame and snapshot list
|
||||
node_info = {} # temp node information, as we need several steps to build backend
|
||||
# Total memory size we need to hold nodes in both frame and snapshot list
|
||||
self._data_size = 0
|
||||
# Temp node information, as we need several steps to build backend
|
||||
node_info = {}
|
||||
|
||||
for node_name, node_attrs in self._node_attr_dict.items():
|
||||
node_number = self._node_num_dict[node_name]
|
||||
dtype = np.dtype([ai.gen_numpy_dtype() for ai in node_attrs])
|
||||
for node_type, node_attrs in self._node_attr_dict.items():
|
||||
ni = self._nodes_list[node_type]
|
||||
|
||||
node_number = ni.number
|
||||
|
||||
data_type = np.dtype([ai.gen_numpy_dtype() for ai in node_attrs])
|
||||
|
||||
# for each node, we keep frame and snapshot in one big numpy array
|
||||
# 1st slot is the node's frame data
|
||||
|
@ -220,12 +268,12 @@ cdef class NumpyBackend(BackendAbc):
|
|||
# NOTE: we have to keep data type here, or it will be collected by GC at sometime,
|
||||
# then will cause numpy array cannot get the reference
|
||||
# , we will increase he reference later
|
||||
node_info[node_name] = (shape, dtype, self._data_size)
|
||||
node_info[node_type] = (shape, data_type, self._data_size)
|
||||
|
||||
self._data_size += shape[0] * shape[1] * dtype.itemsize
|
||||
self._data_size += shape[0] * shape[1] * data_type.itemsize
|
||||
ELSE:
|
||||
# one memory block for each node
|
||||
self._node_data_dict[node_name] = np.zeros(shape, dtype)
|
||||
self._node_data_dict[node_type] = np.zeros(shape, data_type)
|
||||
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
# allocate memory, and construct numpy array with numpy c api
|
||||
|
@ -239,39 +287,70 @@ cdef class NumpyBackend(BackendAbc):
|
|||
cdef int offset
|
||||
cdef np.npy_intp np_dims[2]
|
||||
|
||||
for node_name, info in node_info.items():
|
||||
for node_type, info in node_info.items():
|
||||
shape = info[0]
|
||||
dtype = info[1]
|
||||
data_type = info[1]
|
||||
offset = info[2]
|
||||
|
||||
np_dims[0] = shape[0]
|
||||
np_dims[1] = shape[1]
|
||||
|
||||
self._node_data_dict[node_name] = PyArray_NewFromDescr(&PyArray_Type, dtype, 2, np_dims, NULL, &self._data[offset], np.NPY_ARRAY_C_CONTIGUOUS | np.NPY_ARRAY_WRITEABLE, None)
|
||||
self._node_data_dict[node_type] = PyArray_NewFromDescr(&PyArray_Type, data_type, 2, np_dims, NULL, &self._data[offset], np.NPY_ARRAY_C_CONTIGUOUS | np.NPY_ARRAY_WRITEABLE, None)
|
||||
|
||||
# NOTE: we have to increate the reference count of related dtype,
|
||||
# or it will cause seg fault
|
||||
Py_INCREF(dtype)
|
||||
Py_INCREF(data_type)
|
||||
|
||||
if enable_snapshot:
|
||||
self.snapshots = NPSnapshotList(self, snapshot_number + 1)
|
||||
|
||||
cdef void reset(self) except *:
|
||||
def __dealloc__(self):
|
||||
"""Clear resources before deleted"""
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
self._node_data_dict = None
|
||||
|
||||
PyMem_Free(self._data)
|
||||
ELSE:
|
||||
pass
|
||||
|
||||
cdef dict get_node_info(self) except +:
|
||||
cdef dict node_info = {}
|
||||
|
||||
cdef NODE_TYPE node_type
|
||||
cdef list node_attrs
|
||||
|
||||
for node_type, node_attrs in self._node_attr_dict.items():
|
||||
node = self._nodes_list[node_type]
|
||||
|
||||
node_info[node.name] = {
|
||||
"number": node.number,
|
||||
"attributes": {
|
||||
attr.name: {
|
||||
"type": attr.dtype,
|
||||
"slots": attr.slot_number
|
||||
} for attr in node_attrs
|
||||
}
|
||||
}
|
||||
|
||||
return node_info
|
||||
|
||||
|
||||
cdef void reset(self) except +:
|
||||
"""Reset all the attributes value"""
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef AttrInfo attr_info
|
||||
cdef np.ndarray data_arr
|
||||
|
||||
for node_name, data_arr in self._node_data_dict.items():
|
||||
for node_type, data_arr in self._node_data_dict.items():
|
||||
# we have to reset by each attribute
|
||||
for attr_info in self._node_attr_dict[node_name]:
|
||||
for attr_info in self._node_attr_dict[node_type]:
|
||||
# we only reset frame here, without snapshot list
|
||||
data_arr[0][attr_info.name] = 0
|
||||
|
||||
cdef void dump(self, str filePath):
|
||||
cdef void dump(self, str folder) except +:
|
||||
for node_name, data_arr in self._node_data_dict.items():
|
||||
filename = os.path.join(filePath, node_name + ".npy")
|
||||
descFilename = os.path.join(filePath, node_name + ".meta")
|
||||
filename = os.path.join(folder, node_name + ".npy")
|
||||
descFilename = os.path.join(folder, node_name + ".meta")
|
||||
with open(filename, "wb+") as f:
|
||||
np.save(f, data_arr)
|
||||
with open(descFilename, "wt+") as f:
|
||||
|
@ -290,20 +369,30 @@ cdef class NPSnapshotList(SnapshotListAbc):
|
|||
|
||||
self._tick2index_dict = {}
|
||||
self._index2tick_dict = {}
|
||||
self._node_name2type_dict = {}
|
||||
self._cur_index = 0
|
||||
self._max_size = max_size
|
||||
self._is_history_enabled = False
|
||||
self._history_dict = {}
|
||||
|
||||
cdef list get_frame_index_list(self):
|
||||
return list(self._index2tick_dict.values())
|
||||
for node in backend._nodes_list:
|
||||
self._node_name2type_dict[node.name] = node.type
|
||||
|
||||
cdef void take_snapshot(self, int tick) except *:
|
||||
cdef NODE_INDEX get_node_number(self, NODE_TYPE node_type) except +:
|
||||
cdef NodeInfo node = self._backend._nodes_list[node_type]
|
||||
|
||||
return node.number
|
||||
|
||||
cdef list get_frame_index_list(self) except +:
|
||||
return list(self._tick2index_dict.keys())
|
||||
|
||||
cdef void take_snapshot(self, INT tick) except +:
|
||||
"""Take snapshot for current backend"""
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef NodeInfo ni
|
||||
cdef np.ndarray data_arr
|
||||
cdef int target_index = 0
|
||||
cdef int old_tick # old tick to be removed
|
||||
cdef UINT target_index = 0
|
||||
cdef INT old_tick # old tick to be removed
|
||||
|
||||
# check if we are overriding exist snapshot, or not inserted yet
|
||||
if tick not in self._tick2index_dict:
|
||||
|
@ -325,24 +414,23 @@ cdef class NPSnapshotList(SnapshotListAbc):
|
|||
del self._tick2index_dict[old_tick]
|
||||
|
||||
# recording will copy data at 1st row into _cur_index row
|
||||
for node_name, data_arr in self._backend._node_data_dict.items():
|
||||
for node_type, data_arr in self._backend._node_data_dict.items():
|
||||
ni = self._backend._nodes_list[node_type]
|
||||
data_arr[target_index] = data_arr[0]
|
||||
|
||||
if self._is_history_enabled:
|
||||
self._history_dict[node_name].record(data_arr[0])
|
||||
self._history_dict[ni.name].record(data_arr[0])
|
||||
|
||||
self._index2tick_dict[target_index] = tick
|
||||
|
||||
self._tick2index_dict[tick] = target_index
|
||||
|
||||
cdef query(self, str node_name, list ticks, list node_index_list, list attr_name_list):
|
||||
cdef int tick
|
||||
cdef int data_index
|
||||
cdef int node_index
|
||||
cdef str attr_name
|
||||
cdef query(self, NODE_TYPE node_type, list ticks, list node_index_list, list attr_list) except +:
|
||||
cdef UINT tick
|
||||
cdef NODE_INDEX node_index
|
||||
cdef ATTR_TYPE attr_type
|
||||
cdef AttrInfo attr
|
||||
|
||||
cdef np.ndarray data_arr = self._backend._node_data_dict[node_name]
|
||||
cdef np.ndarray data_arr = self._backend._node_data_dict[node_type]
|
||||
|
||||
# TODO: how about use a pre-allocate np array instead concat?
|
||||
cdef list retq = []
|
||||
|
@ -351,56 +439,59 @@ cdef class NPSnapshotList(SnapshotListAbc):
|
|||
ticks = [t for t in self._tick2index_dict.keys()][-(self._max_size-1):]
|
||||
|
||||
if len(node_index_list) == 0:
|
||||
node_index_list = [i for i in range(self._backend._node_num_dict[node_name])]
|
||||
node_index_list = [i for i in range(self._backend._nodes_list[node_type].number)]
|
||||
|
||||
# querying by tick attribute
|
||||
for tick in ticks:
|
||||
for node_index in node_index_list:
|
||||
for attr_name in attr_name_list:
|
||||
for attr_type in attr_list:
|
||||
attr = self._backend._attrs_list[attr_type]
|
||||
|
||||
# since we have a clear tick to index mapping, do not need additional checking here
|
||||
if tick in self._tick2index_dict:
|
||||
retq.append(data_arr[attr_name][self._tick2index_dict[tick], node_index].astype("f").flatten())
|
||||
retq.append(data_arr[attr.name][self._tick2index_dict[tick], node_index].astype("f").flatten())
|
||||
else:
|
||||
# padding for tick which not exist
|
||||
attr = self._backend._node_attr_lut[(node_name, attr_name)]
|
||||
retq.append(np.zeros(attr.slot_number, dtype="f"))
|
||||
retq.append(np.zeros(attr.slot_number, dtype='f'))
|
||||
|
||||
return np.concatenate(retq)
|
||||
|
||||
cdef void enable_history(self, str history_folder) except *:
|
||||
cdef void enable_history(self, str history_folder) except +:
|
||||
"""Enable history recording, used to save all the snapshots into file"""
|
||||
if self._is_history_enabled:
|
||||
return
|
||||
|
||||
self._is_history_enabled = True
|
||||
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef NodeInfo ni
|
||||
cdef str dump_path
|
||||
cdef np.ndarray data_arr
|
||||
|
||||
for node_name, data_arr in self._backend._node_data_dict.items():
|
||||
dump_path = os.path.join(history_folder, f"{node_name}.bin")
|
||||
for node_type, data_arr in self._backend._node_data_dict.items():
|
||||
ni = self._backend._nodes_list[node_type]
|
||||
dump_path = os.path.join(history_folder, f"{ni.name}.bin")
|
||||
|
||||
self._history_dict[node_name] = NPBufferedMmap(dump_path, data_arr.dtype, self._backend._node_num_dict[node_name])
|
||||
self._history_dict[ni.name] = NPBufferedMmap(dump_path, data_arr.dtype, ni.number)
|
||||
|
||||
cdef void reset(self) except *:
|
||||
cdef void reset(self) except +:
|
||||
"""Reset snapshot list"""
|
||||
self._cur_index = 0
|
||||
self._tick2index_dict.clear()
|
||||
self._index2tick_dict.clear()
|
||||
self._history_dict.clear()
|
||||
|
||||
cdef str node_name
|
||||
cdef NODE_TYPE node_type
|
||||
cdef AttrInfo attr_info
|
||||
cdef np.ndarray data_arr
|
||||
|
||||
for node_name, data_arr in self._backend._node_data_dict.items():
|
||||
for node_type, data_arr in self._backend._node_data_dict.items():
|
||||
# we have to reset by each attribute
|
||||
for attr_info in self._backend._node_attr_dict[node_name]:
|
||||
for attr_info in self._backend._node_attr_dict[node_type]:
|
||||
# we only reset frame here, without snapshot list
|
||||
data_arr[1:][attr_info.name] = 0
|
||||
|
||||
# NOTE: we do not reset the history file here, so the file will keep increasing
|
||||
|
||||
def __len__(self):
|
||||
return self._max_size - 1
|
||||
return len(self._index2tick_dict)
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "attribute.h"
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
Attribute::Attribute() noexcept
|
||||
{
|
||||
memset(_data, 0, ATTRIBUTE_DATA_LENGTH);
|
||||
}
|
||||
|
||||
// Macro for all type of constructors.
|
||||
#define CONSTRUCTOR(data_type, type_name) \
|
||||
Attribute::Attribute(data_type value) noexcept \
|
||||
{ \
|
||||
memcpy(_data, &value, sizeof(data_type)); \
|
||||
_type = type_name; \
|
||||
}
|
||||
|
||||
CONSTRUCTOR(ATTR_CHAR, AttrDataType::ACHAR)
|
||||
CONSTRUCTOR(ATTR_UCHAR, AttrDataType::AUCHAR)
|
||||
CONSTRUCTOR(ATTR_SHORT, AttrDataType::ASHORT)
|
||||
CONSTRUCTOR(ATTR_USHORT, AttrDataType::AUSHORT)
|
||||
CONSTRUCTOR(ATTR_INT, AttrDataType::AINT)
|
||||
CONSTRUCTOR(ATTR_UINT, AttrDataType::AUINT)
|
||||
CONSTRUCTOR(ATTR_LONG, AttrDataType::ALONG)
|
||||
CONSTRUCTOR(ATTR_ULONG, AttrDataType::AULONG)
|
||||
CONSTRUCTOR(ATTR_FLOAT, AttrDataType::AFLOAT)
|
||||
CONSTRUCTOR(ATTR_DOUBLE, AttrDataType::ADOUBLE)
|
||||
|
||||
AttrDataType Attribute::get_type() const noexcept
|
||||
{
|
||||
return _type;
|
||||
}
|
||||
|
||||
Attribute::operator QUERY_FLOAT() const
|
||||
{
|
||||
switch (_type)
|
||||
{
|
||||
case AttrDataType::AUCHAR: { return QUERY_FLOAT(get_value<ATTR_UCHAR>()); }
|
||||
case AttrDataType::AUSHORT: { return QUERY_FLOAT(get_value<ATTR_USHORT>()); }
|
||||
case AttrDataType::AUINT: { return QUERY_FLOAT(get_value<ATTR_UINT>()); }
|
||||
case AttrDataType::AULONG: { return QUERY_FLOAT(get_value<ATTR_ULONG>()); }
|
||||
case AttrDataType::ACHAR: { return QUERY_FLOAT(get_value<ATTR_CHAR>()); }
|
||||
case AttrDataType::ASHORT: { return QUERY_FLOAT(get_value<ATTR_SHORT>()); }
|
||||
case AttrDataType::AINT: { return QUERY_FLOAT(get_value<ATTR_INT>()); }
|
||||
case AttrDataType::ALONG: { return QUERY_FLOAT(get_value<ATTR_LONG>()); }
|
||||
case AttrDataType::AFLOAT: { return QUERY_FLOAT(get_value<ATTR_FLOAT>()); }
|
||||
case AttrDataType::ADOUBLE: { return QUERY_FLOAT(get_value<ATTR_DOUBLE>()); }
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
throw AttributeInvalidDataTypeError();
|
||||
}
|
||||
|
||||
bool Attribute::is_nan() const noexcept
|
||||
{
|
||||
return _type == AttrDataType::AFLOAT && isnan(get_value<ATTR_FLOAT>());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
typename Attribute_Trait<T>::type Attribute::get_value() const noexcept
|
||||
{
|
||||
T value = T();
|
||||
|
||||
// NOTE: we do not check type here, if the type not match, will get invalid value.
|
||||
memcpy(&value, _data, sizeof(T));
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
// Macro for attribute getter template.
|
||||
#define GETTER(type) template type Attribute::get_value<type>() const noexcept;
|
||||
|
||||
GETTER(ATTR_CHAR)
|
||||
GETTER(ATTR_UCHAR)
|
||||
GETTER(ATTR_SHORT)
|
||||
GETTER(ATTR_USHORT)
|
||||
GETTER(ATTR_INT)
|
||||
GETTER(ATTR_UINT)
|
||||
GETTER(ATTR_LONG)
|
||||
GETTER(ATTR_ULONG)
|
||||
GETTER(ATTR_FLOAT)
|
||||
GETTER(ATTR_DOUBLE)
|
||||
|
||||
Attribute& Attribute::operator=(const Attribute& attr) noexcept
|
||||
{
|
||||
if (this != &attr)
|
||||
{
|
||||
_type = attr._type;
|
||||
|
||||
memcpy(_data, attr._data, ATTRIBUTE_DATA_LENGTH);
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Macro for setters.
|
||||
#define SETTER(data_type, value_type) \
|
||||
Attribute& Attribute::operator=(data_type value) noexcept \
|
||||
{ \
|
||||
memcpy(_data, &value, sizeof(data_type)); \
|
||||
_type = value_type; \
|
||||
return *this; \
|
||||
}
|
||||
|
||||
SETTER(ATTR_CHAR, AttrDataType::ACHAR)
|
||||
SETTER(ATTR_UCHAR, AttrDataType::AUCHAR)
|
||||
SETTER(ATTR_SHORT, AttrDataType::ASHORT)
|
||||
SETTER(ATTR_USHORT, AttrDataType::AUSHORT)
|
||||
SETTER(ATTR_INT, AttrDataType::AINT)
|
||||
SETTER(ATTR_UINT, AttrDataType::AUINT)
|
||||
SETTER(ATTR_LONG, AttrDataType::ALONG)
|
||||
SETTER(ATTR_ULONG, AttrDataType::AULONG)
|
||||
SETTER(ATTR_FLOAT, AttrDataType::AFLOAT)
|
||||
SETTER(ATTR_DOUBLE, AttrDataType::ADOUBLE)
|
||||
|
||||
|
||||
const char* AttributeInvalidDataTypeError::what() const noexcept
|
||||
{
|
||||
return "Invalid attribute data type.";
|
||||
}
|
||||
} // namespace raw
|
||||
} // namespace backends
|
||||
} // namespace maro
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKEND_RAW_ATTRIBUTE_
|
||||
#define _MARO_BACKEND_RAW_ATTRIBUTE_
|
||||
|
||||
#include <string>
|
||||
#include <math.h>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
// Length of the attribute data.
|
||||
const int ATTRIBUTE_DATA_LENGTH = 8;
|
||||
|
||||
/// <summary>
|
||||
/// Trait struct to support getter template.
|
||||
/// </summary>
|
||||
template<typename T>
|
||||
struct Attribute_Trait
|
||||
{
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Attribute for a node, used to hold all supported data type.
|
||||
/// </summary>
|
||||
class Attribute
|
||||
{
|
||||
// Chars to hold all data we supported.
|
||||
char _data[ATTRIBUTE_DATA_LENGTH];
|
||||
|
||||
// Type of current attribute, defalut is char
|
||||
AttrDataType _type = AttrDataType::ACHAR;
|
||||
|
||||
public:
|
||||
// Slot number of list attribute, it will alway be 0 for fixed size attribute.
|
||||
SLOT_INDEX slot_number = 0;
|
||||
|
||||
// Constructors
|
||||
Attribute() noexcept;
|
||||
Attribute(ATTR_CHAR value) noexcept;
|
||||
Attribute(ATTR_UCHAR value) noexcept;
|
||||
Attribute(ATTR_SHORT value) noexcept;
|
||||
Attribute(ATTR_USHORT value) noexcept;
|
||||
Attribute(ATTR_INT value) noexcept;
|
||||
Attribute(ATTR_UINT value) noexcept;
|
||||
Attribute(ATTR_LONG value) noexcept;
|
||||
Attribute(ATTR_ULONG value) noexcept;
|
||||
Attribute(ATTR_FLOAT value) noexcept;
|
||||
Attribute(ATTR_DOUBLE value) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get type of current attribute.
|
||||
/// </summary>
|
||||
AttrDataType get_type() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get value of current attribute.
|
||||
/// </summary>
|
||||
template<typename T>
|
||||
typename Attribute_Trait<T>::type get_value() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Cast current value into float, for snapshot querying.
|
||||
/// </summary>
|
||||
operator QUERY_FLOAT() const;
|
||||
|
||||
// Assignment, copy from another attribute (deep copy).
|
||||
Attribute& operator=(const Attribute& attr) noexcept;
|
||||
|
||||
// Setters.
|
||||
// NOTE: setters will change its type.
|
||||
Attribute& operator=(ATTR_CHAR value) noexcept;
|
||||
Attribute& operator=(ATTR_UCHAR value) noexcept;
|
||||
Attribute& operator=(ATTR_SHORT value) noexcept;
|
||||
Attribute& operator=(ATTR_USHORT value) noexcept;
|
||||
Attribute& operator=(ATTR_INT value) noexcept;
|
||||
Attribute& operator=(ATTR_UINT value) noexcept;
|
||||
Attribute& operator=(ATTR_LONG value) noexcept;
|
||||
Attribute& operator=(ATTR_ULONG value) noexcept;
|
||||
Attribute& operator=(ATTR_FLOAT value) noexcept;
|
||||
Attribute& operator=(ATTR_DOUBLE value) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Is current value is nan, for float type only.
|
||||
/// </summary>
|
||||
/// <returns>True if value is nan, or false.</returns>
|
||||
bool is_nan() const noexcept;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Invalid casting
|
||||
/// </summary>
|
||||
struct AttributeInvalidDataTypeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
} // namespace raw
|
||||
} // namespace backends
|
||||
} // namespace maro
|
||||
|
||||
#endif
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "bitset.h"
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
inline size_t ceil_to_times(UINT number)
|
||||
{
|
||||
auto bits = sizeof(ULONG) * BITS_PER_BYTE;
|
||||
|
||||
return number % bits == 0 ? number / bits : (floorl(number / bits) + 1);
|
||||
}
|
||||
|
||||
Bitset::Bitset()
|
||||
{
|
||||
}
|
||||
|
||||
Bitset::Bitset(UINT size)
|
||||
{
|
||||
auto vector_size = ceil_to_times(size);
|
||||
|
||||
_masks.resize(vector_size);
|
||||
|
||||
_bit_size = ULONG(vector_size) * BITS_PER_MASK;
|
||||
}
|
||||
|
||||
Bitset& Bitset::operator=(const Bitset& set) noexcept
|
||||
{
|
||||
if (this != &set)
|
||||
{
|
||||
_masks.resize(set._masks.size());
|
||||
|
||||
memcpy(&_masks[0], &set._masks[0], _masks.size() * sizeof(ULONG));
|
||||
|
||||
_bit_size = set._bit_size;
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Bitset::resize(UINT size) noexcept
|
||||
{
|
||||
auto new_size = ceil_to_times(size);
|
||||
|
||||
_masks.resize(new_size);
|
||||
|
||||
_bit_size = ULONG(new_size) * BITS_PER_MASK;
|
||||
}
|
||||
|
||||
void Bitset::reset(bool value) noexcept
|
||||
{
|
||||
auto v = value ? ULONG_MAX : 0ULL;
|
||||
|
||||
memset(&_masks[0], v, _masks.size() * sizeof(ULONG));
|
||||
}
|
||||
|
||||
ULONG Bitset::size() const noexcept
|
||||
{
|
||||
return _bit_size;
|
||||
}
|
||||
|
||||
UINT Bitset::mask_size() const noexcept
|
||||
{
|
||||
return _masks.size();
|
||||
}
|
||||
|
||||
bool Bitset::get(ULONG index) const noexcept
|
||||
{
|
||||
if (index >= _bit_size)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
ULONG i = floorl(index / BITS_PER_MASK);
|
||||
|
||||
auto offset = index % BITS_PER_MASK;
|
||||
|
||||
auto mask = _masks[i];
|
||||
|
||||
auto target = mask >> offset & 0x1ULL;
|
||||
|
||||
return target == 1;
|
||||
}
|
||||
|
||||
void Bitset::set(ULONG index, bool value)
|
||||
{
|
||||
if (index >= _bit_size)
|
||||
{
|
||||
throw BitsetIndexOutRangeError();
|
||||
}
|
||||
|
||||
ULONG i = floorl(index / BITS_PER_MASK);
|
||||
auto offset = index % BITS_PER_MASK;
|
||||
|
||||
if (value)
|
||||
{
|
||||
// Set to 1.
|
||||
_masks[i] |= 0x1ULL << offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
_masks[i] &= ~(0x1ULL << offset);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const char* BitsetIndexOutRangeError::what() const noexcept
|
||||
{
|
||||
return "Index of bit flag out of range.";
|
||||
}
|
||||
} // namespace raw
|
||||
} // namespace backends
|
||||
} // namespace maro
|
|
@ -0,0 +1,92 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKENDS_RAW_BITSET_
|
||||
#define _MARO_BACKENDS_RAW_BITSET_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
const USHORT BITS_PER_BYTE = 8;
|
||||
const USHORT BITS_PER_MASK = sizeof(ULONG) * BITS_PER_BYTE;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// A simple bitset implementation.
|
||||
/// </summary>
|
||||
class Bitset
|
||||
{
|
||||
// Masks of current bitset, we use ULL for each item.
|
||||
vector<ULONG> _masks;
|
||||
|
||||
// Size of bits.
|
||||
ULONG _bit_size = 0;
|
||||
public:
|
||||
Bitset();
|
||||
Bitset(UINT size);
|
||||
|
||||
// Copy all from input set.
|
||||
Bitset& operator=(const Bitset& set) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Resize bitset with spcified size.
|
||||
/// </summary>
|
||||
/// <param name="size">Size to extend, it should be 64 times.</param>
|
||||
void resize(UINT size) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Reset all bit to specified value.
|
||||
/// </summary>
|
||||
/// <param name="">Value to reset.</param>
|
||||
void reset(bool value = false) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get value at specified index.
|
||||
/// </summary>
|
||||
/// <param name="index">Index of bit.</param>
|
||||
/// <returns>True if the bit is 1, or false for 0 (not exist).</returns>
|
||||
bool get(ULONG index) const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Set value for specified position.
|
||||
/// </summary>
|
||||
/// <param name="index">Index of item.</param>
|
||||
/// <param name="value">Value to set.</param>
|
||||
void set(ULONG index, bool value);
|
||||
|
||||
/// <summary>
|
||||
/// Current size of items (in bit).
|
||||
/// </summary>
|
||||
/// <returns>Number of bits.</returns>
|
||||
ULONG size() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get size of mask items (in ULL).
|
||||
/// </summary>
|
||||
/// <returns>Number of mask items.</returns>
|
||||
UINT mask_size() const noexcept;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Query index out of range.
|
||||
/// </summary>
|
||||
struct BitsetIndexOutRangeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
} // namespace raw
|
||||
} // namespace backends
|
||||
} // namespace maro
|
||||
|
||||
#endif // !_MARO_BACKENDS_RAW_BITSET_
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKEND_RAW_COMMON_
|
||||
#define _MARO_BACKEND_RAW_COMMON_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
using UCHAR = unsigned char;
|
||||
using USHORT = unsigned short;
|
||||
using UINT = uint32_t;
|
||||
using LONG = long long;
|
||||
using ULONG = unsigned long long;
|
||||
|
||||
using NODE_TYPE = unsigned short;
|
||||
using ATTR_TYPE = uint32_t;
|
||||
|
||||
const size_t MAX_NODE_TYPE = USHRT_MAX;
|
||||
const size_t MAX_ATTR_TYPE = USHRT_MAX;
|
||||
const size_t MAX_SLOT_NUMBER = UINT32_MAX;
|
||||
|
||||
using NODE_INDEX = uint32_t;
|
||||
using SLOT_INDEX = uint32_t;
|
||||
using QUERY_FLOAT = float;
|
||||
|
||||
using ATTR_CHAR = char;
|
||||
using ATTR_UCHAR = unsigned char;
|
||||
using ATTR_SHORT = short;
|
||||
using ATTR_USHORT = unsigned short;
|
||||
using ATTR_INT = int32_t;
|
||||
using ATTR_UINT = uint32_t;
|
||||
using ATTR_LONG = int64_t;
|
||||
using ATTR_ULONG = uint64_t;
|
||||
using ATTR_FLOAT = float;
|
||||
using ATTR_DOUBLE = double;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Attribute data type.
|
||||
/// </summary>
|
||||
enum class AttrDataType : char
|
||||
{
|
||||
ACHAR,
|
||||
AUCHAR,
|
||||
ASHORT,
|
||||
AUSHORT,
|
||||
AINT,
|
||||
AUINT,
|
||||
ALONG,
|
||||
AULONG,
|
||||
AFLOAT,
|
||||
ADOUBLE,
|
||||
APOINTER,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ! _MARO_BACKEND_RAW_COMMON_
|
|
@ -0,0 +1,391 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "frame.h"
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
inline NODE_TYPE extract_node_type(ATTR_TYPE attr_type)
|
||||
{
|
||||
// Our ATTR_TYPE is composed with 2 parts:
|
||||
// 2 bytes: NODE_TYPE
|
||||
// 2 bytes: Attribute index in current node type
|
||||
return NODE_TYPE(attr_type >> 16);
|
||||
}
|
||||
|
||||
inline void Frame::copy_from(const Frame& frame)
|
||||
{
|
||||
_nodes = frame._nodes;
|
||||
|
||||
_is_setup = frame._is_setup;
|
||||
}
|
||||
|
||||
inline void Frame::ensure_setup()
|
||||
{
|
||||
if (!_is_setup)
|
||||
{
|
||||
throw FrameNotSetupError();
|
||||
}
|
||||
}
|
||||
|
||||
inline void Frame::ensure_node_type(NODE_TYPE node_type)
|
||||
{
|
||||
if (node_type >= _nodes.size())
|
||||
{
|
||||
throw FrameBadNodeTypeError();
|
||||
}
|
||||
}
|
||||
|
||||
Frame::Frame()
|
||||
{
|
||||
}
|
||||
|
||||
Frame::Frame(const Frame& frame)
|
||||
{
|
||||
copy_from(frame);
|
||||
}
|
||||
|
||||
Frame& Frame::operator=(const Frame& frame)
|
||||
{
|
||||
if (this != &frame)
|
||||
{
|
||||
copy_from(frame);
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
NODE_TYPE Frame::add_node(string node_name, NODE_INDEX node_number)
|
||||
{
|
||||
if (_is_setup)
|
||||
{
|
||||
throw FrameAlreadySetupError();
|
||||
}
|
||||
|
||||
if (node_number == 0)
|
||||
{
|
||||
throw FrameInvalidNodeNumerError();
|
||||
}
|
||||
|
||||
_nodes.emplace_back();
|
||||
|
||||
// We use index as node type for easily querying.
|
||||
NODE_TYPE node_type = _nodes.size() - 1;
|
||||
|
||||
auto& node = _nodes[node_type];
|
||||
|
||||
node.set_name(node_name);
|
||||
node.set_type(node_type);
|
||||
node.set_defined_number(node_number);
|
||||
|
||||
return node_type;
|
||||
}
|
||||
|
||||
ATTR_TYPE Frame::add_attr(NODE_TYPE node_type, string attr_name, AttrDataType data_type,
|
||||
SLOT_INDEX slot_number, bool is_const, bool is_list)
|
||||
{
|
||||
if (_is_setup)
|
||||
{
|
||||
throw FrameAlreadySetupError();
|
||||
}
|
||||
|
||||
ensure_node_type(node_type);
|
||||
|
||||
auto& node = _nodes[node_type];
|
||||
|
||||
return node.add_attr(attr_name, data_type, slot_number, is_const, is_list);
|
||||
}
|
||||
|
||||
Node& Frame::get_node(NODE_TYPE node_type)
|
||||
{
|
||||
ensure_setup();
|
||||
ensure_node_type(node_type);
|
||||
|
||||
return _nodes[node_type];
|
||||
}
|
||||
|
||||
void Frame::append_node(NODE_TYPE node_type, NODE_INDEX node_number)
|
||||
{
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.append_nodes(node_number);
|
||||
}
|
||||
|
||||
void Frame::remove_node(NODE_TYPE node_type, NODE_INDEX node_index)
|
||||
{
|
||||
auto& node = get_node(node_type);
|
||||
node.remove_node(node_index);
|
||||
}
|
||||
|
||||
void Frame::resume_node(NODE_TYPE node_type, NODE_INDEX node_index)
|
||||
{
|
||||
auto& node = get_node(node_type);
|
||||
node.resume_node(node_index);
|
||||
}
|
||||
|
||||
void Frame::clear_list(NODE_INDEX node_index, ATTR_TYPE attr_type)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.clear_list(node_index, attr_type);
|
||||
}
|
||||
|
||||
void Frame::resize_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX new_size)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.resize_list(node_index, attr_type, new_size);
|
||||
}
|
||||
|
||||
void Frame::setup()
|
||||
{
|
||||
if (_is_setup)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& node : _nodes)
|
||||
{
|
||||
node.setup();
|
||||
}
|
||||
|
||||
_is_setup = true;
|
||||
}
|
||||
|
||||
void Frame::reset()
|
||||
{
|
||||
ensure_setup();
|
||||
|
||||
for (auto& node : _nodes)
|
||||
{
|
||||
node.reset();
|
||||
}
|
||||
}
|
||||
|
||||
bool Frame::is_node_exist(NODE_TYPE node_type) const noexcept
|
||||
{
|
||||
return node_type < _nodes.size();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
typename Attribute_Trait<T>::type Frame::get_value(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
auto& target_attr = node.get_attr(node_index, attr_type, slot_index);
|
||||
|
||||
return target_attr.get_value<T>();
|
||||
}
|
||||
|
||||
#define ATTRIBUTE_GETTER(type) \
|
||||
template type Frame::get_value<type>(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
|
||||
ATTRIBUTE_GETTER(ATTR_CHAR)
|
||||
ATTRIBUTE_GETTER(ATTR_UCHAR)
|
||||
ATTRIBUTE_GETTER(ATTR_SHORT)
|
||||
ATTRIBUTE_GETTER(ATTR_USHORT)
|
||||
ATTRIBUTE_GETTER(ATTR_INT)
|
||||
ATTRIBUTE_GETTER(ATTR_UINT)
|
||||
ATTRIBUTE_GETTER(ATTR_LONG)
|
||||
ATTRIBUTE_GETTER(ATTR_ULONG)
|
||||
ATTRIBUTE_GETTER(ATTR_FLOAT)
|
||||
ATTRIBUTE_GETTER(ATTR_DOUBLE)
|
||||
|
||||
template<typename T>
|
||||
void Frame::set_value(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
auto& target_attr = node.get_attr(node_index, attr_type, slot_index);
|
||||
|
||||
target_attr = T(value);
|
||||
}
|
||||
|
||||
#define ATTRIBUTE_SETTER(type) \
|
||||
template void Frame::set_value(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, type value);
|
||||
|
||||
ATTRIBUTE_SETTER(ATTR_CHAR)
|
||||
ATTRIBUTE_SETTER(ATTR_UCHAR)
|
||||
ATTRIBUTE_SETTER(ATTR_SHORT)
|
||||
ATTRIBUTE_SETTER(ATTR_USHORT)
|
||||
ATTRIBUTE_SETTER(ATTR_INT)
|
||||
ATTRIBUTE_SETTER(ATTR_UINT)
|
||||
ATTRIBUTE_SETTER(ATTR_LONG)
|
||||
ATTRIBUTE_SETTER(ATTR_ULONG)
|
||||
ATTRIBUTE_SETTER(ATTR_FLOAT)
|
||||
ATTRIBUTE_SETTER(ATTR_DOUBLE)
|
||||
|
||||
template<typename T>
|
||||
void Frame::append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, T value)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.append_to_list<T>(node_index, attr_type, value);
|
||||
}
|
||||
|
||||
#define ATTRIBUTE_APPENDER(type) \
|
||||
template void Frame::append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, type value);
|
||||
|
||||
ATTRIBUTE_APPENDER(ATTR_CHAR)
|
||||
ATTRIBUTE_APPENDER(ATTR_UCHAR)
|
||||
ATTRIBUTE_APPENDER(ATTR_SHORT)
|
||||
ATTRIBUTE_APPENDER(ATTR_USHORT)
|
||||
ATTRIBUTE_APPENDER(ATTR_INT)
|
||||
ATTRIBUTE_APPENDER(ATTR_UINT)
|
||||
ATTRIBUTE_APPENDER(ATTR_LONG)
|
||||
ATTRIBUTE_APPENDER(ATTR_ULONG)
|
||||
ATTRIBUTE_APPENDER(ATTR_FLOAT)
|
||||
ATTRIBUTE_APPENDER(ATTR_DOUBLE)
|
||||
|
||||
void Frame::remove_from_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.remove_from_list(node_index, attr_type, slot_index);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Frame::insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value)
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
auto& node = get_node(node_type);
|
||||
|
||||
node.insert_to_list(node_index, attr_type, slot_index, value);
|
||||
}
|
||||
|
||||
#define ATTRIBUTE_INSERTER(type) \
|
||||
template void Frame::insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, type value);
|
||||
|
||||
ATTRIBUTE_INSERTER(ATTR_CHAR)
|
||||
ATTRIBUTE_INSERTER(ATTR_UCHAR)
|
||||
ATTRIBUTE_INSERTER(ATTR_SHORT)
|
||||
ATTRIBUTE_INSERTER(ATTR_USHORT)
|
||||
ATTRIBUTE_INSERTER(ATTR_INT)
|
||||
ATTRIBUTE_INSERTER(ATTR_UINT)
|
||||
ATTRIBUTE_INSERTER(ATTR_LONG)
|
||||
ATTRIBUTE_INSERTER(ATTR_ULONG)
|
||||
ATTRIBUTE_INSERTER(ATTR_FLOAT)
|
||||
ATTRIBUTE_INSERTER(ATTR_DOUBLE)
|
||||
|
||||
void Frame::write_attribute(ofstream &file, NODE_INDEX node_index, ATTR_TYPE attr_id, SLOT_INDEX slot_index)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void Frame::dump(string folder)
|
||||
{
|
||||
// for dump, we will save for each node, named as "node_<node_name>.csv"
|
||||
// content of the csv will follow padans' output that list will be wrapped into a "[]",
|
||||
for (auto& node : _nodes)
|
||||
{
|
||||
auto output_path = folder + "/" + "node_" + node._name + ".csv";
|
||||
|
||||
ofstream file(output_path);
|
||||
|
||||
// Write header - first column.
|
||||
file << "node_index";
|
||||
|
||||
// Futhure columns (attribute name).
|
||||
for(auto& attr_def : node._attribute_definitions)
|
||||
{
|
||||
file << "," << attr_def.name;
|
||||
}
|
||||
|
||||
// End of header.
|
||||
file << "\n";
|
||||
|
||||
// Write for each node instance.
|
||||
for (NODE_INDEX node_index = 0; node_index < node._max_node_number; node_index++)
|
||||
{
|
||||
// Ignore deleted node instance.
|
||||
if(!node.is_node_alive(node_index))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Row - node index.
|
||||
file << node_index;
|
||||
|
||||
for (auto& attr_def : node._attribute_definitions)
|
||||
{
|
||||
if (!attr_def.is_list && attr_def.slot_number == 1)
|
||||
{
|
||||
file << ",";
|
||||
|
||||
auto& attr = node.get_attr(node_index, attr_def.attr_type, 0);
|
||||
|
||||
file << QUERY_FLOAT(attr);
|
||||
}
|
||||
else
|
||||
{
|
||||
// List start.
|
||||
file << ",\"[";
|
||||
|
||||
auto slot_number = node.get_slot_number(node_index, attr_def.attr_type);
|
||||
|
||||
for (SLOT_INDEX slot_index = 0; slot_index < slot_number; slot_index++)
|
||||
{
|
||||
auto& attr = node.get_attr(node_index, attr_def.attr_type, 0);
|
||||
|
||||
file << QUERY_FLOAT(attr);
|
||||
|
||||
file << ",";
|
||||
}
|
||||
|
||||
// List end.
|
||||
file << "]\"";
|
||||
}
|
||||
}
|
||||
|
||||
// end of row
|
||||
file << "\n";
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
|
||||
const char* FrameNotSetupError::what() const noexcept
|
||||
{
|
||||
return "Frame has not been setup.";
|
||||
}
|
||||
|
||||
const char* FrameAlreadySetupError::what() const noexcept
|
||||
{
|
||||
return "Cannot add new node or attribute type after setting up.";
|
||||
}
|
||||
|
||||
const char* FrameBadNodeTypeError::what() const noexcept
|
||||
{
|
||||
return "Not exist node type.";
|
||||
}
|
||||
|
||||
const char* FrameBadAttributeTypeError::what() const noexcept
|
||||
{
|
||||
return "Not exist attribute type.";
|
||||
}
|
||||
|
||||
const char* FrameInvalidNodeNumerError::what() const noexcept
|
||||
{
|
||||
return "Node number must be greater than 0.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKENDS_RAW_FRAME_
|
||||
#define _MARO_BACKENDS_RAW_FRAME_
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "attribute.h"
|
||||
#include "node.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
/// <summary>
|
||||
/// Extract node type from attribute type.
|
||||
/// </summary>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <returns>Type of node.</returns>
|
||||
inline NODE_TYPE extract_node_type(ATTR_TYPE attr_type);
|
||||
|
||||
/// <summary>
|
||||
/// A frame used to hold nodes and their attribute, it can be a current frame or a snapshot in snapshot list.
|
||||
/// </summary>
|
||||
class Frame
|
||||
{
|
||||
friend class SnapshotList;
|
||||
|
||||
private:
|
||||
// All node types, index is the NODE_TYPE.
|
||||
vector<Node> _nodes;
|
||||
|
||||
// Is current frame instance already being set up.
|
||||
bool _is_setup = false;
|
||||
|
||||
// Copy from another frame, used for taking snapshot.
|
||||
inline void copy_from(const Frame& frame);
|
||||
|
||||
// Make sure frame already setup.
|
||||
inline void ensure_setup();
|
||||
|
||||
// Make sure node type correct.
|
||||
inline void ensure_node_type(NODE_TYPE node_type);
|
||||
|
||||
// Helper function for dump one attribute slot
|
||||
void write_attribute(ofstream &file, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
public:
|
||||
Frame();
|
||||
|
||||
/// <summary>
|
||||
/// Copy contents from another frame, deep copy.
|
||||
/// </summary>
|
||||
/// <param name="frame">Source frame to copy.</param>
|
||||
Frame(const Frame& frame);
|
||||
|
||||
/// <summary>
|
||||
/// Copy contents from another frame, for taking snapshot,
|
||||
/// copy without name, const block and attribute definitions.
|
||||
/// </summary>
|
||||
/// <param name="frame">Source frame to copy.</param>
|
||||
/// <returns>Current frame instance.</returns>
|
||||
Frame& operator=(const Frame& frame);
|
||||
|
||||
/// <summary>
|
||||
/// Add a node type in frame.
|
||||
/// </summary>
|
||||
/// <param name="node_name">Name of the new node type.</param>
|
||||
/// <param name="node_number">Number of initial instance for this node type.</param>
|
||||
/// <returns>Node type used to identify this kind of node.</returns>
|
||||
NODE_TYPE add_node(string node_name, NODE_INDEX node_number);
|
||||
|
||||
/// <summary>
|
||||
/// Add an attribute for specified node type.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node.</param>
|
||||
/// <param name="attr_name">Name of new attribute.</param>
|
||||
/// <param name="data_type">Data type of new attribute, default is int.</param>
|
||||
/// <param name="slot_number">How many slot of this attribute, default is 1.</param>
|
||||
/// <param name="is_const">Is this is a const attribute?</param>
|
||||
/// <param name="is_list">Is this a list attribute that without fixed slot number.</param>
|
||||
/// <returns>Type of this attribute.</returns>
|
||||
ATTR_TYPE add_attr(NODE_TYPE node_type, string attr_name,
|
||||
AttrDataType data_type = AttrDataType::AINT, SLOT_INDEX slot_number = 1,
|
||||
bool is_const = false, bool is_list = false);
|
||||
|
||||
/// <summary>
|
||||
/// Get specified node.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node.</param>
|
||||
/// <returns>Target node reference.</returns>
|
||||
Node& get_node(NODE_TYPE node_type);
|
||||
|
||||
/// <summary>
|
||||
/// Add node instance for specified node type.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node.</param>
|
||||
/// <param name="node_number">Number to append.</param>
|
||||
void append_node(NODE_TYPE node_type, NODE_INDEX node_number);
|
||||
|
||||
/// <summary>
|
||||
/// Remove specified node instace from node type.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node.</param>
|
||||
/// <param name="node_index">Index of node instance to remove.</param>
|
||||
void remove_node(NODE_TYPE node_type, NODE_INDEX node_index);
|
||||
|
||||
/// <summary>
|
||||
/// Resume a node instance.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node.</param>
|
||||
/// <param name="node_index">Index of node instance to resume.</param>
|
||||
void resume_node(NODE_TYPE node_type, NODE_INDEX node_index);
|
||||
|
||||
/// <summary>
|
||||
/// Get value from specified attribute.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Type of attribute value.</typeparam>
|
||||
/// <param name="node_index">Index of the node instance.</param>
|
||||
/// <param name="attr_type">Type of the attribute.</param>
|
||||
/// <param name="slot_index">Which slot to query.</param>
|
||||
/// <returns>Value of attribute.</returns>
|
||||
template<typename T>
|
||||
typename Attribute_Trait<T>::type get_value(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
|
||||
/// <summary>
|
||||
/// Set value for specified attribute.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Type of attribute.</typeparam>
|
||||
/// <param name="node_index">Index of node instance to set.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="slot_index">Which slot to set.</param>
|
||||
/// <param name="value">Value to set.</param>
|
||||
template<typename T>
|
||||
void set_value(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value);
|
||||
|
||||
/// <summary>
|
||||
/// Append a value to a list attribute.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Type of the value.</typeparam>
|
||||
/// <param name="node_index">Index of node instance to set.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="value">Value to append.</param>
|
||||
template<typename T>
|
||||
void append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, T value);
|
||||
|
||||
/// <summary>
|
||||
/// Clear a list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to clear.</param>
|
||||
/// <param name="attr_type">Type of attribute to clear</param>
|
||||
void clear_list(NODE_INDEX node_index, ATTR_TYPE attr_type);
|
||||
|
||||
/// <summary>
|
||||
/// Resize a list attribute with specified size.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="new_size">New size to resize.</param>
|
||||
void resize_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX new_size);
|
||||
|
||||
/// <summary>
|
||||
/// Remove specified slot from list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="slot_index">Slot to remove.</param>
|
||||
void remove_from_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
|
||||
/// <summary>
|
||||
/// Insert a value to specified slot for list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="slot_index">Slot to insert.</param>
|
||||
/// <param name="value">Value to insert. </param>
|
||||
template<typename T>
|
||||
void insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value);
|
||||
|
||||
/// <summary>
|
||||
/// Initial current frame.
|
||||
/// </summary>
|
||||
void setup();
|
||||
|
||||
/// <summary>
|
||||
/// Reset current frame, it will recover the node instance number to pre-defined one.
|
||||
/// </summary>
|
||||
void reset();
|
||||
|
||||
/// <summary>
|
||||
/// Dump current frame content into specified folder, nodes will be dump into
|
||||
/// different files.
|
||||
/// </summary>
|
||||
/// <param name="folder">Folder to dump file.</param>
|
||||
void dump(string folder);
|
||||
|
||||
/// <summary>
|
||||
/// Check if specified node type exist or not.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Type of node</param>
|
||||
/// <returns>True if exist, or false.</returns>
|
||||
bool is_node_exist(NODE_TYPE node_type) const noexcept;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Operations before frame being setup.
|
||||
/// </summary>
|
||||
struct FrameNotSetupError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Try to add new node/attribute type after seting up.
|
||||
/// </summary>
|
||||
struct FrameAlreadySetupError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Invalid node type.
|
||||
/// </summary>
|
||||
struct FrameBadNodeTypeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Invalid attribute type.
|
||||
/// </summary>
|
||||
struct FrameBadAttributeTypeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct FrameInvalidNodeNumerError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ! _MARO_BACKENDS_RAW_FRAME_
|
|
@ -0,0 +1,618 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "node.h"
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
inline USHORT extract_attr_index(ATTR_TYPE attr_type)
|
||||
{
|
||||
return USHORT(attr_type & 0x0000ffff);
|
||||
}
|
||||
|
||||
inline size_t compose_attr_offset_in_node(NODE_INDEX node_index, size_t node_size,
|
||||
size_t attr_offset, SLOT_INDEX slot)
|
||||
{
|
||||
return node_index * node_size + attr_offset + slot;
|
||||
}
|
||||
|
||||
AttributeDef::AttributeDef(string name, AttrDataType data_type, SLOT_INDEX slot_number,
|
||||
size_t offset, bool is_list, bool is_const, ATTR_TYPE attr_type) :
|
||||
name(name),
|
||||
slot_number(slot_number),
|
||||
offset(offset),
|
||||
is_list(is_list),
|
||||
is_const(is_const),
|
||||
data_type(data_type),
|
||||
attr_type(attr_type)
|
||||
{
|
||||
}
|
||||
|
||||
void Node::copy_from(const Node& node, bool is_deep_copy)
|
||||
{
|
||||
// Copy normal fields.
|
||||
_dynamic_size_per_node = node._dynamic_size_per_node;
|
||||
_const_size_per_node = node._const_size_per_node;
|
||||
_max_node_number = node._max_node_number;
|
||||
_alive_node_number = node._alive_node_number;
|
||||
_defined_node_number = node._defined_node_number;
|
||||
_type = node._type;
|
||||
_is_setup = node._is_setup;
|
||||
|
||||
// Ignore name.
|
||||
_name = "";
|
||||
|
||||
// Copy dynamic block.
|
||||
if (node._dynamic_block.size() > 0)
|
||||
{
|
||||
// Copy according to max_node number, as memory block may larger than it (after reset).
|
||||
auto valid_dynamic_size = node._dynamic_size_per_node * node._max_node_number;
|
||||
|
||||
_dynamic_block.resize(valid_dynamic_size);
|
||||
|
||||
memcpy(&_dynamic_block[0], &node._dynamic_block[0], valid_dynamic_size * sizeof(Attribute));
|
||||
}
|
||||
|
||||
// Copy list attributes store.
|
||||
if (node._list_store.size() > 0)
|
||||
{
|
||||
_list_store.resize(node._list_store.size());
|
||||
|
||||
for (size_t i = 0; i < _list_store.size(); i++)
|
||||
{
|
||||
auto& source_list = node._list_store[i];
|
||||
|
||||
if (source_list.size() > 0)
|
||||
{
|
||||
auto& target_list = _list_store[i];
|
||||
|
||||
target_list.resize(source_list.size());
|
||||
|
||||
memcpy(&target_list[0], &source_list[0], source_list.size() * sizeof(Attribute));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy masks.
|
||||
_node_instance_masks = node._node_instance_masks;
|
||||
|
||||
// Copy others for deep-copy.
|
||||
if (is_deep_copy)
|
||||
{
|
||||
_name = node.get_name();
|
||||
|
||||
_attribute_definitions = node._attribute_definitions;
|
||||
|
||||
// NOTE: we do not copy const block here, as this operation occur before setting up
|
||||
// and there is nothing in const block
|
||||
}
|
||||
}
|
||||
|
||||
inline void Node::ensure_setup() const
|
||||
{
|
||||
if (!_is_setup)
|
||||
{
|
||||
throw OperationsBeforeSetupError();
|
||||
}
|
||||
}
|
||||
|
||||
inline void Node::ensure_attr_index(USHORT attr_index) const
|
||||
{
|
||||
if (attr_index >= _attribute_definitions.size())
|
||||
{
|
||||
throw InvalidAttributeTypeError();
|
||||
}
|
||||
}
|
||||
|
||||
inline void Node::ensure_node_index(NODE_INDEX node_index) const
|
||||
{
|
||||
// check is node alive
|
||||
if (!_node_instance_masks.get(node_index))
|
||||
{
|
||||
throw InvalidNodeIndexError();
|
||||
}
|
||||
}
|
||||
|
||||
Node::Node()
|
||||
{
|
||||
}
|
||||
|
||||
Node::Node(const Node& node)
|
||||
{
|
||||
// This function invoked when the node list is increasing its size,
|
||||
// then it need to copy nodes to new memory block.
|
||||
copy_from(node, true);
|
||||
}
|
||||
|
||||
Node& Node::operator=(const Node& node)
|
||||
{
|
||||
if (this != &node)
|
||||
{
|
||||
copy_from(node);
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Node::set_type(NODE_TYPE type) noexcept
|
||||
{
|
||||
_type = type;
|
||||
}
|
||||
|
||||
NODE_TYPE Node::get_type() const noexcept
|
||||
{
|
||||
return _type;
|
||||
}
|
||||
|
||||
void Node::set_name(string name) noexcept
|
||||
{
|
||||
_name = name;
|
||||
}
|
||||
|
||||
string Node::get_name() const noexcept
|
||||
{
|
||||
return _name;
|
||||
}
|
||||
|
||||
void Node::set_defined_number(NODE_INDEX number)
|
||||
{
|
||||
if (number == 0)
|
||||
{
|
||||
throw InvalidNodeNumberError();
|
||||
}
|
||||
|
||||
_defined_node_number = number;
|
||||
_max_node_number = number;
|
||||
_alive_node_number = number;
|
||||
}
|
||||
|
||||
NODE_INDEX Node::get_defined_number() const noexcept
|
||||
{
|
||||
return _defined_node_number;
|
||||
}
|
||||
|
||||
NODE_INDEX Node::get_max_number() const noexcept
|
||||
{
|
||||
return _max_node_number;
|
||||
}
|
||||
|
||||
const AttributeDef& Node::get_attr_definition(ATTR_TYPE attr_type) const
|
||||
{
|
||||
USHORT attr_index = extract_attr_index(attr_type);
|
||||
|
||||
ensure_attr_index(attr_index);
|
||||
|
||||
return _attribute_definitions[attr_index];
|
||||
}
|
||||
|
||||
bool Node::is_node_alive(NODE_INDEX node_index) const noexcept
|
||||
{
|
||||
ensure_setup();
|
||||
|
||||
return _node_instance_masks.get(node_index);
|
||||
}
|
||||
|
||||
SLOT_INDEX Node::get_slot_number(NODE_INDEX node_index, ATTR_TYPE attr_type) const
|
||||
{
|
||||
ensure_setup();
|
||||
ensure_node_index(node_index);
|
||||
|
||||
auto& attr_def = get_attr_definition(attr_type);
|
||||
|
||||
// If it is a list attribute, we will return actual list size.
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, _dynamic_size_per_node, attr_def.offset);
|
||||
auto& target_attr = _dynamic_block[attr_offset];
|
||||
|
||||
return target_attr.slot_number;
|
||||
}
|
||||
|
||||
// Or used pre-defined number.
|
||||
return attr_def.slot_number;
|
||||
}
|
||||
|
||||
void Node::setup()
|
||||
{
|
||||
// Ignore is already been setup.
|
||||
if (_is_setup)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize dynamic and const block.
|
||||
_const_block.resize(_defined_node_number * _const_size_per_node);
|
||||
_dynamic_block.resize(_defined_node_number * _dynamic_size_per_node);
|
||||
|
||||
// Prepare bitset for masks.
|
||||
_node_instance_masks.resize(_defined_node_number);
|
||||
_node_instance_masks.reset(true);
|
||||
|
||||
// Prepare memory for list attributes.
|
||||
for (auto& attr_def : _attribute_definitions)
|
||||
{
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
// Assign each attribute with the index of actual list.
|
||||
for (NODE_INDEX i = 0; i < _defined_node_number; i++)
|
||||
{
|
||||
auto& target_attr = _dynamic_block[_dynamic_size_per_node * i + attr_def.offset];
|
||||
|
||||
// Save the index of list in list store.
|
||||
target_attr = UINT(_list_store.size());
|
||||
|
||||
// Append a new vector for this attribute.
|
||||
_list_store.emplace_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_is_setup = true;
|
||||
}
|
||||
|
||||
void Node::reset()
|
||||
{
|
||||
ensure_setup();
|
||||
|
||||
// Reset all node number to pre-defined.
|
||||
_max_node_number = _defined_node_number;
|
||||
_alive_node_number = _defined_node_number;
|
||||
|
||||
// Clear all attribute to 0.
|
||||
memset(&_dynamic_block[0], 0, _dynamic_block.size() * sizeof(Attribute));
|
||||
|
||||
// Clear all list attribute.
|
||||
for (auto& list : _list_store)
|
||||
{
|
||||
list.clear();
|
||||
}
|
||||
|
||||
// Reset bitset masks.
|
||||
_node_instance_masks.resize(_defined_node_number);
|
||||
_node_instance_masks.reset(true);
|
||||
}
|
||||
|
||||
void Node::append_nodes(NODE_INDEX node_number)
|
||||
{
|
||||
ensure_setup();
|
||||
|
||||
if (node_number == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_max_node_number += node_number;
|
||||
_alive_node_number += node_number;
|
||||
|
||||
// Extend const memory block.
|
||||
auto extend_size = _max_node_number * _const_size_per_node;
|
||||
|
||||
if (extend_size > _const_block.size())
|
||||
{
|
||||
_const_block.resize(extend_size);
|
||||
}
|
||||
|
||||
// Extend dynamic memory block.
|
||||
extend_size = _max_node_number * _dynamic_size_per_node;
|
||||
|
||||
if (extend_size > _dynamic_block.size())
|
||||
{
|
||||
_dynamic_block.resize(extend_size);
|
||||
}
|
||||
|
||||
// Prepare memory for new list attributes.
|
||||
for (auto& attr_def : _attribute_definitions)
|
||||
{
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
// Again allocate list.
|
||||
for (NODE_INDEX i = 0; i < node_number; i++)
|
||||
{
|
||||
auto node_index = _max_node_number - node_number + i;
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, _dynamic_size_per_node, attr_def.offset);
|
||||
auto& target_attr = _dynamic_block[attr_offset];
|
||||
|
||||
target_attr = UINT(_list_store.size());
|
||||
|
||||
_list_store.emplace_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extern masks.
|
||||
_node_instance_masks.resize(_max_node_number);
|
||||
|
||||
// Set new node instance as alive.
|
||||
for (NODE_INDEX i = 0; i < node_number; i++)
|
||||
{
|
||||
_node_instance_masks.set(_max_node_number - node_number + i, true);
|
||||
}
|
||||
}
|
||||
|
||||
void Node::remove_node(NODE_INDEX node_index)
|
||||
{
|
||||
ensure_setup();
|
||||
ensure_node_index(node_index);
|
||||
|
||||
_node_instance_masks.set(node_index, false);
|
||||
}
|
||||
|
||||
void Node::resume_node(NODE_INDEX node_index)
|
||||
{
|
||||
ensure_setup();
|
||||
|
||||
if(node_index < _max_node_number)
|
||||
{
|
||||
_node_instance_masks.set(node_index, true);
|
||||
}
|
||||
}
|
||||
|
||||
ATTR_TYPE Node::add_attr(string attr_name, AttrDataType data_type,
|
||||
SLOT_INDEX slot_number, bool is_const, bool is_list)
|
||||
{
|
||||
if (_is_setup)
|
||||
{
|
||||
throw OperationsAfterSetupError();
|
||||
}
|
||||
|
||||
USHORT attr_index = USHORT(_attribute_definitions.size());
|
||||
ATTR_TYPE attr_type = UINT(_type) << 16 | attr_index;
|
||||
|
||||
size_t offset = 0;
|
||||
|
||||
// We do not support const list attribute.
|
||||
if (is_const && is_list)
|
||||
{
|
||||
throw InvalidAttributeDescError();
|
||||
}
|
||||
|
||||
// List attribute take 1 attribute to hold its list index in list store.
|
||||
slot_number = is_list ? 1 : slot_number;
|
||||
|
||||
// Calculate size of each node instance in different memory block.
|
||||
if (is_const)
|
||||
{
|
||||
offset = _const_size_per_node;
|
||||
|
||||
_const_size_per_node += slot_number;
|
||||
}
|
||||
else
|
||||
{
|
||||
offset = _dynamic_size_per_node;
|
||||
|
||||
_dynamic_size_per_node += slot_number;
|
||||
}
|
||||
|
||||
_attribute_definitions.emplace_back(attr_name, data_type, slot_number, offset, is_list, is_const, attr_type);
|
||||
|
||||
return attr_type;
|
||||
}
|
||||
|
||||
Attribute& Node::get_attr(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
{
|
||||
ensure_setup();
|
||||
ensure_node_index(node_index);
|
||||
|
||||
auto& attr_def = get_attr_definition(attr_type);
|
||||
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
// For list attribute, we need to get its index in list store.
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, _dynamic_size_per_node, attr_def.offset);
|
||||
auto& target_attr = _dynamic_block[attr_offset];
|
||||
|
||||
// Slot number of list attribute save in itr attribute.
|
||||
if (slot_index >= target_attr.slot_number)
|
||||
{
|
||||
throw InvalidSlotIndexError();
|
||||
}
|
||||
|
||||
const auto list_index = target_attr.get_value<ATTR_UINT>();
|
||||
|
||||
// Then get the actual list reference for furthure operation.
|
||||
auto& target_list = _list_store[list_index];
|
||||
|
||||
return target_list[slot_index];
|
||||
}
|
||||
|
||||
// Check slot number for normal attributes.
|
||||
if (slot_index >= attr_def.slot_number)
|
||||
{
|
||||
throw InvalidSlotIndexError();
|
||||
}
|
||||
|
||||
// Get attribute for const and dynamic attribute.
|
||||
vector<Attribute>* target_block = nullptr;
|
||||
size_t node_size = 0;
|
||||
|
||||
if (attr_def.is_const)
|
||||
{
|
||||
target_block = &_const_block;
|
||||
node_size = _const_size_per_node;
|
||||
}
|
||||
else
|
||||
{
|
||||
target_block = &_dynamic_block;
|
||||
node_size = _dynamic_size_per_node;
|
||||
}
|
||||
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, node_size, attr_def.offset, slot_index);
|
||||
|
||||
return (*target_block)[attr_offset];
|
||||
}
|
||||
|
||||
inline Attribute& Node::get_list_attribute(NODE_INDEX node_index, ATTR_TYPE attr_type)
|
||||
{
|
||||
ensure_setup();
|
||||
ensure_node_index(node_index);
|
||||
|
||||
auto& attr_def = get_attr_definition(attr_type);
|
||||
|
||||
if (!attr_def.is_list)
|
||||
{
|
||||
throw OperationsOnNonListAttributeError();
|
||||
}
|
||||
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, _dynamic_size_per_node, attr_def.offset);
|
||||
auto& target_attr = _dynamic_block[attr_offset];
|
||||
|
||||
return target_attr;
|
||||
}
|
||||
|
||||
inline vector<Attribute>& Node::get_attribute_list(Attribute& attribute)
|
||||
{
|
||||
const auto& list_index = attribute.get_value<ATTR_UINT>();
|
||||
|
||||
auto& target_list = _list_store[list_index];
|
||||
|
||||
return target_list;
|
||||
}
|
||||
|
||||
void Node::clear_list(NODE_INDEX node_index, ATTR_TYPE attr_type)
|
||||
{
|
||||
auto& target_attr = get_list_attribute(node_index, attr_type);
|
||||
auto& target_list = get_attribute_list(target_attr);
|
||||
|
||||
target_list.clear();
|
||||
|
||||
target_attr.slot_number = 0;
|
||||
}
|
||||
|
||||
void Node::resize_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX new_size)
|
||||
{
|
||||
auto& target_attr = get_list_attribute(node_index, attr_type);
|
||||
auto& target_list = get_attribute_list(target_attr);
|
||||
|
||||
target_list.resize(new_size);
|
||||
|
||||
target_attr.slot_number = new_size;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Node::append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, T value)
|
||||
{
|
||||
auto& target_attr = get_list_attribute(node_index, attr_type);
|
||||
auto& target_list = get_attribute_list(target_attr);
|
||||
|
||||
target_list.push_back(value);
|
||||
|
||||
target_attr.slot_number++;
|
||||
}
|
||||
|
||||
#define APPEND_TO_LIST(type) \
|
||||
template void Node::append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, type value);
|
||||
|
||||
APPEND_TO_LIST(ATTR_CHAR)
|
||||
APPEND_TO_LIST(ATTR_UCHAR)
|
||||
APPEND_TO_LIST(ATTR_SHORT)
|
||||
APPEND_TO_LIST(ATTR_USHORT)
|
||||
APPEND_TO_LIST(ATTR_INT)
|
||||
APPEND_TO_LIST(ATTR_UINT)
|
||||
APPEND_TO_LIST(ATTR_LONG)
|
||||
APPEND_TO_LIST(ATTR_ULONG)
|
||||
APPEND_TO_LIST(ATTR_FLOAT)
|
||||
APPEND_TO_LIST(ATTR_DOUBLE)
|
||||
|
||||
void Node::remove_from_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
{
|
||||
auto& target_attr = get_list_attribute(node_index, attr_type);
|
||||
auto& target_list = get_attribute_list(target_attr);
|
||||
|
||||
if(slot_index >= target_list.size())
|
||||
{
|
||||
throw InvalidSlotIndexError();
|
||||
}
|
||||
|
||||
target_list.erase(target_list.begin() + slot_index);
|
||||
|
||||
target_attr.slot_number--;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Node::insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value)
|
||||
{
|
||||
auto& target_attr = get_list_attribute(node_index, attr_type);
|
||||
auto& target_list = get_attribute_list(target_attr);
|
||||
|
||||
// NOTE: the insert index can same as size, then it is the last one.
|
||||
if(slot_index > target_list.size())
|
||||
{
|
||||
throw InvalidSlotIndexError();
|
||||
}
|
||||
|
||||
// Check if reach the max slot number
|
||||
if(target_list.size() >= MAX_SLOT_NUMBER)
|
||||
{
|
||||
throw MaxSlotNumberError();
|
||||
}
|
||||
|
||||
target_list.insert(target_list.begin() + slot_index, Attribute(value));
|
||||
|
||||
target_attr.slot_number++;
|
||||
}
|
||||
|
||||
#define INSERT_TO_LIST(type) \
|
||||
template void Node::insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, type value);
|
||||
|
||||
INSERT_TO_LIST(ATTR_CHAR)
|
||||
INSERT_TO_LIST(ATTR_UCHAR)
|
||||
INSERT_TO_LIST(ATTR_SHORT)
|
||||
INSERT_TO_LIST(ATTR_USHORT)
|
||||
INSERT_TO_LIST(ATTR_INT)
|
||||
INSERT_TO_LIST(ATTR_UINT)
|
||||
INSERT_TO_LIST(ATTR_LONG)
|
||||
INSERT_TO_LIST(ATTR_ULONG)
|
||||
INSERT_TO_LIST(ATTR_FLOAT)
|
||||
INSERT_TO_LIST(ATTR_DOUBLE)
|
||||
|
||||
const char* OperationsBeforeSetupError::what() const noexcept
|
||||
{
|
||||
return "Node has not been setup.";
|
||||
}
|
||||
|
||||
const char* InvalidAttributeDescError::what() const noexcept
|
||||
{
|
||||
return "Const attribute cannot be a list.";
|
||||
}
|
||||
|
||||
const char* InvalidNodeIndexError::what() const noexcept
|
||||
{
|
||||
return "Node index not exist.";
|
||||
}
|
||||
|
||||
const char* InvalidSlotIndexError::what() const noexcept
|
||||
{
|
||||
return "Slot index not exist.";
|
||||
}
|
||||
|
||||
const char* InvalidNodeNumberError::what() const noexcept
|
||||
{
|
||||
return "Node number must be greater than 0.";
|
||||
}
|
||||
|
||||
const char* InvalidAttributeTypeError::what() const noexcept
|
||||
{
|
||||
return "Attriute type note exist.";
|
||||
}
|
||||
|
||||
const char* OperationsAfterSetupError::what() const noexcept
|
||||
{
|
||||
return "Cannot add attribute after setup.";
|
||||
}
|
||||
|
||||
const char* OperationsOnNonListAttributeError::what() const noexcept
|
||||
{
|
||||
return "Append, clear and resize function only support for list attribute.";
|
||||
}
|
||||
|
||||
const char* MaxSlotNumberError::what() const noexcept
|
||||
{
|
||||
return "Reach the max number of slot.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,346 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKENDS_RAW_NODE_
|
||||
#define _MARO_BACKENDS_RAW_NODE_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "common.h"
|
||||
#include "attribute.h"
|
||||
#include "bitset.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
/// <summary>
|
||||
/// Extract attribute index from attribute type.
|
||||
/// </summary>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <returns>Index of this attribute in its node.</returns>
|
||||
inline USHORT extract_attr_index(ATTR_TYPE attr_type);
|
||||
|
||||
/// <summary>
|
||||
/// Compose the attribute offset in memory block.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance.</param>
|
||||
/// <param name="node_size">Per node size in related memory block.</param>
|
||||
/// <param name="attr_offset">Attribute offset in node instance.</param>
|
||||
/// <param name="slot">Slot index of attribute.</param>
|
||||
/// <returns>Attribute offset in memory block.</returns>
|
||||
inline size_t compose_attr_offset_in_node(NODE_INDEX node_index, size_t node_size, size_t attr_offset, SLOT_INDEX slot = 0);
|
||||
|
||||
/// <summary>
|
||||
/// Definition of attribute.
|
||||
/// </summary>
|
||||
struct AttributeDef
|
||||
{
|
||||
// Is this a list attribute.
|
||||
bool is_list;
|
||||
|
||||
// Is this a const attribute.
|
||||
bool is_const;
|
||||
|
||||
// Data type of this attribute.
|
||||
AttrDataType data_type;
|
||||
|
||||
// Number of slot, for fixed size only, list attribute'slot number samed in attribute class.
|
||||
SLOT_INDEX slot_number;
|
||||
|
||||
// Type of this attribute.
|
||||
ATTR_TYPE attr_type;
|
||||
|
||||
// Offset in each node instance, used to retrieve attribute from node instance.
|
||||
size_t offset;
|
||||
|
||||
// Name of attribute.
|
||||
string name;
|
||||
|
||||
AttributeDef(string name, AttrDataType data_type, SLOT_INDEX slot_number, size_t offset, bool is_list, bool is_const, ATTR_TYPE attr_type);
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Node type in memory, there is not node instance in physical, just a concept.
|
||||
/// </summary>
|
||||
class Node
|
||||
{
|
||||
friend class SnapshotList;
|
||||
friend class Frame;
|
||||
|
||||
private:
|
||||
// Memory block to hold dyanmic attributes, these attributes will be copied into snapshot list.
|
||||
vector<Attribute> _dynamic_block;
|
||||
|
||||
// Memory block to hold const attribute, these attributes will not be copied into snapshot list,
|
||||
// and its value can be set only one time.
|
||||
vector<Attribute> _const_block;
|
||||
|
||||
// Attribute defintions of this node type.
|
||||
vector<AttributeDef> _attribute_definitions;
|
||||
|
||||
// Used to store all the list of list attribute in this node.
|
||||
vector<vector<Attribute>> _list_store;
|
||||
|
||||
// Used to mark which node instance is alive.
|
||||
Bitset _node_instance_masks;
|
||||
|
||||
// Number of node instance that alive.
|
||||
NODE_INDEX _alive_node_number = 0;
|
||||
|
||||
// Max number of node instance we have (include deleted nodes), used to for padding in snapshot list.
|
||||
NODE_INDEX _max_node_number = 0;
|
||||
|
||||
// Node number from definition time, used for reset.
|
||||
NODE_INDEX _defined_node_number = 0;
|
||||
|
||||
// Size of each node instance (by attribute number).
|
||||
size_t _const_size_per_node = 0;
|
||||
size_t _dynamic_size_per_node = 0;
|
||||
|
||||
// Type of this node.
|
||||
NODE_TYPE _type = 0;
|
||||
|
||||
// Name of this node type.
|
||||
string _name;
|
||||
|
||||
// Is this node been setup.
|
||||
bool _is_setup = false;
|
||||
|
||||
// Copy content from source node, for taking snapshot.
|
||||
void copy_from(const Node& node, bool is_deep_copy = false);
|
||||
|
||||
// Make sure setup called.
|
||||
inline void ensure_setup() const;
|
||||
|
||||
// Make sure attribute index correct.
|
||||
inline void ensure_attr_index(USHORT attr_index) const;
|
||||
|
||||
// Make sure node index correct.
|
||||
inline void ensure_node_index(NODE_INDEX node_index) const;
|
||||
|
||||
// Get list attribute reference.
|
||||
inline Attribute& get_list_attribute(NODE_INDEX node_index, ATTR_TYPE attr_type);
|
||||
|
||||
// Get actual list of a list attribute
|
||||
inline vector<Attribute>& get_attribute_list(Attribute& attribute);
|
||||
public:
|
||||
Node();
|
||||
|
||||
Node(const Node& node);
|
||||
|
||||
Node& operator=(const Node& node);
|
||||
|
||||
/// <summary>
|
||||
/// Set type of this node.
|
||||
/// </summary>
|
||||
/// <param name="type">Type of this node.</param>
|
||||
void set_type(NODE_TYPE type) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get type of this node.
|
||||
/// </summary>
|
||||
/// <returns>Type of this node.</returns>
|
||||
NODE_TYPE get_type() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Set name of this node.
|
||||
/// </summary>
|
||||
/// <param name="name">Name to set.</param>
|
||||
void set_name(string name) noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get name of this node.
|
||||
/// </summary>
|
||||
/// <returns>Name of this node.</returns>
|
||||
string get_name() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Set defined node number, this is the orign node number, used to reset.
|
||||
/// </summary>
|
||||
/// <param name="number">Number of node instance.</param>
|
||||
void set_defined_number(NODE_INDEX number);
|
||||
|
||||
/// <summary>
|
||||
/// Get predefined node instance number.
|
||||
/// </summary>
|
||||
/// <returns>Number of node instance.</returns>
|
||||
NODE_INDEX get_defined_number() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get current max node instance number.
|
||||
/// </summary>
|
||||
/// <returns>Number of max node instance for this node type.</returns>
|
||||
NODE_INDEX get_max_number() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get attribute definition.
|
||||
/// </summary>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <returns>Definition of specified attribute.</returns>
|
||||
const AttributeDef& get_attr_definition(ATTR_TYPE attr_type) const;
|
||||
|
||||
/// <summary>
|
||||
/// Check if specified node instance is alive.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of specified node instance.</param>
|
||||
/// <returns>True if node instance is alive, or false.</returns>
|
||||
bool is_node_alive(NODE_INDEX node_index) const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get slot number of specified attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to query.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <returns>Slot number of specified attribute, this is the predefined one for normal attributes,
|
||||
/// and current list size for list attributes,
|
||||
/// </returns>
|
||||
SLOT_INDEX get_slot_number(NODE_INDEX node_index, ATTR_TYPE attr_type) const;
|
||||
|
||||
/// <summary>
|
||||
/// Initial this node.
|
||||
/// </summary>
|
||||
void setup();
|
||||
|
||||
/// <summary>
|
||||
/// Reset this node to intial state.
|
||||
/// </summary>
|
||||
void reset();
|
||||
|
||||
/// <summary>
|
||||
/// Append node instance for this node type.
|
||||
/// </summary>
|
||||
/// <param name="node_number">Number of new instance.</param>
|
||||
void append_nodes(NODE_INDEX node_number);
|
||||
|
||||
/// <summary>
|
||||
/// Remove a node instance.
|
||||
/// NOTE: this will not delete the attributes from memory, just mark them as deleted.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to remove.</param>
|
||||
void remove_node(NODE_INDEX node_index);
|
||||
|
||||
/// <summary>
|
||||
/// Resume a node instance.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resume.</param>
|
||||
void resume_node(NODE_INDEX node_index);
|
||||
|
||||
/// <summary>
|
||||
/// Add an attribute to this node type.
|
||||
/// </summary>
|
||||
/// <param name="attr_name">Name of new attribute.</param>
|
||||
/// <param name="data_type">Data type of new attribute.</param>
|
||||
/// <param name="slot_number">Number of slot for new attribute.</param>
|
||||
/// <param name="is_const">Is a const attribute?</param>
|
||||
/// <param name="is_list">Is a list attribute?</param>
|
||||
/// <returns>Type of new attribute.</returns>
|
||||
ATTR_TYPE add_attr(string attr_name, AttrDataType data_type, SLOT_INDEX slot_number = 1, bool is_const = false, bool is_list = false);
|
||||
|
||||
/// <summary>
|
||||
/// Get specified attribute from an node instance.
|
||||
/// NOTE: this function only used for current frame, not for snapshot, as nodes in snapshot list do not contains
|
||||
/// attribute definition.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="slot_index">Slot index to query.</param>
|
||||
/// <returns>Specified attribute instance.</returns>
|
||||
Attribute& get_attr(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
|
||||
/// <summary>
|
||||
/// Append a value to list attribute.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Data type</typeparam>
|
||||
/// <param name="node_index">Index of node instance to append.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="value">Value to append</param>
|
||||
template<typename T>
|
||||
void append_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, T value);
|
||||
|
||||
/// <summary>
|
||||
/// Clear values in a list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
void clear_list(NODE_INDEX node_index, ATTR_TYPE attr_type);
|
||||
|
||||
/// <summary>
|
||||
/// Resize size of a list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="new_size">New size to resize.</param>
|
||||
void resize_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX new_size);
|
||||
|
||||
/// <summary>
|
||||
/// Remove an index from list attribute.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
/// <param name="slot_index">Slot index to remove.</param>
|
||||
void remove_from_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
|
||||
/// <summary>
|
||||
/// Insert a value to specified slot.
|
||||
/// </summary>
|
||||
/// <param name="node_index">Index of node instance to resize.</param>
|
||||
/// <param name="attr_type">Type of attribute.</param>
|
||||
template<typename T>
|
||||
void insert_to_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value);
|
||||
};
|
||||
|
||||
struct OperationsBeforeSetupError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct InvalidAttributeDescError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct InvalidNodeIndexError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct InvalidSlotIndexError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct InvalidNodeNumberError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct InvalidAttributeTypeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct OperationsAfterSetupError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct OperationsOnNonListAttributeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct MaxSlotNumberError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ! _MARO_BACKENDS_RAW_NODE_
|
|
@ -0,0 +1,532 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "snapshotlist.h"
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
inline void SnapshotList::ensure_cur_frame()
|
||||
{
|
||||
if (_cur_frame == nullptr)
|
||||
{
|
||||
throw SnapshotInvalidFrameStateError();
|
||||
}
|
||||
}
|
||||
|
||||
inline void SnapshotList::ensure_max_size()
|
||||
{
|
||||
if (_max_size == 0)
|
||||
{
|
||||
throw SnapshotSizeError();
|
||||
}
|
||||
}
|
||||
|
||||
void SnapshotList::set_max_size(USHORT max_size)
|
||||
{
|
||||
_max_size = max_size;
|
||||
|
||||
ensure_max_size();
|
||||
}
|
||||
|
||||
void SnapshotList::setup(Frame* frame)
|
||||
{
|
||||
_cur_frame = frame;
|
||||
|
||||
ensure_cur_frame();
|
||||
}
|
||||
|
||||
void SnapshotList::take_snapshot(int tick)
|
||||
{
|
||||
ensure_max_size();
|
||||
ensure_cur_frame();
|
||||
|
||||
// Try to remove exist tick.
|
||||
_snapshots.erase(tick);
|
||||
|
||||
// Remove oldest one if we reach the max size limitation.
|
||||
if (_snapshots.size() > 0 && _snapshots.size() >= _max_size)
|
||||
{
|
||||
_snapshots.erase(_snapshots.begin());
|
||||
}
|
||||
|
||||
// Copy current frame.
|
||||
_snapshots[tick] = *_cur_frame;
|
||||
}
|
||||
|
||||
UINT SnapshotList::size() const noexcept
|
||||
{
|
||||
return _snapshots.size();
|
||||
}
|
||||
|
||||
UINT SnapshotList::max_size() const noexcept
|
||||
{
|
||||
return _max_size;
|
||||
}
|
||||
|
||||
NODE_INDEX SnapshotList::get_max_node_number(NODE_TYPE node_type) const
|
||||
{
|
||||
auto& cur_node = _cur_frame->get_node(node_type);
|
||||
|
||||
return cur_node.get_max_number();
|
||||
}
|
||||
|
||||
void SnapshotList::reset()
|
||||
{
|
||||
_snapshots.clear();
|
||||
}
|
||||
|
||||
void SnapshotList::get_ticks(int* result) const
|
||||
{
|
||||
if (result == nullptr)
|
||||
{
|
||||
throw SnapshotQueryResultPtrNullError();
|
||||
}
|
||||
|
||||
auto i = 0;
|
||||
for (auto& iter : _snapshots)
|
||||
{
|
||||
result[i] = iter.first;
|
||||
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
SnapshotQueryResultShape SnapshotList::prepare(NODE_TYPE node_type, int ticks[], UINT tick_length, NODE_INDEX node_indices[], UINT node_length, ATTR_TYPE attributes[], UINT attr_length)
|
||||
{
|
||||
SnapshotQueryResultShape shape;
|
||||
|
||||
if (attributes == nullptr)
|
||||
{
|
||||
throw SnapshotQueryNoAttributesError();
|
||||
}
|
||||
|
||||
// Node in current frame, used to get attribute definition.
|
||||
auto& cur_node = _cur_frame->get_node(node_type);
|
||||
auto first_attr_type = attributes[0];
|
||||
auto& attr_definition = cur_node.get_attr_definition(first_attr_type);
|
||||
|
||||
// We use first attribute determine the type of current querying.
|
||||
_query_parameters.is_list = attr_definition.is_list;
|
||||
|
||||
shape.max_node_number = node_indices == nullptr ? cur_node.get_max_number() : node_length;
|
||||
shape.tick_number = ticks == nullptr ? _snapshots.size() : tick_length;
|
||||
|
||||
if (!_query_parameters.is_list)
|
||||
{
|
||||
// If it is not a list attriubte, then accept all attribute except list .
|
||||
for (UINT attr_index = 0; attr_index < attr_length; attr_index++)
|
||||
{
|
||||
auto attr_type = attributes[attr_index];
|
||||
auto& attr_def = cur_node.get_attr_definition(attr_type);
|
||||
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
// warning and ignore it
|
||||
cerr << "Ignore list attribute: " << attr_def.name << " for fixed size attribute querying." << endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
shape.attr_number++;
|
||||
shape.max_slot_number = max(attr_def.slot_number, shape.max_slot_number);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// If it is a list attribute, then just use first one as querying attribute,
|
||||
// we only support query 1 list attribute (1st one) for 1 node at 1 tick each time to reduce too much padding.
|
||||
|
||||
// Make sure we have at least one tick.
|
||||
if (_snapshots.size() == 0)
|
||||
{
|
||||
throw SnapshotQueryNoSnapshotsError();
|
||||
}
|
||||
|
||||
// There must be 1 node index for list attribute querying.
|
||||
if (node_indices == nullptr)
|
||||
{
|
||||
throw SnapshotListQueryNoNodeIndexError();
|
||||
}
|
||||
|
||||
// 1 tick, 1 node and 1 attribute for list attribute querying.
|
||||
shape.attr_number = 1;
|
||||
shape.tick_number = 1;
|
||||
shape.max_node_number = 1;
|
||||
|
||||
// Use first tick in parameter, or latest tick in snapshot.
|
||||
int tick = ticks == nullptr ? _snapshots.rbegin()->first : ticks[0];
|
||||
auto target_node_index = node_indices[0];
|
||||
|
||||
// Check if tick exist.
|
||||
auto target_tick_pair = _snapshots.find(tick);
|
||||
|
||||
if (target_tick_pair == _snapshots.end())
|
||||
{
|
||||
throw SnapshotQueryNoSnapshotsError();
|
||||
}
|
||||
|
||||
auto& snapshot = target_tick_pair->second;
|
||||
auto& history_node = snapshot.get_node(node_type);
|
||||
|
||||
// Check if the node index exist.
|
||||
if (!history_node.is_node_alive(target_node_index))
|
||||
{
|
||||
throw SnapshotListQueryNoNodeIndexError();
|
||||
}
|
||||
|
||||
shape.max_slot_number = history_node.get_slot_number(target_node_index, first_attr_type);
|
||||
}
|
||||
|
||||
_query_parameters.ticks = ticks;
|
||||
_query_parameters.node_indices = node_indices;
|
||||
_query_parameters.attributes = attributes;
|
||||
|
||||
_query_parameters.node_type = node_type;
|
||||
|
||||
_query_parameters.max_slot_number = shape.max_slot_number;
|
||||
_query_parameters.attr_length = shape.attr_number;
|
||||
_query_parameters.tick_length = shape.tick_number;
|
||||
_query_parameters.node_length = shape.max_node_number;
|
||||
|
||||
_is_prepared = true;
|
||||
|
||||
return shape;
|
||||
}
|
||||
|
||||
void SnapshotList::query(QUERY_FLOAT* result)
|
||||
{
|
||||
if (!_is_prepared)
|
||||
{
|
||||
throw SnapshotQueryNotPreparedError();
|
||||
}
|
||||
|
||||
_is_prepared = false;
|
||||
|
||||
if (!_query_parameters.is_list)
|
||||
{
|
||||
query_for_normal(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
query_for_list(result);
|
||||
}
|
||||
|
||||
_query_parameters.reset();
|
||||
}
|
||||
|
||||
void SnapshotList::query_for_list(QUERY_FLOAT* result)
|
||||
{
|
||||
auto* ticks = _query_parameters.ticks;
|
||||
auto max_slot_number = _query_parameters.max_slot_number;
|
||||
auto tick = ticks == nullptr ? _snapshots.rbegin()->first : ticks[0];
|
||||
auto node_index = _query_parameters.node_indices[0];
|
||||
auto attr_type = _query_parameters.attributes[0];
|
||||
|
||||
// Go through all slots.
|
||||
for (UINT i = 0; i < max_slot_number; i++)
|
||||
{
|
||||
auto& attr = get_attr(tick, node_index, attr_type, i);
|
||||
|
||||
// Ignore nan for now, use default value from outside.
|
||||
if (!attr.is_nan())
|
||||
{
|
||||
result[i] = QUERY_FLOAT(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SnapshotList::query_for_normal(QUERY_FLOAT* result)
|
||||
{
|
||||
auto node_type = _query_parameters.node_type;
|
||||
|
||||
// Node in current frame, used to get attribute defition and const value.
|
||||
auto& node = _cur_frame->get_node(node_type);
|
||||
|
||||
auto* ticks = _query_parameters.ticks;
|
||||
auto* node_indices = _query_parameters.node_indices;
|
||||
auto* attrs = _query_parameters.attributes;
|
||||
auto tick_length = _query_parameters.tick_length;
|
||||
auto node_length = _query_parameters.node_length;
|
||||
auto attr_length = _query_parameters.attr_length;
|
||||
auto max_slot_number = _query_parameters.max_slot_number;
|
||||
|
||||
vector<int> _ticks;
|
||||
|
||||
// Prepare ticks if no one provided.
|
||||
if (_query_parameters.ticks == nullptr)
|
||||
{
|
||||
tick_length = _snapshots.size();
|
||||
|
||||
for (auto& iter : _snapshots)
|
||||
{
|
||||
_ticks.push_back(iter.first);
|
||||
}
|
||||
}
|
||||
|
||||
vector<NODE_INDEX> _node_indices;
|
||||
|
||||
// Prepare node indices if no one provided.
|
||||
if (node_indices == nullptr)
|
||||
{
|
||||
node_length = node.get_max_number();
|
||||
|
||||
for (UINT i = 0; i < node_length; i++)
|
||||
{
|
||||
_node_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
const int* __ticks = ticks == nullptr ? &_ticks[0] : ticks;
|
||||
const NODE_INDEX* __node_indices = node_indices == nullptr ? &_node_indices[0] : node_indices;
|
||||
|
||||
// Index in result list.
|
||||
auto result_index = 0;
|
||||
|
||||
// Go through by tick -> node -> attribute -> slot.
|
||||
for (UINT i = 0; i < tick_length; i++)
|
||||
{
|
||||
auto tick = __ticks[i];
|
||||
|
||||
for (UINT j = 0; j < node_length; j++)
|
||||
{
|
||||
auto node_index = __node_indices[j];
|
||||
|
||||
for (UINT k = 0; k < attr_length; k++)
|
||||
{
|
||||
auto attr_type = attrs[k];
|
||||
|
||||
for (SLOT_INDEX slot_index = 0; slot_index < max_slot_number; slot_index++)
|
||||
{
|
||||
auto& attr = get_attr(tick, node_index, attr_type, slot_index);
|
||||
|
||||
if (!attr.is_nan())
|
||||
{
|
||||
result[result_index] = ATTR_FLOAT(attr);
|
||||
}
|
||||
|
||||
result_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SnapshotList::cancel_query() noexcept
|
||||
{
|
||||
_is_prepared = false;
|
||||
_query_parameters.reset();
|
||||
}
|
||||
|
||||
Attribute& SnapshotList::get_attr(int tick, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) noexcept
|
||||
{
|
||||
NODE_TYPE node_type = extract_node_type(attr_type);
|
||||
|
||||
// check if node exist
|
||||
if(!_cur_frame->is_node_exist(node_type))
|
||||
{
|
||||
return _nan_attr;
|
||||
}
|
||||
|
||||
auto& cur_node = _cur_frame->get_node(node_type);
|
||||
const auto& attr_def = cur_node.get_attr_definition(attr_type);
|
||||
|
||||
// Check slot index for non-list attribute.
|
||||
if (!attr_def.is_list && slot_index >= attr_def.slot_number)
|
||||
{
|
||||
return _nan_attr;
|
||||
}
|
||||
|
||||
// If it is a const attribute, retrieve from const block,
|
||||
// we do not care if tick exist for const attribute.
|
||||
if (attr_def.is_const)
|
||||
{
|
||||
return cur_node.get_attr(node_index, attr_type, slot_index);
|
||||
}
|
||||
|
||||
auto target_tick_pair = _snapshots.find(tick);
|
||||
|
||||
// Check if tick valid.
|
||||
if (target_tick_pair == _snapshots.end())
|
||||
{
|
||||
return _nan_attr;
|
||||
}
|
||||
|
||||
auto& snapshot = target_tick_pair->second;
|
||||
auto& history_node = snapshot.get_node(node_type);
|
||||
|
||||
// Check if node index valid.
|
||||
if (node_index >= history_node._max_node_number || !history_node.is_node_alive(node_index))
|
||||
{
|
||||
return _nan_attr;
|
||||
}
|
||||
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, history_node._dynamic_size_per_node, attr_def.offset);
|
||||
auto& target_attr = history_node._dynamic_block[attr_offset];
|
||||
|
||||
const auto list_index = target_attr.get_value<ATTR_UINT>();
|
||||
|
||||
auto& target_list = history_node._list_store[list_index];
|
||||
|
||||
// Check slot for list attribute.
|
||||
if (slot_index >= target_list.size())
|
||||
{
|
||||
return _nan_attr;
|
||||
}
|
||||
|
||||
return target_list[slot_index];
|
||||
}
|
||||
|
||||
auto attr_offset = compose_attr_offset_in_node(node_index, history_node._dynamic_size_per_node, attr_def.offset, slot_index);
|
||||
|
||||
return history_node._dynamic_block[attr_offset];
|
||||
}
|
||||
|
||||
void SnapshotList::SnapshotQueryParameters::reset()
|
||||
{
|
||||
ticks = nullptr;
|
||||
attributes = nullptr;
|
||||
node_indices = nullptr;
|
||||
|
||||
tick_length = 0;
|
||||
node_length = 0;
|
||||
attr_length = 0;
|
||||
max_slot_number = 0;
|
||||
|
||||
is_list = false;
|
||||
}
|
||||
|
||||
inline void SnapshotList::write_attribute(ofstream &file, int tick, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
{
|
||||
auto &attr = get_attr(tick, node_index, attr_type, slot_index);
|
||||
|
||||
if (attr.is_nan())
|
||||
{
|
||||
file << "nan";
|
||||
}
|
||||
else
|
||||
{
|
||||
file << ATTR_FLOAT(attr);
|
||||
}
|
||||
}
|
||||
|
||||
void SnapshotList::dump(string path)
|
||||
{
|
||||
for (auto& node : _cur_frame->_nodes)
|
||||
{
|
||||
auto full_path = path + "/" + "snapshots_" + node._name + ".csv";
|
||||
|
||||
ofstream file(full_path);
|
||||
|
||||
// Headers.
|
||||
file << "tick,node_index";
|
||||
|
||||
for(auto& attr_def : node._attribute_definitions)
|
||||
{
|
||||
file << "," << attr_def.name;
|
||||
}
|
||||
|
||||
// End of Headers.
|
||||
file << "\n";
|
||||
|
||||
// Rows.
|
||||
for(auto& snapshot_iter : _snapshots)
|
||||
{
|
||||
auto tick = snapshot_iter.first;
|
||||
auto& snapshot = snapshot_iter.second;
|
||||
auto& history_node = snapshot.get_node(node._type);
|
||||
|
||||
for (NODE_INDEX node_index = 0; node_index < node._max_node_number; node_index++)
|
||||
{
|
||||
// ignore deleted node
|
||||
if(!history_node.is_node_alive(node_index))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
file << tick << "," << node_index;
|
||||
|
||||
for(auto& attr_def : node._attribute_definitions)
|
||||
{
|
||||
if(!attr_def.is_list && attr_def.slot_number == 1)
|
||||
{
|
||||
file << ",";
|
||||
|
||||
write_attribute(file, tick, node_index, attr_def.attr_type, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
file << ",\"[";
|
||||
|
||||
auto slot_number = history_node.get_slot_number(node_index, attr_def.attr_type);
|
||||
|
||||
for(SLOT_INDEX slot_index = 0; slot_index < slot_number; slot_index++)
|
||||
{
|
||||
write_attribute(file, tick, node._type, attr_def.attr_type, slot_index);
|
||||
|
||||
file << ",";
|
||||
}
|
||||
|
||||
file << "]\"";
|
||||
}
|
||||
}
|
||||
|
||||
file << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
|
||||
const char* SnapshotTickError::what() const noexcept
|
||||
{
|
||||
return "Invalid tick to take snapshot, same tick must be used sequentially.";
|
||||
}
|
||||
|
||||
const char* SnapshotSizeError::what() const noexcept
|
||||
{
|
||||
return "Invalid snapshot list max size, it must be larger than 0.";
|
||||
}
|
||||
|
||||
const char* SnapshotQueryNotPreparedError::what() const noexcept
|
||||
{
|
||||
return "Query must be after prepare function.";
|
||||
}
|
||||
|
||||
const char* SnapshotQueryNoAttributesError::what() const noexcept
|
||||
{
|
||||
return "Attribute list for query should contain at least 1.";
|
||||
}
|
||||
|
||||
const char* SnapshotInvalidFrameStateError::what() const noexcept
|
||||
{
|
||||
return "Not set frame before operations.";
|
||||
}
|
||||
|
||||
const char* SnapshotQueryResultPtrNullError::what() const noexcept
|
||||
{
|
||||
return "Result pointer is NULL.";
|
||||
}
|
||||
|
||||
const char* SnapshotQueryInvalidTickError::what() const noexcept
|
||||
{
|
||||
return "Only support one tick to query for list attribute, and the tick must exist.";
|
||||
}
|
||||
|
||||
const char* SnapshotQueryNoSnapshotsError::what() const noexcept
|
||||
{
|
||||
return "List attribute querying need at lease one snapshot, it does not support invalid tick padding.";
|
||||
}
|
||||
|
||||
const char* SnapshotListQueryNoNodeIndexError::what() const noexcept
|
||||
{
|
||||
return "List attribute querying need one alive node index.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,277 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef _MARO_BACKENDS_RAW_SNAPSHOTLIST_
|
||||
#define _MARO_BACKENDS_RAW_SNAPSHOTLIST_
|
||||
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "common.h"
|
||||
#include "attribute.h"
|
||||
#include "node.h"
|
||||
#include "frame.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace maro
|
||||
{
|
||||
namespace backends
|
||||
{
|
||||
namespace raw
|
||||
{
|
||||
/// <summary>
|
||||
/// Shape of current querying.
|
||||
/// </summary>
|
||||
struct SnapshotQueryResultShape
|
||||
{
|
||||
// Number of attribute in result.
|
||||
USHORT attr_number = 0;
|
||||
|
||||
// Number of ticks in result.
|
||||
int tick_number = 0;
|
||||
|
||||
// Number of slot in result, include padding slot.
|
||||
SLOT_INDEX max_slot_number = 0;
|
||||
|
||||
// Number of node in result, include padding nodes.
|
||||
NODE_INDEX max_node_number = 0;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Snapshot list used to hold snapshot of current frame at specified tick.
|
||||
/// </summary>
|
||||
class SnapshotList
|
||||
{
|
||||
/// <summary>
|
||||
/// Querying parameter from prepare step.
|
||||
/// </summary>
|
||||
struct SnapshotQueryParameters
|
||||
{
|
||||
// Is this query for list?
|
||||
bool is_list = false;
|
||||
|
||||
// For furthur querying, these fields would be changed by prepare function.
|
||||
NODE_TYPE node_type = 0;
|
||||
|
||||
// List of ticks to query.
|
||||
int* ticks = nullptr;
|
||||
|
||||
// Number of ticks in tick list.
|
||||
UINT tick_length = 0;
|
||||
|
||||
// List of node instance index to query.
|
||||
NODE_INDEX* node_indices = nullptr;
|
||||
|
||||
// Node number
|
||||
UINT node_length = 0;
|
||||
|
||||
// Attributes to query.
|
||||
ATTR_TYPE* attributes = nullptr;
|
||||
|
||||
// Number of attribute to query.
|
||||
UINT attr_length = 0;
|
||||
|
||||
// Max slot number in result, for padding.
|
||||
SLOT_INDEX max_slot_number = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Reset current parameter after querying.
|
||||
/// </summary>
|
||||
void reset();
|
||||
};
|
||||
|
||||
|
||||
private:
|
||||
// Tick and its snapshot frame, we will keep a copy of frame.
|
||||
map<int, Frame> _snapshots;
|
||||
|
||||
// Max size of snapshot is memory.
|
||||
USHORT _max_size = 0;
|
||||
|
||||
// Current frame that used to copy.
|
||||
Frame* _cur_frame;
|
||||
|
||||
// Used to hold parameters from prepare function.
|
||||
SnapshotQueryParameters _query_parameters;
|
||||
|
||||
// Is prepare function called?
|
||||
bool _is_prepared = false;
|
||||
|
||||
// Default attribute for invalid attribute, for padding.
|
||||
Attribute _nan_attr = NAN;
|
||||
|
||||
// Query state for list attribute.
|
||||
// NOTE: for list attribute, we only support 1 tick, 1 attribute, 1 node.
|
||||
// and node cannot be null. If ticks not provided, then use latest tick.
|
||||
void query_for_list(QUERY_FLOAT* result);
|
||||
|
||||
// Query for normal attributes.
|
||||
void query_for_normal(QUERY_FLOAT* result);
|
||||
|
||||
// Get attribute from specified tick, this function will not throw exception, it will return a NAN attribute
|
||||
// if invalid.
|
||||
Attribute& get_attr(int tick, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index) noexcept;
|
||||
|
||||
// Make sure currect frame not null.
|
||||
inline void ensure_cur_frame();
|
||||
|
||||
// Make sure max size greater than 0.
|
||||
inline void ensure_max_size();
|
||||
|
||||
inline void write_attribute(ofstream &file, int tick, NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index);
|
||||
public:
|
||||
/// <summary>
|
||||
/// Set max size of snapshot in memory.
|
||||
/// </summary>
|
||||
/// <param name="max_size">Max size to set.</param>
|
||||
void set_max_size(USHORT max_size);
|
||||
|
||||
/// <summary>
|
||||
/// Setup snapshot list with current frame.
|
||||
/// </summary>
|
||||
/// <param name="frame">Current frame that used for snapshots.</param>
|
||||
void setup(Frame* frame);
|
||||
|
||||
/// <summary>
|
||||
/// Take snapshot for specified tick.
|
||||
/// </summary>
|
||||
/// <param name="ticks">Tick to take snapshot.</param>
|
||||
void take_snapshot(int ticks);
|
||||
|
||||
/// <summary>
|
||||
/// Current size of snapshots.
|
||||
/// </summary>
|
||||
/// <returns>Number of current snapshots.</returns>
|
||||
UINT size() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Get max size of current snapshot list.
|
||||
/// </summary>
|
||||
/// <returns>Max number of snapshot list.</returns>
|
||||
UINT max_size() const noexcept;
|
||||
|
||||
/// <summary>
|
||||
/// Reset snapshot list states.
|
||||
/// </summary>
|
||||
void reset();
|
||||
|
||||
/// <summary>
|
||||
/// Dump current snapshots into folder, node will be split into different files.
|
||||
/// </summary>
|
||||
void dump(string path);
|
||||
|
||||
/// <summary>
|
||||
/// Get avaiable ticks from snapshot list.
|
||||
/// </summary>
|
||||
/// <param name="result">List pointer to hold ticks.</param>
|
||||
void get_ticks(int* result) const;
|
||||
|
||||
/// <summary>
|
||||
/// Get current max node number for specified node type.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Target node type.</param>
|
||||
/// <returns>Max node number.</returns>
|
||||
NODE_INDEX get_max_node_number(NODE_TYPE node_type) const;
|
||||
|
||||
/// <summary>
|
||||
/// Prepare for querying.
|
||||
/// </summary>
|
||||
/// <param name="node_type">Target node type.</param>
|
||||
/// <param name="ticks">Ticks to query, leave it as null to retrieve all avaible ticks from snapshots.
|
||||
/// NOTE: if it is null, then use latest tick for list attribute querying.</param>
|
||||
/// <param name="tick_length">Number of ticks to query.</param>
|
||||
/// <param name="node_indices">Indices of node instance to query, leave it as null to retrieve all node instance from snapshots.
|
||||
/// NOTE: it cannot be null if qury for list attribute</param>
|
||||
/// <param name="node_length">Number of node instance to query.</param>
|
||||
/// <param name="attributes">Attribute type list to query, cannot be null.
|
||||
/// NOTE: if first attribute if a list attribute, then there will be a list querying, means only support 1 tick, 1 node, 1 attribute querying.
|
||||
/// </param>
|
||||
/// <param name="attr_length">Target node type.</param>
|
||||
/// <returns>Result shape for input query parameters.</returns>
|
||||
SnapshotQueryResultShape prepare(NODE_TYPE node_type, int ticks[], UINT tick_length,
|
||||
NODE_INDEX node_indices[], UINT node_length, ATTR_TYPE attributes[], UINT attr_length);
|
||||
|
||||
/// <summary>
|
||||
/// Qeury with parameters from prepare function.
|
||||
/// </summary>
|
||||
/// <param name="result">Pointer to list to hold result value. NOTE: query function will leave the default value for padding.</param>
|
||||
void query(QUERY_FLOAT* result);
|
||||
|
||||
/// <summary>
|
||||
/// Cancel current querying, this will clear the parameters from last prepare calling.
|
||||
/// </summary>
|
||||
void cancel_query() noexcept;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Tick not supported, like negative tick
|
||||
/// </summary>
|
||||
struct SnapshotTickError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Snapshot list max size is 0
|
||||
/// </summary>
|
||||
struct SnapshotSizeError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Query without call prepare function
|
||||
/// </summary>
|
||||
struct SnapshotQueryNotPreparedError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Attribute not exist when querying
|
||||
/// </summary>
|
||||
struct SnapshotQueryNoAttributesError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Frame not set before operations
|
||||
/// </summary>
|
||||
struct SnapshotInvalidFrameStateError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Array pointer is nullptr
|
||||
/// </summary>
|
||||
struct SnapshotQueryResultPtrNullError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct SnapshotQueryInvalidTickError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct SnapshotQueryNoSnapshotsError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
|
||||
struct SnapshotListQueryNoNodeIndexError : public exception
|
||||
{
|
||||
const char* what() const noexcept override;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif // !_MARO_BACKENDS_RAW_SNAPSHOTLIST_
|
|
@ -2,11 +2,140 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
cimport cython
|
||||
|
||||
from cpython cimport bool
|
||||
from libcpp cimport bool as cppbool
|
||||
from libcpp.string cimport string
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, INT, UINT, ULONG, USHORT, NODE_INDEX, SLOT_INDEX,
|
||||
ATTR_CHAR, ATTR_SHORT, ATTR_INT, ATTR_LONG, ATTR_FLOAT, ATTR_DOUBLE, QUERY_FLOAT, ATTR_TYPE, NODE_TYPE)
|
||||
|
||||
|
||||
# TODO: another implementation with c/c++ to support more features
|
||||
|
||||
from maro.backends.backend cimport BackendAbc
|
||||
cdef extern from "raw/common.h" namespace "maro::backends::raw":
|
||||
cdef cppclass AttrDataType:
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/common.h" namespace "maro::backends::raw::AttrDataType":
|
||||
cdef AttrDataType ACHAR
|
||||
cdef AttrDataType AUCHAR
|
||||
cdef AttrDataType ASHORT
|
||||
cdef AttrDataType AUSHORT
|
||||
cdef AttrDataType AINT
|
||||
cdef AttrDataType AUINT
|
||||
cdef AttrDataType ALONG
|
||||
cdef AttrDataType AULONG
|
||||
cdef AttrDataType AFLOAT
|
||||
cdef AttrDataType ADOUBLE
|
||||
|
||||
|
||||
cdef extern from "raw/attribute.cpp":
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/attribute.h" namespace "maro::backends::raw":
|
||||
cdef cppclass Attribute:
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/bitset.h" namespace "maro::backends::raw":
|
||||
cdef cppclass Bitset:
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/bitset.cpp" namespace "maro::backends::raw":
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/node.h" namespace "maro::backends::raw":
|
||||
cdef cppclass Node:
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/node.cpp" namespace "maro::backends::raw":
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/frame.h" namespace "maro::backends::raw":
|
||||
cdef cppclass Frame:
|
||||
Frame()
|
||||
Frame(const Frame& frame)
|
||||
Frame& operator=(const Frame& frame)
|
||||
|
||||
NODE_TYPE add_node(string node_name, NODE_INDEX node_number)
|
||||
ATTR_TYPE add_attr(NODE_TYPE node_type, string attr_name, AttrDataType data_type, SLOT_INDEX slot_number, cppbool is_const, cppbool is_list)
|
||||
|
||||
void append_node(NODE_TYPE node_type, NODE_INDEX node_number)
|
||||
void resume_node(NODE_TYPE node_type, NODE_INDEX node_number)
|
||||
void remove_node(NODE_TYPE node_type, NODE_INDEX node_index)
|
||||
|
||||
T get_value[T](NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
void set_value[T](NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value)
|
||||
|
||||
void append_to_list[T](NODE_INDEX node_index, ATTR_TYPE attr_type, T value)
|
||||
void clear_list(NODE_INDEX node_index, ATTR_TYPE attr_type)
|
||||
void resize_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX new_size)
|
||||
void remove_from_list(NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index)
|
||||
void insert_to_list[T](NODE_INDEX node_index, ATTR_TYPE attr_type, SLOT_INDEX slot_index, T value)
|
||||
|
||||
void setup()
|
||||
void reset()
|
||||
void dump(string path)
|
||||
|
||||
|
||||
cdef extern from "raw/frame.cpp" namespace "maro::backends::raw":
|
||||
pass
|
||||
|
||||
|
||||
cdef extern from "raw/snapshotlist.h" namespace "maro::backends::raw":
|
||||
cdef cppclass SnapshotList:
|
||||
void set_max_size(USHORT max_size)
|
||||
void setup(Frame* frame)
|
||||
|
||||
void take_snapshot(int ticks)
|
||||
|
||||
UINT size() const
|
||||
UINT max_size() const
|
||||
NODE_INDEX get_max_node_number(NODE_TYPE node_type) const
|
||||
|
||||
void reset()
|
||||
|
||||
void dump(string path)
|
||||
|
||||
void get_ticks(int* result) const
|
||||
|
||||
SnapshotQueryResultShape prepare(NODE_TYPE node_type, int ticks[], UINT tick_length, NODE_INDEX node_indices[], UINT node_length, ATTR_TYPE attributes[], UINT attr_length)
|
||||
void query(QUERY_FLOAT* result)
|
||||
void cancel_query()
|
||||
|
||||
cdef struct SnapshotQueryResultShape:
|
||||
USHORT attr_number
|
||||
int tick_number
|
||||
SLOT_INDEX max_slot_number
|
||||
NODE_INDEX max_node_number
|
||||
|
||||
|
||||
cdef extern from "raw/snapshotlist.cpp" namespace "maro::backends::raw":
|
||||
pass
|
||||
|
||||
|
||||
cdef class RawBackend(BackendAbc):
|
||||
pass
|
||||
cdef:
|
||||
Frame _frame
|
||||
|
||||
# node name -> ATTR_TYPE
|
||||
dict _node2type_dict
|
||||
|
||||
# attr_type -> dtype
|
||||
dict _attr_type_dict
|
||||
|
||||
dict _node_info
|
||||
|
||||
|
||||
cdef class RawSnapshotList(SnapshotListAbc):
|
||||
cdef:
|
||||
SnapshotList _snapshots
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
#cython: language_level=3
|
||||
|
||||
|
||||
from maro.backends.backend cimport BackendAbc
|
||||
|
||||
|
||||
cdef class RawBackend(BackendAbc):
|
||||
pass
|
|
@ -32,6 +32,8 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
_build_file_name = "vmtable.bin"
|
||||
|
||||
_meta_file_name = "vmtable.yml"
|
||||
# VM category includes three types, converting to 0, 1, 2.
|
||||
_category_map = {'Delay-insensitive': 0, 'Interactive': 1, 'Unknown': 2}
|
||||
|
||||
def __init__(self, topology: str, source: str, sample: int, seed: int, is_temp: bool = False):
|
||||
super().__init__(scenario="vm_scheduling", topology=topology, source=source, is_temp=is_temp)
|
||||
|
@ -90,7 +92,7 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
"""
|
||||
logger.info_green("Downloading vmtable and cpu readings.")
|
||||
# Download parts of cpu reading files.
|
||||
num_files = 10
|
||||
num_files = 195
|
||||
# Open the txt file which contains all the required urls.
|
||||
with open(self._download_file, mode="r", encoding="utf-8") as urls:
|
||||
for remote_url in urls.read().splitlines():
|
||||
|
@ -171,23 +173,39 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
# Preprocess.
|
||||
self._preprocess()
|
||||
|
||||
def _process_vm_table(self, raw_vm_table_file: str) -> pd.DataFrame:
|
||||
def _generate_id_map(self, old_id):
|
||||
num = len(old_id)
|
||||
new_id_list = [i for i in range(1, num + 1)]
|
||||
id_map = dict(zip(old_id, new_id_list))
|
||||
|
||||
return id_map
|
||||
|
||||
def _process_vm_table(self, raw_vm_table_file: str):
|
||||
"""Process vmtable file."""
|
||||
|
||||
headers = [
|
||||
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'maxcpu', 'avgcpu', 'p95maxcpu',
|
||||
'vmcategory', 'vmcorecountbucket', 'vmmemorybucket'
|
||||
]
|
||||
required_headers = ['vmid', 'vmcreated', 'vmdeleted', 'vmcorecountbucket', 'vmmemorybucket']
|
||||
|
||||
required_headers = [
|
||||
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'vmcategory',
|
||||
'vmcorecountbucket', 'vmmemorybucket'
|
||||
]
|
||||
|
||||
vm_table = pd.read_csv(raw_vm_table_file, header=None, index_col=False, names=headers)
|
||||
vm_table = vm_table.loc[:, required_headers]
|
||||
|
||||
# Convert to tick by dividing by 300 (5 minutes).
|
||||
vm_table['vmcreated'] = pd.to_numeric(vm_table['vmcreated'], errors="coerce", downcast="integer") // 300
|
||||
vm_table['vmdeleted'] = pd.to_numeric(vm_table['vmdeleted'], errors="coerce", downcast="integer") // 300
|
||||
# Transform vmcorecount '>24' bucket to 30 and vmmemory '>64' to 70.
|
||||
vm_table = vm_table.replace({'vmcorecountbucket': '>24'}, 30)
|
||||
vm_table = vm_table.replace({'vmmemorybucket': '>64'}, 70)
|
||||
# The lifetime of the VM is deleted time - created time + 1 (tick).
|
||||
vm_table['lifetime'] = vm_table['vmdeleted'] - vm_table['vmcreated'] + 1
|
||||
|
||||
vm_table['vmcategory'] = vm_table['vmcategory'].map(self._category_map)
|
||||
|
||||
# Transform vmcorecount '>24' bucket to 32 and vmmemory '>64' to 128.
|
||||
vm_table = vm_table.replace({'vmcorecountbucket': '>24'}, 32)
|
||||
vm_table = vm_table.replace({'vmmemorybucket': '>64'}, 128)
|
||||
vm_table['vmcorecountbucket'] = pd.to_numeric(
|
||||
vm_table['vmcorecountbucket'], errors="coerce", downcast="integer"
|
||||
)
|
||||
|
@ -195,25 +213,28 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
vm_table.dropna(inplace=True)
|
||||
|
||||
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
|
||||
# Generate new id column.
|
||||
vm_table = vm_table.reset_index(drop=True)
|
||||
vm_table['new_id'] = vm_table.index + 1
|
||||
vm_id_map = vm_table.set_index('vmid')['new_id']
|
||||
# Drop the original id column.
|
||||
vm_table = vm_table.drop(['vmid'], axis=1)
|
||||
# Reorder columns.
|
||||
vm_table = vm_table[['new_id', 'vmcreated', 'vmdeleted', 'vmcorecountbucket', 'vmmemorybucket']]
|
||||
# Rename column name.
|
||||
vm_table.rename(columns={'new_id': 'vmid'}, inplace=True)
|
||||
|
||||
# Generate ID map.
|
||||
vm_id_map = self._generate_id_map(vm_table['vmid'].unique())
|
||||
sub_id_map = self._generate_id_map(vm_table['subscriptionid'].unique())
|
||||
deployment_id_map = self._generate_id_map(vm_table['deploymentid'].unique())
|
||||
|
||||
id_maps = (vm_id_map, sub_id_map, deployment_id_map)
|
||||
|
||||
# Mapping IDs.
|
||||
vm_table['vmid'] = vm_table['vmid'].map(vm_id_map)
|
||||
vm_table['subscriptionid'] = vm_table['subscriptionid'].map(sub_id_map)
|
||||
vm_table['deploymentid'] = vm_table['deploymentid'].map(deployment_id_map)
|
||||
|
||||
# Sampling the VM table.
|
||||
# 2695548 is the total number of vms in the original Azure public dataset.
|
||||
if self._sample < 2695548:
|
||||
vm_table = vm_table.sample(n=self._sample, random_state=self._seed)
|
||||
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
|
||||
vm_id_map = vm_id_map[vm_id_map.isin(vm_table['vmid'])]
|
||||
|
||||
return vm_id_map, vm_table
|
||||
return id_maps, vm_table
|
||||
|
||||
def _convert_cpu_readings_id(self, old_data_path: str, new_data_path: str, vm_id_map: pd.DataFrame):
|
||||
def _convert_cpu_readings_id(self, old_data_path: str, new_data_path: str, vm_id_map: dict):
|
||||
"""Convert vmid in each cpu readings file."""
|
||||
with open(old_data_path, 'r') as f_in:
|
||||
csv_reader = reader(f_in)
|
||||
|
@ -223,16 +244,39 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
for row in csv_reader:
|
||||
# [timestamp, vmid, mincpu, maxcpu, avgcpu]
|
||||
if row[1] in vm_id_map:
|
||||
new_row = [int(row[0]) // 300, vm_id_map.loc[row[1]], round(float(row[3]), 2)]
|
||||
new_row = [int(row[0]) // 300, vm_id_map[row[1]], row[3]]
|
||||
csv_writer.writerow(new_row)
|
||||
|
||||
def _write_id_map_to_csv(self, id_maps):
|
||||
file_name = ['vm_id_map', 'sub_id_map', 'deployment_id_map']
|
||||
for index in range(len(id_maps)):
|
||||
id_map = id_maps[index]
|
||||
with open(os.path.join(self._raw_folder, file_name[index]) + '.csv', 'w') as f:
|
||||
csv_writer = writer(f)
|
||||
csv_writer.writerow(['original_id', 'new_id'])
|
||||
for key, value in id_map.items():
|
||||
csv_writer.writerow([key, value])
|
||||
|
||||
def _filter_out_vmid(self, vm_table: pd.DataFrame, vm_id_map: dict) -> dict:
|
||||
new_id_map = {}
|
||||
for key, value in vm_id_map.items():
|
||||
if value in vm_table.vmid.values:
|
||||
new_id_map[key] = value
|
||||
|
||||
return new_id_map
|
||||
|
||||
def _preprocess(self):
|
||||
logger.info_green("Process vmtable data.")
|
||||
# Process vmtable file.
|
||||
vm_id_map, vm_table = self._process_vm_table(raw_vm_table_file=self._raw_vm_table_file)
|
||||
id_maps, vm_table = self._process_vm_table(raw_vm_table_file=self._raw_vm_table_file)
|
||||
filtered_vm_id_map = self._filter_out_vmid(vm_table=vm_table, vm_id_map=id_maps[0])
|
||||
|
||||
with open(self._clean_file, mode="w", encoding="utf-8", newline="") as f:
|
||||
vm_table.to_csv(f, index=False, header=True)
|
||||
|
||||
logger.info_green("Writing id maps file.")
|
||||
self._write_id_map_to_csv(id_maps=id_maps)
|
||||
|
||||
logger.info_green("Reading cpu data.")
|
||||
# Process every cpu readings file.
|
||||
for clean_cpu_readings_file_name in self._clean_cpu_readings_file_name_list:
|
||||
|
@ -244,7 +288,7 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
self._convert_cpu_readings_id(
|
||||
old_data_path=raw_cpu_readings_file,
|
||||
new_data_path=clean_cpu_readings_file,
|
||||
vm_id_map=vm_id_map
|
||||
vm_id_map=filtered_vm_id_map
|
||||
)
|
||||
|
||||
def build(self):
|
||||
|
|
|
@ -3,7 +3,10 @@
|
|||
|
||||
from maro.rl.actor import AbsActor, SimpleActor
|
||||
from maro.rl.agent import AbsAgent, AbsAgentManager, AgentManagerMode, SimpleAgentManager
|
||||
from maro.rl.algorithms import DQN, AbsAlgorithm, DQNConfig
|
||||
from maro.rl.algorithms import (
|
||||
DQN, AbsAlgorithm, ActionInfo, ActorCritic, ActorCriticConfig, DQNConfig, PolicyGradient, PolicyOptimization,
|
||||
PolicyOptimizationConfig
|
||||
)
|
||||
from maro.rl.dist_topologies import (
|
||||
ActorProxy, ActorWorker, concat_experiences_by_agent, merge_experiences_with_trajectory_boundaries
|
||||
)
|
||||
|
@ -19,7 +22,8 @@ from maro.rl.storage import AbsStore, ColumnBasedStore, OverwriteType
|
|||
__all__ = [
|
||||
"AbsActor", "SimpleActor",
|
||||
"AbsAgent", "AbsAgentManager", "AgentManagerMode", "SimpleAgentManager",
|
||||
"AbsAlgorithm", "DQN", "DQNConfig",
|
||||
"AbsAlgorithm", "ActionInfo", "ActorCritic", "ActorCriticConfig", "DQN", "DQNConfig", "PolicyGradient",
|
||||
"PolicyOptimization", "PolicyOptimizationConfig",
|
||||
"ActorProxy", "ActorWorker", "concat_experiences_by_agent", "merge_experiences_with_trajectory_boundaries",
|
||||
"AbsExplorer", "EpsilonGreedyExplorer", "GaussianNoiseExplorer", "NoiseExplorer", "UniformNoiseExplorer",
|
||||
"AbsLearner", "SimpleLearner",
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
from maro.rl.algorithms.policy_optimization import ActionInfo
|
||||
from maro.rl.shaping.action_shaper import ActionShaper
|
||||
from maro.rl.shaping.experience_shaper import ExperienceShaper
|
||||
from maro.rl.shaping.state_shaper import StateShaper
|
||||
|
@ -45,15 +46,20 @@ class SimpleAgentManager(AbsAgentManager):
|
|||
def choose_action(self, decision_event, snapshot_list):
|
||||
self._assert_inference_mode()
|
||||
agent_id, model_state = self._state_shaper(decision_event, snapshot_list)
|
||||
model_action = self.agent_dict[agent_id].choose_action(model_state)
|
||||
action_info = self.agent_dict[agent_id].choose_action(model_state)
|
||||
self._transition_cache = {
|
||||
"state": model_state,
|
||||
"action": model_action,
|
||||
"reward": None,
|
||||
"agent_id": agent_id,
|
||||
"event": decision_event
|
||||
}
|
||||
return self._action_shaper(model_action, decision_event, snapshot_list)
|
||||
if isinstance(action_info, ActionInfo):
|
||||
self._transition_cache["action"] = action_info.action
|
||||
self._transition_cache["log_action_probability"] = action_info.log_probability
|
||||
else:
|
||||
self._transition_cache["action"] = action_info
|
||||
|
||||
return self._action_shaper(self._transition_cache["action"], decision_event, snapshot_list)
|
||||
|
||||
def on_env_feedback(self, metrics):
|
||||
"""This method records the environment-generated metrics as part of the latest transition in the trajectory.
|
||||
|
|
|
@ -3,5 +3,13 @@
|
|||
|
||||
from .abs_algorithm import AbsAlgorithm
|
||||
from .dqn import DQN, DQNConfig
|
||||
from .policy_optimization import (
|
||||
ActionInfo, ActorCritic, ActorCriticConfig, PolicyGradient, PolicyOptimization, PolicyOptimizationConfig
|
||||
)
|
||||
|
||||
__all__ = ["AbsAlgorithm", "DQN", "DQNConfig"]
|
||||
__all__ = [
|
||||
"AbsAlgorithm",
|
||||
"DQN", "DQNConfig",
|
||||
"ActionInfo", "ActorCritic", "ActorCriticConfig", "PolicyGradient", "PolicyOptimization",
|
||||
"PolicyOptimizationConfig"
|
||||
]
|
||||
|
|
|
@ -87,6 +87,7 @@ class DQN(AbsAlgorithm):
|
|||
if is_single:
|
||||
return greedy_action if np.random.random() > self._config.epsilon else np.random.choice(self._num_actions)
|
||||
|
||||
# batch inference
|
||||
return np.array([
|
||||
act if np.random.random() > self._config.epsilon else np.random.choice(self._num_actions)
|
||||
for act in greedy_action
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from maro.rl.algorithms.abs_algorithm import AbsAlgorithm
|
||||
from maro.rl.models.learning_model import LearningModel
|
||||
from maro.rl.utils.trajectory_utils import get_lambda_returns, get_truncated_cumulative_reward
|
||||
|
||||
ActionInfo = namedtuple("ActionInfo", ["action", "log_probability"])
|
||||
|
||||
|
||||
class PolicyOptimizationConfig:
|
||||
"""Configuration for the policy optimization algorithm family."""
|
||||
__slots__ = ["reward_discount"]
|
||||
|
||||
def __init__(self, reward_discount):
|
||||
self.reward_discount = reward_discount
|
||||
|
||||
|
||||
class PolicyOptimization(AbsAlgorithm):
|
||||
"""Policy optimization algorithm family.
|
||||
|
||||
The algorithm family includes policy gradient (e.g. REINFORCE), actor-critic, PPO, etc.
|
||||
"""
|
||||
def choose_action(self, state: np.ndarray) -> Union[ActionInfo, List[ActionInfo]]:
|
||||
"""Use the actor (policy) model to generate stochastic actions.
|
||||
|
||||
Args:
|
||||
state: Input to the actor model.
|
||||
|
||||
Returns:
|
||||
A single ActionInfo namedtuple or a list of ActionInfo namedtuples.
|
||||
"""
|
||||
state = torch.from_numpy(state).to(self._device)
|
||||
is_single = len(state.shape) == 1
|
||||
if is_single:
|
||||
state = state.unsqueeze(dim=0)
|
||||
|
||||
action_distribution = self._model(state, task_name="actor", is_training=False).squeeze().numpy()
|
||||
if is_single:
|
||||
action = np.random.choice(len(action_distribution), p=action_distribution)
|
||||
return ActionInfo(action=action, log_probability=np.log(action_distribution[action]))
|
||||
|
||||
# batch inference
|
||||
batch_results = []
|
||||
for distribution in action_distribution:
|
||||
action = np.random.choice(len(distribution), p=distribution)
|
||||
batch_results.append(ActionInfo(action=action, log_probability=np.log(distribution[action])))
|
||||
|
||||
return batch_results
|
||||
|
||||
def train(
|
||||
self, states: np.ndarray, actions: np.ndarray, log_action_prob: np.ndarray, rewards: np.ndarray
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PolicyGradient(PolicyOptimization):
|
||||
"""The vanilla Policy Gradient (VPG) algorithm, a.k.a., REINFORCE.
|
||||
|
||||
Reference: https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch.
|
||||
"""
|
||||
def train(
|
||||
self, states: np.ndarray, actions: np.ndarray, log_action_prob: np.ndarray, rewards: np.ndarray
|
||||
):
|
||||
states = torch.from_numpy(states).to(self._device)
|
||||
actions = torch.from_numpy(actions).to(self._device)
|
||||
returns = get_truncated_cumulative_reward(rewards, self._config.reward_discount)
|
||||
returns = torch.from_numpy(returns).to(self._device)
|
||||
action_distributions = self._model(states)
|
||||
action_prob = action_distributions.gather(1, actions.unsqueeze(1)).squeeze() # (N, 1)
|
||||
loss = -(torch.log(action_prob) * returns).mean()
|
||||
self._model.learn(loss)
|
||||
|
||||
|
||||
class ActorCriticConfig(PolicyOptimizationConfig):
|
||||
"""Configuration for the Actor-Critic algorithm.
|
||||
|
||||
Args:
|
||||
reward_discount (float): Reward decay as defined in standard RL terminology.
|
||||
critic_loss_func (Callable): Loss function for the critic model.
|
||||
train_iters (int): Number of gradient descent steps per call to ``train``.
|
||||
actor_loss_coefficient (float): The coefficient for actor loss in the total loss function, e.g.,
|
||||
loss = critic_loss + ``actor_loss_coefficient`` * actor_loss. Defaults to 1.0.
|
||||
k (int): Number of time steps used in computing returns or return estimates. Defaults to -1, in which case
|
||||
rewards are accumulated until the end of the trajectory.
|
||||
lam (float): Lambda coefficient used in computing lambda returns. Defaults to 1.0, in which case the usual
|
||||
k-step return is computed.
|
||||
clip_ratio (float): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). Defaults to None,
|
||||
in which case the actor loss is calculated using the usual policy gradient theorem.
|
||||
"""
|
||||
__slots__ = [
|
||||
"reward_discount", "critic_loss_func", "train_iters", "actor_loss_coefficient", "k", "lam", "clip_ratio"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reward_discount: float,
|
||||
critic_loss_func: Callable,
|
||||
train_iters: int,
|
||||
actor_loss_coefficient: float = 1.0,
|
||||
k: int = -1,
|
||||
lam: float = 1.0,
|
||||
clip_ratio: float = None
|
||||
):
|
||||
super().__init__(reward_discount)
|
||||
self.critic_loss_func = critic_loss_func
|
||||
self.train_iters = train_iters
|
||||
self.actor_loss_coefficient = actor_loss_coefficient
|
||||
self.k = k
|
||||
self.lam = lam
|
||||
self.clip_ratio = clip_ratio
|
||||
|
||||
|
||||
class ActorCritic(PolicyOptimization):
|
||||
"""Actor Critic algorithm with separate policy and value models.
|
||||
|
||||
References:
|
||||
https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch.
|
||||
https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f
|
||||
|
||||
Args:
|
||||
model (LearningModel): Multi-task model that computes action distributions and state values.
|
||||
It may or may not have a shared bottom stack.
|
||||
config: Configuration for the AC algorithm.
|
||||
"""
|
||||
def __init__(self, model: LearningModel, config: ActorCriticConfig):
|
||||
self.validate_task_names(model.task_names, {"actor", "critic"})
|
||||
super().__init__(model, config)
|
||||
|
||||
def _get_values_and_bootstrapped_returns(self, state_sequence, reward_sequence):
|
||||
state_values = self._model(state_sequence, task_name="critic").detach().squeeze()
|
||||
return_est = get_lambda_returns(
|
||||
reward_sequence, state_values, self._config.reward_discount, self._config.lam, k=self._config.k
|
||||
)
|
||||
return state_values, return_est
|
||||
|
||||
def train(
|
||||
self, states: np.ndarray, actions: np.ndarray, log_action_prob: np.ndarray, rewards: np.ndarray
|
||||
):
|
||||
states = torch.from_numpy(states).to(self._device)
|
||||
actions = torch.from_numpy(actions).to(self._device)
|
||||
log_action_prob = torch.from_numpy(log_action_prob).to(self._device)
|
||||
rewards = torch.from_numpy(rewards).to(self._device)
|
||||
state_values, return_est = self._get_values_and_bootstrapped_returns(states, rewards)
|
||||
advantages = return_est - state_values
|
||||
for _ in range(self._config.train_iters):
|
||||
critic_loss = self._config.critic_loss_func(
|
||||
self._model(states, task_name="critic").squeeze(), return_est
|
||||
)
|
||||
action_prob = self._model(states, task_name="actor").gather(1, actions.unsqueeze(1)).squeeze() # (N,)
|
||||
log_action_prob_new = torch.log(action_prob)
|
||||
actor_loss = self._actor_loss(log_action_prob_new, log_action_prob, advantages)
|
||||
loss = critic_loss + self._config.actor_loss_coefficient * actor_loss
|
||||
self._model.learn(loss)
|
||||
|
||||
def _actor_loss(self, log_action_prob_new, log_action_prob_old, advantages):
|
||||
if self._config.clip_ratio is not None:
|
||||
ratio = torch.exp(log_action_prob_new - log_action_prob_old)
|
||||
clip_ratio = torch.clamp(ratio, 1 - self._config.clip_ratio, 1 + self._config.clip_ratio)
|
||||
actor_loss = -(torch.min(ratio * advantages, clip_ratio * advantages)).mean()
|
||||
else:
|
||||
actor_loss = -(log_action_prob_new * advantages).mean()
|
||||
|
||||
return actor_loss
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from functools import reduce
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_truncated_cumulative_reward(
|
||||
rewards: Union[list, np.ndarray, torch.tensor],
|
||||
discount: float,
|
||||
k: int = -1
|
||||
):
|
||||
"""Compute K-step cumulative rewards from a reward sequence.
|
||||
Args:
|
||||
rewards (Union[list, np.ndarray, torch.tensor]): Reward sequence from a trajectory.
|
||||
discount (float): Reward discount as in standard RL.
|
||||
k (int): Number of steps in computing cumulative rewards. If it is -1, returns are computed using the
|
||||
largest possible number of steps. Defaults to -1.
|
||||
|
||||
Returns:
|
||||
An ndarray or torch.tensor instance containing the k-step cumulative rewards for each time step.
|
||||
"""
|
||||
if k < 0:
|
||||
k = len(rewards) - 1
|
||||
pad = np.pad if isinstance(rewards, list) or isinstance(rewards, np.ndarray) else F.pad
|
||||
return reduce(
|
||||
lambda x, y: x * discount + y,
|
||||
[pad(rewards[i:], (0, i)) for i in range(min(k, len(rewards)) - 1, -1, -1)]
|
||||
)
|
||||
|
||||
|
||||
def get_k_step_returns(
|
||||
rewards: Union[list, np.ndarray, torch.tensor],
|
||||
values: Union[list, np.ndarray, torch.tensor],
|
||||
discount: float,
|
||||
k: int = -1
|
||||
):
|
||||
"""Compute K-step returns given reward and value sequences.
|
||||
Args:
|
||||
rewards (Union[list, np.ndarray, torch.tensor]): Reward sequence from a trajectory.
|
||||
values (Union[list, np.ndarray, torch.tensor]): Sequence of values for the traversed states in a trajectory.
|
||||
discount (float): Reward discount as in standard RL.
|
||||
k (int): Number of steps in computing returns. If it is -1, returns are computed using the largest possible
|
||||
number of steps. Defaults to -1.
|
||||
|
||||
Returns:
|
||||
An ndarray or torch.tensor instance containing the k-step returns for each time step.
|
||||
"""
|
||||
assert len(rewards) == len(values), "rewards and values should have the same length"
|
||||
assert len(values.shape) == 1, "values should be a one-dimensional array"
|
||||
rewards[-1] = values[-1]
|
||||
if k < 0:
|
||||
k = len(rewards) - 1
|
||||
pad = np.pad if isinstance(rewards, list) or isinstance(rewards, np.ndarray) else F.pad
|
||||
return reduce(
|
||||
lambda x, y: x * discount + y,
|
||||
[pad(rewards[i:], (0, i)) for i in range(min(k, len(rewards)) - 1, -1, -1)],
|
||||
pad(values[k:], (0, k))
|
||||
)
|
||||
|
||||
|
||||
def get_lambda_returns(
|
||||
rewards: Union[list, np.ndarray, torch.tensor],
|
||||
values: Union[list, np.ndarray, torch.tensor],
|
||||
discount: float,
|
||||
lam: float,
|
||||
k: int = -1
|
||||
):
|
||||
"""Compute lambda returns given reward and value sequences and a k.
|
||||
Args:
|
||||
rewards (Union[list, np.ndarray, torch.tensor]): Reward sequence from a trajectory.
|
||||
values (Union[list, np.ndarray, torch.tensor]): Sequence of values for the traversed states in a trajectory.
|
||||
discount (float): Reward discount as in standard RL.
|
||||
lam (float): Lambda coefficient involved in computing lambda returns.
|
||||
k (int): Number of steps where the lambda return series is truncated. If it is -1, no truncating is done and
|
||||
the lambda return is carried out to the end of the sequence. Defaults to -1.
|
||||
|
||||
Returns:
|
||||
An ndarray or torch.tensor instance containing the lambda returns for each time step.
|
||||
"""
|
||||
if k < 0:
|
||||
k = len(rewards) - 1
|
||||
|
||||
# If lambda is zero, lambda return reduces to one-step return
|
||||
if lam == .0:
|
||||
return get_k_step_returns(rewards, values, discount, k=1)
|
||||
|
||||
# If lambda is one, lambda return reduces to k-step return
|
||||
if lam == 1.0:
|
||||
return get_k_step_returns(rewards, values, discount, k=k)
|
||||
|
||||
k = min(k, len(rewards) - 1)
|
||||
pre_truncate = reduce(
|
||||
lambda x, y: x * lam + y,
|
||||
[get_k_step_returns(rewards, values, discount, k=k) for k in range(k - 1, 0, -1)]
|
||||
)
|
||||
|
||||
post_truncate = get_k_step_returns(rewards, values, discount, k=k) * lam**(k - 1)
|
||||
return (1 - lam) * pre_truncate + post_truncate
|
|
@ -1,27 +1,42 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import csv
|
||||
|
||||
|
||||
class PortOrderExporter:
|
||||
def __init__(self, enabled: bool = False):
|
||||
self._enabled = enabled
|
||||
self._orders = []
|
||||
|
||||
def add(self, order):
|
||||
if self._enabled:
|
||||
self._orders.append(order)
|
||||
|
||||
def dump(self, folder: str):
|
||||
if self._enabled:
|
||||
with open(f"{folder}/orders.csv", "w+", newline="") as fp:
|
||||
writer = csv.writer(fp)
|
||||
|
||||
writer.writerow(["tick", "src_port_idx", "dest_port_idx", "quantity"])
|
||||
|
||||
for order in self._orders:
|
||||
writer.writerow([order.tick, order.src_port_idx, order.dest_port_idx, order.quantity])
|
||||
|
||||
self._orders.clear()
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import csv
|
||||
|
||||
|
||||
class PortOrderExporter:
|
||||
"""Utils used to export full's source and target."""
|
||||
|
||||
def __init__(self, enabled: bool = False):
|
||||
self._enabled = enabled
|
||||
self._orders = []
|
||||
|
||||
def add(self, order):
|
||||
"""Add an order to export, it will be ignored if export is disabled.
|
||||
|
||||
Args:
|
||||
order (object): Order to export.
|
||||
"""
|
||||
if self._enabled:
|
||||
self._orders.append(order)
|
||||
|
||||
def dump(self, folder: str):
|
||||
"""Dump current orders to csv.
|
||||
|
||||
Args:
|
||||
folder (str): Folder to hold dump file.
|
||||
"""
|
||||
if self._enabled:
|
||||
with open(f"{folder}/orders.csv", "w+", newline="") as fp:
|
||||
writer = csv.writer(fp)
|
||||
|
||||
writer.writerow(
|
||||
["tick", "src_port_idx", "dest_port_idx", "quantity"]
|
||||
)
|
||||
|
||||
for order in self._orders:
|
||||
writer.writerow(
|
||||
[order.tick, order.src_port_idx, order.dest_port_idx, order.quantity])
|
||||
|
||||
self._orders.clear()
|
||||
|
|
|
@ -2,21 +2,17 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from .business_engine import VmSchedulingBusinessEngine
|
||||
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, PostponeType, VmRequestPayload
|
||||
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
|
||||
from .cpu_reader import CpuReader
|
||||
from .events import Events
|
||||
from .enums import Events, PmState, PostponeType, VmCategory
|
||||
from .physical_machine import PhysicalMachine
|
||||
from .virtual_machine import VirtualMachine
|
||||
|
||||
__all__ = [
|
||||
"VmSchedulingBusinessEngine",
|
||||
"AllocateAction", "PostponeAction",
|
||||
"DecisionPayload",
|
||||
"Latency",
|
||||
"PostponeType",
|
||||
"VmRequestPayload",
|
||||
"AllocateAction", "PostponeAction", "DecisionPayload", "Latency", "VmRequestPayload",
|
||||
"CpuReader",
|
||||
"Events",
|
||||
"Events", "PmState", "PostponeType", "VmCategory",
|
||||
"PhysicalMachine",
|
||||
"VirtualMachine"
|
||||
]
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
|||
from yaml import safe_load
|
||||
|
||||
from maro.backends.frame import FrameBase, SnapshotList
|
||||
from maro.cli.data_pipeline.utils import download_file, StaticParameter
|
||||
from maro.cli.data_pipeline.utils import StaticParameter, download_file
|
||||
from maro.data_lib import BinaryReader
|
||||
from maro.event_buffer import CascadeEvent, EventBuffer, MaroEvents
|
||||
from maro.simulator.scenarios.abs_business_engine import AbsBusinessEngine
|
||||
|
@ -17,9 +17,9 @@ from maro.simulator.scenarios.helpers import DocableDict
|
|||
from maro.utils.logger import CliLogger
|
||||
from maro.utils.utils import convert_dottable
|
||||
|
||||
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, PostponeType, VmRequestPayload
|
||||
from .common import AllocateAction, DecisionPayload, Latency, PostponeAction, VmRequestPayload
|
||||
from .cpu_reader import CpuReader
|
||||
from .events import Events
|
||||
from .enums import Events, PmState, PostponeType, VmCategory
|
||||
from .frame_builder import build_frame
|
||||
from .physical_machine import PhysicalMachine
|
||||
from .virtual_machine import VirtualMachine
|
||||
|
@ -33,8 +33,11 @@ total_energy_consumption (float): Accumulative total PM energy consumption.
|
|||
successful_allocation (int): Accumulative successful VM allocation until now.
|
||||
successful_completion (int): Accumulative successful completion of tasks.
|
||||
failed_allocation (int): Accumulative failed VM allocation until now.
|
||||
failed_completion (int): Accumulative failed VM completion due to PM overloading.
|
||||
total_latency (Latency): Accumulative used buffer time until now.
|
||||
total_oversubscriptions (int): Accumulative over-subscriptions.
|
||||
total_oversubscriptions (int): Accumulative over-subscriptions. The unit is PM amount * tick.
|
||||
total_overload_pms (int): Accumulative overload pms. The unit is PM amount * tick.
|
||||
total_overload_vms (int): Accumulative VMs on overload pms. The unit is VM amount * tick.
|
||||
"""
|
||||
|
||||
logger = CliLogger(name=__name__)
|
||||
|
@ -57,23 +60,15 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
additional_options=additional_options
|
||||
)
|
||||
|
||||
# Env metrics.
|
||||
self._total_vm_requests: int = 0
|
||||
self._total_energy_consumption: float = 0
|
||||
self._successful_allocation: int = 0
|
||||
self._successful_completion: int = 0
|
||||
self._failed_allocation: int = 0
|
||||
self._total_latency: Latency = Latency()
|
||||
self._total_oversubscriptions: int = 0
|
||||
|
||||
# Initialize environment metrics.
|
||||
self._init_metrics()
|
||||
# Load configurations.
|
||||
self._load_configs()
|
||||
self._register_events()
|
||||
|
||||
self._init_frame()
|
||||
|
||||
# Initialize simulation data.
|
||||
self._init_data()
|
||||
|
||||
# PMs list used for quick accessing.
|
||||
self._init_pms()
|
||||
# All living VMs.
|
||||
|
@ -114,7 +109,26 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
|
||||
self._delay_duration: int = self._config.DELAY_DURATION
|
||||
self._buffer_time_budget: int = self._config.BUFFER_TIME_BUDGET
|
||||
self._pm_amount: int = self._config.PM.AMOUNT
|
||||
# Oversubscription rate.
|
||||
self._max_cpu_oversubscription_rate: float = self._config.MAX_CPU_OVERSUBSCRIPTION_RATE
|
||||
self._max_memory_oversubscription_rate: float = self._config.MAX_MEM_OVERSUBSCRIPTION_RATE
|
||||
self._max_utilization_rate: float = self._config.MAX_UTILIZATION_RATE
|
||||
# Load PM related configs.
|
||||
self._pm_amount: int = self._cal_pm_amount()
|
||||
self._kill_all_vms_if_overload = self._config.KILL_ALL_VMS_IF_OVERLOAD
|
||||
|
||||
def _init_metrics(self):
|
||||
# Env metrics.
|
||||
self._total_vm_requests: int = 0
|
||||
self._total_energy_consumption: float = 0.0
|
||||
self._successful_allocation: int = 0
|
||||
self._successful_completion: int = 0
|
||||
self._failed_allocation: int = 0
|
||||
self._failed_completion: int = 0
|
||||
self._total_latency: Latency = Latency()
|
||||
self._total_oversubscriptions: int = 0
|
||||
self._total_overload_pms: int = 0
|
||||
self._total_overload_vms: int = 0
|
||||
|
||||
def _init_data(self):
|
||||
"""If the file does not exist, then trigger the short data pipeline to download the processed data."""
|
||||
|
@ -127,31 +141,51 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
cpu_readings_data_path = os.path.expanduser(cpu_readings_data_path)
|
||||
|
||||
if (not os.path.exists(vm_table_data_path)) or (not os.path.exists(cpu_readings_data_path)):
|
||||
logger.info_green("Lack data. Start preparing data.")
|
||||
self._download_processed_data()
|
||||
logger.info_green("Data preparation is finished.")
|
||||
|
||||
def _cal_pm_amount(self) -> int:
|
||||
amount: int = 0
|
||||
for pm_type in self._config.PM:
|
||||
amount += pm_type["amount"]
|
||||
|
||||
return amount
|
||||
|
||||
def _init_pms(self):
|
||||
"""Initialize the physical machines based on the config setting. The PM id starts from 0."""
|
||||
self._pm_cpu_cores_capacity: int = self._config.PM.CPU
|
||||
self._pm_memory_capacity: int = self._config.PM.MEMORY
|
||||
|
||||
# TODO: Improve the scalability. Like the use of multiple PM sets.
|
||||
self._machines = self._frame.pms
|
||||
for pm_id in range(self._pm_amount):
|
||||
pm = self._machines[pm_id]
|
||||
pm.set_init_state(
|
||||
id=pm_id,
|
||||
cpu_cores_capacity=self._pm_cpu_cores_capacity,
|
||||
memory_capacity=self._pm_memory_capacity
|
||||
)
|
||||
# PM type dictionary.
|
||||
self._pm_type_dict: dict = {}
|
||||
pm_id = 0
|
||||
for pm_type in self._config.PM:
|
||||
amount = pm_type["amount"]
|
||||
self._pm_type_dict[pm_type["PM_type"]] = pm_type
|
||||
while amount > 0:
|
||||
pm = self._machines[pm_id]
|
||||
pm.set_init_state(
|
||||
id=pm_id,
|
||||
cpu_cores_capacity=pm_type["CPU"],
|
||||
memory_capacity=pm_type["memory"],
|
||||
pm_type=pm_type["PM_type"],
|
||||
oversubscribable=PmState.EMPTY
|
||||
)
|
||||
amount -= 1
|
||||
pm_id += 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset internal states for episode."""
|
||||
self._total_vm_requests: int = 0
|
||||
self._total_energy_consumption: float = 0.0
|
||||
self._successful_allocation: int = 0
|
||||
self._successful_completion: int = 0
|
||||
self._failed_allocation: int = 0
|
||||
self._failed_completion: int = 0
|
||||
self._total_latency: Latency = Latency()
|
||||
self._total_oversubscriptions: int = 0
|
||||
self._total_overload_pms: int = 0
|
||||
self._total_overload_vms: int = 0
|
||||
|
||||
self._frame.reset()
|
||||
self._snapshots.reset()
|
||||
|
@ -189,12 +223,15 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
self._update_pm_workload()
|
||||
|
||||
for vm in self._vm_item_picker.items(tick):
|
||||
# TODO: Calculate
|
||||
# TODO: Batch request support.
|
||||
vm_info = VirtualMachine(
|
||||
id=vm.vm_id,
|
||||
cpu_cores_requirement=vm.vm_cpu_cores,
|
||||
memory_requirement=vm.vm_memory,
|
||||
lifetime=vm.vm_deleted - vm.timestamp + 1
|
||||
lifetime=vm.vm_lifetime,
|
||||
sub_id=vm.sub_id,
|
||||
deployment_id=vm.deploy_id,
|
||||
category=VmCategory(vm.vm_category)
|
||||
)
|
||||
|
||||
if vm.vm_id not in cur_tick_cpu_utilization:
|
||||
|
@ -217,7 +254,12 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
# Update energy to the environment metrices.
|
||||
total_energy: float = 0.0
|
||||
for pm in self._machines:
|
||||
if pm.oversubscribable and pm.cpu_cores_allocated > pm.cpu_cores_capacity:
|
||||
self._total_oversubscriptions += 1
|
||||
total_energy += pm.energy_consumption
|
||||
# Overload PMs.
|
||||
if pm.cpu_utilization > 100:
|
||||
self._overload(pm.id)
|
||||
self._total_energy_consumption += total_energy
|
||||
|
||||
if (tick + 1) % self._snapshot_resolution == 0:
|
||||
|
@ -265,8 +307,11 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
successful_allocation=self._successful_allocation,
|
||||
successful_completion=self._successful_completion,
|
||||
failed_allocation=self._failed_allocation,
|
||||
failed_completion=self._failed_completion,
|
||||
total_latency=self._total_latency,
|
||||
total_oversubscriptions=self._total_oversubscriptions
|
||||
total_oversubscriptions=self._total_oversubscriptions,
|
||||
total_overload_pms=self._total_overload_pms,
|
||||
total_overload_vms=self._total_overload_vms
|
||||
)
|
||||
|
||||
def _register_events(self):
|
||||
|
@ -304,18 +349,42 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
vm = self._live_vms[vm_id]
|
||||
total_pm_cpu_cores_used += vm.cpu_utilization * vm.cpu_cores_requirement
|
||||
pm.update_cpu_utilization(vm=None, cpu_utilization=total_pm_cpu_cores_used / pm.cpu_cores_capacity)
|
||||
pm.energy_consumption = self._cpu_utilization_to_energy_consumption(cpu_utilization=pm.cpu_utilization)
|
||||
pm.energy_consumption = self._cpu_utilization_to_energy_consumption(
|
||||
pm_type=self._pm_type_dict[pm.pm_type],
|
||||
cpu_utilization=pm.cpu_utilization
|
||||
)
|
||||
|
||||
def _cpu_utilization_to_energy_consumption(self, cpu_utilization: float) -> float:
|
||||
def _overload(self, pm_id: int):
|
||||
"""Overload logic.
|
||||
|
||||
Currently only support killing all VMs on the overload PM and note them as failed allocations.
|
||||
"""
|
||||
# TODO: Future features of overload modeling.
|
||||
# 1. Performance degradation
|
||||
# 2. Quiesce specific VMs.
|
||||
pm: PhysicalMachine = self._machines[pm_id]
|
||||
vm_ids: List[int] = [vm_id for vm_id in pm.live_vms]
|
||||
|
||||
if self._kill_all_vms_if_overload:
|
||||
for vm_id in vm_ids:
|
||||
self._live_vms.pop(vm_id)
|
||||
|
||||
pm.deallocate_vms(vm_ids=vm_ids)
|
||||
self._failed_completion += len(vm_ids)
|
||||
|
||||
self._total_overload_vms += len(vm_ids)
|
||||
|
||||
def _cpu_utilization_to_energy_consumption(self, pm_type: dict, cpu_utilization: float) -> float:
|
||||
"""Convert the CPU utilization to energy consumption.
|
||||
|
||||
The formulation refers to https://dl.acm.org/doi/epdf/10.1145/1273440.1250665
|
||||
"""
|
||||
power: float = self._config.PM.POWER_CURVE.CALIBRATION_PARAMETER
|
||||
busy_power = self._config.PM.POWER_CURVE.BUSY_POWER
|
||||
idle_power = self._config.PM.POWER_CURVE.IDLE_POWER
|
||||
power: float = pm_type["power_curve"]["calibration_parameter"]
|
||||
busy_power: int = pm_type["power_curve"]["busy_power"]
|
||||
idle_power: int = pm_type["power_curve"]["idle_power"]
|
||||
|
||||
cpu_utilization /= 100
|
||||
cpu_utilization = min(1, cpu_utilization)
|
||||
|
||||
return idle_power + (busy_power - idle_power) * (2 * cpu_utilization - pow(cpu_utilization, power))
|
||||
|
||||
|
@ -342,19 +411,65 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
# Add failed allocation.
|
||||
self._failed_allocation += 1
|
||||
|
||||
def _get_valid_pms(self, vm_cpu_cores_requirement: int, vm_memory_requirement: int) -> List[int]:
|
||||
def _get_valid_pms(
|
||||
self, vm_cpu_cores_requirement: int, vm_memory_requirement: int, vm_category: VmCategory
|
||||
) -> List[int]:
|
||||
"""Check all valid PMs.
|
||||
|
||||
Args: vm_cpu_cores_requirement (int): The CPU cores requested by the VM.
|
||||
Args:
|
||||
vm_cpu_cores_requirement (int): The CPU cores requested by the VM.
|
||||
vm_memory_requirement (int): The memory requested by the VM.
|
||||
vm_category (VmCategory): The VM category. Delay-insensitive: 0, Interactive: 1, Unknown: 2.
|
||||
"""
|
||||
# NOTE: Should we implement this logic inside the action scope?
|
||||
# TODO: In oversubscribable scenario, we should consider more situations, like
|
||||
# the PM type (oversubscribable and non-oversubscribable).
|
||||
valid_pm_list = []
|
||||
|
||||
# Delay-insensitive: 0, Interactive: 1, and Unknown: 2.
|
||||
if vm_category == VmCategory.INTERACTIVE or vm_category == VmCategory.UNKNOWN:
|
||||
valid_pm_list = self._get_valid_non_oversubscribable_pms(
|
||||
vm_cpu_cores_requirement=vm_cpu_cores_requirement,
|
||||
vm_memory_requirement=vm_memory_requirement
|
||||
)
|
||||
else:
|
||||
valid_pm_list = self._get_valid_oversubscribable_pms(
|
||||
vm_cpu_cores_requirement=vm_cpu_cores_requirement,
|
||||
vm_memory_requirement=vm_memory_requirement
|
||||
)
|
||||
|
||||
return valid_pm_list
|
||||
|
||||
def _get_valid_non_oversubscribable_pms(self, vm_cpu_cores_requirement: int, vm_memory_requirement: int) -> list:
|
||||
valid_pm_list = []
|
||||
for pm in self._machines:
|
||||
if (pm.cpu_cores_capacity - pm.cpu_cores_allocated >= vm_cpu_cores_requirement and
|
||||
pm.memory_capacity - pm.memory_allocated >= vm_memory_requirement):
|
||||
valid_pm_list.append(pm.id)
|
||||
if pm.oversubscribable == PmState.EMPTY or pm.oversubscribable == PmState.NON_OVERSUBSCRIBABLE:
|
||||
# In the condition of non-oversubscription, the valid PMs mean:
|
||||
# PM allocated resource + VM allocated resource <= PM capacity.
|
||||
if (pm.cpu_cores_allocated + vm_cpu_cores_requirement <= pm.cpu_cores_capacity
|
||||
and pm.memory_allocated + vm_memory_requirement <= pm.memory_capacity):
|
||||
valid_pm_list.append(pm.id)
|
||||
|
||||
return valid_pm_list
|
||||
|
||||
def _get_valid_oversubscribable_pms(self, vm_cpu_cores_requirement: int, vm_memory_requirement: int) -> List[int]:
|
||||
valid_pm_list = []
|
||||
for pm in self._machines:
|
||||
if pm.oversubscribable == PmState.EMPTY or pm.oversubscribable == PmState.OVERSUBSCRIBABLE:
|
||||
# In the condition of oversubscription, the valid PMs mean:
|
||||
# 1. PM allocated resource + VM allocated resource <= Max oversubscription rate * PM capacity.
|
||||
# 2. PM CPU usage + VM requirements <= Max utilization rate * PM capacity.
|
||||
if (
|
||||
(
|
||||
pm.cpu_cores_allocated + vm_cpu_cores_requirement
|
||||
<= self._max_cpu_oversubscription_rate * pm.cpu_cores_capacity
|
||||
) and (
|
||||
pm.memory_allocated + vm_memory_requirement
|
||||
<= self._max_memory_oversubscription_rate * pm.memory_capacity
|
||||
) and (
|
||||
pm.cpu_utilization / 100 * pm.cpu_cores_capacity + vm_cpu_cores_requirement
|
||||
<= self._max_utilization_rate * pm.cpu_cores_capacity
|
||||
)
|
||||
):
|
||||
valid_pm_list.append(pm.id)
|
||||
|
||||
return valid_pm_list
|
||||
|
||||
|
@ -369,6 +484,9 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
pm.cpu_cores_allocated -= vm.cpu_cores_requirement
|
||||
pm.memory_allocated -= vm.memory_requirement
|
||||
pm.deallocate_vms(vm_ids=[vm.id])
|
||||
# If the VM list is empty, switch the state to empty.
|
||||
if not pm.live_vms:
|
||||
pm.oversubscribable = PmState.EMPTY
|
||||
|
||||
vm_id_list.append(vm.id)
|
||||
# VM completed task succeed.
|
||||
|
@ -390,12 +508,14 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
# Get valid pm list.
|
||||
valid_pm_list = self._get_valid_pms(
|
||||
vm_cpu_cores_requirement=vm_info.cpu_cores_requirement,
|
||||
vm_memory_requirement=vm_info.memory_requirement
|
||||
vm_memory_requirement=vm_info.memory_requirement,
|
||||
vm_category=vm_info.category
|
||||
)
|
||||
|
||||
if len(valid_pm_list) > 0:
|
||||
# Generate pending decision.
|
||||
decision_payload = DecisionPayload(
|
||||
frame_index=self.frame_index(tick=self._tick),
|
||||
valid_pms=valid_pm_list,
|
||||
vm_id=vm_info.id,
|
||||
vm_cpu_cores_requirement=vm_info.cpu_cores_requirement,
|
||||
|
@ -444,10 +564,18 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
self._pending_vm_request_payload.pop(vm_id)
|
||||
self._live_vms[vm_id] = vm
|
||||
|
||||
# TODO: Current logic can not fulfill the oversubscription case.
|
||||
|
||||
# Update PM resources requested by VM.
|
||||
pm = self._machines[pm_id]
|
||||
|
||||
# Empty pm (init state).
|
||||
if pm.oversubscribable == PmState.EMPTY:
|
||||
# Delay-Insensitive: oversubscribable.
|
||||
if vm.category == VmCategory.DELAY_INSENSITIVE:
|
||||
pm.oversubscribable = PmState.OVERSUBSCRIBABLE
|
||||
# Interactive or Unknown: non-oversubscribable
|
||||
else:
|
||||
pm.oversubscribable = PmState.NON_OVERSUBSCRIBABLE
|
||||
|
||||
pm.allocate_vms(vm_ids=[vm.id])
|
||||
pm.cpu_cores_allocated += vm.cpu_cores_requirement
|
||||
pm.memory_allocated += vm.memory_requirement
|
||||
|
@ -455,7 +583,10 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
vm=vm,
|
||||
cpu_utilization=None
|
||||
)
|
||||
pm.energy_consumption = self._cpu_utilization_to_energy_consumption(cpu_utilization=pm.cpu_utilization)
|
||||
pm.energy_consumption = self._cpu_utilization_to_energy_consumption(
|
||||
pm_type=self._pm_type_dict[pm.pm_type],
|
||||
cpu_utilization=pm.cpu_utilization
|
||||
)
|
||||
self._successful_allocation += 1
|
||||
elif type(action) == PostponeAction:
|
||||
postpone_step = action.postpone_step
|
||||
|
@ -483,16 +614,19 @@ class VmSchedulingBusinessEngine(AbsBusinessEngine):
|
|||
else:
|
||||
logger.info_green("File already exists, skipping download.")
|
||||
|
||||
logger.info_green(f"Unzip {download_file_path} to {build_folder}")
|
||||
# Unzip files.
|
||||
logger.info_green(f"Unzip {download_file_path} to {build_folder}")
|
||||
tar = tarfile.open(download_file_path, "r:gz")
|
||||
tar.extractall(path=build_folder)
|
||||
tar.close()
|
||||
|
||||
# Move to the correct path.
|
||||
unzip_file = os.path.join(build_folder, "build")
|
||||
file_names = os.listdir(unzip_file)
|
||||
for file_name in file_names:
|
||||
shutil.move(os.path.join(unzip_file, file_name), build_folder)
|
||||
for _, directories, _ in os.walk(build_folder):
|
||||
for directory in directories:
|
||||
unzip_file = os.path.join(build_folder, directory)
|
||||
logger.info_green(f"Move files to {build_folder} from {unzip_file}")
|
||||
for file_name in os.listdir(unzip_file):
|
||||
if file_name.endswith(".bin"):
|
||||
shutil.move(os.path.join(unzip_file, file_name), build_folder)
|
||||
|
||||
os.rmdir(unzip_file)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from .virtual_machine import VirtualMachine
|
||||
|
@ -59,22 +58,28 @@ class DecisionPayload:
|
|||
"""Decision event in VM Scheduling scenario that contains information for agent to choose action.
|
||||
|
||||
Args:
|
||||
frame_index (int): The current frame index (converted by tick).
|
||||
valid_pms (List[int]): A list contains pm id of all valid pms.
|
||||
vm_id (int): The id of the VM.
|
||||
vm_cpu_cores_requirement (int): The CPU requested by VM.
|
||||
vm_memory_requirement (int): The memory requested by VM.
|
||||
remaining_buffer_time (int): The remaining buffer time.
|
||||
"""
|
||||
summary_key = ["valid_pms", "vm_id", "vm_cpu_cores_requirement", "vm_memory_requirement", "remaining_buffer_time"]
|
||||
summary_key = [
|
||||
"frame_index", "valid_pms", "vm_id", "vm_cpu_cores_requirement", "vm_memory_requirement",
|
||||
"remaining_buffer_time"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_index: int,
|
||||
valid_pms: List[int],
|
||||
vm_id: int,
|
||||
vm_cpu_cores_requirement: int,
|
||||
vm_memory_requirement: int,
|
||||
remaining_buffer_time: int
|
||||
):
|
||||
self.frame_index = frame_index
|
||||
self.valid_pms = valid_pms
|
||||
self.vm_id = vm_id
|
||||
self.vm_cpu_cores_requirement = vm_cpu_cores_requirement
|
||||
|
@ -82,14 +87,6 @@ class DecisionPayload:
|
|||
self.remaining_buffer_time = remaining_buffer_time
|
||||
|
||||
|
||||
class PostponeType(Enum):
|
||||
"""Postpone type."""
|
||||
# Postpone the VM requirement due to the resource exhaustion.
|
||||
Resource = 'resource'
|
||||
# Postpone the VM requirement due to the agent's decision.
|
||||
Agent = 'agent'
|
||||
|
||||
|
||||
class Latency:
|
||||
"""Accumulative latency.
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class Events(Enum):
|
||||
"""VM-PM pairs related events."""
|
||||
# VM request events.
|
||||
REQUEST = "vm_required"
|
||||
|
||||
|
||||
class PostponeType(Enum):
|
||||
"""Postpone type."""
|
||||
# Postpone the VM requirement due to the resource exhaustion.
|
||||
Resource = "resource"
|
||||
# Postpone the VM requirement due to the agent's decision.
|
||||
Agent = "agent"
|
||||
|
||||
|
||||
class PmState(IntEnum):
|
||||
"""PM oversubscription state, includes empty, oversubscribable, non-oversubscribable."""
|
||||
NON_OVERSUBSCRIBABLE = -1
|
||||
EMPTY = 0
|
||||
OVERSUBSCRIBABLE = 1
|
||||
|
||||
|
||||
class VmCategory(IntEnum):
|
||||
DELAY_INSENSITIVE = 0
|
||||
INTERACTIVE = 1
|
||||
UNKNOWN = 2
|
|
@ -1,10 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Events(Enum):
|
||||
"""VM-PM pairs related events."""
|
||||
# VM request events.
|
||||
REQUEST = "vm_required"
|
|
@ -1,9 +1,4 @@
|
|||
AzurePublicDataset:
|
||||
vm_data:
|
||||
azure.2019.original:
|
||||
remote_url: "https://raw.githubusercontent.com/Azure/AzurePublicDataset/master/AzurePublicDatasetLinksV2.txt"
|
||||
sample: 2695548
|
||||
seed: 9
|
||||
azure.2019.336k:
|
||||
remote_url: "https://raw.githubusercontent.com/Azure/AzurePublicDataset/master/AzurePublicDatasetLinksV2.txt"
|
||||
sample: 336000
|
||||
|
|
|
@ -8,12 +8,24 @@ entity:
|
|||
vm_id:
|
||||
column: 'vmid'
|
||||
dtype: 'i'
|
||||
sub_id:
|
||||
column: 'subscriptionid'
|
||||
dtype: 'i'
|
||||
deploy_id:
|
||||
column: 'deploymentid'
|
||||
dtype: 'i'
|
||||
timestamp:
|
||||
column: 'vmcreated'
|
||||
dtype: 'i'
|
||||
vm_lifetime:
|
||||
column: 'lifetime'
|
||||
dtype: 'i'
|
||||
vm_deleted:
|
||||
column: 'vmdeleted'
|
||||
dtype: 'i'
|
||||
vm_category:
|
||||
column: 'vmcategory'
|
||||
dtype: 'i'
|
||||
vm_cpu_cores:
|
||||
column: 'vmcorecountbucket'
|
||||
dtype: 'i'
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import List, Set
|
|||
|
||||
from maro.backends.frame import NodeAttribute, NodeBase, node
|
||||
|
||||
from .enums import PmState
|
||||
from .virtual_machine import VirtualMachine
|
||||
|
||||
|
||||
|
@ -15,6 +16,7 @@ class PhysicalMachine(NodeBase):
|
|||
id = NodeAttribute("i")
|
||||
cpu_cores_capacity = NodeAttribute("i2")
|
||||
memory_capacity = NodeAttribute("i2")
|
||||
pm_type = NodeAttribute("i2")
|
||||
# Statistical features.
|
||||
cpu_cores_allocated = NodeAttribute("i2")
|
||||
memory_allocated = NodeAttribute("i2")
|
||||
|
@ -22,11 +24,16 @@ class PhysicalMachine(NodeBase):
|
|||
cpu_utilization = NodeAttribute("f")
|
||||
energy_consumption = NodeAttribute("f")
|
||||
|
||||
# PM type: non-oversubscribable is -1, empty: 0, oversubscribable is 1.
|
||||
oversubscribable = NodeAttribute("i2")
|
||||
|
||||
def __init__(self):
|
||||
"""Internal use for reset."""
|
||||
self._id = 0
|
||||
self._init_cpu_cores_capacity = 0
|
||||
self._init_memory_capacity = 0
|
||||
self._init_pm_type = 0
|
||||
self._init_pm_state = 0
|
||||
# PM resource.
|
||||
self._live_vms: Set[int] = set()
|
||||
|
||||
|
@ -42,17 +49,26 @@ class PhysicalMachine(NodeBase):
|
|||
|
||||
self.cpu_utilization = round(max(0, cpu_utilization), 2)
|
||||
|
||||
def set_init_state(self, id: int, cpu_cores_capacity: int, memory_capacity: int):
|
||||
def set_init_state(
|
||||
self, id: int, cpu_cores_capacity: int, memory_capacity: int, pm_type: int, oversubscribable: PmState = 0
|
||||
):
|
||||
"""Set initialize state, that will be used after frame reset.
|
||||
|
||||
Args:
|
||||
id (int): PM id, from 0 to N. N means the amount of PM, which can be set in config.
|
||||
cpu_cores_capacity (int): The capacity of cores of the PM, which can be set in config.
|
||||
memory_capacity (int): The capacity of memory of the PM, which can be set in config.
|
||||
pm_type (int): The type of the PM.
|
||||
oversubscribable (int): The state of the PM:
|
||||
- non-oversubscribable: -1.
|
||||
- empty: 0.
|
||||
- oversubscribable: 1.
|
||||
"""
|
||||
self._id = id
|
||||
self._init_cpu_cores_capacity = cpu_cores_capacity
|
||||
self._init_memory_capacity = memory_capacity
|
||||
self._init_pm_type = pm_type
|
||||
self._init_pm_state = oversubscribable
|
||||
|
||||
self.reset()
|
||||
|
||||
|
@ -62,6 +78,8 @@ class PhysicalMachine(NodeBase):
|
|||
self.id = self._id
|
||||
self.cpu_cores_capacity = self._init_cpu_cores_capacity
|
||||
self.memory_capacity = self._init_memory_capacity
|
||||
self.pm_type = self._init_pm_type
|
||||
self.oversubscribable = self._init_pm_state
|
||||
|
||||
self._live_vms.clear()
|
||||
|
||||
|
|
|
@ -11,17 +11,34 @@ CPU_READINGS: "~/.maro/data/vm_scheduling/.build/azure.2019.10k/vm_cpu_readings-
|
|||
|
||||
PROCESSED_DATA_URL: "https://marodatasource.blob.core.windows.net/vm-scheduling-azure/azure.2019.10k/azure.2019.10k.tar.gz"
|
||||
|
||||
# True means kill all VMs on the overload PM.
|
||||
# False means only count these VMs as failed allocation, but not kill them.
|
||||
KILL_ALL_VMS_IF_OVERLOAD: True
|
||||
|
||||
# Oversubscription configuration.
|
||||
# Max CPU oversubscription rate.
|
||||
MAX_CPU_OVERSUBSCRIPTION_RATE: 1.15
|
||||
# Max memory oversubscription rate.
|
||||
MAX_MEM_OVERSUBSCRIPTION_RATE: 1
|
||||
# Max CPU utilization rate.
|
||||
MAX_UTILIZATION_RATE: 1
|
||||
|
||||
PM:
|
||||
AMOUNT: 100
|
||||
CPU: 32
|
||||
# The unit of the memory is GB.
|
||||
MEMORY: 128
|
||||
# GPU: 4
|
||||
# NOTE: Energy consumption parameters should refer to more research.
|
||||
POWER_CURVE:
|
||||
# The calibration parameter used for the CPU utilization vs energy consumption model.
|
||||
CALIBRATION_PARAMETER: 1.4
|
||||
# The idle power usage of a machine.
|
||||
BUSY_POWER: 10
|
||||
# The busy power usage of a machine.
|
||||
IDLE_POWER: 1
|
||||
- PM_type: 0 # PM type is currently "int" only.
|
||||
amount: 100
|
||||
CPU: 32
|
||||
# GPU: 0
|
||||
memory: 128
|
||||
power_curve:
|
||||
calibration_parameter: 1.4
|
||||
busy_power: 10
|
||||
idle_power: 1
|
||||
# - PM_type: 1
|
||||
# amount: 50
|
||||
# CPU: 32
|
||||
# GPU: 0
|
||||
# memory: 128
|
||||
# power_curve:
|
||||
# calibration_parameter: 1.4
|
||||
# busy_power: 10
|
||||
# idle_power: 1
|
||||
|
|
|
@ -11,16 +11,34 @@ CPU_READINGS: "~/.maro/data/vm_scheduling/.build/azure.2019.336k/vm_cpu_readings
|
|||
|
||||
PROCESSED_DATA_URL: "https://marodatasource.blob.core.windows.net/vm-scheduling-azure/azure.2019.336k/azure.2019.336k.tar.gz"
|
||||
|
||||
# True means kill all VMs on the overload PM.
|
||||
# False means only count these VMs as failed allocation, but not kill them.
|
||||
KILL_ALL_VMS_IF_OVERLOAD: True
|
||||
|
||||
# Oversubscription configuration.
|
||||
# Max CPU oversubscription rate.
|
||||
MAX_CPU_OVERSUBSCRIPTION_RATE: 1.15
|
||||
# Max memory oversubscription rate.
|
||||
MAX_MEM_OVERSUBSCRIPTION_RATE: 1
|
||||
# Max CPU utilization rate.
|
||||
MAX_UTILIZATION_RATE: 1
|
||||
|
||||
PM:
|
||||
AMOUNT: 880
|
||||
CPU: 16
|
||||
MEMORY: 112
|
||||
# GPU: 4
|
||||
# NOTE: Energy consumption parameters should refer to more research.
|
||||
POWER_CURVE:
|
||||
# The calibration parameter used for the CPU utilization vs energy consumption model.
|
||||
CALIBRATION_PARAMETER: 1.4
|
||||
# The idle power usage of a machine.
|
||||
BUSY_POWER: 10
|
||||
# The busy power usage of a machine.
|
||||
IDLE_POWER: 1
|
||||
- PM_type: 0 # PM type is currently "int" only.
|
||||
amount: 880
|
||||
CPU: 16
|
||||
# GPU: 0
|
||||
memory: 112
|
||||
power_curve:
|
||||
calibration_parameter: 1.4
|
||||
busy_power: 10
|
||||
idle_power: 1
|
||||
# - PM_type: 1
|
||||
# amount: 440
|
||||
# CPU: 16
|
||||
# GPU: 0
|
||||
# memory: 112
|
||||
# power_curve:
|
||||
# calibration_parameter: 1.4
|
||||
# busy_power: 10
|
||||
# idle_power: 1
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from .enums import VmCategory
|
||||
|
||||
|
||||
class VirtualMachine:
|
||||
"""VM object.
|
||||
|
@ -19,12 +21,28 @@ class VirtualMachine:
|
|||
memory_requirement (int): The memory requested by VM. The unit is (GBs).
|
||||
lifetime (int): The lifetime of the VM, that is, deletion tick - creation tick.
|
||||
"""
|
||||
def __init__(self, id: int, cpu_cores_requirement: int, memory_requirement: int, lifetime: int):
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
cpu_cores_requirement: int,
|
||||
memory_requirement: int,
|
||||
lifetime: int,
|
||||
sub_id: int,
|
||||
deployment_id: int,
|
||||
category: VmCategory
|
||||
):
|
||||
# VM Requirement parameters.
|
||||
self.id: int = id
|
||||
self.cpu_cores_requirement: int = cpu_cores_requirement
|
||||
self.memory_requirement: int = memory_requirement
|
||||
self.lifetime: int = lifetime
|
||||
# The VM belong to a subscription.
|
||||
self.sub_id: int = sub_id
|
||||
# The region of PM that VM allocated (under a subscription) called a deployment group.
|
||||
self.deployment_id: int = deployment_id
|
||||
# The category of the VM. Now includes Delay-insensitive: 0, Interactive: 1, and Unknown: 2.
|
||||
self.category: VmCategory = category
|
||||
|
||||
# VM utilization list with VM cpu utilization(%) in corresponding tick.
|
||||
self._utilization_series: List[float] = []
|
||||
# The physical machine Id that the VM is assigned.
|
||||
|
|
|
@ -29,10 +29,74 @@ class BackendsSetItemInvalidException(MAROException):
|
|||
|
||||
|
||||
class BackendsArrayAttributeAccessException(MAROException):
|
||||
"""Exception then access attribute that slot number greater than 1.
|
||||
"""Exception when access attribute that slot number greater than 1.
|
||||
|
||||
This exception is caused when using invalid slice interface to access slots.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2102, ERROR_CODE[2102])
|
||||
|
||||
|
||||
class BackendsAppendToNonListAttributeException(MAROException):
|
||||
"""Exception when append value to a non list attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2103, ERROR_CODE[2103])
|
||||
|
||||
|
||||
class BackendsResizeNonListAttributeException(MAROException):
|
||||
"""Exception when try to resize a non list attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2104, ERROR_CODE[2104])
|
||||
|
||||
|
||||
class BackendsClearNonListAttributeException(MAROException):
|
||||
"""Exception when try to clear a non list attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2105, ERROR_CODE[2105])
|
||||
|
||||
|
||||
class BackendsInsertNonListAttributeException(MAROException):
|
||||
"""Exception when try to insert a value to non list attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2106, ERROR_CODE[2106])
|
||||
|
||||
|
||||
class BackendsRemoveFromNonListAttributeException(MAROException):
|
||||
"""Exception when try to from a value to non list attribute.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2107, ERROR_CODE[2107])
|
||||
|
||||
|
||||
class BackendsAccessDeletedNodeException(MAROException):
|
||||
"""Exception when try to access a deleted node.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2108, ERROR_CODE[2108])
|
||||
|
||||
|
||||
class BackendsInvalidNodeException(MAROException):
|
||||
"""Exception when try to access a not exist node type.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2109, ERROR_CODE[2109])
|
||||
|
||||
|
||||
class BackendsInvalidAttributeException(MAROException):
|
||||
"""Exception when try to access a not exist attribute type.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(2110, ERROR_CODE[2110])
|
||||
|
|
|
@ -28,6 +28,14 @@ ERROR_CODE = {
|
|||
2100: "Invalid parameter to get attribute value, please use tuple, list or slice instead",
|
||||
2101: "Invalid parameter to set attribute value",
|
||||
2102: "Cannot set value for frame fields directly if slot number more than 1, please use slice interface instead",
|
||||
2103: "Append method only support for list attribute.",
|
||||
2104: "Resize method only support for list attribute.",
|
||||
2105: "Clear method only fupport for list attribute.",
|
||||
2106: "Insert method only fupport for list attribute.",
|
||||
2107: "Remove method only fupport for list attribute.",
|
||||
2108: "Node already been deleted.",
|
||||
2109: "Node not exist.",
|
||||
2110: "Invalid attribute.",
|
||||
|
||||
# simulator
|
||||
2200: "Cannot find specified business engine",
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
chdir "%~dp0.."
|
||||
|
||||
python code_gen.py
|
||||
python setup.py sdist
|
|
@ -8,6 +8,8 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then
|
|||
cd "$(cd "$(dirname "$0")"; pwd -P)/.."
|
||||
fi
|
||||
|
||||
python code_gen.py
|
||||
|
||||
bash ./scripts/compile_cython.sh
|
||||
|
||||
python setup.py sdist
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
# Generate code for raw backend attribute accessors.
|
||||
raw_backend_path = "maro/backends"
|
||||
|
||||
attr_type_list = [
|
||||
("ATTR_CHAR", "Char"),
|
||||
("ATTR_UCHAR", "UChar"),
|
||||
("ATTR_SHORT", "Short"),
|
||||
("ATTR_USHORT", "UShort"),
|
||||
("ATTR_INT", "Int"),
|
||||
("ATTR_UINT", "UInt"),
|
||||
("ATTR_LONG", "Long"),
|
||||
("ATTR_ULONG", "ULong"),
|
||||
("ATTR_FLOAT", "Float"),
|
||||
("ATTR_DOUBLE", "Double"),
|
||||
]
|
||||
|
||||
# Load template for attribute accessors.
|
||||
attr_acc_template = open(
|
||||
f"{raw_backend_path}/_raw_backend_attr_acc_.pyx.tml").read()
|
||||
|
||||
# Base code of raw backend.
|
||||
raw_backend_code = open(f"{raw_backend_path}/_raw_backend_.pyx").read()
|
||||
|
||||
# Real file we use to build.
|
||||
with open(f"{raw_backend_path}/raw_backend.pyx", "w+") as fp:
|
||||
fp.write(raw_backend_code)
|
||||
|
||||
# Append attribute accessor implementations to the end.
|
||||
for attr_type_pair in attr_type_list:
|
||||
attr_acc_def = attr_acc_template.format(T=attr_type_pair[0], CLSNAME=attr_type_pair[1])
|
||||
|
||||
fp.write(attr_acc_def)
|
|
@ -6,8 +6,11 @@ pip install -r .\maro\requirements.build.txt
|
|||
|
||||
REM delete old .c files
|
||||
|
||||
DEL /F .\maro\backends\*.c
|
||||
DEL /F .\maro\backends\*.cpp
|
||||
|
||||
REM generate code
|
||||
python scripts\code_gen.py
|
||||
|
||||
REM compile pyx into .c files
|
||||
REM use numpy backend, and use a big memory block to hold array
|
||||
cython .\maro\backends\backend.pyx .\maro\backends\np_backend.pyx .\maro\backends\raw_backend.pyx .\maro\backends\frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
cython .\maro\backends\backend.pyx .\maro\backends\np_backend.pyx .\maro\backends\raw_backend.pyx .\maro\backends\frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
|
|
|
@ -9,8 +9,10 @@ fi
|
|||
pip install -r ./maro/requirements.build.txt
|
||||
|
||||
# delete old .c files
|
||||
rm -f ./maro/backends/*.c
|
||||
rm -f ./maro/backends/*.cpp
|
||||
|
||||
python scripts\code_gen.py
|
||||
|
||||
# compile pyx into .c files
|
||||
# use numpy backend, and use a big memory block to hold array
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx -3 -E FRAME_BACKEND=NUMPY,NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
cython ./maro/backends/backend.pyx ./maro/backends/np_backend.pyx ./maro/backends/raw_backend.pyx ./maro/backends/frame.pyx --cplus -3 -E NODES_MEMORY_LAYOUT=ONE_BLOCK -X embedsignature=True
|
||||
|
|
50
setup.py
50
setup.py
|
@ -3,6 +3,7 @@
|
|||
|
||||
import io
|
||||
import os
|
||||
import numpy
|
||||
|
||||
# NOTE: DO NOT change the import order, as sometimes there is a conflict between setuptools and distutils,
|
||||
# it will cause following error:
|
||||
|
@ -30,10 +31,6 @@ compile_conditions = {}
|
|||
# CURRENTLY we using environment variables to specified compiling conditions
|
||||
# TODO: used command line arguments instead
|
||||
|
||||
# specified frame backend
|
||||
FRAME_BACKEND = os.environ.get("FRAME_BACKEND", "NUMPY") # NUMPY or empty
|
||||
|
||||
|
||||
# include dirs for frame and its backend
|
||||
include_dirs = []
|
||||
|
||||
|
@ -41,37 +38,38 @@ include_dirs = []
|
|||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.backend",
|
||||
sources=[f"{BASE_SRC_PATH}/backend.c"])
|
||||
sources=[f"{BASE_SRC_PATH}/backend.cpp"],
|
||||
extra_compile_args=['-std=c++11'])
|
||||
)
|
||||
|
||||
if FRAME_BACKEND == "NUMPY":
|
||||
import numpy
|
||||
|
||||
include_dirs.append(numpy.get_include())
|
||||
include_dirs.append(numpy.get_include())
|
||||
|
||||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.np_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/np_backend.c"],
|
||||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
|
||||
include_dirs=include_dirs)
|
||||
)
|
||||
else:
|
||||
# raw implementation
|
||||
# NOTE: not implemented now
|
||||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.raw_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/raw_backend.c"])
|
||||
)
|
||||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.np_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/np_backend.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
)
|
||||
|
||||
# raw implementation
|
||||
# NOTE: not implemented now
|
||||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.raw_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/raw_backend.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
)
|
||||
|
||||
# frame
|
||||
extensions.append(
|
||||
Extension(
|
||||
f"{BASE_MODULE_NAME}.frame",
|
||||
sources=[f"{BASE_SRC_PATH}/frame.c"],
|
||||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
|
||||
include_dirs=include_dirs)
|
||||
sources=[f"{BASE_SRC_PATH}/frame.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -2,15 +2,19 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import csv
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from collections import namedtuple
|
||||
from maro.event_buffer import EventBuffer, EventState
|
||||
from maro.simulator.scenarios.cim.business_engine import CimBusinessEngine, Events
|
||||
from maro.simulator.scenarios.cim.ports_order_export import PortOrderExporter
|
||||
from tests.utils import next_step
|
||||
from .mock_data_container import MockDataContainer
|
||||
|
||||
from tests.utils import next_step, backends_to_test
|
||||
|
||||
MAX_TICK = 20
|
||||
|
||||
|
@ -55,400 +59,440 @@ class TestCimScenarios(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_init_state(self):
|
||||
eb: EventBuffer = None
|
||||
be: CimBusinessEngine = None
|
||||
eb, be = setup_case("case_01")
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
# check frame
|
||||
self.assertEqual(3, len(be.frame.ports), "static node number should be same with port number after "
|
||||
"initialization")
|
||||
self.assertEqual(2, len(be.frame.vessels), "dynamic node number should be same with vessel number "
|
||||
"after initialization")
|
||||
eb: EventBuffer = None
|
||||
be: CimBusinessEngine = None
|
||||
eb, be = setup_case("case_01")
|
||||
|
||||
# check snapshot
|
||||
self.assertEqual(MAX_TICK, len(be.snapshots),
|
||||
f"snapshots should be {MAX_TICK} after initialization")
|
||||
# check frame
|
||||
self.assertEqual(3, len(be.frame.ports), "static node number should be same with port number after "
|
||||
"initialization")
|
||||
self.assertEqual(2, len(be.frame.vessels), "dynamic node number should be same with vessel number "
|
||||
"after initialization")
|
||||
|
||||
# check snapshot
|
||||
self.assertEqual(0, len(be.snapshots), f"snapshots should be 0 after initialization")
|
||||
|
||||
def test_vessel_moving_correct(self):
|
||||
eb, be = setup_case("case_01")
|
||||
tick = 0
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
eb, be = setup_case("case_01")
|
||||
tick = 0
|
||||
|
||||
#####################################
|
||||
# STEP : beginning
|
||||
v = be._vessels[0]
|
||||
#####################################
|
||||
# STEP : beginning
|
||||
v = be._vessels[0]
|
||||
|
||||
self.assertEqual(0, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 0 at beginning")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 0 at beginning")
|
||||
self.assertEqual(
|
||||
0, v.next_loc_idx, "next_loc_idx of vessel 0 should be 0 at beginning")
|
||||
self.assertEqual(
|
||||
0, v.last_loc_idx, "last_loc_idx of vessel 0 should be 0 at beginning")
|
||||
|
||||
stop = be._data_cntr.vessel_stops[0, v.next_loc_idx]
|
||||
stop = be._data_cntr.vessel_stops[0, v.next_loc_idx]
|
||||
|
||||
self.assertEqual(0, stop.port_idx,
|
||||
"vessel 0 should parking at port 0 at beginning")
|
||||
self.assertEqual(0, stop.port_idx,
|
||||
"vessel 0 should parking at port 0 at beginning")
|
||||
|
||||
v = be._vessels[1]
|
||||
v = be._vessels[1]
|
||||
|
||||
self.assertEqual(0, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 0 at beginning")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 0 at beginning")
|
||||
self.assertEqual(
|
||||
0, v.next_loc_idx, "next_loc_idx of vessel 1 should be 0 at beginning")
|
||||
self.assertEqual(
|
||||
0, v.last_loc_idx, "last_loc_idx of vessel 1 should be 0 at beginning")
|
||||
|
||||
stop = be._data_cntr.vessel_stops[1, v.next_loc_idx]
|
||||
stop = be._data_cntr.vessel_stops[1, v.next_loc_idx]
|
||||
|
||||
self.assertEqual(1, stop.port_idx,
|
||||
"vessel 1 should parking at port 1 at beginning")
|
||||
self.assertEqual(1, stop.port_idx,
|
||||
"vessel 1 should parking at port 1 at beginning")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 2
|
||||
for i in range(3):
|
||||
#####################################
|
||||
# STEP : tick = 2
|
||||
for i in range(3):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# if these 2 idx not equal, then means at sailing state
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 2")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 0 at tick 2")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 2")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 0 at tick 2")
|
||||
|
||||
v = be.snapshots["matrices"][2::"vessel_plans"].flatten()
|
||||
|
||||
# since we already fixed the vessel plans, we just check the value
|
||||
for i in range(2):
|
||||
self.assertEqual(11, v[i*3+0])
|
||||
self.assertEqual(-1, v[i*3+1])
|
||||
self.assertEqual(13, v[i*3+2])
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 8
|
||||
for i in range(6):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 8")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 8")
|
||||
|
||||
stop = be._data_cntr.vessel_stops[0, v.next_loc_idx]
|
||||
|
||||
self.assertEqual(1, stop.port_idx,
|
||||
"vessel 0 should parking at port 1 at tick 8")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 sailing
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 8")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 0 at tick 8")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 10")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 10")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 10")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 1 at tick 10")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 11
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(2, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 2 at tick 11")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 11")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 11")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 1 at tick 11")
|
||||
|
||||
# move the env to next step, so it will take snapshot for current tick 11
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
# we have hard coded the future stops, here we just check if the value correct at each tick
|
||||
for i in range(tick - 1):
|
||||
# check if the future stop at tick 8 (vessel 0 arrive at port 1)
|
||||
stop_list = be.snapshots["vessels"][i:0:[
|
||||
"past_stop_list", "past_stop_tick_list"]].flatten()
|
||||
|
||||
v = be._vessels[0]
|
||||
self.assertEqual(-1, stop_list[0])
|
||||
self.assertEqual(-1, stop_list[2])
|
||||
|
||||
# if these 2 idx not equal, then means at sailing state
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 2")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 0 at tick 2")
|
||||
stop_list = be.snapshots["vessels"][i:0:[
|
||||
"future_stop_list", "future_stop_tick_list"]].flatten()
|
||||
|
||||
v = be._vessels[1]
|
||||
self.assertEqual(2, stop_list[0])
|
||||
self.assertEqual(3, stop_list[1])
|
||||
self.assertEqual(4, stop_list[2])
|
||||
self.assertEqual(4, stop_list[3])
|
||||
self.assertEqual(10, stop_list[4])
|
||||
self.assertEqual(20, stop_list[5])
|
||||
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 2")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 0 at tick 2")
|
||||
# check if statistics data correct
|
||||
order_states = be.snapshots["ports"][i:0:[
|
||||
"shortage", "acc_shortage", "booking", "acc_booking"]].flatten()
|
||||
|
||||
v = be.snapshots["matrices"][2::"vessel_plans"]
|
||||
# all the value should be 0 for this case
|
||||
self.assertEqual(
|
||||
0, order_states[0], f"shortage of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[1], f"acc_shortage of port 0 should be 0 until tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[2], f"booking of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[3], f"acc_booking of port 0 should be 0 until tick {i}")
|
||||
|
||||
# since we already fixed the vessel plans, we just check the value
|
||||
for i in range(2):
|
||||
self.assertEqual(11, v[i*3+0])
|
||||
self.assertEqual(-1, v[i*3+1])
|
||||
self.assertEqual(13, v[i*3+2])
|
||||
# check fulfillment
|
||||
fulfill_states = be.snapshots["ports"][i:0:[
|
||||
"fulfillment", "acc_fulfillment"]].flatten()
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 8
|
||||
for i in range(6):
|
||||
self.assertEqual(
|
||||
0, fulfill_states[0], f"fulfillment of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, fulfill_states[1], f"acc_fulfillment of port 0 should be 0 until tick {i}")
|
||||
|
||||
v = be.snapshots["matrices"][2:: "vessel_plans"].flatten()
|
||||
|
||||
# since we already fixed the vessel plans, we just check the value
|
||||
for i in range(2):
|
||||
self.assertEqual(11, v[i*3+0])
|
||||
self.assertEqual(-1, v[i*3+1])
|
||||
self.assertEqual(13, v[i*3+2])
|
||||
|
||||
def test_order_state(self):
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
eb, be = setup_case("case_02")
|
||||
tick = 0
|
||||
|
||||
p = be._ports[0]
|
||||
|
||||
self.assertEqual(
|
||||
0, p.booking, "port 0 have no booking at beginning")
|
||||
self.assertEqual(
|
||||
0, p.shortage, "port 0 have no shortage at beginning")
|
||||
self.assertEqual(
|
||||
100, p.empty, "port 0 have 100 empty containers at beginning")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 0
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there should be 10 order generated at tick 0
|
||||
self.assertEqual(
|
||||
10, p.booking, "port 0 should have 10 bookings at tick 0")
|
||||
self.assertEqual(
|
||||
0, p.shortage, "port 0 have no shortage at tick 0")
|
||||
self.assertEqual(
|
||||
90, p.empty, "port 0 have 90 empty containers at tick 0")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 1
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# we have 0 booking, so no shortage
|
||||
self.assertEqual(
|
||||
0, p.booking, "port 0 should have 0 bookings at tick 1")
|
||||
self.assertEqual(
|
||||
0, p.shortage, "port 0 have no shortage at tick 1")
|
||||
self.assertEqual(
|
||||
90, p.empty, "port 0 have 90 empty containers at tick 1")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 3
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there is an order that take 40 containers
|
||||
self.assertEqual(
|
||||
40, p.booking, "port 0 should have 40 booking at tick 3")
|
||||
self.assertEqual(
|
||||
0, p.shortage, "port 0 have no shortage at tick 3")
|
||||
self.assertEqual(
|
||||
50, p.empty, "port 0 have 90 empty containers at tick 3")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 7
|
||||
for i in range(4):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there is an order that take 51 containers
|
||||
self.assertEqual(
|
||||
51, p.booking, "port 0 should have 51 booking at tick 7")
|
||||
self.assertEqual(1, p.shortage, "port 0 have 1 shortage at tick 7")
|
||||
self.assertEqual(
|
||||
0, p.empty, "port 0 have 0 empty containers at tick 7")
|
||||
|
||||
# push the simulator to next tick to update snapshot
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 8")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 8")
|
||||
|
||||
stop = be._data_cntr.vessel_stops[0, v.next_loc_idx]
|
||||
|
||||
self.assertEqual(1, stop.port_idx,
|
||||
"vessel 0 should parking at port 1 at tick 8")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 sailing
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 8")
|
||||
self.assertEqual(0, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 0 at tick 8")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 1 at tick 10")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 10")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 10")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 1 at tick 10")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 11
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
|
||||
tick += 1
|
||||
|
||||
v = be._vessels[0]
|
||||
|
||||
# vessel 0 parking
|
||||
self.assertEqual(2, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 0 should be 2 at tick 11")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 0 should be 1 at tick 11")
|
||||
|
||||
v = be._vessels[1]
|
||||
|
||||
# vessel 1 parking
|
||||
self.assertEqual(1, v.next_loc_idx,
|
||||
"next_loc_idx of vessel 1 should be 1 at tick 11")
|
||||
self.assertEqual(1, v.last_loc_idx,
|
||||
"last_loc_idx of vessel 1 should be 1 at tick 11")
|
||||
|
||||
# move the env to next step, so it will take snapshot for current tick 11
|
||||
next_step(eb, be, tick)
|
||||
|
||||
# we have hard coded the future stops, here we just check if the value correct at each tick
|
||||
for i in range(tick - 1):
|
||||
# check if the future stop at tick 8 (vessel 0 arrive at port 1)
|
||||
stop_list = be.snapshots["vessels"][i:0:[
|
||||
"past_stop_list", "past_stop_tick_list"]]
|
||||
|
||||
self.assertEqual(-1, stop_list[0])
|
||||
self.assertEqual(-1, stop_list[2])
|
||||
|
||||
stop_list = be.snapshots["vessels"][i:0:[
|
||||
"future_stop_list", "future_stop_tick_list"]]
|
||||
|
||||
self.assertEqual(2, stop_list[0])
|
||||
self.assertEqual(3, stop_list[1])
|
||||
self.assertEqual(4, stop_list[2])
|
||||
self.assertEqual(4, stop_list[3])
|
||||
self.assertEqual(10, stop_list[4])
|
||||
self.assertEqual(20, stop_list[5])
|
||||
# check if there is any container missing
|
||||
total_cntr_number = sum([port.empty for port in be._ports]) + \
|
||||
sum([vessel.empty for vessel in be._vessels]) + \
|
||||
sum([port.full for port in be._ports]) + \
|
||||
sum([vessel.full for vessel in be._vessels])
|
||||
|
||||
# NOTE: we flatten here, as raw backend query result has 4dim shape
|
||||
# check if statistics data correct
|
||||
order_states = be.snapshots["ports"][i:0:[
|
||||
"shortage", "acc_shortage", "booking", "acc_booking"]]
|
||||
order_states = be.snapshots["ports"][7:0:[
|
||||
"shortage", "acc_shortage", "booking", "acc_booking"]].flatten()
|
||||
|
||||
# all the value should be 0 for this case
|
||||
self.assertEqual(
|
||||
0, order_states[0], f"shortage of port 0 should be 0 at tick {i}")
|
||||
1, order_states[0], f"shortage of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[1], f"acc_shortage of port 0 should be 0 until tick {i}")
|
||||
1, order_states[1], f"acc_shortage of port 0 should be 0 until tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[2], f"booking of port 0 should be 0 at tick {i}")
|
||||
51, order_states[2], f"booking of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, order_states[3], f"acc_booking of port 0 should be 0 until tick {i}")
|
||||
101, order_states[3], f"acc_booking of port 0 should be 0 until tick {i}")
|
||||
|
||||
# check fulfillment
|
||||
fulfill_states = be.snapshots["ports"][i:0:[
|
||||
"fulfillment", "acc_fulfillment"]]
|
||||
fulfill_states = be.snapshots["ports"][7:0:[
|
||||
"fulfillment", "acc_fulfillment"]].flatten()
|
||||
|
||||
self.assertEqual(
|
||||
0, fulfill_states[0], f"fulfillment of port 0 should be 0 at tick {i}")
|
||||
50, fulfill_states[0], f"fulfillment of port 0 should be 50 at tick {i}")
|
||||
self.assertEqual(
|
||||
0, fulfill_states[1], f"acc_fulfillment of port 0 should be 0 until tick {i}")
|
||||
|
||||
v = be.snapshots["matrices"][2:: "vessel_plans"]
|
||||
|
||||
# since we already fixed the vessel plans, we just check the value
|
||||
for i in range(2):
|
||||
self.assertEqual(11, v[i*3+0])
|
||||
self.assertEqual(-1, v[i*3+1])
|
||||
self.assertEqual(13, v[i*3+2])
|
||||
|
||||
def test_order_state(self):
|
||||
eb, be = setup_case("case_02")
|
||||
tick = 0
|
||||
|
||||
p = be._ports[0]
|
||||
|
||||
self.assertEqual(0, p.booking, "port 0 have no booking at beginning")
|
||||
self.assertEqual(0, p.shortage, "port 0 have no shortage at beginning")
|
||||
self.assertEqual(
|
||||
100, p.empty, "port 0 have 100 empty containers at beginning")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 0
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there should be 10 order generated at tick 0
|
||||
self.assertEqual(
|
||||
10, p.booking, "port 0 should have 10 bookings at tick 0")
|
||||
self.assertEqual(0, p.shortage, "port 0 have no shortage at tick 0")
|
||||
self.assertEqual(
|
||||
90, p.empty, "port 0 have 90 empty containers at tick 0")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 1
|
||||
for i in range(1):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# we have 0 booking, so no shortage
|
||||
self.assertEqual(
|
||||
0, p.booking, "port 0 should have 0 bookings at tick 1")
|
||||
self.assertEqual(0, p.shortage, "port 0 have no shortage at tick 1")
|
||||
self.assertEqual(
|
||||
90, p.empty, "port 0 have 90 empty containers at tick 1")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 3
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there is an order that take 40 containers
|
||||
self.assertEqual(
|
||||
40, p.booking, "port 0 should have 40 booking at tick 3")
|
||||
self.assertEqual(0, p.shortage, "port 0 have no shortage at tick 3")
|
||||
self.assertEqual(
|
||||
50, p.empty, "port 0 have 90 empty containers at tick 3")
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 7
|
||||
for i in range(4):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# there is an order that take 51 containers
|
||||
self.assertEqual(
|
||||
51, p.booking, "port 0 should have 51 booking at tick 7")
|
||||
self.assertEqual(1, p.shortage, "port 0 have 1 shortage at tick 7")
|
||||
self.assertEqual(
|
||||
0, p.empty, "port 0 have 0 empty containers at tick 7")
|
||||
|
||||
# push the simulator to next tick to update snapshot
|
||||
next_step(eb, be, tick)
|
||||
|
||||
# check if there is any container missing
|
||||
total_cntr_number = sum([port.empty for port in be._ports]) + \
|
||||
sum([vessel.empty for vessel in be._vessels]) + \
|
||||
sum([port.full for port in be._ports]) + \
|
||||
sum([vessel.full for vessel in be._vessels])
|
||||
|
||||
# check if statistics data correct
|
||||
order_states = be.snapshots["ports"][7:0:[
|
||||
"shortage", "acc_shortage", "booking", "acc_booking"]]
|
||||
|
||||
# all the value should be 0 for this case
|
||||
self.assertEqual(
|
||||
1, order_states[0], f"shortage of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
1, order_states[1], f"acc_shortage of port 0 should be 0 until tick {i}")
|
||||
self.assertEqual(
|
||||
51, order_states[2], f"booking of port 0 should be 0 at tick {i}")
|
||||
self.assertEqual(
|
||||
101, order_states[3], f"acc_booking of port 0 should be 0 until tick {i}")
|
||||
|
||||
# check fulfillment
|
||||
fulfill_states = be.snapshots["ports"][7:0:[
|
||||
"fulfillment", "acc_fulfillment"]]
|
||||
|
||||
self.assertEqual(
|
||||
50, fulfill_states[0], f"fulfillment of port 0 should be 50 at tick {i}")
|
||||
self.assertEqual(
|
||||
100, fulfill_states[1], f"acc_fulfillment of port 0 should be 100 until tick {i}")
|
||||
100, fulfill_states[1], f"acc_fulfillment of port 0 should be 100 until tick {i}")
|
||||
|
||||
def test_order_load_discharge_state(self):
|
||||
eb, be = setup_case("case_03")
|
||||
tick = 0
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 5
|
||||
for i in range(6):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
eb, be = setup_case("case_03")
|
||||
tick = 0
|
||||
|
||||
# check if we have load all 50 full container
|
||||
p = be._ports[0]
|
||||
v = be._vessels[0]
|
||||
#####################################
|
||||
# STEP : tick = 5
|
||||
for i in range(6):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
self.assertEqual(0, p.full, "port 0 should have no full at tick 5")
|
||||
self.assertEqual(
|
||||
50, v.full, "all 50 full container should be loaded on vessel 0")
|
||||
self.assertEqual(
|
||||
50, p.empty, "remaining empty should be 50 after order generated at tick 5")
|
||||
self.assertEqual(0, p.shortage, "no shortage at tick 5 for port 0")
|
||||
self.assertEqual(0, p.booking, "no booking at tick 5 for pot 0")
|
||||
# check if we have load all 50 full container
|
||||
p = be._ports[0]
|
||||
v = be._vessels[0]
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(5):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
self.assertEqual(0, p.full, "port 0 should have no full at tick 5")
|
||||
self.assertEqual(
|
||||
50, v.full, "all 50 full container should be loaded on vessel 0")
|
||||
self.assertEqual(
|
||||
50, p.empty, "remaining empty should be 50 after order generated at tick 5")
|
||||
self.assertEqual(0, p.shortage, "no shortage at tick 5 for port 0")
|
||||
self.assertEqual(0, p.booking, "no booking at tick 5 for pot 0")
|
||||
|
||||
# at tick 10 vessel 0 arrive at port 1, it should discharge all the full containers
|
||||
p1 = be._ports[1]
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(5):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
self.assertEqual(
|
||||
0, v.full, "all 0 full container on vessel 0 after arrive at port 1 at tick 10")
|
||||
self.assertEqual(50, p1.on_consignee,
|
||||
"there should be 50 full containers pending to be empty at tick 10 after discharge")
|
||||
self.assertEqual(0, p1.empty, "no empty for port 1 at tick 10")
|
||||
self.assertEqual(0, p1.full, "no full for port 1 at tick 10")
|
||||
# at tick 10 vessel 0 arrive at port 1, it should discharge all the full containers
|
||||
p1 = be._ports[1]
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 12
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
self.assertEqual(
|
||||
0, v.full, "all 0 full container on vessel 0 after arrive at port 1 at tick 10")
|
||||
self.assertEqual(50, p1.on_consignee,
|
||||
"there should be 50 full containers pending to be empty at tick 10 after discharge")
|
||||
self.assertEqual(0, p1.empty, "no empty for port 1 at tick 10")
|
||||
self.assertEqual(0, p1.full, "no full for port 1 at tick 10")
|
||||
|
||||
# we hard coded the buffer time to 2, so
|
||||
self.assertEqual(0, p1.on_consignee,
|
||||
"all the full become empty at tick 12 for port 1")
|
||||
self.assertEqual(
|
||||
50, p1.empty, "there will be 50 empty at tick 12 for port 1")
|
||||
#####################################
|
||||
# STEP : tick = 12
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
# we hard coded the buffer time to 2, so
|
||||
self.assertEqual(0, p1.on_consignee,
|
||||
"all the full become empty at tick 12 for port 1")
|
||||
self.assertEqual(
|
||||
50, p1.empty, "there will be 50 empty at tick 12 for port 1")
|
||||
|
||||
def test_early_discharge(self):
|
||||
eb, be = setup_case("case_04")
|
||||
tick = 0
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
p0 = be._ports[0]
|
||||
p1 = be._ports[1]
|
||||
p2 = be._ports[2]
|
||||
v = be._vessels[0]
|
||||
eb, be = setup_case("case_04")
|
||||
tick = 0
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(11):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
p0 = be._ports[0]
|
||||
p1 = be._ports[1]
|
||||
p2 = be._ports[2]
|
||||
v = be._vessels[0]
|
||||
|
||||
# at tick 10, vessel 0 arrive port 2, it already loaded 50 full, it need to load 50 at port 2, so it will early dicharge 10 empty
|
||||
self.assertEqual(
|
||||
0, v.empty, "vessel 0 should early discharge all the empty at tick 10")
|
||||
self.assertEqual(
|
||||
100, v.full, "vessel 0 should have 100 full on-board at tick 10")
|
||||
self.assertEqual(
|
||||
10, p2.empty, "port 2 have 10 more empty due to early discharge at tick 10")
|
||||
self.assertEqual(0, p2.full, "no full at port 2 at tick 10")
|
||||
#####################################
|
||||
# STEP : tick = 10
|
||||
for i in range(11):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 18
|
||||
for i in range(8):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
# at tick 10, vessel 0 arrive port 2, it already loaded 50 full, it need to load 50 at port 2, so it will early dicharge 10 empty
|
||||
self.assertEqual(
|
||||
0, v.empty, "vessel 0 should early discharge all the empty at tick 10")
|
||||
self.assertEqual(
|
||||
100, v.full, "vessel 0 should have 100 full on-board at tick 10")
|
||||
self.assertEqual(
|
||||
10, p2.empty, "port 2 have 10 more empty due to early discharge at tick 10")
|
||||
self.assertEqual(0, p2.full, "no full at port 2 at tick 10")
|
||||
|
||||
# at tick 18, vessel 0 arrive at port 1, it will discharge all the full
|
||||
self.assertEqual(
|
||||
0, v.empty, "vessel 0 should have no empty at tick 18")
|
||||
self.assertEqual(
|
||||
0, v.full, "vessel 0 should discharge all full on-board at tick 18")
|
||||
self.assertEqual(100, p1.on_consignee,
|
||||
"100 full pending to become empty at port 1 at tick 18")
|
||||
self.assertEqual(0, p1.empty, "no empty for port 1 at tick 18")
|
||||
#####################################
|
||||
# STEP : tick = 18
|
||||
for i in range(8):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
#####################################
|
||||
# STEP : tick = 20
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
# at tick 18, vessel 0 arrive at port 1, it will discharge all the full
|
||||
self.assertEqual(
|
||||
0, v.empty, "vessel 0 should have no empty at tick 18")
|
||||
self.assertEqual(
|
||||
0, v.full, "vessel 0 should discharge all full on-board at tick 18")
|
||||
self.assertEqual(
|
||||
100, p1.on_consignee, "100 full pending to become empty at port 1 at tick 18")
|
||||
self.assertEqual(0, p1.empty, "no empty for port 1 at tick 18")
|
||||
|
||||
self.assertEqual(
|
||||
100, p1.empty, "there should be 100 empty at tick 20 at port 1")
|
||||
#####################################
|
||||
# STEP : tick = 20
|
||||
for i in range(2):
|
||||
next_step(eb, be, tick)
|
||||
tick += 1
|
||||
|
||||
self.assertEqual(
|
||||
100, p1.empty, "there should be 100 empty at tick 20 at port 1")
|
||||
|
||||
def test_order_export(self):
|
||||
"""order.tick, order.src_port_idx, order.dest_port_idx, order.quantity"""
|
||||
Order = namedtuple("Order", ["tick", "src_port_idx", "dest_port_idx", "quantity"])
|
||||
|
||||
exportor = PortOrderExporter(True)
|
||||
|
||||
for i in range(5):
|
||||
exportor.add(Order(0, 0, 1, i + 1))
|
||||
|
||||
out_folder = tempfile.gettempdir()
|
||||
|
||||
exportor.dump(out_folder)
|
||||
|
||||
with open(f"{out_folder}/orders.csv") as fp:
|
||||
reader = csv.DictReader(fp)
|
||||
|
||||
row = 0
|
||||
for line in reader:
|
||||
self.assertEqual(row+1, int(line["quantity"]))
|
||||
|
||||
row += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -7,9 +7,10 @@ import unittest
|
|||
from maro.data_lib import BinaryConverter
|
||||
from maro.event_buffer import EventBuffer
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.citi_bike.business_engine import CitibikeBusinessEngine
|
||||
from maro.simulator.scenarios.citi_bike.business_engine import \
|
||||
CitibikeBusinessEngine
|
||||
from maro.simulator.scenarios.citi_bike.events import CitiBikeEvents
|
||||
from tests.utils import be_run_to_end, next_step
|
||||
from tests.utils import backends_to_test, be_run_to_end, next_step
|
||||
|
||||
|
||||
def setup_case(case_name: str, max_tick: int):
|
||||
|
@ -21,7 +22,8 @@ def setup_case(case_name: str, max_tick: int):
|
|||
trips_bin = os.path.join(config_path, "trips.bin")
|
||||
|
||||
if not os.path.exists(trips_bin):
|
||||
converter = BinaryConverter(trips_bin, os.path.join("tests/data/citi_bike", "trips.meta.yml"))
|
||||
converter = BinaryConverter(trips_bin, os.path.join(
|
||||
"tests/data/citi_bike", "trips.meta.yml"))
|
||||
|
||||
converter.add_csv(os.path.join(config_path, "trips.csv"))
|
||||
converter.flush()
|
||||
|
@ -30,122 +32,139 @@ def setup_case(case_name: str, max_tick: int):
|
|||
weathers_bin = os.path.join("tests/data/citi_bike", "weathers.bin")
|
||||
|
||||
if not os.path.exists(weathers_bin):
|
||||
converter = BinaryConverter(weathers_bin, os.path.join("tests/data/citi_bike", "weather.meta.yml"))
|
||||
converter = BinaryConverter(weathers_bin, os.path.join(
|
||||
"tests/data/citi_bike", "weather.meta.yml"))
|
||||
|
||||
converter.add_csv(os.path.join("tests/data/citi_bike", "weather.csv"))
|
||||
converter.flush()
|
||||
|
||||
eb = EventBuffer()
|
||||
be = CitibikeBusinessEngine(event_buffer=eb, topology=config_path, start_tick=0, max_tick=max_tick, snapshot_resolution=1, max_snapshots=None, additional_options={})
|
||||
be = CitibikeBusinessEngine(event_buffer=eb, topology=config_path, start_tick=0,
|
||||
max_tick=max_tick, snapshot_resolution=1, max_snapshots=None, additional_options={})
|
||||
|
||||
return eb, be
|
||||
|
||||
|
||||
class TestCitibike(unittest.TestCase):
|
||||
def test_trips_without_shortage(self):
|
||||
"""Normal case without shortage, case_1"""
|
||||
eb, be = setup_case("case_1", max_tick=10)
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
next_step(eb, be, 0)
|
||||
eb, be = setup_case("case_1", max_tick=10)
|
||||
|
||||
station_num = len(be.frame.stations)
|
||||
next_step(eb, be, 0)
|
||||
|
||||
station_0 = be.frame.stations[0]
|
||||
station_1 = be.frame.stations[1]
|
||||
station_num = len(be.frame.stations)
|
||||
|
||||
# check bikes at station 0, 1 should be moved
|
||||
self.assertEqual(4, station_0.bikes)
|
||||
self.assertEqual(10, station_1.bikes)
|
||||
station_0 = be.frame.stations[0]
|
||||
station_1 = be.frame.stations[1]
|
||||
|
||||
pending_evts = eb.get_pending_events(5)
|
||||
# check bikes at station 0, 1 should be moved
|
||||
self.assertEqual(4, station_0.bikes)
|
||||
self.assertEqual(10, station_1.bikes)
|
||||
|
||||
# check event in pending pool, there should be 1 returned event
|
||||
self.assertEqual(1, len(pending_evts))
|
||||
self.assertEqual(CitiBikeEvents.ReturnBike, pending_evts[0].event_type)
|
||||
pending_evts = eb.get_pending_events(5)
|
||||
|
||||
next_step(eb, be, 1)
|
||||
# check event in pending pool, there should be 1 returned event
|
||||
self.assertEqual(1, len(pending_evts))
|
||||
self.assertEqual(CitiBikeEvents.ReturnBike,
|
||||
pending_evts[0].event_type)
|
||||
|
||||
# station 0 and 1 have 1 trip
|
||||
self.assertEqual(3, station_0.bikes)
|
||||
self.assertEqual(9, station_1.bikes)
|
||||
next_step(eb, be, 1)
|
||||
|
||||
# no shortage
|
||||
self.assertEqual(0, station_0.shortage)
|
||||
self.assertEqual(0, station_1.shortage)
|
||||
# station 0 and 1 have 1 trip
|
||||
self.assertEqual(3, station_0.bikes)
|
||||
self.assertEqual(9, station_1.bikes)
|
||||
|
||||
# check if snapshot correct
|
||||
states = be.snapshots["stations"][::["shortage", "bikes", "fulfillment", "trip_requirement"]]
|
||||
# no shortage
|
||||
self.assertEqual(0, station_0.shortage)
|
||||
self.assertEqual(0, station_1.shortage)
|
||||
|
||||
# reshape by tick, attribute numbr and station number
|
||||
states = states.reshape(-1, station_num, 4)
|
||||
# check if snapshot correct
|
||||
states = be.snapshots["stations"][::[
|
||||
"shortage", "bikes", "fulfillment", "trip_requirement"]]
|
||||
|
||||
self.assertEqual(2, len(states))
|
||||
# reshape by tick, attribute numbr and station number
|
||||
states = states.reshape(-1, station_num, 4)
|
||||
|
||||
states_at_tick_0 = states[0]
|
||||
states_at_tick_1 = states[1]
|
||||
self.assertEqual(2, len(states))
|
||||
|
||||
# no shortage
|
||||
self.assertEqual(0, states_at_tick_0[:,0].sum())
|
||||
self.assertEqual(4+10, states_at_tick_0[:,1].sum())
|
||||
states_at_tick_0 = states[0]
|
||||
states_at_tick_1 = states[1]
|
||||
|
||||
# since no shortage, trips == fulfillments
|
||||
self.assertEqual(states_at_tick_0[:,2].sum(), states_at_tick_0[:,3].sum())
|
||||
# no shortage
|
||||
self.assertEqual(0, states_at_tick_0[:, 0].sum())
|
||||
self.assertEqual(4+10, states_at_tick_0[:, 1].sum())
|
||||
|
||||
#
|
||||
self.assertEqual(0, states_at_tick_1[:,0].sum())
|
||||
self.assertEqual(3+9, states_at_tick_1[:,1].sum())
|
||||
# since no shortage, trips == fulfillments
|
||||
self.assertEqual(
|
||||
states_at_tick_0[:, 2].sum(), states_at_tick_0[:, 3].sum())
|
||||
|
||||
self.assertEqual(states_at_tick_1[:,2].sum(), states_at_tick_1[:,3].sum())
|
||||
#
|
||||
self.assertEqual(0, states_at_tick_1[:, 0].sum())
|
||||
self.assertEqual(3+9, states_at_tick_1[:, 1].sum())
|
||||
|
||||
self.assertEqual(
|
||||
states_at_tick_1[:, 2].sum(), states_at_tick_1[:, 3].sum())
|
||||
|
||||
def test_trips_on_multiple_epsiode(self):
|
||||
"""Test if total trips of multiple episodes with same config are same"""
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
max_ep = 100
|
||||
max_ep = 100
|
||||
|
||||
eb, be = setup_case("case_1", max_tick=100)
|
||||
eb, be = setup_case("case_1", max_tick=100)
|
||||
|
||||
total_trips_list = []
|
||||
total_trips_list = []
|
||||
|
||||
for ep in range(max_ep):
|
||||
eb.reset()
|
||||
be.reset()
|
||||
for ep in range(max_ep):
|
||||
eb.reset()
|
||||
be.reset()
|
||||
|
||||
be_run_to_end(eb, be)
|
||||
be_run_to_end(eb, be)
|
||||
|
||||
total_trips = be.snapshots["stations"][::"trip_requirement"].sum()
|
||||
shortage_and_fulfillment = be.snapshots["stations"][::["shortage", "fulfillment"]].sum()
|
||||
total_trips = be.snapshots["stations"][::"trip_requirement"].sum(
|
||||
)
|
||||
shortage_and_fulfillment = be.snapshots["stations"][::[
|
||||
"shortage", "fulfillment"]].sum()
|
||||
|
||||
self.assertEqual(total_trips, shortage_and_fulfillment)
|
||||
self.assertEqual(total_trips, shortage_and_fulfillment)
|
||||
|
||||
total_trips_list.append(total_trips)
|
||||
|
||||
# if same with previous episodes
|
||||
self.assertEqual(total_trips_list[0], total_trips)
|
||||
total_trips_list.append(total_trips)
|
||||
|
||||
# if same with previous episodes
|
||||
self.assertEqual(total_trips_list[0], total_trips)
|
||||
|
||||
def test_trips_with_shortage(self):
|
||||
"""Test if shortage states correct"""
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
eb, be = setup_case("case_2", max_tick=5)
|
||||
eb, be = setup_case("case_2", max_tick=5)
|
||||
|
||||
stations_snapshots = be.snapshots["stations"]
|
||||
stations_snapshots = be.snapshots["stations"]
|
||||
|
||||
be_run_to_end(eb, be)
|
||||
be_run_to_end(eb, be)
|
||||
|
||||
states_at_tick_0 = stations_snapshots[0:0:["shortage", "bikes"]]
|
||||
states_at_tick_0 = stations_snapshots[0:0:[
|
||||
"shortage", "bikes"]].flatten()
|
||||
|
||||
shortage_at_tick_0 = states_at_tick_0[0]
|
||||
bikes_at_tick_0 = states_at_tick_0[1]
|
||||
shortage_at_tick_0 = states_at_tick_0[0]
|
||||
bikes_at_tick_0 = states_at_tick_0[1]
|
||||
|
||||
# there should be no shortage, and 4 left
|
||||
self.assertEqual(0, shortage_at_tick_0)
|
||||
self.assertEqual(4, bikes_at_tick_0)
|
||||
# there should be no shortage, and 4 left
|
||||
self.assertEqual(0, shortage_at_tick_0)
|
||||
self.assertEqual(4, bikes_at_tick_0)
|
||||
|
||||
# there should be 6 trips from 1st station, so there will be 2 shortage
|
||||
states_at_tick_1 = stations_snapshots[1:0:["shortage", "bikes", "trip_requirement"]]
|
||||
# there should be 6 trips from 1st station, so there will be 2 shortage
|
||||
states_at_tick_1 = stations_snapshots[1:0:[
|
||||
"shortage", "bikes", "trip_requirement"]].flatten()
|
||||
|
||||
self.assertEqual(2, states_at_tick_1[0])
|
||||
self.assertEqual(0, states_at_tick_1[1])
|
||||
self.assertEqual(6, states_at_tick_1[2])
|
||||
self.assertEqual(2, states_at_tick_1[0])
|
||||
self.assertEqual(0, states_at_tick_1[1])
|
||||
self.assertEqual(6, states_at_tick_1[2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,98 +1,120 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import os
|
||||
from time import time
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.cim.frame_builder import gen_cim_frame
|
||||
from termgraph import termgraph as tg
|
||||
|
||||
"""
|
||||
In this file we will test performance for frame, snapshotlist, and cim scenario, with following config
|
||||
from maro.backends.frame import (FrameBase, FrameNode, NodeAttribute, NodeBase,
|
||||
node)
|
||||
|
||||
1. ports: 100
|
||||
2. vessels: 100
|
||||
3. max_tick: 10000
|
||||
|
||||
"""
|
||||
|
||||
PORTS_NUMBER = 100
|
||||
VESSELS_NUMBER = 100
|
||||
NODE1_NUMBER = 100
|
||||
NODE2_NUMBER = 100
|
||||
MAX_TICK = 10000
|
||||
STOP_NUMBER = (6, 6)
|
||||
|
||||
READ_WRITE_NUMBER = 1000000
|
||||
STATES_QURING_TIME = 100000
|
||||
|
||||
def test_frame_only():
|
||||
READ_WRITE_NUMBER = 10000000
|
||||
STATES_QURING_TIME = 10000
|
||||
TAKE_SNAPSHOT_TIME = 10000
|
||||
|
||||
AVG_TIME = 4
|
||||
|
||||
|
||||
@node("node1")
|
||||
class TestNode1(NodeBase):
|
||||
a = NodeAttribute("i")
|
||||
b = NodeAttribute("i")
|
||||
c = NodeAttribute("i")
|
||||
d = NodeAttribute("i")
|
||||
e = NodeAttribute("i", 16)
|
||||
|
||||
|
||||
@node("node2")
|
||||
class TestNode2(NodeBase):
|
||||
b = NodeAttribute("i", 20)
|
||||
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
node1 = FrameNode(TestNode1, NODE1_NUMBER)
|
||||
node2 = FrameNode(TestNode2, NODE2_NUMBER)
|
||||
|
||||
def __init__(self, backend_name):
|
||||
super().__init__(enable_snapshot=True,
|
||||
total_snapshot=TAKE_SNAPSHOT_TIME, backend_name=backend_name)
|
||||
|
||||
|
||||
def build_frame(backend_name: str):
|
||||
return TestFrame(backend_name)
|
||||
|
||||
|
||||
def attribute_access(frame, times: int):
|
||||
"""Return time cost (in seconds) for attribute acceesing test"""
|
||||
start_time = time()
|
||||
|
||||
frm = gen_cim_frame(PORTS_NUMBER, VESSELS_NUMBER, STOP_NUMBER, MAX_TICK)
|
||||
n1 = frame.node1[0]
|
||||
|
||||
static_node = frm.ports[0]
|
||||
for _ in range(times):
|
||||
a = n1.a
|
||||
n1.a = 12
|
||||
|
||||
# read & write one attribute N times with simplified interface
|
||||
for _ in range(READ_WRITE_NUMBER):
|
||||
static_node.a2 = 10
|
||||
a = static_node.a2
|
||||
|
||||
end_time = time()
|
||||
|
||||
print(f"node read & write {READ_WRITE_NUMBER} times: {end_time - start_time}")
|
||||
return time() - start_time
|
||||
|
||||
|
||||
def test_snapshot_list_only():
|
||||
frm = gen_cim_frame(PORTS_NUMBER, VESSELS_NUMBER, STOP_NUMBER, MAX_TICK)
|
||||
def take_snapshot(frame, times: int):
|
||||
"""Return times cost (in seconds) for take_snapshot operation"""
|
||||
|
||||
start_time = time()
|
||||
|
||||
# 1. take snapshot
|
||||
for i in range(MAX_TICK):
|
||||
frm.take_snapshot(i)
|
||||
for i in range(times):
|
||||
frame.take_snapshot(i)
|
||||
|
||||
end_time = time()
|
||||
|
||||
print(f"take {MAX_TICK} snapshot: {end_time - start_time}")
|
||||
return time() - start_time
|
||||
|
||||
|
||||
def test_states_quering():
|
||||
frm = gen_cim_frame(PORTS_NUMBER, VESSELS_NUMBER, STOP_NUMBER, MAX_TICK)
|
||||
frm.take_snapshot(0)
|
||||
def snapshot_query(frame, times: int):
|
||||
"""Return time cost (in seconds) for snapshot querying"""
|
||||
|
||||
start_time = time()
|
||||
|
||||
static_ss = frm.snapshots["ports"]
|
||||
for i in range(times):
|
||||
states = frame.snapshots["node1"][i::"a"]
|
||||
|
||||
for i in range(STATES_QURING_TIME):
|
||||
states = static_ss[::"empty"]
|
||||
|
||||
end_time = time()
|
||||
|
||||
print(f"Single state quering {STATES_QURING_TIME} times: {end_time - start_time}")
|
||||
|
||||
|
||||
def test_cim():
|
||||
eps = 4
|
||||
|
||||
env = Env("cim", "toy.5p_ssddd_l0.0", durations=MAX_TICK)
|
||||
|
||||
start_time = time()
|
||||
|
||||
for _ in range(eps):
|
||||
_, _, is_done = env.step(None)
|
||||
|
||||
while not is_done:
|
||||
_, _, is_done = env.step(None)
|
||||
|
||||
env.reset()
|
||||
|
||||
end_time = time()
|
||||
|
||||
print(f"cim 5p toplogy with {MAX_TICK} total time cost: {(end_time - start_time)/eps}")
|
||||
return time() - start_time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_frame_only()
|
||||
test_snapshot_list_only()
|
||||
test_states_quering()
|
||||
test_cim()
|
||||
chart_colors = [91, 94]
|
||||
|
||||
chart_args = {'filename': '-', 'title': "Performance comparison between cpp and np backends", 'width': 40,
|
||||
'format': '{:<5.2f}', 'suffix': '', 'no_labels': False,
|
||||
'color': None, 'vertical': False, 'stacked': False,
|
||||
'different_scale': False, 'calendar': False,
|
||||
'start_dt': None, 'custom_tick': '', 'delim': '',
|
||||
'verbose': False, 'version': False,
|
||||
'histogram': False, 'no_values': False}
|
||||
|
||||
chart_labels = [f'attribute accessing ({READ_WRITE_NUMBER})',
|
||||
f'take snapshot ({STATES_QURING_TIME})', f'states querying ({STATES_QURING_TIME})']
|
||||
|
||||
chart_data = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
|
||||
|
||||
i = 0
|
||||
j = 0
|
||||
|
||||
for backend_name in ["static", "dynamic"]:
|
||||
frame = build_frame(backend_name)
|
||||
|
||||
j = 0
|
||||
|
||||
for func, args in [(attribute_access, READ_WRITE_NUMBER), (take_snapshot, TAKE_SNAPSHOT_TIME), (snapshot_query, STATES_QURING_TIME)]:
|
||||
t = func(frame, args)
|
||||
|
||||
chart_data[j][i] = t
|
||||
|
||||
j += 1
|
||||
|
||||
i += 1
|
||||
|
||||
tg.print_categories(['static', 'dynamic'], chart_colors)
|
||||
tg.chart(chart_colors, chart_data, chart_args, chart_labels)
|
||||
|
|
|
@ -16,6 +16,7 @@ azure-storage-common
|
|||
torch
|
||||
pytest
|
||||
coverage
|
||||
termgraph
|
||||
paramiko==2.7.2
|
||||
pytz==2019.3
|
||||
aria2p==0.9.1
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from dummy.dummy_business_engine import DummyEngine
|
||||
|
||||
from maro.simulator.utils import get_available_envs, get_scenarios, get_topologies
|
||||
from maro.simulator.utils.common import frame_index_to_ticks
|
||||
from maro.simulator.core import BusinessEngineNotFoundError, Env
|
||||
from tests.utils import backends_to_test
|
||||
|
||||
|
||||
def run_to_end(env: Env):
|
||||
|
@ -23,184 +28,234 @@ class TestEnv(unittest.TestCase):
|
|||
|
||||
def test_builtin_scenario_with_default_parameters(self):
|
||||
"""Test if the env with built-in scenario initializing correct"""
|
||||
max_tick = 10
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", durations=max_tick)
|
||||
max_tick = 10
|
||||
|
||||
run_to_end(env)
|
||||
env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0",
|
||||
durations=max_tick)
|
||||
|
||||
# check port number
|
||||
ports_number = len(env.snapshot_list["ports"])
|
||||
run_to_end(env)
|
||||
|
||||
self.assertEqual(ports_number, 5, msg=f"5pssddd topology should contains 5 ports, got {ports_number}")
|
||||
# check port number
|
||||
ports_number = len(env.snapshot_list["ports"])
|
||||
|
||||
self.assertEqual(
|
||||
ports_number, 5, msg=f"5pssddd topology should contains 5 ports, got {ports_number}")
|
||||
|
||||
def test_env_interfaces_with_specified_business_engine_cls(self):
|
||||
"""Test if env interfaces works as expect"""
|
||||
max_tick = 5
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
max_tick = 5
|
||||
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick)
|
||||
env = Env(business_engine_cls=DummyEngine,
|
||||
start_tick=0, durations=max_tick)
|
||||
|
||||
run_to_end(env)
|
||||
run_to_end(env)
|
||||
|
||||
# check if the snapshot number equals with max_tick
|
||||
# NOTE: the snapshot_resolution defaults to 1, so the number of snapshots is same with max_tick
|
||||
num_of_snapshots = len(env.snapshot_list)
|
||||
# check if the snapshot number equals with max_tick
|
||||
# NOTE: the snapshot_resolution defaults to 1, so the number of snapshots is same with max_tick
|
||||
num_of_snapshots = len(env.snapshot_list)
|
||||
|
||||
self.assertEqual(max_tick, len(env.snapshot_list), msg=f"number of snapshots ({num_of_snapshots}) should be same "
|
||||
f"with max tick ({max_tick}) without specified snapshot_resolution and max_snapshots")
|
||||
self.assertEqual(max_tick, len(env.snapshot_list), msg=f"number of snapshots ({num_of_snapshots}) should be same "
|
||||
f"with max tick ({max_tick}) without specified snapshot_resolution and max_snapshots")
|
||||
|
||||
# check if we can reach to the end [start_tick, max_tick)
|
||||
self.assertEqual(max_tick-1, env.tick)
|
||||
# check if we can reach to the end [start_tick, max_tick)
|
||||
self.assertEqual(max_tick-1, env.tick)
|
||||
|
||||
# check if frame_index
|
||||
# NOTE: since we have not specified snapshot_resolution, frame_index should same with tick
|
||||
self.assertEqual(env.tick, env.frame_index)
|
||||
# check if frame_index
|
||||
# NOTE: since we have not specified snapshot_resolution, frame_index should same with tick
|
||||
self.assertEqual(env.tick, env.frame_index)
|
||||
|
||||
# check if config is same as we defined
|
||||
self.assertDictEqual(env.configs, {"name":"dummy"}, msg="configs should same as defined")
|
||||
# check if config is same as we defined
|
||||
self.assertDictEqual(
|
||||
env.configs, {"name": "dummy"}, msg="configs should same as defined")
|
||||
|
||||
# check node information
|
||||
node_info = env.summary["node_detail"]
|
||||
# check node information
|
||||
node_info = env.summary["node_detail"]
|
||||
|
||||
# check node exist
|
||||
self.assertTrue("dummies" in node_info, msg="dummy engine should contains dummy node")
|
||||
# check node exist
|
||||
self.assertTrue("dummies" in node_info,
|
||||
msg="dummy engine should contains dummy node")
|
||||
|
||||
# check node number
|
||||
dummy_number = node_info["dummies"]["number"]
|
||||
# check node number
|
||||
dummy_number = node_info["dummies"]["number"]
|
||||
|
||||
self.assertEqual(10, dummy_number, msg=f"dummy should contains 10 nodes, got {dummy_number}")
|
||||
self.assertEqual(
|
||||
10, dummy_number, msg=f"dummy should contains 10 nodes, got {dummy_number}")
|
||||
|
||||
attributes = node_info["dummies"]["attributes"]
|
||||
attributes = node_info["dummies"]["attributes"]
|
||||
|
||||
# it will contains one attribute
|
||||
self.assertEqual(1, len(attributes), msg=f"dummy node should only contains 1 attribute, got {len(attributes)}")
|
||||
# it will contains one attribute
|
||||
self.assertEqual(1, len(
|
||||
attributes), msg=f"dummy node should only contains 1 attribute, got {len(attributes)}")
|
||||
|
||||
# and the attribute name is val
|
||||
self.assertTrue("val" in attributes)
|
||||
# and the attribute name is val
|
||||
self.assertTrue("val" in attributes)
|
||||
|
||||
# attribute type should be i
|
||||
val_dtype = attributes['val']["type"]
|
||||
# attribute type should be i
|
||||
val_dtype = attributes['val']["type"]
|
||||
|
||||
self.assertEqual("i", val_dtype, msg=f"dummy's val attribute should be int type, got {val_dtype}")
|
||||
self.assertEqual(
|
||||
"int", val_dtype, msg=f"dummy's val attribute should be int type, got {val_dtype}")
|
||||
|
||||
# val should have only one slot (default)
|
||||
val_slots = attributes['val']["slots"]
|
||||
# val should have only one slot (default)
|
||||
val_slots = attributes['val']["slots"]
|
||||
|
||||
self.assertEqual(1, val_slots, msg=f"dummy's val attribute should be int type, got {val_slots}")
|
||||
self.assertEqual(
|
||||
1, val_slots, msg=f"dummy's val attribute should be int type, got {val_slots}")
|
||||
|
||||
# agent list should be [0, dummy_number)
|
||||
self.assertListEqual(list(range(0, dummy_number)), env.agent_idx_list, msg=f"dummy engine should have {dummy_number} agents")
|
||||
# agent list should be [0, dummy_number)
|
||||
self.assertListEqual(list(range(0, dummy_number)), env.agent_idx_list,
|
||||
msg=f"dummy engine should have {dummy_number} agents")
|
||||
|
||||
# check if snapshot list available
|
||||
self.assertIsNotNone(
|
||||
env.snapshot_list, msg="snapshot list should be None")
|
||||
|
||||
# check if snapshot list available
|
||||
self.assertIsNotNone(env.snapshot_list, msg="snapshot list should be None")
|
||||
# reset should work
|
||||
|
||||
# reset should work
|
||||
dummies_ss = env.snapshot_list["dummies"]
|
||||
vals_before_reset = dummies_ss[env.frame_index::"val"]
|
||||
|
||||
dummies_ss = env.snapshot_list["dummies"]
|
||||
vals_before_reset = dummies_ss[env.frame_index::"val"]
|
||||
# before reset, snapshot should have value
|
||||
self.assertListEqual(list(vals_before_reset.flatten()), [
|
||||
env.tick]*dummy_number, msg=f"we should have val value same as last tick, got {vals_before_reset}")
|
||||
|
||||
# before reset, snapshot should have value
|
||||
self.assertListEqual(list(vals_before_reset), [env.tick]*dummy_number, msg=f"we should have val value same as last tick, got {vals_before_reset}")
|
||||
env.reset()
|
||||
|
||||
env.reset()
|
||||
# after reset, it should 0
|
||||
vals_after_reset = dummies_ss[env.frame_index::"val"]
|
||||
|
||||
# after reset, it should 0
|
||||
vals_after_reset = dummies_ss[env.frame_index::"val"]
|
||||
|
||||
self.assertListEqual(list(vals_after_reset), [0]*dummy_number, msg=f"we should have val value same as last tick, got {vals_after_reset}")
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(vals_after_reset).all())
|
||||
else:
|
||||
self.assertListEqual(list(vals_after_reset.flatten()), [
|
||||
0]*dummy_number, msg=f"we should have padding values")
|
||||
|
||||
def test_snapshot_resolution(self):
|
||||
"""Test env with snapshot_resolution, it should take snapshot every snapshot_resolution ticks"""
|
||||
max_tick = 10
|
||||
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick, snapshot_resolution=3)
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
run_to_end(env)
|
||||
max_tick = 10
|
||||
|
||||
# we should have 4 snapshots totally without max_snapshots speified
|
||||
self.assertEqual(4, len(env.snapshot_list), msg="We should have 4 snapshots in memory")
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0,
|
||||
durations=max_tick, snapshot_resolution=3)
|
||||
|
||||
# snapshot at 2, 5, 8, 9 ticks
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
run_to_end(env)
|
||||
|
||||
# NOTE: frame_index is the index of frame in snapshot list, it is 0 based, so snapshot resolution will make tick not equals to frame_index
|
||||
#
|
||||
for frame_index, tick in enumerate((2, 5, 8, 9)):
|
||||
self.assertListEqual(list(states[frame_index]), [tick] * 10, msg=f"states should be {tick}")
|
||||
# we should have 4 snapshots totally without max_snapshots speified
|
||||
self.assertEqual(4, len(env.snapshot_list),
|
||||
msg="We should have 4 snapshots in memory")
|
||||
|
||||
# snapshot at 2, 5, 8, 9 ticks
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
|
||||
# NOTE: frame_index is the index of frame in snapshot list, it is 0 based, so snapshot resolution will make tick not equals to frame_index
|
||||
#
|
||||
for frame_index, tick in enumerate((2, 5, 8, 9)):
|
||||
self.assertListEqual(list(states[frame_index]), [
|
||||
tick] * 10, msg=f"states should be {tick}")
|
||||
|
||||
def test_max_snapshots(self):
|
||||
"""Test env with max_snapshots, it should take snapshot every tick, but should last N kept"""
|
||||
max_tick = 10
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick, max_snapshots=2)
|
||||
max_tick = 10
|
||||
|
||||
run_to_end(env)
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0,
|
||||
durations=max_tick, max_snapshots=2)
|
||||
|
||||
# we should have 2 snapshots totally with max_snapshots speified
|
||||
self.assertEqual(2, len(env.snapshot_list), msg="We should have 2 snapshots in memory")
|
||||
run_to_end(env)
|
||||
|
||||
# and only 87 and 9 in snapshot
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
# we should have 2 snapshots totally with max_snapshots speified
|
||||
self.assertEqual(2, len(env.snapshot_list),
|
||||
msg="We should have 2 snapshots in memory")
|
||||
|
||||
# 1st should states at tick 7
|
||||
self.assertListEqual(list(states[0]), [8] * 10, msg="1st snapshot should be at tick 8")
|
||||
# and only 87 and 9 in snapshot
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
|
||||
# 2nd should states at tick 9
|
||||
self.assertListEqual(list(states[1]), [9] * 10, msg="2nd snapshot should be at tick 9")
|
||||
# 1st should states at tick 7
|
||||
self.assertListEqual(
|
||||
list(states[0]), [8] * 10, msg="1st snapshot should be at tick 8")
|
||||
|
||||
# 2nd should states at tick 9
|
||||
self.assertListEqual(
|
||||
list(states[1]), [9] * 10, msg="2nd snapshot should be at tick 9")
|
||||
|
||||
def test_snapshot_resolution_with_max_snapshots(self):
|
||||
"""Test env with both snapshot_resolution and max_snapshots parameters, and it should work as expected"""
|
||||
max_tick = 10
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick, snapshot_resolution=2, max_snapshots=2)
|
||||
max_tick = 10
|
||||
|
||||
run_to_end(env)
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0,
|
||||
durations=max_tick, snapshot_resolution=2, max_snapshots=2)
|
||||
|
||||
# we should have snapshot same as max_snapshots
|
||||
self.assertEqual(2, len(env.snapshot_list), msg="We should have 2 snapshots in memory")
|
||||
run_to_end(env)
|
||||
|
||||
# and only 7 and 9 in snapshot
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
# we should have snapshot same as max_snapshots
|
||||
self.assertEqual(2, len(env.snapshot_list),
|
||||
msg="We should have 2 snapshots in memory")
|
||||
|
||||
# 1st should states at tick 7
|
||||
self.assertListEqual(list(states[0]), [7] * 10, msg="1st snapshot should be at tick 7")
|
||||
# and only 7 and 9 in snapshot
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
|
||||
# 2nd should states at tick 9
|
||||
self.assertListEqual(list(states[1]), [9] * 10, msg="2nd snapshot should be at tick 9")
|
||||
# 1st should states at tick 7
|
||||
self.assertListEqual(
|
||||
list(states[0]), [7] * 10, msg="1st snapshot should be at tick 7")
|
||||
|
||||
# 2nd should states at tick 9
|
||||
self.assertListEqual(
|
||||
list(states[1]), [9] * 10, msg="2nd snapshot should be at tick 9")
|
||||
|
||||
def test_early_stop(self):
|
||||
"""Test if we can stop at specified tick with early stop at post_step function"""
|
||||
max_tick = 10
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick,
|
||||
options={"post_step_early_stop": 6}) # early stop at tick 6, NOTE: simulator still
|
||||
max_tick = 10
|
||||
|
||||
run_to_end(env)
|
||||
env = Env(business_engine_cls=DummyEngine, start_tick=0, durations=max_tick,
|
||||
options={"post_step_early_stop": 6}) # early stop at tick 6, NOTE: simulator still
|
||||
|
||||
# the end tick of env should be 6 as specified
|
||||
self.assertEqual(6, env.tick, msg=f"env should stop at tick 6, but {env.tick}")
|
||||
run_to_end(env)
|
||||
|
||||
# avaiable snapshot should be 7 (0-6)
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
# the end tick of env should be 6 as specified
|
||||
self.assertEqual(
|
||||
6, env.tick, msg=f"env should stop at tick 6, but {env.tick}")
|
||||
|
||||
self.assertEqual(7, len(states), msg=f"available snapshot number should be 7, but {len(states)}")
|
||||
# avaiable snapshot should be 7 (0-6)
|
||||
states = env.snapshot_list["dummies"][::"val"].reshape(-1, 10)
|
||||
|
||||
# and last one should be 6
|
||||
self.assertListEqual(list(states[-1]), [6]*10, msg="last states should be 6")
|
||||
self.assertEqual(
|
||||
7, len(states), msg=f"available snapshot number should be 7, but {len(states)}")
|
||||
|
||||
# and last one should be 6
|
||||
self.assertListEqual(
|
||||
list(states[-1]), [6]*10, msg="last states should be 6")
|
||||
|
||||
def test_builtin_scenario_with_customized_topology(self):
|
||||
"""Test using built-in scenario with customized topology"""
|
||||
for backend_name in backends_to_test:
|
||||
os.environ["DEFAULT_BACKEND_NAME"] = backend_name
|
||||
|
||||
max_tick = 10
|
||||
max_tick = 10
|
||||
|
||||
env = Env(scenario="cim", topology="tests/data/cim/customized_config", start_tick=0, durations=max_tick)
|
||||
env = Env(scenario="cim", topology="tests/data/cim/customized_config",
|
||||
start_tick=0, durations=max_tick)
|
||||
|
||||
run_to_end(env)
|
||||
run_to_end(env)
|
||||
|
||||
# check if the config same as ours
|
||||
self.assertEqual([2], env.configs["container_volumes"], msg="customized container_volumes should be 2")
|
||||
# check if the config same as ours
|
||||
self.assertEqual([2], env.configs["container_volumes"],
|
||||
msg="customized container_volumes should be 2")
|
||||
|
||||
def test_invalid_scenario(self):
|
||||
"""Test specified invalid scenario"""
|
||||
|
@ -213,7 +268,30 @@ class TestEnv(unittest.TestCase):
|
|||
with self.assertRaises(FileNotFoundError) as ctx:
|
||||
env = Env("cim", "None", 100)
|
||||
|
||||
def test_get_avaiable_envs(self):
|
||||
scenario_names = get_scenarios()
|
||||
|
||||
# we have 2 built-in scenarios
|
||||
self.assertEqual(3, len(scenario_names))
|
||||
|
||||
self.assertTrue("cim" in scenario_names)
|
||||
self.assertTrue("citi_bike" in scenario_names)
|
||||
|
||||
cim_topoloies = get_topologies("cim")
|
||||
citi_bike_topologies = get_topologies("citi_bike")
|
||||
vm_topoloties = get_topologies("vm_scheduling")
|
||||
|
||||
env_list = get_available_envs()
|
||||
|
||||
self.assertEqual(len(env_list), len(cim_topoloies) + len(citi_bike_topologies) + len(vm_topoloties))
|
||||
|
||||
def test_frame_index_to_ticks(self):
|
||||
ticks = frame_index_to_ticks(0, 10, 2)
|
||||
|
||||
self.assertEqual(5, len(ticks))
|
||||
|
||||
self.assertListEqual([0, 1], ticks[0])
|
||||
self.assertListEqual([8, 9], ticks[4])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,160 +1,593 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from maro.backends.frame import FrameBase, FrameNode, NodeAttribute, NodeBase, node
|
||||
from math import isnan
|
||||
|
||||
from maro.backends.backend import AttributeType
|
||||
from maro.backends.frame import (FrameBase, FrameNode, NodeAttribute, NodeBase,
|
||||
node)
|
||||
from maro.utils.exception.backends_exception import (
|
||||
BackendsArrayAttributeAccessException, BackendsGetItemInvalidException, BackendsSetItemInvalidException
|
||||
)
|
||||
BackendsArrayAttributeAccessException, BackendsGetItemInvalidException,
|
||||
BackendsSetItemInvalidException)
|
||||
from tests.utils import backends_to_test
|
||||
|
||||
STATIC_NODE_NUM = 5
|
||||
DYNAMIC_NODE_NUM = 10
|
||||
|
||||
|
||||
@node("static")
|
||||
class StaticNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 2)
|
||||
a2 = NodeAttribute("i2")
|
||||
a3 = NodeAttribute("i8")
|
||||
a2 = NodeAttribute(AttributeType.Short)
|
||||
a3 = NodeAttribute(AttributeType.Long)
|
||||
|
||||
|
||||
@node("dynamic")
|
||||
class DynamicNode(NodeBase):
|
||||
b1 = NodeAttribute("f")
|
||||
b2 = NodeAttribute("d")
|
||||
b1 = NodeAttribute(AttributeType.Float)
|
||||
b2 = NodeAttribute(AttributeType.Double)
|
||||
|
||||
def build_frame(enable_snapshot:bool=False, total_snapshot:int=10):
|
||||
|
||||
def build_frame(enable_snapshot: bool = False, total_snapshot: int = 10, backend_name="static"):
|
||||
|
||||
class MyFrame(FrameBase):
|
||||
static_nodes = FrameNode(StaticNode, STATIC_NODE_NUM)
|
||||
dynamic_nodes = FrameNode(DynamicNode, DYNAMIC_NODE_NUM)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(enable_snapshot=enable_snapshot, total_snapshot=total_snapshot)
|
||||
super().__init__(enable_snapshot=enable_snapshot,
|
||||
total_snapshot=total_snapshot, backend_name=backend_name)
|
||||
|
||||
return MyFrame()
|
||||
|
||||
|
||||
class TestFrame(unittest.TestCase):
|
||||
def test_node_number(self):
|
||||
"""Test if node number same as defined"""
|
||||
frame = build_frame()
|
||||
for backend_name in backends_to_test:
|
||||
frame = build_frame(backend_name=backend_name)
|
||||
|
||||
self.assertEqual(STATIC_NODE_NUM, len(
|
||||
frame.static_nodes), backend_name)
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, len(
|
||||
frame.dynamic_nodes), backend_name)
|
||||
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frame.static_nodes))
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, len(frame.dynamic_nodes))
|
||||
|
||||
def test_node_accessing(self):
|
||||
"""Test node accessing correct"""
|
||||
frame = build_frame()
|
||||
for backend_name in backends_to_test:
|
||||
|
||||
# accessing for 1st node for both static and dynamic node
|
||||
static_node: StaticNode = frame.static_nodes[0]
|
||||
dynamic_node: DynamicNode = frame.dynamic_nodes[0]
|
||||
frame = build_frame(backend_name=backend_name)
|
||||
|
||||
static_node.a2 = 10
|
||||
dynamic_node.b1 = 12.34
|
||||
# accessing for 1st node for both static and dynamic node
|
||||
static_node: StaticNode = frame.static_nodes[0]
|
||||
dynamic_node: DynamicNode = frame.dynamic_nodes[0]
|
||||
|
||||
self.assertEqual(10, static_node.a2, msg="a2 attribute should be 10 for 1st static node")
|
||||
self.assertAlmostEqual(12.34, dynamic_node.b1, 2, msg="b1 attribute should be 12.34 for 1st dynamic node")
|
||||
static_node.a2 = 10
|
||||
dynamic_node.b1 = 12.34
|
||||
|
||||
# check if values correct for multiple nodes
|
||||
for node in frame.static_nodes:
|
||||
node.a2 = node.index
|
||||
self.assertEqual(
|
||||
10, static_node.a2, msg="a2 attribute should be 10 for 1st static node")
|
||||
self.assertAlmostEqual(
|
||||
12.34, dynamic_node.b1, 2, msg="b1 attribute should be 12.34 for 1st dynamic node")
|
||||
|
||||
# check if the value correct
|
||||
for node in frame.static_nodes:
|
||||
self.assertEqual(node.index, node.a2, msg=f"static node.a2 should be {node.index}")
|
||||
# check if values correct for multiple nodes
|
||||
for node in frame.static_nodes:
|
||||
node.a2 = node.index
|
||||
|
||||
# check slice accessing
|
||||
static_node.a1[1] = 12
|
||||
static_node.a1[0] = 20
|
||||
# check if the value correct
|
||||
for node in frame.static_nodes:
|
||||
self.assertEqual(node.index, node.a2,
|
||||
msg=f"static node.a2 should be {node.index}")
|
||||
|
||||
self.assertListEqual([20, 12], list(static_node.a1[:]), msg="static node's a1 should be [20, 12]")
|
||||
self.assertEqual(20, static_node.a1[0], msg="1st slot of a1 should be 20")
|
||||
self.assertEqual(12, static_node.a1[1], msg="2nd slot of a1 should be 12")
|
||||
# check slice accessing
|
||||
static_node.a1[1] = 12
|
||||
static_node.a1[0] = 20
|
||||
|
||||
# set again with another way
|
||||
static_node.a1[(1, 0)] = (22, 11)
|
||||
self.assertListEqual([20, 12], list(
|
||||
static_node.a1[:]), msg="static node's a1 should be [20, 12]")
|
||||
self.assertEqual(
|
||||
20, static_node.a1[0], msg="1st slot of a1 should be 20")
|
||||
self.assertEqual(
|
||||
12, static_node.a1[1], msg="2nd slot of a1 should be 12")
|
||||
|
||||
self.assertListEqual([11, 22], list(static_node.a1[:]), msg="static node a1 should be [11, 22]")
|
||||
# set again with another way
|
||||
static_node.a1[(1, 0)] = (22, 11)
|
||||
|
||||
# another way
|
||||
# NOTE: additional value will be ignored
|
||||
static_node.a1[:] = (1, 2, 3)
|
||||
self.assertListEqual([11, 22], list(
|
||||
static_node.a1[:]), msg="static node a1 should be [11, 22]")
|
||||
|
||||
self.assertListEqual([1, 2], list(static_node.a1[:]), msg="static node a1 should be [1, 2")
|
||||
# another way
|
||||
# NOTE: additional value will be ignored
|
||||
static_node.a1[:] = (1, 2, 3)
|
||||
|
||||
self.assertListEqual([1, 2], list(
|
||||
static_node.a1[:]), msg="static node a1 should be [1, 2")
|
||||
|
||||
def test_invalid_node_accessing(self):
|
||||
for backend_name in backends_to_test:
|
||||
frm = build_frame(backend_name=backend_name)
|
||||
|
||||
frm = build_frame()
|
||||
static_node: StaticNode = frm.static_nodes[0]
|
||||
|
||||
static_node: StaticNode = frm.static_nodes[0]
|
||||
# get attribute value with not supported parameter
|
||||
with self.assertRaises(BackendsGetItemInvalidException) as ctx:
|
||||
a = static_node.a1["a"]
|
||||
|
||||
# get attribute value with not supported parameter
|
||||
with self.assertRaises(BackendsGetItemInvalidException) as ctx:
|
||||
a = static_node.a1["a"]
|
||||
with self.assertRaises(BackendsSetItemInvalidException) as ctx:
|
||||
static_node.a1["a"] = 1
|
||||
|
||||
with self.assertRaises(BackendsSetItemInvalidException) as ctx:
|
||||
static_node.a1["a"] = 1
|
||||
|
||||
with self.assertRaises(BackendsArrayAttributeAccessException) as ctx:
|
||||
static_node.a1 = 1
|
||||
with self.assertRaises(BackendsArrayAttributeAccessException) as ctx:
|
||||
static_node.a1 = 1
|
||||
|
||||
def test_get_node_info(self):
|
||||
"""Test if node information correct"""
|
||||
frm = build_frame()
|
||||
for backend_name in backends_to_test:
|
||||
"""Test if node information correct"""
|
||||
frm = build_frame(backend_name=backend_name)
|
||||
|
||||
node_info = frm.get_node_info()
|
||||
node_info = frm.get_node_info()
|
||||
|
||||
# if should contains 2 nodes
|
||||
self.assertTrue("static" in node_info)
|
||||
self.assertTrue("dynamic" in node_info)
|
||||
# if should contains 2 nodes
|
||||
self.assertTrue("static" in node_info)
|
||||
self.assertTrue("dynamic" in node_info)
|
||||
|
||||
# node number
|
||||
self.assertEqual(STATIC_NODE_NUM, node_info["static"]["number"])
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, node_info["dynamic"]["number"])
|
||||
# node number
|
||||
self.assertEqual(STATIC_NODE_NUM, node_info["static"]["number"])
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, node_info["dynamic"]["number"])
|
||||
|
||||
# check attributes
|
||||
self.assertTrue("a1" in node_info["static"]["attributes"])
|
||||
self.assertTrue("a2" in node_info["static"]["attributes"])
|
||||
self.assertTrue("a3" in node_info["static"]["attributes"])
|
||||
self.assertTrue("b1" in node_info["dynamic"]["attributes"])
|
||||
self.assertTrue("b2" in node_info["dynamic"]["attributes"])
|
||||
|
||||
# check slot number
|
||||
self.assertEqual(2, node_info["static"]["attributes"]["a1"]["slots"])
|
||||
self.assertEqual(1, node_info["static"]["attributes"]["a2"]["slots"])
|
||||
# check attributes
|
||||
self.assertTrue("a1" in node_info["static"]["attributes"])
|
||||
self.assertTrue("a2" in node_info["static"]["attributes"])
|
||||
self.assertTrue("a3" in node_info["static"]["attributes"])
|
||||
self.assertTrue("b1" in node_info["dynamic"]["attributes"])
|
||||
self.assertTrue("b2" in node_info["dynamic"]["attributes"])
|
||||
|
||||
# check slot number
|
||||
self.assertEqual(2, node_info["static"]
|
||||
["attributes"]["a1"]["slots"])
|
||||
self.assertEqual(1, node_info["static"]
|
||||
["attributes"]["a2"]["slots"])
|
||||
|
||||
def test_enable_snapshots(self):
|
||||
"""Test if snapshot enabled"""
|
||||
frame = build_frame(enable_snapshot=True)
|
||||
for backend_name in backends_to_test:
|
||||
"""Test if snapshot enabled"""
|
||||
frame = build_frame(enable_snapshot=True,
|
||||
backend_name=backend_name)
|
||||
|
||||
# snapshots should not be None
|
||||
self.assertIsNotNone(frame)
|
||||
# snapshots should not be None
|
||||
self.assertIsNotNone(frame)
|
||||
|
||||
# length should be 10
|
||||
self.assertEqual(10, len(frame.snapshots), msg="snapshot length should be 10")
|
||||
# length should be 0 before taking snapshot
|
||||
self.assertEqual(0, len(frame.snapshots),
|
||||
msg="snapshot length should be 0")
|
||||
|
||||
# another frame without snapshots enabled
|
||||
frame1 = build_frame()
|
||||
|
||||
self.assertIsNone(frame1.snapshots)
|
||||
# another frame without snapshots enabled
|
||||
frame1 = build_frame(backend_name=backend_name)
|
||||
|
||||
self.assertIsNone(frame1.snapshots)
|
||||
|
||||
def test_reset(self):
|
||||
"""Test reset work as expected, reset all attributes to 0"""
|
||||
frame = build_frame()
|
||||
for backend_name in backends_to_test:
|
||||
"""Test reset work as expected, reset all attributes to 0"""
|
||||
frame = build_frame(backend_name=backend_name)
|
||||
|
||||
frame.static_nodes[0].a1[:] = (1, 234)
|
||||
frame.static_nodes[0].a1[:] = (1, 234)
|
||||
|
||||
# before reset
|
||||
self.assertListEqual([1, 234], list(frame.static_nodes[0].a1[:]), msg="static node's a1 should be [1, 234] before reset")
|
||||
# before reset
|
||||
self.assertListEqual([1, 234], list(
|
||||
frame.static_nodes[0].a1[:]), msg="static node's a1 should be [1, 234] before reset")
|
||||
|
||||
frame.reset()
|
||||
|
||||
# after reset
|
||||
self.assertListEqual([0, 0], list(
|
||||
frame.static_nodes[0].a1[:]), msg="static node's a1 should be [0, 0] after reset")
|
||||
|
||||
def test_append_nodes(self):
|
||||
# NOTE: this case only support raw backend
|
||||
frame = build_frame(enable_snapshot=True,
|
||||
total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
# set value for last static node
|
||||
last_static_node = frame.static_nodes[-1]
|
||||
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frame.static_nodes))
|
||||
|
||||
last_static_node.a2 = 2
|
||||
last_static_node.a3 = 9
|
||||
|
||||
# this snapshot should keep 5 static nodes
|
||||
frame.take_snapshot(0)
|
||||
|
||||
# append 2 new node
|
||||
frame.append_node("static", 2)
|
||||
|
||||
# then there should be 2 new node instance
|
||||
self.assertEqual(STATIC_NODE_NUM + 2, len(frame.static_nodes))
|
||||
|
||||
# then index should keep sequentially
|
||||
for i in range(len(frame.static_nodes)):
|
||||
self.assertEqual(i, frame.static_nodes[i].index)
|
||||
|
||||
# value should be zero
|
||||
for node in frame.static_nodes[-2:]:
|
||||
self.assertEqual(0, node.a3)
|
||||
self.assertEqual(0, node.a2)
|
||||
self.assertEqual(0, node.a1[0])
|
||||
self.assertEqual(0, node.a1[1])
|
||||
|
||||
last_static_node.a3 = 12
|
||||
|
||||
# this snapshot should contains 7 static node
|
||||
frame.take_snapshot(1)
|
||||
|
||||
static_snapshot = frame.snapshots["static"]
|
||||
|
||||
# snapshot only provide current number (include delete ones)
|
||||
self.assertEqual(7, len(static_snapshot))
|
||||
|
||||
# query for 1st tick
|
||||
states = static_snapshot[0::"a3"]
|
||||
|
||||
# the query result of raw snapshotlist has 4 dim shape
|
||||
# (ticks, max nodes, attributes, max slots)
|
||||
self.assertTupleEqual(states.shape, (1, 7, 1, 1))
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
# there should be 7 items, 5 for 5 nodes, 2 for padding as we do not provide node index to query,
|
||||
# snapshotlist will padding to max_number fo node
|
||||
self.assertEqual(7, len(states))
|
||||
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])
|
||||
|
||||
# 2 padding (NAN) in the end
|
||||
self.assertTrue(np.isnan(states[-2:]).all())
|
||||
|
||||
states = static_snapshot[1::"a3"]
|
||||
|
||||
self.assertTupleEqual(states.shape, (1, 7, 1, 1))
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
self.assertEqual(7, len(states))
|
||||
|
||||
# no padding value
|
||||
self.assertListEqual(
|
||||
[0.0, 0.0, 0.0, 0.0, 12.0, 0.0, 0.0], list(states))
|
||||
|
||||
# with specify node indices, will not padding to max node number
|
||||
states = static_snapshot[0:[0, 1, 2, 3, 4]:"a3"]
|
||||
|
||||
self.assertTupleEqual(states.shape, (1, 5, 1, 1))
|
||||
|
||||
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0],
|
||||
list(states.flatten()[0:5]))
|
||||
|
||||
frame.snapshots.reset()
|
||||
frame.reset()
|
||||
|
||||
# after reset
|
||||
self.assertListEqual([0, 0], list(frame.static_nodes[0].a1[:]), msg="static node's a1 should be [0, 0] after reset")
|
||||
# node number will resume to origin one after reset
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frame.static_nodes))
|
||||
|
||||
def test_delete_node(self):
|
||||
frame = build_frame(enable_snapshot=True,
|
||||
total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
# set value for last static node
|
||||
last_static_node = frame.static_nodes[-1]
|
||||
second_static_node = frame.static_nodes[1]
|
||||
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frame.static_nodes))
|
||||
|
||||
second_static_node.a3 = 444
|
||||
last_static_node.a2 = 2
|
||||
last_static_node.a3 = 9
|
||||
|
||||
# this snapshot should keep 5 static nodes
|
||||
frame.take_snapshot(0)
|
||||
|
||||
# delete 2nd node
|
||||
frame.delete_node(second_static_node)
|
||||
|
||||
last_static_node.a3 = 123
|
||||
|
||||
frame.take_snapshot(1)
|
||||
|
||||
# deleted node's instance will not be removed, just mark as deleted
|
||||
self.assertTrue(second_static_node.is_deleted)
|
||||
|
||||
# future setter will cause exception
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
second_static_node.a3 = 11
|
||||
|
||||
# attribute getter failed too
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
a = second_static_node.a3
|
||||
|
||||
static_snapshots = frame.snapshots["static"]
|
||||
|
||||
# snapshot will try to padding to max node number if not specify node indices
|
||||
states = static_snapshots[0::"a3"]
|
||||
|
||||
self.assertTupleEqual(states.shape, (1, 5, 1, 1))
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
# no nan for 1st snapshot
|
||||
self.assertFalse(np.isnan(states).all())
|
||||
self.assertListEqual([0.0, 444.0, 0.0, 0.0, 9.0], list(states))
|
||||
|
||||
states = static_snapshots[1::"a3"]
|
||||
|
||||
self.assertTupleEqual(states.shape, (1, 5, 1, 1))
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
# 2nd is padding value
|
||||
self.assertTrue(np.isnan(states[1]))
|
||||
|
||||
self.assertListEqual([0.0, 0.0, 0.0, 123.0],
|
||||
list(states[[0, 2, 3, 4]]))
|
||||
|
||||
# then we resume the deleted node, this mark it as not deleted, but values will be reset to 0
|
||||
frame.resume_node(second_static_node)
|
||||
|
||||
# DELETE node's value will not be reset after deleted
|
||||
self.assertEqual(444, second_static_node.a3)
|
||||
|
||||
second_static_node.a3 = 222
|
||||
|
||||
frame.take_snapshot(2)
|
||||
|
||||
states = static_snapshots[2::"a3"]
|
||||
|
||||
self.assertTupleEqual(states.shape, (1, 5, 1, 1))
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
self.assertListEqual([0.0, 222.0, 0.0, 0.0, 123.0], list(states))
|
||||
|
||||
frame.snapshots.reset()
|
||||
frame.reset()
|
||||
|
||||
# node number will resume to origin one after reset
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frame.static_nodes))
|
||||
|
||||
# and no nodes marked as deleted
|
||||
for node in frame.static_nodes:
|
||||
self.assertTrue(node.is_deleted == False)
|
||||
|
||||
def test_invalid_attribute_description(self):
|
||||
# we do not support const list attribute
|
||||
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 2, is_const=True, is_list=True)
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
test_nodes = FrameNode(TestNode, 1)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(enable_snapshot=True, total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
with self.assertRaises(RuntimeError) as ctx:
|
||||
frame = TestFrame()
|
||||
|
||||
def test_query_const_attribute_without_taking_snapshot(self):
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 2, is_const=True)
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
test_nodes = FrameNode(TestNode, 2)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(enable_snapshot=True, total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
frame = TestFrame()
|
||||
|
||||
t1 = frame.test_nodes[0]
|
||||
|
||||
t1.a1[0] = 10
|
||||
|
||||
t1_ss = frame.snapshots["test"]
|
||||
|
||||
# default snapshot length is 0
|
||||
self.assertEqual(0, len(frame.snapshots))
|
||||
|
||||
# we DO have to provide a tick to it for padding, as there is no snapshots there
|
||||
states = t1_ss[0::"a1"]
|
||||
|
||||
states = states.flatten()
|
||||
|
||||
self.assertListEqual([10.0, 0.0, 0.0, 0.0], list(states))
|
||||
|
||||
def test_list_attribute(self):
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 1, is_list=True)
|
||||
a2 = NodeAttribute("i", 2, is_const=True)
|
||||
a3 = NodeAttribute("i")
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
test_nodes = FrameNode(TestNode, 2)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(enable_snapshot=True, total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
frame = TestFrame()
|
||||
|
||||
frame.take_snapshot(0)
|
||||
|
||||
n1 = frame.test_nodes[0]
|
||||
|
||||
n1.a2[:] = (2221, 2222)
|
||||
n1.a3 = 333
|
||||
|
||||
# slot number of list attribute is 0 by default
|
||||
# so get/set value by index will cause error
|
||||
|
||||
# append value to it
|
||||
n1.a1.append(10)
|
||||
n1.a1.append(11)
|
||||
n1.a1.append(12)
|
||||
|
||||
expected_value = [10, 11, 12]
|
||||
|
||||
# check if value set append correct
|
||||
self.assertListEqual(expected_value, n1.a1[:])
|
||||
|
||||
# Check if length correct
|
||||
self.assertEqual(3, len(n1.a1))
|
||||
|
||||
# For loop to go through all items in list
|
||||
for i, a_value in enumerate(n1.a1):
|
||||
self.assertEqual(expected_value[i], a_value)
|
||||
|
||||
frame.take_snapshot(1)
|
||||
|
||||
# resize it to 2
|
||||
n1.a1.resize(2)
|
||||
|
||||
# this will cause last value to be removed
|
||||
self.assertEqual(2, len(n1.a1))
|
||||
|
||||
self.assertListEqual([10, 11], n1.a1[:])
|
||||
|
||||
# exterd its size, then default value should be 0
|
||||
n1.a1.resize(5)
|
||||
|
||||
self.assertEqual(5, len(n1.a1))
|
||||
self.assertListEqual([10, 11, 0, 0, 0], n1.a1[:])
|
||||
|
||||
# clear will cause length be 0
|
||||
n1.a1.clear()
|
||||
|
||||
self.assertEqual(0, len(n1.a1))
|
||||
|
||||
# insert a new value to 0, as it is empty now
|
||||
n1.a1.insert(0, 10)
|
||||
|
||||
self.assertEqual(1, len(n1.a1))
|
||||
|
||||
self.assertEqual(10, n1.a1[0])
|
||||
|
||||
# [11, 10] after insert
|
||||
n1.a1.insert(0, 11)
|
||||
|
||||
# remove 2nd one
|
||||
n1.a1.remove(1)
|
||||
|
||||
self.assertEqual(1, len(n1.a1))
|
||||
|
||||
self.assertEqual(11, n1.a1[0])
|
||||
|
||||
# test if snapshot correct
|
||||
# NOTE: list attribute querying need to provide 1 attribute and 1 node index
|
||||
states = frame.snapshots["test"][0:0:"a1"]
|
||||
|
||||
# first tick a1 has no value, so states will be None
|
||||
self.assertIsNone(states)
|
||||
|
||||
states = frame.snapshots['test'][1:0:"a1"]
|
||||
states = states.flatten()
|
||||
|
||||
# a1 has 3 value at tick 1
|
||||
self.assertEqual(3, len(states))
|
||||
|
||||
self.assertListEqual([10, 11, 12], list(states))
|
||||
|
||||
# tick can be empty, then means get state for latest snapshot
|
||||
states = frame.snapshots["test"][:0:"a1"].flatten()
|
||||
|
||||
self.assertEqual(3, len(states))
|
||||
self.assertListEqual([10, 11, 12], list(states))
|
||||
|
||||
def test_list_attribute_with_large_size(self):
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 1, is_list=True)
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
test_nodes = FrameNode(TestNode, 2)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(backend_name="dynamic")
|
||||
|
||||
frame = TestFrame()
|
||||
|
||||
n1a1 = frame.test_nodes[0].a1
|
||||
|
||||
max_size = 200*10000
|
||||
|
||||
for i in range(max_size):
|
||||
n1a1.append(1)
|
||||
|
||||
print(len(n1a1))
|
||||
self.assertEqual(max_size, len(n1a1))
|
||||
|
||||
def test_list_attribute_invalid_index_access(self):
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 1, is_list=True)
|
||||
|
||||
class TestFrame(FrameBase):
|
||||
test_nodes = FrameNode(TestNode, 2)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(backend_name="dynamic")
|
||||
|
||||
frame = TestFrame()
|
||||
|
||||
n1a1 = frame.test_nodes[0].a1
|
||||
|
||||
# default list attribute's size is 0, so index accessing will out of range
|
||||
with self.assertRaises(RuntimeError) as ctx:
|
||||
a = n1a1[0]
|
||||
|
||||
with self.assertRaises(RuntimeError) as ctx:
|
||||
n1a1.remove(0)
|
||||
|
||||
def test_frame_dump(self):
|
||||
frame = build_frame(enable_snapshot=True, total_snapshot=10, backend_name="dynamic")
|
||||
|
||||
frame.dump(".")
|
||||
|
||||
# there should be 2 output files
|
||||
self.assertTrue(os.path.exists("node_static.csv"))
|
||||
self.assertTrue(os.path.exists("node_dynamic.csv"))
|
||||
list_parser = lambda c: c if not c.startswith("[") else [float(n) for n in c.strip('[] ,').split(",")]
|
||||
|
||||
# a1 is a list
|
||||
static_df = pd.read_csv("node_static.csv", converters={"a1": list_parser})
|
||||
|
||||
# all value should be 0
|
||||
for i in range(STATIC_NODE_NUM):
|
||||
row = static_df.loc[i]
|
||||
|
||||
a1 = row["a1"]
|
||||
a2 = row["a2"]
|
||||
a3 = row["a3"]
|
||||
|
||||
self.assertEqual(2, len(a1))
|
||||
|
||||
self.assertListEqual([0.0, 0.0], a1)
|
||||
self.assertEqual(0, a2)
|
||||
self.assertEqual(0, a3)
|
||||
|
||||
frame.take_snapshot(0)
|
||||
|
||||
frame.take_snapshot(1)
|
||||
|
||||
frame.snapshots.dump(".")
|
||||
|
||||
self.assertTrue(os.path.exists("snapshots_dynamic.csv"))
|
||||
self.assertTrue(os.path.exists("snapshots_static.csv"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -4,167 +4,240 @@
|
|||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from test_frame import DYNAMIC_NODE_NUM, STATIC_NODE_NUM, build_frame
|
||||
|
||||
from tests.utils import backends_to_test
|
||||
|
||||
|
||||
class TestFrame(unittest.TestCase):
|
||||
def test_take_snapshot(self):
|
||||
"""Test if take_stapshot work"""
|
||||
frame = build_frame(True)
|
||||
for backend_name in backends_to_test:
|
||||
"""Test if take_stapshot work"""
|
||||
frame = build_frame(True, backend_name=backend_name)
|
||||
|
||||
# 1st static node
|
||||
static_node = frame.static_nodes[0]
|
||||
# 1st static node
|
||||
static_node = frame.static_nodes[0]
|
||||
|
||||
static_node.a1[:] = [1, 23]
|
||||
static_node.a1[:] = [1, 23]
|
||||
|
||||
frame.take_snapshot(0)
|
||||
frame.take_snapshot(0)
|
||||
|
||||
a1_at_tick_0 = frame.snapshots["static"][:0:"a1"]
|
||||
frame_index_list = frame.snapshots.get_frame_index_list()
|
||||
|
||||
# the value should be same with current
|
||||
self.assertListEqual(list(a1_at_tick_0.astype("i")), [1, 23], msg="1st static node's a1 should be [1, 23] at tick 0")
|
||||
# check if frame list correct
|
||||
# NOTE: since our resolution is 1 here, so tick==frame_index
|
||||
self.assertListEqual(frame_index_list, [0])
|
||||
|
||||
# test if the value in snapshot will be changed after change frame
|
||||
static_node.a1[1] = 123
|
||||
a1_at_tick_0 = frame.snapshots["static"][:0:"a1"]
|
||||
|
||||
a1_at_tick_0 = frame.snapshots["static"][:0:"a1"]
|
||||
# NOTE: we use flatten here, as raw backend's snapshotlist will have 4 dim result
|
||||
# the value should be same with current
|
||||
self.assertListEqual(list(a1_at_tick_0.flatten().astype("i")), [
|
||||
1, 23], msg="1st static node's a1 should be [1, 23] at tick 0")
|
||||
|
||||
self.assertListEqual(list(a1_at_tick_0.astype("i")), [1, 23], msg="1st static node's a1 should be [1, 23] at tick 0 even static node value changed")
|
||||
# test if the value in snapshot will be changed after change frame
|
||||
static_node.a1[1] = 123
|
||||
|
||||
a1_at_tick_0 = frame.snapshots["static"][:0:"a1"]
|
||||
|
||||
self.assertListEqual(list(a1_at_tick_0.flatten().astype("i")), [
|
||||
1, 23], msg="1st static node's a1 should be [1, 23] at tick 0 even static node value changed")
|
||||
|
||||
frame.take_snapshot(1)
|
||||
|
||||
frame_index_list = frame.snapshots.get_frame_index_list()
|
||||
|
||||
self.assertListEqual(frame_index_list, [0, 1])
|
||||
|
||||
def test_slice_quering(self):
|
||||
"""Test if states quering result correct"""
|
||||
frame = build_frame(True, total_snapshot=2)
|
||||
for backend_name in backends_to_test:
|
||||
"""Test if states quering result correct"""
|
||||
frame = build_frame(True, total_snapshot=2,
|
||||
backend_name=backend_name)
|
||||
|
||||
# one node changes
|
||||
static_node = frame.static_nodes[0]
|
||||
# one node changes
|
||||
static_node = frame.static_nodes[0]
|
||||
|
||||
static_node.a2 = 1
|
||||
static_node.a2 = 1
|
||||
|
||||
# before takeing snapshot, states should be 0
|
||||
static_node_a2_states = frame.snapshots["static"][0:0:"a2"]
|
||||
# before takeing snapshot, states should be 0
|
||||
static_node_a2_states = frame.snapshots["static"][0:0:"a2"]
|
||||
|
||||
self.assertEqual(1, len(static_node_a2_states), msg="slicing with 1 tick, 1 node and 1 attr, should return array with 1 result")
|
||||
self.assertEqual(0, static_node_a2_states.astype("i")[0], msg="states before taking snashot should be 0")
|
||||
self.assertEqual(1, len(static_node_a2_states),
|
||||
msg="slicing with 1 tick, 1 node and 1 attr, should return array with 1 result")
|
||||
|
||||
frame.take_snapshot(0)
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(static_node_a2_states).all())
|
||||
else:
|
||||
self.assertEqual(0, static_node_a2_states.astype(
|
||||
"i")[0], msg="states before taking snapshot should be 0")
|
||||
|
||||
# set a2 and a3 for all static nodes
|
||||
for i, node in enumerate(frame.static_nodes):
|
||||
node.a3 = 100 * i
|
||||
node.a2 = 100 * i + 1
|
||||
frame.take_snapshot(0)
|
||||
|
||||
# take snapshot
|
||||
frame.take_snapshot(1)
|
||||
# set a2 and a3 for all static nodes
|
||||
for i, node in enumerate(frame.static_nodes):
|
||||
node.a3 = 100 * i
|
||||
node.a2 = 100 * i + 1
|
||||
|
||||
# query with 2 attributes
|
||||
states = frame.snapshots["static"][1::["a3", "a2"]]
|
||||
# take snapshot
|
||||
frame.take_snapshot(1)
|
||||
|
||||
# with this quering, the result should be like
|
||||
# so we can reshape it as
|
||||
states = states.reshape(len(frame.static_nodes), 2)
|
||||
# query with 2 attributes
|
||||
states = frame.snapshots["static"][1::["a3", "a2"]]
|
||||
|
||||
# then 1st should a3 value for all static node
|
||||
# 2nd should be a2 value for all static node
|
||||
self.assertListEqual(list(states[:, 0].astype("i")), [100 * i for i in range(len(frame.static_nodes))], msg="1st row of states should be a3 value")
|
||||
self.assertListEqual(list(states[:, 1].astype("i")), [100 * i + 1 for i in range(len(frame.static_nodes))], msg="2nd row of states should be a2 value")
|
||||
# with this quering, the result should be like
|
||||
# so we can reshape it as
|
||||
states = states.reshape(len(frame.static_nodes), 2)
|
||||
|
||||
# quering without tick, means return all ticks in snapshot
|
||||
states = frame.snapshots["static"][:0:"a2"]
|
||||
# then 1st should a3 value for all static node
|
||||
# 2nd should be a2 value for all static node
|
||||
self.assertListEqual(list(states[:, 0].astype("i")), [
|
||||
100 * i for i in range(len(frame.static_nodes))], msg="1st row of states should be a3 value")
|
||||
self.assertListEqual(list(states[:, 1].astype("i")), [100 * i + 1 for i in range(
|
||||
len(frame.static_nodes))], msg="2nd row of states should be a2 value")
|
||||
|
||||
# reshape it as 2-dim, so row is tick
|
||||
states = states.reshape(2, -1).astype("i")
|
||||
# quering without tick, means return all ticks in snapshot
|
||||
states = frame.snapshots["static"][:0:"a2"]
|
||||
|
||||
# then each row is a2 value for 1st static node at that tick
|
||||
self.assertEqual(1, len(states[0]), msg="1st static should contains 1 a2 value at tick 0")
|
||||
# reshape it as 2-dim, so row is tick
|
||||
states = states.reshape(2, -1).astype("i")
|
||||
|
||||
self.assertEqual(1, states[0], msg="1st static node a2 value should be 1 at tick 0")
|
||||
# then each row is a2 value for 1st static node at that tick
|
||||
self.assertEqual(
|
||||
1, len(states[0]), msg="1st static should contains 1 a2 value at tick 0")
|
||||
|
||||
self.assertEqual(1, states[1], msg="1st staic node a2 value should be 1 at tick 1")
|
||||
self.assertEqual(
|
||||
1, states[0], msg="1st static node a2 value should be 1 at tick 0")
|
||||
|
||||
# quering without node index, means return attributes of all the nodes
|
||||
states = frame.snapshots["static"][1::"a2"]
|
||||
self.assertEqual(
|
||||
1, states[1], msg="1st staic node a2 value should be 1 at tick 1")
|
||||
|
||||
self.assertEqual(len(frame.static_nodes), len(states), msg="1 tick 1 attribute and not specified ticks, should return array length same as node number")
|
||||
self.assertListEqual(list(states.astype("i")), [100 * i + 1 for i in range(len(frame.static_nodes))], msg="a2 at 1st row should be values at tick 1")
|
||||
# quering without node index, means return attributes of all the nodes
|
||||
states = frame.snapshots["static"][1::"a2"]
|
||||
|
||||
# when reach the max size of snapshot, oldest one will be overwrite
|
||||
static_node.a2 = 1000
|
||||
# NOTE: dynamic backend have shape
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(len(states), len(frame.static_nodes))
|
||||
else:
|
||||
self.assertEqual(len(frame.static_nodes),
|
||||
len(states), msg="1 tick 1 attribute and not specified ticks, should return array length same as node number")
|
||||
|
||||
frame.take_snapshot(2)
|
||||
self.assertListEqual(list(states.flatten().astype("i")), [100 * i + 1 for i in range(
|
||||
len(frame.static_nodes))], msg="a2 at 1st row should be values at tick 1")
|
||||
|
||||
# check if current snapshot max size correct
|
||||
self.assertEqual(2, len(frame.snapshots), msg="snapshot list max size should be 2")
|
||||
# when reach the max size of snapshot, oldest one will be overwrite
|
||||
static_node.a2 = 1000
|
||||
|
||||
# and result without ticks should return 2 row: 2*len(static_nodes)
|
||||
states = frame.snapshots["static"][::"a2"]
|
||||
states = states.reshape(-1, len(frame.static_nodes))
|
||||
frame.take_snapshot(2)
|
||||
|
||||
self.assertEqual(2, len(states), msg="states should contains 2 row")
|
||||
# check if current snapshot max size correct
|
||||
self.assertEqual(2, len(frame.snapshots),
|
||||
msg="snapshot list max size should be 2")
|
||||
|
||||
# 1st row should be values at tick 1
|
||||
self.assertListEqual(list(states[0].astype("i")), [100 * i + 1 for i in range(len(frame.static_nodes))], msg="a2 at tick 1 for all nodes should be correct")
|
||||
|
||||
# 2nd row should be lastest one
|
||||
self.assertEqual(1000, states[1][0], msg="a2 for 1st static node for 2nd row should be 1000")
|
||||
# and result without ticks should return 2 row: 2*len(static_nodes)
|
||||
states = frame.snapshots["static"][::"a2"]
|
||||
|
||||
# quering with ticks that being over-wrote, should return 0 for that tick
|
||||
states = frame.snapshots["static"][(0, 1, 2)::"a2"]
|
||||
states = states.reshape(-1, len(frame.static_nodes))
|
||||
states = states.reshape(-1, len(frame.static_nodes))
|
||||
|
||||
self.assertEqual(3, len(states), msg="states should contains 3 row")
|
||||
self.assertListEqual([0]*len(frame.static_nodes), list(states[0].astype("i")), msg="over-wrote tick should return 0")
|
||||
self.assertListEqual(list(states[1].astype("i")), [100 * i + 1 for i in range(len(frame.static_nodes))], msg="a2 at tick 1 for all nodes should be correct")
|
||||
self.assertEqual(1000, states[2][0], msg="a2 for 1st static node for 2nd row should be 1000")
|
||||
self.assertEqual(
|
||||
2, len(states), msg="states should contains 2 row")
|
||||
|
||||
# 1st row should be values at tick 1
|
||||
self.assertListEqual(list(states[0].astype("i")), [100 * i + 1 for i in range(
|
||||
len(frame.static_nodes))], msg="a2 at tick 1 for all nodes should be correct")
|
||||
|
||||
# 2nd row should be lastest one
|
||||
self.assertEqual(
|
||||
1000, states[1][0], msg="a2 for 1st static node for 2nd row should be 1000")
|
||||
|
||||
# quering with ticks that being over-wrote, should return 0 for that tick
|
||||
states = frame.snapshots["static"][(0, 1, 2)::"a2"]
|
||||
states = states.reshape(-1, len(frame.static_nodes))
|
||||
|
||||
self.assertEqual(
|
||||
3, len(states), msg="states should contains 3 row")
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(states[0]).all())
|
||||
else:
|
||||
self.assertListEqual([0]*len(frame.static_nodes),
|
||||
list(states[0].astype("i")), msg="over-wrote tick should return 0")
|
||||
|
||||
self.assertListEqual(list(states[1].astype("i")), [100 * i + 1 for i in range(
|
||||
len(frame.static_nodes))], msg="a2 at tick 1 for all nodes should be correct")
|
||||
|
||||
self.assertEqual(
|
||||
1000, states[2][0], msg="a2 for 1st static node for 2nd row should be 1000")
|
||||
|
||||
frame_index_list = frame.snapshots.get_frame_index_list()
|
||||
|
||||
self.assertListEqual(frame_index_list, [1, 2])
|
||||
|
||||
def test_snapshot_length(self):
|
||||
"""Test __len__ function result"""
|
||||
for backend_name in backends_to_test:
|
||||
frm = build_frame(True, total_snapshot=10,
|
||||
backend_name=backend_name)
|
||||
|
||||
frm = build_frame(True, total_snapshot=10)
|
||||
|
||||
self.assertEqual(10, len(frm.snapshots))
|
||||
|
||||
self.assertEqual(0, len(frm.snapshots))
|
||||
|
||||
def test_snapshot_node_length(self):
|
||||
"""Test if node number in snapshot correct"""
|
||||
for backend_name in backends_to_test:
|
||||
frm = build_frame(True, backend_name=backend_name)
|
||||
|
||||
frm = build_frame(True)
|
||||
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frm.snapshots["static"]))
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, len(frm.snapshots["dynamic"]))
|
||||
self.assertEqual(STATIC_NODE_NUM, len(frm.snapshots["static"]))
|
||||
self.assertEqual(DYNAMIC_NODE_NUM, len(frm.snapshots["dynamic"]))
|
||||
|
||||
def test_quering_with_not_exist_indices(self):
|
||||
# NOTE: when quering with not exist indices, snapshot will try to fill the result of that index with 0
|
||||
for backend_name in backends_to_test:
|
||||
frm = build_frame(True, backend_name=backend_name)
|
||||
|
||||
frm = build_frame(True)
|
||||
for node in frm.static_nodes:
|
||||
node.a2 = node.index
|
||||
|
||||
for node in frm.static_nodes:
|
||||
node.a2 = node.index
|
||||
frm.take_snapshot(0)
|
||||
|
||||
frm.take_snapshot(0)
|
||||
# with 1 invalid index, all should be 0
|
||||
states = frm.snapshots["static"][1::"a2"]
|
||||
|
||||
# with 1 invalid index, all should be 0
|
||||
states = frm.snapshots["static"][1::"a2"]
|
||||
# NOTE: raw backend will padding with nan while numpy padding with 0
|
||||
if backend_name == "dynamic":
|
||||
# all should be nan
|
||||
self.assertTrue(np.isnan(states).all())
|
||||
else:
|
||||
self.assertListEqual(list(states.astype("I")), [
|
||||
0]*STATIC_NODE_NUM)
|
||||
|
||||
self.assertListEqual(list(states.astype("I")), [0]*STATIC_NODE_NUM)
|
||||
# with 1 invalid index, one valid index
|
||||
states = frm.snapshots["static"][(0, 1)::"a2"]
|
||||
|
||||
# with 1 invalid index, one valid index
|
||||
states = frm.snapshots["static"][(0, 1)::"a2"]
|
||||
states = states.reshape(-1, STATIC_NODE_NUM)
|
||||
# NOTE: this reshape for raw backend will get 2 dim array, each for one tick
|
||||
states = states.reshape(-1, STATIC_NODE_NUM)
|
||||
|
||||
# index 0 should be same with out current value
|
||||
self.assertListEqual(list(states[0].astype("i")), [i for i in range(STATIC_NODE_NUM)])
|
||||
self.assertListEqual(list(states[1].astype("i")), [0]*STATIC_NODE_NUM)
|
||||
# index 0 should be same with current value
|
||||
self.assertListEqual(list(states[0].astype("i")), [
|
||||
i for i in range(STATIC_NODE_NUM)])
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(states[1]).all())
|
||||
else:
|
||||
self.assertListEqual(list(states[1].astype("i")), [
|
||||
0]*STATIC_NODE_NUM)
|
||||
|
||||
def test_get_attribute_with_undefined_attribute(self):
|
||||
frm = build_frame(True)
|
||||
frm.take_snapshot(0)
|
||||
for backend_name in backends_to_test:
|
||||
frm = build_frame(True, backend_name=backend_name)
|
||||
frm.take_snapshot(0)
|
||||
|
||||
# not exist attribute name
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
states = frm.snapshots["static"][::"a8"]
|
||||
# not exist attribute name
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
states = frm.snapshots["static"][::"a8"]
|
||||
|
||||
# not exist node name
|
||||
self.assertIsNone(frm.snapshots["hehe"])
|
||||
# not exist node name
|
||||
self.assertIsNone(frm.snapshots["hehe"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from maro.rl.utils.trajectory_utils import get_k_step_returns, get_lambda_returns
|
||||
|
||||
|
||||
class TestTrajectoryUtils(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.rewards = np.asarray([3, 2, 4, 1, 5])
|
||||
self.values = np.asarray([4, 7, 1, 3, 6])
|
||||
self.lam = 0.6
|
||||
self.discount = 0.8
|
||||
self.k = 4
|
||||
|
||||
def test_k_step_return(self):
|
||||
returns = get_k_step_returns(self.rewards, self.values, self.discount, k=self.k)
|
||||
expected = np.asarray([10.1296, 8.912, 8.64, 5.8, 6.0])
|
||||
np.testing.assert_allclose(returns, expected, rtol=1e-4)
|
||||
|
||||
def test_lambda_return(self):
|
||||
returns = get_lambda_returns(self.rewards, self.values, self.discount, self.lam, k=self.k)
|
||||
expected = np.asarray([8.1378176, 6.03712, 7.744, 5.8, 6.0])
|
||||
np.testing.assert_allclose(returns, expected, rtol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -5,6 +5,8 @@ from maro.event_buffer import EventBuffer, EventState
|
|||
from maro.simulator.scenarios import AbsBusinessEngine
|
||||
|
||||
|
||||
backends_to_test = ["static", "dynamic"]
|
||||
|
||||
def next_step(eb: EventBuffer, be: AbsBusinessEngine, tick: int):
|
||||
if tick > 0:
|
||||
# lets post process last tick first before start a new tick
|
||||
|
|
Загрузка…
Ссылка в новой задаче