Format code on `master` (black + isort) (#538)
* Config files * Add autoflake * Update isort exclude; add pre-commit to requirements * Manually fix a few bad cases
This commit is contained in:
Родитель
17ad48928c
Коммит
7e3c1d5893
|
@ -0,0 +1,7 @@
|
|||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 120
|
||||
known_first_party = "maro"
|
|
@ -5,7 +5,9 @@ ignore =
|
|||
# line break after binary operator
|
||||
W504,
|
||||
# line break before binary operator
|
||||
W503
|
||||
W503,
|
||||
# whitespace before ':'
|
||||
E203
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
|
@ -27,14 +29,5 @@ max-line-length = 120
|
|||
per-file-ignores =
|
||||
# import not used: ignore in __init__.py files
|
||||
__init__.py:F401
|
||||
# igore invalid escape sequence in cli main script to show banner
|
||||
# ignore invalid escape sequence in cli main script to show banner
|
||||
maro.py:W605
|
||||
|
||||
[isort]
|
||||
indent = " "
|
||||
line_length = 120
|
||||
use_parentheses = True
|
||||
multi_line_output = 6
|
||||
known_first_party = maro
|
||||
filter_files = True
|
||||
skip_glob = maro/__init__.py, tests/*, examples/*, setup.py
|
||||
|
|
|
@ -45,12 +45,12 @@ jobs:
|
|||
uses: github/super-linter@latest
|
||||
env:
|
||||
VALIDATE_ALL_CODEBASE: false
|
||||
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configure it
|
||||
VALIDATE_PYTHON_BLACK: false # same as above
|
||||
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configured it
|
||||
VALIDATE_PYTHON_MYPY: false # same as above
|
||||
VALIDATE_JSCPD: false # Can not exclude specific file: https://github.com/kucherenko/jscpd/issues/215
|
||||
PYTHON_FLAKE8_CONFIG_FILE: tox.ini
|
||||
PYTHON_ISORT_CONFIG_FILE: tox.ini
|
||||
PYTHON_BLACK_CONFIG_FILE: pyproject.toml
|
||||
PYTHON_ISORT_CONFIG_FILE: pyproject.toml
|
||||
EDITORCONFIG_FILE_NAME: ../../.editorconfig
|
||||
FILTER_REGEX_INCLUDE: maro/.*
|
||||
FILTER_REGEX_EXCLUDE: tests/.*
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args:
|
||||
- --in-place
|
||||
- --remove-unused-variables
|
||||
- --remove-all-unused-imports
|
||||
exclude: .*/__init__\.py|setup\.py
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
- --settings-path=.github/linters/pyproject.toml
|
||||
- --check
|
||||
- repo: https://github.com/asottile/add-trailing-comma
|
||||
rev: v2.2.3
|
||||
hooks:
|
||||
- id: add-trailing-comma
|
||||
name: add-trailing-comma (1st round)
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
name: black (1st round)
|
||||
args:
|
||||
- --config=.github/linters/pyproject.toml
|
||||
- repo: https://github.com/asottile/add-trailing-comma
|
||||
rev: v2.2.3
|
||||
hooks:
|
||||
- id: add-trailing-comma
|
||||
name: add-trailing-comma (2nd round)
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
name: black (2nd round)
|
||||
args:
|
||||
- --config=.github/linters/pyproject.toml
|
||||
- repo: https://gitlab.com/pycqa/flake8
|
||||
rev: 3.7.9
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
- --config=.github/linters/tox.ini
|
||||
exclude: \.git|__pycache__|docs|build|dist|.*\.egg-info|docker_files|\.vscode|\.github|scripts|tests|maro\/backends\/.*.cp|setup.py
|
|
@ -65,7 +65,7 @@ else:
|
|||
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = 'sphinx'
|
||||
pygments_style = "sphinx"
|
||||
|
||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||
todo_include_todos = False
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
|
||||
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
|
@ -58,10 +56,10 @@ def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
|
|||
name=name,
|
||||
params=ActorCriticParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=.0,
|
||||
reward_discount=0.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=.0,
|
||||
lam=0.0,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import RMSprop
|
||||
|
||||
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
|
||||
from maro.rl.model import DiscreteQNet, FullyConnected
|
||||
from maro.rl.policy import ValueBasedPolicy
|
||||
from maro.rl.training.algorithms import DQNTrainer, DQNParams
|
||||
from maro.rl.training.algorithms import DQNParams, DQNTrainer
|
||||
|
||||
q_net_conf = {
|
||||
"hidden_dims": [256, 128, 64, 32],
|
||||
|
@ -38,14 +36,18 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
|
|||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num),
|
||||
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[(
|
||||
"epsilon", MultiLinearExplorationScheduler, {
|
||||
"splits": [(2, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 5,
|
||||
"final_value": 0.0,
|
||||
}
|
||||
)],
|
||||
exploration_scheduling_options=[
|
||||
(
|
||||
"epsilon",
|
||||
MultiLinearExplorationScheduler,
|
||||
{
|
||||
"splits": [(2, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 5,
|
||||
"final_value": 0.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
@ -54,7 +56,7 @@ def get_dqn(name: str) -> DQNTrainer:
|
|||
return DQNTrainer(
|
||||
name=name,
|
||||
params=DQNParams(
|
||||
reward_discount=.0,
|
||||
reward_discount=0.0,
|
||||
update_target_every=5,
|
||||
num_epochs=10,
|
||||
soft_update_coef=0.1,
|
||||
|
|
|
@ -2,22 +2,21 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams
|
||||
|
||||
from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrainer
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
"activation": torch.nn.Tanh,
|
||||
"softmax": True,
|
||||
"batch_norm": False,
|
||||
"head": True
|
||||
"head": True,
|
||||
}
|
||||
critic_net_conf = {
|
||||
"hidden_dims": [256, 128, 64],
|
||||
|
@ -25,7 +24,7 @@ critic_net_conf = {
|
|||
"activation": torch.nn.LeakyReLU,
|
||||
"softmax": False,
|
||||
"batch_norm": True,
|
||||
"head": True
|
||||
"head": True,
|
||||
}
|
||||
actor_learning_rate = 0.001
|
||||
critic_learning_rate = 0.001
|
||||
|
@ -64,9 +63,9 @@ def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMAD
|
|||
return DiscreteMADDPGTrainer(
|
||||
name=name,
|
||||
params=DiscreteMADDPGParams(
|
||||
reward_discount=.0,
|
||||
reward_discount=0.0,
|
||||
num_epoch=10,
|
||||
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
|
||||
shared_critic=False
|
||||
)
|
||||
shared_critic=False,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -15,11 +15,11 @@ def get_ppo(state_dim: int, name: str) -> PPOTrainer:
|
|||
name=name,
|
||||
params=PPOParams(
|
||||
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
|
||||
reward_discount=.0,
|
||||
reward_discount=0.0,
|
||||
grad_iters=10,
|
||||
critic_loss_cls=torch.nn.SmoothL1Loss,
|
||||
min_logp=None,
|
||||
lam=.0,
|
||||
lam=0.0,
|
||||
clip_ratio=0.1,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
env_conf = {
|
||||
"scenario": "cim",
|
||||
"topology": "toy.4p_ssdd_l0.0",
|
||||
"durations": 560
|
||||
"durations": 560,
|
||||
}
|
||||
|
||||
if env_conf["topology"].startswith("toy"):
|
||||
|
@ -17,27 +17,26 @@ vessel_attributes = ["empty", "full", "remaining_space"]
|
|||
|
||||
state_shaping_conf = {
|
||||
"look_back": 7,
|
||||
"max_ports_downstream": 2
|
||||
"max_ports_downstream": 2,
|
||||
}
|
||||
|
||||
action_shaping_conf = {
|
||||
"action_space": [(i - 10) / 10 for i in range(21)],
|
||||
"finite_vessel_space": True,
|
||||
"has_early_discharge": True
|
||||
"has_early_discharge": True,
|
||||
}
|
||||
|
||||
reward_shaping_conf = {
|
||||
"time_window": 99,
|
||||
"fulfillment_factor": 1.0,
|
||||
"shortage_factor": 1.0,
|
||||
"time_decay": 0.97
|
||||
"time_decay": 0.97,
|
||||
}
|
||||
|
||||
# obtain state dimension from a temporary env_wrapper instance
|
||||
state_dim = (
|
||||
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
|
||||
+ len(vessel_attributes)
|
||||
)
|
||||
state_dim = (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(
|
||||
port_attributes,
|
||||
) + len(vessel_attributes)
|
||||
|
||||
action_num = len(action_shaping_conf["action_space"])
|
||||
|
||||
|
|
|
@ -8,29 +8,32 @@ import numpy as np
|
|||
from maro.rl.rollout import AbsEnvSampler, CacheElement
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
|
||||
|
||||
from .config import (
|
||||
action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
|
||||
vessel_attributes,
|
||||
)
|
||||
from .config import action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
|
||||
|
||||
|
||||
class CIMEnvSampler(AbsEnvSampler):
|
||||
def _get_global_and_agent_state_impl(
|
||||
self, event: DecisionEvent, tick: int = None,
|
||||
self,
|
||||
event: DecisionEvent,
|
||||
tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
tick = self._env.tick
|
||||
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
|
||||
port_idx, vessel_idx = event.port_idx, event.vessel_idx
|
||||
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
|
||||
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
|
||||
state = np.concatenate([
|
||||
port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
|
||||
vessel_snapshots[tick: vessel_idx: vessel_attributes]
|
||||
])
|
||||
future_port_list = vessel_snapshots[tick:vessel_idx:"future_stop_list"].astype("int")
|
||||
state = np.concatenate(
|
||||
[
|
||||
port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
|
||||
vessel_snapshots[tick:vessel_idx:vessel_attributes],
|
||||
],
|
||||
)
|
||||
return state, {port_idx: state}
|
||||
|
||||
def _translate_to_env_action(
|
||||
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,
|
||||
self,
|
||||
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
event: DecisionEvent,
|
||||
) -> Dict[Any, object]:
|
||||
action_space = action_shaping_conf["action_space"]
|
||||
finite_vsl_space = action_shaping_conf["finite_vessel_space"]
|
||||
|
@ -40,7 +43,7 @@ class CIMEnvSampler(AbsEnvSampler):
|
|||
|
||||
vsl_idx, action_scope = event.vessel_idx, event.action_scope
|
||||
vsl_snapshots = self._env.snapshot_list["vessels"]
|
||||
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")
|
||||
vsl_space = vsl_snapshots[self._env.tick : vsl_idx : vessel_attributes][2] if finite_vsl_space else float("inf")
|
||||
|
||||
percent = abs(action_space[model_action[0]])
|
||||
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
|
||||
|
@ -49,7 +52,9 @@ class CIMEnvSampler(AbsEnvSampler):
|
|||
actual_action = min(round(percent * action_scope.load), vsl_space)
|
||||
elif model_action > zero_action_idx:
|
||||
action_type = ActionType.DISCHARGE
|
||||
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
|
||||
early_discharge = (
|
||||
vsl_snapshots[self._env.tick : vsl_idx : "early_discharge"][0] if has_early_discharge else 0
|
||||
)
|
||||
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
|
||||
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
|
||||
else:
|
||||
|
@ -70,7 +75,7 @@ class CIMEnvSampler(AbsEnvSampler):
|
|||
decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
|
||||
rewards = np.float32(
|
||||
reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
|
||||
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
|
||||
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list),
|
||||
)
|
||||
return {agent_id: reward for agent_id, reward in zip(ports, rewards)}
|
||||
|
||||
|
|
|
@ -3,19 +3,16 @@ from typing import Any, Callable, Dict, Optional
|
|||
|
||||
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
|
||||
from examples.cim.rl.env_sampler import CIMEnvSampler
|
||||
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
from maro.rl.training import AbsTrainer
|
||||
|
||||
from .algorithms.ac import get_ac_policy
|
||||
from .algorithms.dqn import get_dqn_policy
|
||||
from .algorithms.maddpg import get_maddpg_policy
|
||||
from .algorithms.ppo import get_ppo_policy
|
||||
from .algorithms.ac import get_ac
|
||||
from .algorithms.ppo import get_ppo
|
||||
from .algorithms.dqn import get_dqn
|
||||
from .algorithms.maddpg import get_maddpg
|
||||
from .algorithms.ac import get_ac, get_ac_policy
|
||||
from .algorithms.dqn import get_dqn, get_dqn_policy
|
||||
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
|
||||
from .algorithms.ppo import get_ppo, get_ppo_policy
|
||||
|
||||
|
||||
class CIMBundle(RLComponentBundle):
|
||||
|
@ -29,7 +26,7 @@ class CIMBundle(RLComponentBundle):
|
|||
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
|
||||
|
||||
def get_agent2policy(self) -> Dict[Any, str]:
|
||||
return {agent: f"{algorithm}_{agent}.policy"for agent in self.env.agent_idx_list}
|
||||
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}
|
||||
|
||||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
|
||||
if algorithm == "ac":
|
||||
|
@ -60,23 +57,17 @@ class CIMBundle(RLComponentBundle):
|
|||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "ppo":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
}
|
||||
trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)}
|
||||
elif algorithm == "discrete_maddpg":
|
||||
trainer_creator = {
|
||||
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}")
|
||||
for i in range(num_agents)
|
||||
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
|
|
@ -63,8 +63,13 @@ class GreedyPolicy:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = Env(scenario=config.env.scenario, topology=config.env.topology, start_tick=config.env.start_tick,
|
||||
durations=config.env.durations, snapshot_resolution=config.env.resolution)
|
||||
env = Env(
|
||||
scenario=config.env.scenario,
|
||||
topology=config.env.topology,
|
||||
start_tick=config.env.start_tick,
|
||||
durations=config.env.durations,
|
||||
snapshot_resolution=config.env.resolution,
|
||||
)
|
||||
|
||||
if config.env.seed is not None:
|
||||
env.set_seed(config.env.seed)
|
||||
|
|
|
@ -7,11 +7,15 @@ from pulp import PULP_CBC_CMD, LpInteger, LpMaximize, LpProblem, LpVariable, lpS
|
|||
from maro.utils import DottableDict
|
||||
|
||||
|
||||
class CitiBikeILP():
|
||||
class CitiBikeILP:
|
||||
def __init__(
|
||||
self, num_station: int, num_neighbor: int,
|
||||
station_capacity: List[int], station_neighbor_list: List[List[int]],
|
||||
decision_interval: int, config: DottableDict
|
||||
self,
|
||||
num_station: int,
|
||||
num_neighbor: int,
|
||||
station_capacity: List[int],
|
||||
station_neighbor_list: List[List[int]],
|
||||
decision_interval: int,
|
||||
config: DottableDict,
|
||||
):
|
||||
"""A simple Linear Programming formulation for solving the bike repositioning problem.
|
||||
|
||||
|
@ -68,37 +72,37 @@ class CitiBikeILP():
|
|||
name=f"T{decision_point}_S{station}_Inv",
|
||||
lowBound=0,
|
||||
upBound=self._station_capacity[station],
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
self._safety_inventory[decision_point][station] = LpVariable(
|
||||
name=f"T{decision_point}_S{station}_SafetyInv",
|
||||
lowBound=0,
|
||||
upBound=round(self._safety_inventory_limit * self._station_capacity[station]),
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
self._fulfillment[decision_point][station] = LpVariable(
|
||||
name=f"T{decision_point}_S{station}_Fulfillment",
|
||||
lowBound=0,
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
|
||||
# For intermediate variables.
|
||||
self._transfer_from[decision_point][station] = LpVariable(
|
||||
name=f"T{decision_point}_TransferFrom{station}",
|
||||
lowBound=0,
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
self._transfer_to[decision_point][station] = LpVariable(
|
||||
name=f"T{decision_point}_TransferTo{station}",
|
||||
lowBound=0,
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
|
||||
for neighbor_idx in range(self._num_neighbor):
|
||||
self._transfer[decision_point][station][neighbor_idx] = LpVariable(
|
||||
name=f"T{decision_point}_Transfer_from{station}_to{neighbor_idx}th",
|
||||
lowBound=0,
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
|
||||
# Initialize inventory of the first decision point with the environment's current inventory.
|
||||
|
@ -113,23 +117,26 @@ class CitiBikeILP():
|
|||
), f"Fulfillment_Limit_T{decision_point}_S{station}"
|
||||
# For intermediate variables.
|
||||
problem += (
|
||||
self._transfer_from[decision_point][station] == lpSum(
|
||||
self._transfer_from[decision_point][station]
|
||||
== lpSum(
|
||||
self._transfer[decision_point][station][neighbor_idx]
|
||||
for neighbor_idx in range(self._num_neighbor)
|
||||
)
|
||||
), f"TotalTransferFrom_T{decision_point}_S{station}"
|
||||
problem += (
|
||||
self._transfer_to[decision_point][station] == lpSum(
|
||||
self._transfer_to[decision_point][station]
|
||||
== lpSum(
|
||||
self._transfer[decision_point][neighbor][self._station_neighbor_list[neighbor].index(station)]
|
||||
for neighbor in range(self._num_station)
|
||||
if station in self._station_neighbor_list[neighbor][:self._num_neighbor]
|
||||
if station in self._station_neighbor_list[neighbor][: self._num_neighbor]
|
||||
)
|
||||
), f"TotalTransferTo_T{decision_point}_S{station}"
|
||||
|
||||
for decision_point in range(1, self._num_decision_point):
|
||||
for station in range(self._num_station):
|
||||
problem += (
|
||||
self._inventory[decision_point][station] == (
|
||||
self._inventory[decision_point][station]
|
||||
== (
|
||||
self._inventory[decision_point - 1][station]
|
||||
+ supply[decision_point - 1, station]
|
||||
- self._fulfillment[decision_point - 1][station]
|
||||
|
@ -138,7 +145,8 @@ class CitiBikeILP():
|
|||
)
|
||||
), f"Inventory_T{decision_point}_S{station}"
|
||||
problem += (
|
||||
self._safety_inventory[decision_point][station] <= (
|
||||
self._safety_inventory[decision_point][station]
|
||||
<= (
|
||||
self._inventory[decision_point - 1][station]
|
||||
+ supply[decision_point - 1, station]
|
||||
- self._fulfillment[decision_point - 1][station]
|
||||
|
@ -148,26 +156,31 @@ class CitiBikeILP():
|
|||
|
||||
def _set_objective(self, problem: LpProblem):
|
||||
fulfillment_gain = lpSum(
|
||||
math.pow(self._fulfillment_time_decay_factor, decision_point) * lpSum(
|
||||
self._fulfillment[decision_point][station] for station in range(self._num_station)
|
||||
) for decision_point in range(self._num_decision_point)
|
||||
)
|
||||
|
||||
safety_inventory_reward = self._safety_inventory_reward_factor * lpSum(
|
||||
math.pow(self._safety_inventory_reward_time_decay_factor, decision_point) * lpSum(
|
||||
self._safety_inventory[decision_point][station] for station in range(self._num_station)
|
||||
) for decision_point in range(self._num_decision_point)
|
||||
)
|
||||
|
||||
transfer_cost = self._transfer_cost_factor * lpSum(
|
||||
self._transfer_to[decision_point][station] for station in range(self._num_station)
|
||||
math.pow(self._fulfillment_time_decay_factor, decision_point)
|
||||
* lpSum(self._fulfillment[decision_point][station] for station in range(self._num_station))
|
||||
for decision_point in range(self._num_decision_point)
|
||||
)
|
||||
|
||||
problem += (fulfillment_gain + safety_inventory_reward - transfer_cost)
|
||||
safety_inventory_reward = self._safety_inventory_reward_factor * lpSum(
|
||||
math.pow(self._safety_inventory_reward_time_decay_factor, decision_point)
|
||||
* lpSum(self._safety_inventory[decision_point][station] for station in range(self._num_station))
|
||||
for decision_point in range(self._num_decision_point)
|
||||
)
|
||||
|
||||
transfer_cost = self._transfer_cost_factor * lpSum(
|
||||
self._transfer_to[decision_point][station]
|
||||
for station in range(self._num_station)
|
||||
for decision_point in range(self._num_decision_point)
|
||||
)
|
||||
|
||||
problem += fulfillment_gain + safety_inventory_reward - transfer_cost
|
||||
|
||||
def _formulate_and_solve(
|
||||
self, env_tick: int, init_inventory: np.ndarray, demand: np.ndarray, supply: np.ndarray
|
||||
self,
|
||||
env_tick: int,
|
||||
init_inventory: np.ndarray,
|
||||
demand: np.ndarray,
|
||||
supply: np.ndarray,
|
||||
):
|
||||
problem = LpProblem(
|
||||
name=f"Citi_Bike_Repositioning_from_tick_{env_tick}",
|
||||
|
@ -181,7 +194,11 @@ class CitiBikeILP():
|
|||
# ============================= private end =============================
|
||||
|
||||
def get_transfer_list(
|
||||
self, env_tick: int, init_inventory: np.ndarray, demand: np.ndarray, supply: np.ndarray
|
||||
self,
|
||||
env_tick: int,
|
||||
init_inventory: np.ndarray,
|
||||
demand: np.ndarray,
|
||||
supply: np.ndarray,
|
||||
) -> List[Tuple[int, int, int]]:
|
||||
"""Get the transfer list for the given env_tick.
|
||||
|
||||
|
@ -201,7 +218,10 @@ class CitiBikeILP():
|
|||
if env_tick >= self._last_start_tick + self._apply_buffer_size:
|
||||
self._last_start_tick = env_tick
|
||||
self._formulate_and_solve(
|
||||
env_tick=env_tick, init_inventory=init_inventory, demand=demand, supply=supply
|
||||
env_tick=env_tick,
|
||||
init_inventory=init_inventory,
|
||||
demand=demand,
|
||||
supply=supply,
|
||||
)
|
||||
|
||||
decision_point = (env_tick - self._last_start_tick) // self._decision_interval
|
||||
|
|
|
@ -5,8 +5,8 @@ from typing import List, Tuple
|
|||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from citi_bike_ilp import CitiBikeILP
|
||||
|
||||
from maro.data_lib import BinaryReader, ItemTickPicker
|
||||
from maro.event_buffer import AbsEvent
|
||||
from maro.forecasting import OneStepFixWindowMA as Forecaster
|
||||
|
@ -22,9 +22,14 @@ ENV: Env = None
|
|||
TRIP_PICKER: ItemTickPicker = None
|
||||
|
||||
|
||||
class MaIlpAgent():
|
||||
class MaIlpAgent:
|
||||
def __init__(
|
||||
self, ilp: CitiBikeILP, num_station: int, num_time_interval: int, ticks_per_interval: int, ma_window_size: int
|
||||
self,
|
||||
ilp: CitiBikeILP,
|
||||
num_station: int,
|
||||
num_time_interval: int,
|
||||
ticks_per_interval: int,
|
||||
ma_window_size: int,
|
||||
):
|
||||
"""An agent that make decisions by ILP in Citi Bike scenario.
|
||||
|
||||
|
@ -100,15 +105,23 @@ class MaIlpAgent():
|
|||
The second item indicates the forecasting supply for each station in each time interval,
|
||||
with shape: (num_time_interval, num_station).
|
||||
"""
|
||||
demand = np.array(
|
||||
[round(self._demand_forecaster[i].forecast()) for i in range(self._num_station)],
|
||||
dtype=np.int16
|
||||
).reshape((1, -1)).repeat(self._num_time_interval, axis=0)
|
||||
demand = (
|
||||
np.array(
|
||||
[round(self._demand_forecaster[i].forecast()) for i in range(self._num_station)],
|
||||
dtype=np.int16,
|
||||
)
|
||||
.reshape((1, -1))
|
||||
.repeat(self._num_time_interval, axis=0)
|
||||
)
|
||||
|
||||
supply = np.array(
|
||||
[round(self._supply_forecaster[i].forecast()) for i in range(self._num_station)],
|
||||
dtype=np.int16
|
||||
).reshape((1, -1)).repeat(self._num_time_interval, axis=0)
|
||||
supply = (
|
||||
np.array(
|
||||
[round(self._supply_forecaster[i].forecast()) for i in range(self._num_station)],
|
||||
dtype=np.int16,
|
||||
)
|
||||
.reshape((1, -1))
|
||||
.repeat(self._num_time_interval, axis=0)
|
||||
)
|
||||
|
||||
return demand, supply
|
||||
|
||||
|
@ -147,7 +160,7 @@ class MaIlpAgent():
|
|||
env_tick=env_tick,
|
||||
init_inventory=init_inventory,
|
||||
demand=demand,
|
||||
supply=supply
|
||||
supply=supply,
|
||||
)
|
||||
|
||||
action_list = [
|
||||
|
@ -160,20 +173,28 @@ class MaIlpAgent():
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--peep", action='store_true',
|
||||
help="If set, peep the future demand and supply of bikes for each station directly from the log data."
|
||||
"--peep",
|
||||
action="store_true",
|
||||
help="If set, peep the future demand and supply of bikes for each station directly from the log data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config", type=str, default="examples/citi_bike/online_lp/config.yml",
|
||||
help="The path of the config file."
|
||||
"-c",
|
||||
"--config",
|
||||
type=str,
|
||||
default="examples/citi_bike/online_lp/config.yml",
|
||||
help="The path of the config file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--topology", type=str,
|
||||
help="Which topology to use. If set, it will over-write the topology set in the config file."
|
||||
"-t",
|
||||
"--topology",
|
||||
type=str,
|
||||
help="Which topology to use. If set, it will over-write the topology set in the config file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r", "--seed", type=int,
|
||||
help="The random seed for the environment. If set, it will over-write the seed set in the config file."
|
||||
"-r",
|
||||
"--seed",
|
||||
type=int,
|
||||
help="The random seed for the environment. If set, it will over-write the seed set in the config file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -205,7 +226,7 @@ if __name__ == "__main__":
|
|||
TRIP_PICKER = BinaryReader(env.configs["trip_data"]).items_tick_picker(
|
||||
start_time_offset=config.env.start_tick,
|
||||
end_time_offset=(config.env.start_tick + config.env.durations),
|
||||
time_unit="m"
|
||||
time_unit="m",
|
||||
)
|
||||
|
||||
if config.env.seed is not None:
|
||||
|
@ -220,29 +241,26 @@ if __name__ == "__main__":
|
|||
# TODO: Update the Env interface.
|
||||
num_station = len(env.agent_idx_list)
|
||||
station_distance_adj = np.array(
|
||||
load_adj_from_csv(env.configs["distance_adj_data"], skiprows=1)
|
||||
load_adj_from_csv(env.configs["distance_adj_data"], skiprows=1),
|
||||
).reshape(num_station, num_station)
|
||||
station_neighbor_list = [
|
||||
neighbor_list[1:]
|
||||
for neighbor_list in np.argsort(station_distance_adj, axis=1).tolist()
|
||||
]
|
||||
station_neighbor_list = [neighbor_list[1:] for neighbor_list in np.argsort(station_distance_adj, axis=1).tolist()]
|
||||
|
||||
# Init a Moving-Average based ILP agent.
|
||||
decision_interval = env.configs["decision"]["resolution"]
|
||||
ilp = CitiBikeILP(
|
||||
num_station=num_station,
|
||||
num_neighbor=min(config.ilp.num_neighbor, num_station - 1),
|
||||
station_capacity=env.snapshot_list["stations"][env.frame_index:env.agent_idx_list:"capacity"],
|
||||
station_capacity=env.snapshot_list["stations"][env.frame_index : env.agent_idx_list : "capacity"],
|
||||
station_neighbor_list=station_neighbor_list,
|
||||
decision_interval=decision_interval,
|
||||
config=config.ilp
|
||||
config=config.ilp,
|
||||
)
|
||||
agent = MaIlpAgent(
|
||||
ilp=ilp,
|
||||
num_station=num_station,
|
||||
num_time_interval=math.ceil(config.ilp.plan_window_size / decision_interval),
|
||||
ticks_per_interval=decision_interval,
|
||||
ma_window_size=config.forecasting.ma_window_size
|
||||
ma_window_size=config.forecasting.ma_window_size,
|
||||
)
|
||||
|
||||
pre_decision_tick: int = -1
|
||||
|
@ -252,15 +270,15 @@ if __name__ == "__main__":
|
|||
else:
|
||||
action = agent.get_action_list(
|
||||
env_tick=env.tick,
|
||||
init_inventory=env.snapshot_list["stations"][
|
||||
env.frame_index:env.agent_idx_list:"bikes"
|
||||
].astype(np.int16),
|
||||
finished_events=env.get_finished_events()
|
||||
init_inventory=env.snapshot_list["stations"][env.frame_index : env.agent_idx_list : "bikes"].astype(
|
||||
np.int16,
|
||||
),
|
||||
finished_events=env.get_finished_events(),
|
||||
)
|
||||
pre_decision_tick = decision_event.tick
|
||||
_, decision_event, is_done = env.step(action=action)
|
||||
|
||||
print(
|
||||
f"[{'De' if PEEP_AND_USE_REAL_DATA else 'MA'}] "
|
||||
f"Topology {config.env.topology} with seed {config.env.seed}: {env.metrics}"
|
||||
f"Topology {config.env.topology} with seed {config.env.seed}: {env.metrics}",
|
||||
)
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from random import seed, randint
|
||||
from random import randint
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize an environment with a specific scenario, related topology.
|
||||
env = Env(scenario="cim", topology="global_trade.22p_l0.1", start_tick=0, durations=100)
|
||||
|
@ -29,7 +28,7 @@ if __name__ == "__main__":
|
|||
decision_event.vessel_idx,
|
||||
decision_event.port_idx,
|
||||
randint(0, action_scope.discharge if to_discharge else action_scope.load),
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
|
||||
)
|
||||
|
||||
# Drive environment with dummy action (no repositioning)
|
||||
|
|
|
@ -9,13 +9,12 @@ os.environ["MARO_STREAMIT_ENABLED"] = "true"
|
|||
os.environ["MARO_STREAMIT_EXPERIMENT_NAME"] = "experiment_example"
|
||||
|
||||
|
||||
from random import seed, randint
|
||||
from random import randint, seed
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionScope, ActionType
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType
|
||||
from maro.streamit import streamit
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed(0)
|
||||
NUM_EPISODE = 2
|
||||
|
@ -42,7 +41,7 @@ if __name__ == "__main__":
|
|||
decision_event.vessel_idx,
|
||||
decision_event.port_idx,
|
||||
randint(0, action_scope.discharge if to_discharge else action_scope.load),
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
|
||||
)
|
||||
|
||||
# Drive environment with dummy action (no repositioning)
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from random import seed, randint
|
||||
from random import randint, seed
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed(0)
|
||||
NUM_EPISODE = 2
|
||||
|
@ -17,12 +16,15 @@ if __name__ == "__main__":
|
|||
If you leave value to empty string, it will dump to current folder.
|
||||
For getting dump data, please uncomment below line and specify dump destination folder.
|
||||
"""
|
||||
opts['enable-dump-snapshot'] = 'YOUR_FOLDER_NAME'
|
||||
opts["enable-dump-snapshot"] = "YOUR_FOLDER_NAME"
|
||||
|
||||
# Initialize an environment with a specific scenario, related topology.
|
||||
env = Env(
|
||||
scenario="cim", topology="global_trade.22p_l0.1",
|
||||
start_tick=0, durations=100, options=opts
|
||||
scenario="cim",
|
||||
topology="global_trade.22p_l0.1",
|
||||
start_tick=0,
|
||||
durations=100,
|
||||
options=opts,
|
||||
)
|
||||
|
||||
# To reset environmental data before starting a new experiment.
|
||||
|
@ -40,7 +42,7 @@ if __name__ == "__main__":
|
|||
decision_event.vessel_idx,
|
||||
decision_event.port_idx,
|
||||
randint(0, action_scope.discharge if to_discharge else action_scope.load),
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
|
||||
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
|
||||
)
|
||||
|
||||
# Drive environment with dummy action (no repositioning)
|
||||
|
|
|
@ -17,8 +17,14 @@ For getting dump data, please uncomment below line and specify dump destination
|
|||
"""
|
||||
# opts['enable-dump-snapshot'] = ''
|
||||
|
||||
env = Env(scenario="citi_bike", topology="toy.4s_4t", start_tick=start_tick,
|
||||
durations=durations, snapshot_resolution=60, options=opts)
|
||||
env = Env(
|
||||
scenario="citi_bike",
|
||||
topology="toy.4s_4t",
|
||||
start_tick=start_tick,
|
||||
durations=durations,
|
||||
snapshot_resolution=60,
|
||||
options=opts,
|
||||
)
|
||||
|
||||
print(env.summary)
|
||||
|
||||
|
@ -35,11 +41,11 @@ for ep in range(max_ep):
|
|||
if decision_evt is not None:
|
||||
action = Action(decision_evt.station_idx, 0, 10)
|
||||
|
||||
station_ss = env.snapshot_list['stations']
|
||||
shortage_states = station_ss[::'shortage']
|
||||
station_ss = env.snapshot_list["stations"]
|
||||
shortage_states = station_ss[::"shortage"]
|
||||
print("total shortage", shortage_states.sum())
|
||||
|
||||
trips_states = station_ss[::'trip_requirement']
|
||||
trips_states = station_ss[::"trip_requirement"]
|
||||
print("total trip", trips_states.sum())
|
||||
|
||||
cost_states = station_ss[::["extra_cost", "transfer_cost"]]
|
||||
|
@ -54,7 +60,7 @@ for ep in range(max_ep):
|
|||
|
||||
# NOTE: We have not clear the trip adj at each tick so it is an accumulative value,
|
||||
# then we can just query last snapshot to calc total trips.
|
||||
trips_adj = matrix_ss[last_snapshot_index::'trips_adj']
|
||||
trips_adj = matrix_ss[last_snapshot_index::"trips_adj"]
|
||||
|
||||
# Reshape it we need an easy way to access.
|
||||
# trips_adj = trips_adj.reshape((-1, len(station_ss)))
|
||||
|
|
|
@ -14,9 +14,11 @@ def worker(group_name):
|
|||
Args:
|
||||
group_name (str): Identifier for the group of all communication components.
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="worker",
|
||||
expected_peers={"master": 1})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="worker",
|
||||
expected_peers={"master": 1},
|
||||
)
|
||||
counter = 0
|
||||
print(f"{proxy.name}'s counter is {counter}.")
|
||||
|
||||
|
@ -45,14 +47,14 @@ def master(group_name: str, worker_num: int, is_immediate: bool = False):
|
|||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="master",
|
||||
expected_peers={"worker": worker_num}
|
||||
expected_peers={"worker": worker_num},
|
||||
)
|
||||
|
||||
if is_immediate:
|
||||
session_ids = proxy.ibroadcast(
|
||||
component_type="worker",
|
||||
tag="INC",
|
||||
session_type=SessionType.NOTIFICATION
|
||||
session_type=SessionType.NOTIFICATION,
|
||||
)
|
||||
# Do some tasks with higher priority here.
|
||||
replied_msgs = proxy.receive_by_id(session_ids, timeout=-1)
|
||||
|
@ -61,13 +63,13 @@ def master(group_name: str, worker_num: int, is_immediate: bool = False):
|
|||
component_type="worker",
|
||||
tag="INC",
|
||||
session_type=SessionType.NOTIFICATION,
|
||||
timeout=-1
|
||||
timeout=-1,
|
||||
)
|
||||
|
||||
for msg in replied_msgs:
|
||||
print(
|
||||
f"{proxy.name} get receive notification from {msg.source} with "
|
||||
f"message session stage {msg.session_stage}."
|
||||
f"message session stage {msg.session_stage}.",
|
||||
)
|
||||
|
||||
|
||||
|
@ -84,7 +86,7 @@ if __name__ == "__main__":
|
|||
|
||||
workers = mp.Pool(worker_number)
|
||||
|
||||
master_process = mp.Process(target=master, args=(group_name, worker_number, is_immediate,))
|
||||
master_process = mp.Process(target=master, args=(group_name, worker_number, is_immediate))
|
||||
master_process.start()
|
||||
|
||||
workers.map(worker, [group_name] * worker_number)
|
||||
|
|
|
@ -16,9 +16,11 @@ def summation_worker(group_name):
|
|||
Args:
|
||||
group_name (str): Identifier for the group of all communication components.
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="sum_worker",
|
||||
expected_peers={"master": 1})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="sum_worker",
|
||||
expected_peers={"master": 1},
|
||||
)
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
msg = proxy.receive_once()
|
||||
|
@ -36,9 +38,11 @@ def multiplication_worker(group_name):
|
|||
Args:
|
||||
group_name (str): Identifier for the group of all communication components.
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="multiply_worker",
|
||||
expected_peers={"master": 1})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="multiply_worker",
|
||||
expected_peers={"master": 1},
|
||||
)
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
msg = proxy.receive_once()
|
||||
|
@ -62,10 +66,14 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
|
|||
you can do something with high priority before receiving replied messages from peers.
|
||||
Sync Mode: It will block until the proxy returns all the replied messages.
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="master",
|
||||
expected_peers={"sum_worker": sum_worker_number,
|
||||
"multiply_worker": multiply_worker_number})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="master",
|
||||
expected_peers={
|
||||
"sum_worker": sum_worker_number,
|
||||
"multiply_worker": multiply_worker_number,
|
||||
},
|
||||
)
|
||||
|
||||
sum_list = np.random.randint(0, 10, 100)
|
||||
multiple_list = np.random.randint(1, 10, 20)
|
||||
|
@ -75,25 +83,30 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
|
|||
destination_payload_list = []
|
||||
for idx, peer in enumerate(proxy.peers["sum_worker"]):
|
||||
data_length_per_peer = int(len(sum_list) / len(proxy.peers["sum_worker"]))
|
||||
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
|
||||
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer : (idx + 1) * data_length_per_peer]))
|
||||
|
||||
# Assign multiply tasks for multiplication workers.
|
||||
for idx, peer in enumerate(proxy.peers["multiply_worker"]):
|
||||
data_length_per_peer = int(len(multiple_list) / len(proxy.peers["multiply_worker"]))
|
||||
destination_payload_list.append(
|
||||
(peer, multiple_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
|
||||
(peer, multiple_list[idx * data_length_per_peer : (idx + 1) * data_length_per_peer]),
|
||||
)
|
||||
|
||||
if is_immediate:
|
||||
session_ids = proxy.iscatter(tag="job",
|
||||
session_type=SessionType.TASK,
|
||||
destination_payload_list=destination_payload_list)
|
||||
session_ids = proxy.iscatter(
|
||||
tag="job",
|
||||
session_type=SessionType.TASK,
|
||||
destination_payload_list=destination_payload_list,
|
||||
)
|
||||
# Do some tasks with higher priority here.
|
||||
replied_msgs = proxy.receive_by_id(session_ids, timeout=-1)
|
||||
else:
|
||||
replied_msgs = proxy.scatter(tag="job",
|
||||
session_type=SessionType.TASK,
|
||||
destination_payload_list=destination_payload_list,
|
||||
timeout=-1)
|
||||
replied_msgs = proxy.scatter(
|
||||
tag="job",
|
||||
session_type=SessionType.TASK,
|
||||
destination_payload_list=destination_payload_list,
|
||||
timeout=-1,
|
||||
)
|
||||
|
||||
sum_result, multiply_result = 0, 1
|
||||
for msg in replied_msgs:
|
||||
|
@ -105,8 +118,8 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
|
|||
multiply_result *= msg.body
|
||||
|
||||
# Check task result correction.
|
||||
assert(sum(sum_list) == sum_result)
|
||||
assert(np.prod(multiple_list) == multiply_result)
|
||||
assert sum(sum_list) == sum_result
|
||||
assert np.prod(multiple_list) == multiply_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -124,8 +137,10 @@ if __name__ == "__main__":
|
|||
# Worker's pool for sum_worker and prod_worker.
|
||||
workers = mp.Pool(sum_worker_number + multiply_worker_number)
|
||||
|
||||
master_process = mp.Process(target=master,
|
||||
args=(group_name, sum_worker_number, multiply_worker_number, is_immediate,))
|
||||
master_process = mp.Process(
|
||||
target=master,
|
||||
args=(group_name, sum_worker_number, multiply_worker_number, is_immediate),
|
||||
)
|
||||
master_process.start()
|
||||
|
||||
for s in range(sum_worker_number):
|
||||
|
|
|
@ -16,9 +16,11 @@ def worker(group_name):
|
|||
Args:
|
||||
group_name (str): Identifier for the group of all communication components
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="worker",
|
||||
expected_peers={"master": 1})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="worker",
|
||||
expected_peers={"master": 1},
|
||||
)
|
||||
|
||||
# Nonrecurring receive the message from the proxy.
|
||||
msg = proxy.receive_once()
|
||||
|
@ -40,19 +42,23 @@ def master(group_name: str, is_immediate: bool = False):
|
|||
you can do something with high priority before receiving replied messages from peers.
|
||||
Sync Mode: It will block until the proxy returns all the replied messages.
|
||||
"""
|
||||
proxy = Proxy(group_name=group_name,
|
||||
component_type="master",
|
||||
expected_peers={"worker": 1})
|
||||
proxy = Proxy(
|
||||
group_name=group_name,
|
||||
component_type="master",
|
||||
expected_peers={"worker": 1},
|
||||
)
|
||||
|
||||
random_integer_list = np.random.randint(0, 100, 5)
|
||||
print(f"generate random integer list: {random_integer_list}.")
|
||||
|
||||
for peer in proxy.peers["worker"]:
|
||||
message = SessionMessage(tag="sum",
|
||||
source=proxy.name,
|
||||
destination=peer,
|
||||
body=random_integer_list,
|
||||
session_type=SessionType.TASK)
|
||||
message = SessionMessage(
|
||||
tag="sum",
|
||||
source=proxy.name,
|
||||
destination=peer,
|
||||
body=random_integer_list,
|
||||
session_type=SessionType.TASK,
|
||||
)
|
||||
if is_immediate:
|
||||
session_id = proxy.isend(message)
|
||||
# Do some tasks with higher priority here.
|
||||
|
@ -74,7 +80,7 @@ if __name__ == "__main__":
|
|||
group_name = "proxy_send_simple_example"
|
||||
is_immediate = False
|
||||
|
||||
master_process = mp.Process(target=master, args=(group_name, is_immediate,))
|
||||
master_process = mp.Process(target=master, args=(group_name, is_immediate))
|
||||
worker_process = mp.Process(target=worker, args=(group_name,))
|
||||
master_process.start()
|
||||
worker_process.start()
|
||||
|
|
|
@ -5,7 +5,7 @@ from maro.cli.local.commands import run
|
|||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("conf_path", help='Path of the job deployment')
|
||||
parser.add_argument("conf_path", help="Path of the job deployment")
|
||||
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
|
@ -80,5 +80,5 @@ retail_frame.stores[0].shortages[:] = [i + 1 for i in range(TOTAL_PRODUCT_CATEGO
|
|||
retail_frame.take_snapshot(1)
|
||||
|
||||
# Query shortage information of first and second store at first and second tick.
|
||||
store_shortage_history = snapshot_list["store"][[0, 1]: [0, 1]: "shortages"].reshape(2, -1)
|
||||
store_shortage_history = snapshot_list["store"][[0, 1]:[0, 1]:"shortages"].reshape(2, -1)
|
||||
print(f"First and second store shortage history at the first and second tick (numpy array): {store_shortage_history}")
|
||||
|
|
|
@ -2,13 +2,16 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
|
||||
from maro.vector_env import VectorEnv
|
||||
|
||||
|
||||
class VectorEnvUsage(Enum):
|
||||
PUSH_ONE_FORWARD = "push the first environment forward and left others behind"
|
||||
PUSH_ALL_FORWARD = "push all environments forward together"
|
||||
|
||||
|
||||
USAGE = VectorEnvUsage.PUSH_ONE_FORWARD
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -28,7 +31,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Showcase: how to access information from snapshot list in vector env.
|
||||
if env0_dec:
|
||||
ss0 = env.snapshot_list["vessels"][env0_dec.tick:env0_dec.vessel_idx:"remaining_space"]
|
||||
ss0 = env.snapshot_list["vessels"][env0_dec.tick : env0_dec.vessel_idx : "remaining_space"]
|
||||
|
||||
# 1. Only push specified (1st for this example) environment, leave others behind.
|
||||
if USAGE == VectorEnvUsage.PUSH_ONE_FORWARD and env0_dec:
|
||||
|
@ -38,8 +41,8 @@ if __name__ == "__main__":
|
|||
vessel_idx=env0_dec.vessel_idx,
|
||||
port_idx=env0_dec.port_idx,
|
||||
quantity=env0_dec.action_scope.load,
|
||||
action_type=ActionType.LOAD
|
||||
)
|
||||
action_type=ActionType.LOAD,
|
||||
),
|
||||
}
|
||||
|
||||
# 2. Only pass action to 1st environment (give None to other environments),
|
||||
|
@ -51,7 +54,7 @@ if __name__ == "__main__":
|
|||
vessel_idx=env0_dec.vessel_idx,
|
||||
port_idx=env0_dec.port_idx,
|
||||
quantity=env0_dec.action_scope.load,
|
||||
action_type=ActionType.LOAD
|
||||
action_type=ActionType.LOAD,
|
||||
)
|
||||
|
||||
metrics, decision_event, is_done = env.step(action)
|
||||
|
|
|
@ -3,19 +3,21 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IlpPmCapacity:
|
||||
core: int
|
||||
mem: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class IlpVmInfo:
|
||||
id: int=-1
|
||||
pm_idx: int=-2
|
||||
core: int=-1
|
||||
mem: int=-1
|
||||
lifetime: int=-1
|
||||
arrival_env_tick: int=-1
|
||||
id: int = -1
|
||||
pm_idx: int = -2
|
||||
core: int = -1
|
||||
mem: int = -1
|
||||
lifetime: int = -1
|
||||
arrival_env_tick: int = -1
|
||||
|
||||
def remaining_lifetime(self, env_tick: int):
|
||||
return self.lifetime - (env_tick - self.arrival_env_tick)
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import timeit
|
||||
from collections import defaultdict, Counter
|
||||
from collections import Counter, defaultdict
|
||||
from typing import List, Set
|
||||
|
||||
from maro.data_lib import BinaryReader
|
||||
from maro.simulator.scenarios.vm_scheduling import PostponeAction, AllocateAction
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
from maro.utils import DottableDict, Logger
|
||||
|
||||
import numpy as np
|
||||
from common import IlpPmCapacity, IlpVmInfo
|
||||
from vm_scheduling_ilp import NOT_ALLOCATE_NOW, VmSchedulingILP
|
||||
|
||||
class IlpAgent():
|
||||
from maro.data_lib import BinaryReader
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, PostponeAction
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
from maro.utils import DottableDict, Logger
|
||||
|
||||
|
||||
class IlpAgent:
|
||||
def __init__(
|
||||
self,
|
||||
ilp_config: DottableDict,
|
||||
|
@ -24,11 +24,11 @@ class IlpAgent():
|
|||
env_duration: int,
|
||||
simulation_logger: Logger,
|
||||
ilp_logger: Logger,
|
||||
log_path: str
|
||||
log_path: str,
|
||||
):
|
||||
self._simulation_logger = simulation_logger
|
||||
self._ilp_logger = ilp_logger
|
||||
|
||||
|
||||
self._allocation_counter = Counter()
|
||||
|
||||
pm_capacity: List[IlpPmCapacity] = [IlpPmCapacity(core=pm[0], mem=pm[1]) for pm in pm_capacity]
|
||||
|
@ -41,7 +41,7 @@ class IlpAgent():
|
|||
self.vm_item_picker = self.vm_reader.items_tick_picker(
|
||||
env_start_tick,
|
||||
env_start_tick + env_duration,
|
||||
time_unit="s"
|
||||
time_unit="s",
|
||||
)
|
||||
|
||||
# Used to keep the info already read from the vm_item_picker.
|
||||
|
@ -87,7 +87,7 @@ class IlpAgent():
|
|||
core=vm.vm_cpu_cores,
|
||||
mem=vm.vm_memory,
|
||||
lifetime=vm.vm_lifetime,
|
||||
arrival_env_tick=tick
|
||||
arrival_env_tick=tick,
|
||||
)
|
||||
if tick < env_tick + self.ilp_apply_buffer_size:
|
||||
self.refreshed_allocated_vm_dict[vm.vm_id] = vmInfo
|
||||
|
@ -109,7 +109,13 @@ class IlpAgent():
|
|||
|
||||
# Choose action by ILP, may trigger a new formulation and solution,
|
||||
# may directly return the decision if the cur_vm_id is still in the apply buffer size of last solution.
|
||||
chosen_pm_idx = self.ilp.choose_pm(env_tick, cur_vm_id, self.allocated_vm, self.future_vm_req, self._vm_id_to_idx)
|
||||
chosen_pm_idx = self.ilp.choose_pm(
|
||||
env_tick,
|
||||
cur_vm_id,
|
||||
self.allocated_vm,
|
||||
self.future_vm_req,
|
||||
self._vm_id_to_idx,
|
||||
)
|
||||
self._simulation_logger.info(f"tick: {env_tick}, vm: {cur_vm_id} -> pm: {chosen_pm_idx}")
|
||||
|
||||
if chosen_pm_idx == NOT_ALLOCATE_NOW:
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import io
|
||||
import os
|
||||
import pprint
|
||||
import shutil
|
||||
import timeit
|
||||
from typing import List, Set
|
||||
|
||||
import pprint
|
||||
import yaml
|
||||
from ilp_agent import IlpAgent
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import DecisionPayload
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
from maro.utils import convert_dottable, Logger, LogFormat
|
||||
from maro.utils import LogFormat, Logger, convert_dottable
|
||||
|
||||
from ilp_agent import IlpAgent
|
||||
|
||||
os.environ['LOG_LEVEL'] = 'CRITICAL'
|
||||
os.environ["LOG_LEVEL"] = "CRITICAL"
|
||||
FILE_PATH = os.path.split(os.path.realpath(__file__))[0]
|
||||
CONFIG_PATH = os.path.join(FILE_PATH, "config.yml")
|
||||
with io.open(CONFIG_PATH, "r") as in_file:
|
||||
|
@ -32,11 +31,11 @@ if __name__ == "__main__":
|
|||
topology=config.env.topology,
|
||||
start_tick=config.env.start_tick,
|
||||
durations=config.env.durations,
|
||||
snapshot_resolution=config.env.resolution
|
||||
snapshot_resolution=config.env.resolution,
|
||||
)
|
||||
shutil.copy(
|
||||
os.path.join(env._business_engine._config_path, "config.yml"),
|
||||
os.path.join(LOG_PATH, "BEconfig.yml")
|
||||
os.path.join(LOG_PATH, "BEconfig.yml"),
|
||||
)
|
||||
shutil.copy(CONFIG_PATH, os.path.join(LOG_PATH, "config.yml"))
|
||||
|
||||
|
@ -51,9 +50,7 @@ if __name__ == "__main__":
|
|||
metrics, decision_event, is_done = env.step(None)
|
||||
|
||||
# Get the core & memory capacity of all PMs in this environment.
|
||||
pm_capacity = env.snapshot_list["pms"][
|
||||
env.frame_index::["cpu_cores_capacity", "memory_capacity"]
|
||||
].reshape(-1, 2)
|
||||
pm_capacity = env.snapshot_list["pms"][env.frame_index :: ["cpu_cores_capacity", "memory_capacity"]].reshape(-1, 2)
|
||||
pm_num = pm_capacity.shape[0]
|
||||
|
||||
# ILP agent.
|
||||
|
@ -65,7 +62,7 @@ if __name__ == "__main__":
|
|||
env_duration=config.env.durations,
|
||||
simulation_logger=simulation_logger,
|
||||
ilp_logger=ilp_logger,
|
||||
log_path=LOG_PATH
|
||||
log_path=LOG_PATH,
|
||||
)
|
||||
|
||||
while not is_done:
|
||||
|
@ -83,7 +80,7 @@ if __name__ == "__main__":
|
|||
end_time = timeit.default_timer()
|
||||
simulation_logger.info(
|
||||
f"[Offline ILP] Topology: {config.env.topology}. Total ticks: {config.env.durations}."
|
||||
f" Start tick: {config.env.start_tick}."
|
||||
f" Start tick: {config.env.start_tick}.",
|
||||
)
|
||||
simulation_logger.info(f"[Timer] {end_time - start_time:.2f} seconds to finish the simulation.")
|
||||
ilp_agent.report_allocation_summary()
|
||||
|
|
|
@ -1,23 +1,22 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import os
|
||||
import timeit
|
||||
from typing import List
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from pulp import PULP_CBC_CMD, GLPK, GUROBI_CMD, LpInteger, LpMaximize, LpProblem, LpStatus, LpVariable, lpSum
|
||||
from common import IlpPmCapacity, IlpVmInfo
|
||||
from pulp import GLPK, PULP_CBC_CMD, LpInteger, LpMaximize, LpProblem, LpStatus, LpVariable, lpSum
|
||||
|
||||
from maro.utils import DottableDict, Logger
|
||||
|
||||
from common import IlpPmCapacity, IlpVmInfo
|
||||
|
||||
|
||||
# To indicate the decision of not allocate or cannot allocate any PM for current VM request.
|
||||
NOT_ALLOCATE_NOW = -1
|
||||
|
||||
class VmSchedulingILP():
|
||||
|
||||
class VmSchedulingILP:
|
||||
def __init__(self, config: DottableDict, pm_capacity: List[IlpPmCapacity], logger: Logger, log_path: str):
|
||||
self._logger = logger
|
||||
self._log_path = log_path
|
||||
|
@ -76,7 +75,7 @@ class VmSchedulingILP():
|
|||
name=f"Place_VM{vm_idx}_{self._future_vm_req[vm_idx].id}_on_PM{pm_idx}",
|
||||
lowBound=0,
|
||||
upBound=1,
|
||||
cat=LpInteger
|
||||
cat=LpInteger,
|
||||
)
|
||||
|
||||
def _add_constraints(self, problem: LpProblem):
|
||||
|
@ -84,7 +83,7 @@ class VmSchedulingILP():
|
|||
for vm_idx in range(self._vm_num):
|
||||
problem += (
|
||||
lpSum(self._mapping[pm_idx][vm_idx] for pm_idx in range(self._pm_num)) <= 1,
|
||||
f"Mapping_VM{vm_idx}_to_max_1_PM"
|
||||
f"Mapping_VM{vm_idx}_to_max_1_PM",
|
||||
)
|
||||
|
||||
# PM capacity limitation: core + mem.
|
||||
|
@ -95,16 +94,20 @@ class VmSchedulingILP():
|
|||
vm.core * self._mapping[pm_idx][vm_idx]
|
||||
for vm_idx, vm in enumerate(self._future_vm_req)
|
||||
if (vm.arrival_env_tick - self._env_tick <= t and vm.remaining_lifetime(self._env_tick) >= t)
|
||||
) + self._pm_allocated_core[t][pm_idx] <= self._pm_capacity[pm_idx].core * self.core_upper_ratio,
|
||||
f"T{t}_PM{pm_idx}_core_capacity_limit"
|
||||
)
|
||||
+ self._pm_allocated_core[t][pm_idx]
|
||||
<= self._pm_capacity[pm_idx].core * self.core_upper_ratio,
|
||||
f"T{t}_PM{pm_idx}_core_capacity_limit",
|
||||
)
|
||||
problem += (
|
||||
lpSum(
|
||||
vm.mem * self._mapping[pm_idx][vm_idx]
|
||||
for vm_idx, vm in enumerate(self._future_vm_req)
|
||||
if (vm.arrival_env_tick - self._env_tick <= t and vm.remaining_lifetime(self._env_tick) >= t)
|
||||
) + self._pm_allocated_mem[t][pm_idx] <= self._pm_capacity[pm_idx].mem * self.mem_upper_ratio,
|
||||
f"T{t}_PM{pm_idx}_mem_capacity_limit"
|
||||
)
|
||||
+ self._pm_allocated_mem[t][pm_idx]
|
||||
<= self._pm_capacity[pm_idx].mem * self.mem_upper_ratio,
|
||||
f"T{t}_PM{pm_idx}_mem_capacity_limit",
|
||||
)
|
||||
|
||||
def _set_objective(self, problem: LpProblem):
|
||||
|
@ -129,7 +132,7 @@ class VmSchedulingILP():
|
|||
|
||||
problem = LpProblem(
|
||||
name=f"VM_Scheduling_from_tick_{self._env_tick}",
|
||||
sense=LpMaximize
|
||||
sense=LpMaximize,
|
||||
)
|
||||
self._init_variables()
|
||||
self._add_constraints(problem=problem)
|
||||
|
@ -148,8 +151,9 @@ class VmSchedulingILP():
|
|||
if self._mapping[pm_idx][vm_idx].varValue:
|
||||
chosen_pm_idx = pm_idx
|
||||
break
|
||||
self._logger.info(f"Solution tick: {self._env_tick}, vm: {self._future_vm_req[vm_idx].id} -> pm: {chosen_pm_idx}")
|
||||
|
||||
self._logger.info(
|
||||
f"Solution tick: {self._env_tick}, vm: {self._future_vm_req[vm_idx].id} -> pm: {chosen_pm_idx}",
|
||||
)
|
||||
|
||||
def choose_pm(
|
||||
self,
|
||||
|
|
|
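The import churn in the file above (the `pulp` names reshuffled, `GUROBI_CMD` dropped presumably because it was unused, and `from common import ...` moved below the `maro` import) follows isort's grouping and ordering rules: imports are grouped by origin, and names inside a single `from` import are sorted with UPPER_CASE constants before CamelCase classes before lowercase names. A hedged, standard-library-only sketch of the target layout (the commented lines mirror the diff; everything else is illustrative):

    # The layout isort aims for: one group per origin, blank lines between groups,
    # and names inside a "from" import ordered by type, then alphabetically.
    import math
    import os
    from collections import OrderedDict, defaultdict, namedtuple

    # Third-party, first-party and local groups would follow, each after a blank
    # line, e.g. (commented out here because those packages are not imported):
    # from pulp import GLPK, PULP_CBC_CMD, LpProblem, lpSum
    #
    # from maro.utils import DottableDict, Logger
    #
    # from common import IlpPmCapacity, IlpVmInfo
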
@ -1,15 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam, SGD
|
||||
from torch.optim import SGD, Adam
|
||||
|
||||
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
|
||||
from maro.rl.policy import DiscretePolicyGradient
|
||||
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
|
||||
|
||||
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
|
||||
|
||||
actor_net_conf = {
|
||||
"hidden_dims": [64, 32, 32],
|
||||
|
@ -39,7 +36,7 @@ class MyActorNet(DiscreteACBasedNet):
|
|||
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
|
||||
|
||||
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
|
||||
features, masks = states[:, :self._num_features], states[:, self._num_features:]
|
||||
features, masks = states[:, : self._num_features], states[:, self._num_features :]
|
||||
masks += 1e-8 # this is to prevent zero probability and infinite logP.
|
||||
return self._actor(features) * masks
|
||||
|
||||
|
@ -52,7 +49,7 @@ class MyCriticNet(VNet):
|
|||
self._optim = SGD(self._critic.parameters(), lr=critic_learning_rate)
|
||||
|
||||
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
|
||||
features, masks = states[:, :self._num_features], states[:, self._num_features:]
|
||||
features, masks = states[:, : self._num_features], states[:, self._num_features :]
|
||||
masks += 1e-8 # this is to prevent zero probability and infinite logP.
|
||||
return self._critic(features).squeeze(-1)
|
||||
|
||||
|
@ -70,6 +67,6 @@ def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
|
|||
grad_iters=100,
|
||||
critic_loss_cls=torch.nn.MSELoss,
|
||||
min_logp=-20,
|
||||
lam=.0,
|
||||
lam=0.0,
|
||||
),
|
||||
)
|
||||
|
|
|
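The change from `states[:, :self._num_features]` to `states[:, : self._num_features]` above is Black's slice formatting: when a slice bound is anything more complex than a plain name or number, the colon is treated like a low-priority binary operator and gets an equal amount of space on both sides. A small runnable illustration (array and variable names are made up):

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    k = 2

    # Simple names keep the compact form:
    head = x[:, :k]

    # Complex expressions get symmetric spacing around the slice colon:
    head = x[:, : k + 1]
    tail = x[:, k + 1 :]
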
@ -33,8 +33,8 @@ class MyQNet(DiscreteQNet):
|
|||
self._lr_scheduler = CosineAnnealingWarmRestarts(self._optim, **q_net_lr_scheduler_params)
|
||||
|
||||
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
|
||||
masks = states[:, self._num_features:]
|
||||
q_for_all_actions = self._fc(states[:, :self._num_features])
|
||||
masks = states[:, self._num_features :]
|
||||
q_for_all_actions = self._fc(states[:, : self._num_features])
|
||||
return q_for_all_actions + (masks - 1) * 1e8
|
||||
|
||||
|
||||
|
@ -44,11 +44,13 @@ class MaskedEpsGreedy:
|
|||
self._num_features = num_features
|
||||
|
||||
def __call__(self, states, actions, num_actions, *, epsilon):
|
||||
masks = states[:, self._num_features:]
|
||||
return np.array([
|
||||
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
|
||||
for action, mask in zip(actions, masks)
|
||||
])
|
||||
masks = states[:, self._num_features :]
|
||||
return np.array(
|
||||
[
|
||||
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
|
||||
for action, mask in zip(actions, masks)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str) -> ValueBasedPolicy:
|
||||
|
@ -56,14 +58,18 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
|
|||
name=name,
|
||||
q_net=MyQNet(state_dim, action_num, num_features),
|
||||
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
|
||||
exploration_scheduling_options=[(
|
||||
"epsilon", MultiLinearExplorationScheduler, {
|
||||
"splits": [(100, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 400,
|
||||
"final_value": 0.0,
|
||||
}
|
||||
)],
|
||||
exploration_scheduling_options=[
|
||||
(
|
||||
"epsilon",
|
||||
MultiLinearExplorationScheduler,
|
||||
{
|
||||
"splits": [(100, 0.32)],
|
||||
"initial_value": 0.4,
|
||||
"last_ep": 400,
|
||||
"final_value": 0.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
warmup=100,
|
||||
)
|
||||
|
||||
|
|
|
@@ -3,7 +3,6 @@

from maro.simulator import Env


env_conf = {
"scenario": "vm_scheduling",
"topology": "azure.2019.10k",

@ -15,7 +15,13 @@ from maro.simulator import Env
|
|||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
|
||||
|
||||
from .config import (
|
||||
num_features, pm_attributes, pm_window_size, reward_shaping_conf, seed, test_reward_shaping_conf, test_seed,
|
||||
num_features,
|
||||
pm_attributes,
|
||||
pm_window_size,
|
||||
reward_shaping_conf,
|
||||
seed,
|
||||
test_reward_shaping_conf,
|
||||
test_seed,
|
||||
)
|
||||
|
||||
timestamp = str(time.time())
|
||||
|
@ -31,13 +37,15 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
self._test_env.set_seed(test_seed)
|
||||
|
||||
# adjust the ratio of the success allocation and the total income when computing the reward
|
||||
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
|
||||
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
|
||||
self._durations = self._learn_env.business_engine._max_tick
|
||||
self._pm_state_history = np.zeros((pm_window_size - 1, self.num_pms, 2))
|
||||
self._legal_pm_mask = None
|
||||
|
||||
def _get_global_and_agent_state_impl(
|
||||
self, event: DecisionPayload, tick: int = None,
|
||||
self,
|
||||
event: DecisionPayload,
|
||||
tick: int = None,
|
||||
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
|
||||
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
|
||||
# get the legal number of PM.
|
||||
|
@ -61,7 +69,9 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
return None, {"AGENT": state}
|
||||
|
||||
def _translate_to_env_action(
|
||||
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionPayload,
|
||||
self,
|
||||
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
|
||||
event: DecisionPayload,
|
||||
) -> Dict[Any, object]:
|
||||
if action_dict["AGENT"] == self.num_pms:
|
||||
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
|
||||
|
@ -71,17 +81,17 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
|
||||
action = env_action_dict["AGENT"]
|
||||
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
|
||||
if isinstance(action, PostponeAction): # postponement
|
||||
if isinstance(action, PostponeAction): # postponement
|
||||
if np.sum(self._legal_pm_mask) != 1:
|
||||
reward = -0.1 * conf["alpha"] + 0.0 * conf["beta"]
|
||||
else:
|
||||
reward = 0.0 * conf["alpha"] + 0.0 * conf["beta"]
|
||||
else:
|
||||
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else .0
|
||||
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else 0.0
|
||||
return {"AGENT": np.float32(reward)}
|
||||
|
||||
def _get_pm_state(self):
|
||||
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index::pm_attributes]
|
||||
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index :: pm_attributes]
|
||||
total_pm_info = total_pm_info.reshape(self.num_pms, len(pm_attributes))
|
||||
|
||||
# normalize the attributes of pms' cpu and memory
|
||||
|
@ -99,21 +109,24 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
|
||||
# get the sequence pms' information
|
||||
self._pm_state_history = np.concatenate((self._pm_state_history, total_pm_info), axis=0)
|
||||
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
|
||||
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
|
||||
|
||||
def _get_vm_state(self, event):
|
||||
return np.array([
|
||||
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
|
||||
event.vm_memory_requirement / self._max_memory_capacity,
|
||||
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
|
||||
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement)
|
||||
])
|
||||
return np.array(
|
||||
[
|
||||
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
|
||||
event.vm_memory_requirement / self._max_memory_capacity,
|
||||
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
|
||||
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement),
|
||||
],
|
||||
)
|
||||
|
||||
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
|
||||
vm_unit_price = self._env.business_engine._get_unit_price(
|
||||
event.vm_cpu_cores_requirement, event.vm_memory_requirement
|
||||
event.vm_cpu_cores_requirement,
|
||||
event.vm_memory_requirement,
|
||||
)
|
||||
return (alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time))
|
||||
return alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time)
|
||||
|
||||
def _post_step(self, cache_element: CacheElement) -> None:
|
||||
self._info["env_metric"] = {k: v for k, v in self._env.metrics.items() if k != "total_latency"}
|
||||
|
@ -127,7 +140,9 @@ class VMEnvSampler(AbsEnvSampler):
|
|||
action = cache_element.action_dict["AGENT"]
|
||||
if cache_element.state:
|
||||
mask = cache_element.state[num_features:]
|
||||
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append([action, mask])
|
||||
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append(
|
||||
[action, mask],
|
||||
)
|
||||
self._info["action_sequence"].append(action)
|
||||
|
||||
def _post_eval_step(self, cache_element: CacheElement) -> None:
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from examples.vm_scheduling.rl.algorithms.ac import get_ac_policy
|
||||
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn_policy
|
||||
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
|
||||
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
|
||||
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
|
||||
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
|
||||
|
||||
from maro.rl.policy import AbsPolicy
|
||||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
|
||||
from maro.rl.rollout import AbsEnvSampler
|
||||
|
@ -30,14 +31,22 @@ class VMBundle(RLComponentBundle):
|
|||
if algorithm == "ac":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_ac_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
|
||||
)
|
||||
get_ac_policy,
|
||||
state_dim,
|
||||
action_num,
|
||||
num_features,
|
||||
f"{algorithm}.policy",
|
||||
),
|
||||
}
|
||||
elif algorithm == "dqn":
|
||||
policy_creator = {
|
||||
f"{algorithm}.policy": partial(
|
||||
get_dqn_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
|
||||
)
|
||||
get_dqn_policy,
|
||||
state_dim,
|
||||
action_num,
|
||||
num_features,
|
||||
f"{algorithm}.policy",
|
||||
),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
@ -46,10 +55,8 @@ class VMBundle(RLComponentBundle):
|
|||
|
||||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
|
||||
if algorithm == "ac":
|
||||
from .algorithms.ac import get_ac, get_ac_policy
|
||||
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
|
||||
elif algorithm == "dqn":
|
||||
from .algorithms.dqn import get_dqn, get_dqn_policy
|
||||
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
|
|
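The policy and trainer creators in the file above lean on `functools.partial` to bind constructor arguments now and defer the actual construction until the bundle calls the zero-argument factory. A minimal, self-contained sketch of the pattern (the factory and its arguments are invented for illustration):

    from functools import partial

    def make_policy(state_dim: int, action_num: int, name: str) -> dict:
        # Stand-in for a real policy constructor.
        return {"name": name, "state_dim": state_dim, "action_num": action_num}

    # Bind the arguments now, build the object later.
    policy_creator = {"dqn.policy": partial(make_policy, 16, 4, "dqn.policy")}

    policy = policy_creator["dqn.policy"]()  # -> {'name': 'dqn.policy', ...}
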
@ -1,6 +1,6 @@
|
|||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
|
||||
from maro.simulator.scenarios.vm_scheduling.common import Action
|
||||
|
||||
|
||||
class VMSchedulingAgent(object):
|
||||
|
@ -8,15 +8,14 @@ class VMSchedulingAgent(object):
|
|||
self._algorithm = algorithm
|
||||
|
||||
def choose_action(self, decision_event: DecisionPayload, env: Env) -> Action:
|
||||
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM.
|
||||
"""
|
||||
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM."""
|
||||
valid_pm_num: int = len(decision_event.valid_pms)
|
||||
|
||||
if valid_pm_num <= 0:
|
||||
# No valid PM now, postpone.
|
||||
action: PostponeAction = PostponeAction(
|
||||
vm_id=decision_event.vm_id,
|
||||
postpone_step=1
|
||||
postpone_step=1,
|
||||
)
|
||||
else:
|
||||
action: AllocateAction = self._algorithm.allocate_vm(decision_event, env)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import numpy as np
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
|
||||
class BestFit(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -17,7 +16,7 @@ class BestFit(RuleBasedAlgorithm):
|
|||
# Take action to allocate on the chose PM.
|
||||
action: AllocateAction = AllocateAction(
|
||||
vm_id=decision_event.vm_id,
|
||||
pm_id=decision_event.valid_pms[chosen_idx]
|
||||
pm_id=decision_event.valid_pms[chosen_idx],
|
||||
)
|
||||
|
||||
return action
|
||||
|
@ -25,8 +24,12 @@ class BestFit(RuleBasedAlgorithm):
|
|||
def _pick_pm_func(self, decision_event, env) -> int:
|
||||
# Get the capacity and allocated cores from snapshot.
|
||||
valid_pm_info = env.snapshot_list["pms"][
|
||||
env.frame_index:decision_event.valid_pms:[
|
||||
"cpu_cores_capacity", "cpu_cores_allocated", "memory_capacity", "memory_allocated", "energy_consumption"
|
||||
env.frame_index : decision_event.valid_pms : [
|
||||
"cpu_cores_capacity",
|
||||
"cpu_cores_allocated",
|
||||
"memory_capacity",
|
||||
"memory_allocated",
|
||||
"energy_consumption",
|
||||
]
|
||||
].reshape(-1, 5)
|
||||
# Calculate to get the remaining cpu cores.
|
||||
|
@ -37,23 +40,19 @@ class BestFit(RuleBasedAlgorithm):
|
|||
energy_consumption = valid_pm_info[:, 4]
|
||||
# Choose the PM with the preference rule.
|
||||
chosen_idx: int = 0
|
||||
if self._metric_type == 'remaining_cpu_cores':
|
||||
if self._metric_type == "remaining_cpu_cores":
|
||||
chosen_idx = np.argmin(cpu_cores_remaining)
|
||||
elif self._metric_type == 'remaining_memory':
|
||||
elif self._metric_type == "remaining_memory":
|
||||
chosen_idx = np.argmin(memory_remaining)
|
||||
elif self._metric_type == 'energy_consumption':
|
||||
elif self._metric_type == "energy_consumption":
|
||||
chosen_idx = np.argmax(energy_consumption)
|
||||
elif self._metric_type == 'remaining_cpu_cores_and_energy_consumption':
|
||||
elif self._metric_type == "remaining_cpu_cores_and_energy_consumption":
|
||||
maximum_energy_consumption = energy_consumption[0]
|
||||
minimum_remaining_cpu_cores = cpu_cores_remaining[0]
|
||||
for i, remaining in enumerate(cpu_cores_remaining):
|
||||
energy = energy_consumption[i]
|
||||
if (
|
||||
remaining < minimum_remaining_cpu_cores
|
||||
or (
|
||||
remaining == minimum_remaining_cpu_cores
|
||||
and energy > maximum_energy_consumption
|
||||
)
|
||||
if remaining < minimum_remaining_cpu_cores or (
|
||||
remaining == minimum_remaining_cpu_cores and energy > maximum_energy_consumption
|
||||
):
|
||||
chosen_idx = i
|
||||
minimum_remaining_cpu_cores = remaining
|
||||
|
|
|
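The `'remaining_cpu_cores'` to `"remaining_cpu_cores"` changes above come from Black's string normalization: literals are rewritten to use double quotes unless that would force extra escaping. A quick sketch (the values are illustrative):

    # Before formatting:
    metric = 'remaining_memory'
    message = 'it said "ok"'

    # After Black:
    metric = "remaining_memory"
    message = 'it said "ok"'  # left alone; switching would require escapes
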
@ -1,21 +1,19 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
from yaml import safe_load
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
from maro.utils.utils import convert_dottable
|
||||
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
|
||||
class BinPacking(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self._max_cpu_oversubscription_rate: float = kwargs["env"].configs.MAX_CPU_OVERSUBSCRIPTION_RATE
|
||||
total_pm_cpu_info = kwargs["env"].snapshot_list["pms"][
|
||||
kwargs["env"].frame_index::["cpu_cores_capacity"]
|
||||
].reshape(-1)
|
||||
total_pm_cpu_info = (
|
||||
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].reshape(-1)
|
||||
)
|
||||
self._pm_num: int = total_pm_cpu_info.shape[0]
|
||||
self._pm_cpu_core_num: int = int(np.max(total_pm_cpu_info) * self._max_cpu_oversubscription_rate)
|
||||
|
||||
|
@ -29,7 +27,7 @@ class BinPacking(RuleBasedAlgorithm):
|
|||
|
||||
# Get the number of PM, maximum CPU core and max cpu oversubscription rate.
|
||||
total_pm_info = env.snapshot_list["pms"][
|
||||
env.frame_index::["cpu_cores_capacity", "cpu_cores_allocated"]
|
||||
env.frame_index :: ["cpu_cores_capacity", "cpu_cores_allocated"]
|
||||
].reshape(-1, 2)
|
||||
|
||||
cpu_cores_remaining = total_pm_info[:, 0] * self._max_cpu_oversubscription_rate - total_pm_info[:, 1]
|
||||
|
@ -57,7 +55,7 @@ class BinPacking(RuleBasedAlgorithm):
|
|||
# Take action to allocate on the chosen pm.
|
||||
action: AllocateAction = AllocateAction(
|
||||
vm_id=decision_event.vm_id,
|
||||
pm_id=chosen_idx
|
||||
pm_id=chosen_idx,
|
||||
)
|
||||
|
||||
return action
|
||||
|
|
|
@@ -1,8 +1,8 @@
from rule_based_algorithm import RuleBasedAlgorithm

from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload

from rule_based_algorithm import RuleBasedAlgorithm


class FirstFit(RuleBasedAlgorithm):
def __init__(self, **kwargs):

@@ -14,7 +14,7 @@ class FirstFit(RuleBasedAlgorithm):
# Take action to allocate on the chose PM.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=chosen_idx
pm_id=chosen_idx,
)

return action

@ -1,19 +1,18 @@
|
|||
import importlib
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import timeit
|
||||
|
||||
import yaml
|
||||
import importlib
|
||||
from agent import VMSchedulingAgent
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.utils import convert_dottable
|
||||
|
||||
from agent import VMSchedulingAgent
|
||||
|
||||
|
||||
def import_class(name):
|
||||
components = name.rsplit('.', 1)
|
||||
components = name.rsplit(".", 1)
|
||||
mod = importlib.import_module(components[0])
|
||||
mod = getattr(mod, components[1])
|
||||
return mod
|
||||
|
@ -33,7 +32,7 @@ if __name__ == "__main__":
|
|||
topology=config.env.topology,
|
||||
start_tick=config.env.start_tick,
|
||||
durations=config.env.durations,
|
||||
snapshot_resolution=config.env.resolution
|
||||
snapshot_resolution=config.env.resolution,
|
||||
)
|
||||
|
||||
if config.env.seed is not None:
|
||||
|
@ -57,7 +56,7 @@ if __name__ == "__main__":
|
|||
end_time = timeit.default_timer()
|
||||
print(
|
||||
f"[{config.algorithm.type.split('.')[1]}] Topology: {config.env.topology}. Total ticks: {config.env.durations}."
|
||||
f" Start tick: {config.env.start_tick}."
|
||||
f" Start tick: {config.env.start_tick}.",
|
||||
)
|
||||
print(f"[Timer] {end_time - start_time:.2f} seconds to finish the simulation.")
|
||||
print(metrics)
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
import random
|
||||
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
|
||||
class RandomPick(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
valid_pm_num: int = len(decision_event.valid_pms)
|
||||
# Random choose a valid PM.
|
||||
|
@ -17,7 +17,7 @@ class RandomPick(RuleBasedAlgorithm):
|
|||
# Take action to allocate on the chosen PM.
|
||||
action: AllocateAction = AllocateAction(
|
||||
vm_id=decision_event.vm_id,
|
||||
pm_id=decision_event.valid_pms[chosen_idx]
|
||||
pm_id=decision_event.valid_pms[chosen_idx],
|
||||
)
|
||||
|
||||
return action
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
from maro.simulator import Env
|
||||
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
|
||||
|
||||
from rule_based_algorithm import RuleBasedAlgorithm
|
||||
|
||||
|
||||
class RoundRobin(RuleBasedAlgorithm):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self._prev_idx: int = 0
|
||||
self._pm_num: int = kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index::["cpu_cores_capacity"]].shape[0]
|
||||
self._pm_num: int = (
|
||||
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].shape[0]
|
||||
)
|
||||
|
||||
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
|
||||
# Choose the valid PM which index is next to the previous chose PM's index
|
||||
|
@ -21,7 +23,7 @@ class RoundRobin(RuleBasedAlgorithm):
|
|||
# Take action to allocate on the chosen PM.
|
||||
action: AllocateAction = AllocateAction(
|
||||
vm_id=decision_event.vm_id,
|
||||
pm_id=chosen_idx
|
||||
pm_id=chosen_idx,
|
||||
)
|
||||
|
||||
return action
|
||||
|
|
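The reworked `self._pm_num` assignment above shows how Black handles an over-long right-hand side: instead of a backslash continuation, it wraps the expression in parentheses and indents it. A runnable sketch under assumed data (the dictionary and the deliberately verbose variable name are hypothetical):

    import numpy as np

    readings = {"pms": np.arange(40).reshape(10, 4)}
    frame_index, attribute_column = 3, 2

    # When the right-hand side no longer fits within the line limit, Black wraps
    # it in parentheses rather than using a backslash:
    cpu_cores_capacity_of_selected_physical_machine_at_frame = (
        readings["pms"][frame_index, attribute_column] * 1.0 / np.max(readings["pms"])
    )
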
|
@@ -9,6 +9,5 @@ class RuleBasedAlgorithm(object):

@abc.abstractmethod
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
"""This method will determine allocate which PM to the current VM.
"""
pass
"""This method will determine allocate which PM to the current VM."""
raise NotImplementedError

@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# isort: skip_file

from .__misc__ import __data_version__, __version__

from maro.utils.utils import check_deployment_status, deploy

@ -6,21 +6,37 @@
|
|||
#distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
from cpython cimport bool
|
||||
from cython cimport view
|
||||
from cython.operator cimport dereference as deref
|
||||
|
||||
from cpython cimport bool
|
||||
from libcpp cimport bool as cppbool
|
||||
from libcpp.map cimport map
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX,
|
||||
ATTR_CHAR, ATTR_UCHAR, ATTR_SHORT, ATTR_USHORT, ATTR_INT, ATTR_UINT,
|
||||
ATTR_LONG, ATTR_ULONG, ATTR_FLOAT, ATTR_DOUBLE)
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_CHAR,
|
||||
ATTR_DOUBLE,
|
||||
ATTR_FLOAT,
|
||||
ATTR_INT,
|
||||
ATTR_LONG,
|
||||
ATTR_SHORT,
|
||||
ATTR_TYPE,
|
||||
ATTR_UCHAR,
|
||||
ATTR_UINT,
|
||||
ATTR_ULONG,
|
||||
ATTR_USHORT,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
AttributeType,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
# Ensure numpy will not crash, as we use numpy as query result
|
||||
np.import_array()
|
||||
|
|
|
@@ -5,8 +5,7 @@
#distutils: language = c++

from cpython cimport bool

from libc.stdint cimport int32_t, int64_t, int16_t, int8_t, uint32_t, uint64_t
from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint32_t, uint64_t

# common types

@@ -5,8 +5,10 @@
#distutils: language = c++

from enum import Enum

from cpython cimport bool


cdef class AttributeType:
Byte = b"byte"
UByte = b"ubyte"

@ -6,8 +6,19 @@
|
|||
|
||||
from cpython cimport bool
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, USHORT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_TYPE,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
USHORT,
|
||||
AttributeType,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
|
||||
cdef class SnapshotList:
|
||||
|
|
|
@ -8,31 +8,42 @@
|
|||
import os
|
||||
|
||||
cimport cython
|
||||
|
||||
cimport numpy as np
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cpython cimport bool
|
||||
|
||||
from typing import Union
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_TYPE,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
USHORT,
|
||||
AttributeType,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
from maro.backends.np_backend cimport NumpyBackend
|
||||
from maro.backends.raw_backend cimport RawBackend
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
|
||||
from maro.utils.exception.backends_exception import (
|
||||
BackendsGetItemInvalidException,
|
||||
BackendsSetItemInvalidException,
|
||||
BackendsArrayAttributeAccessException,
|
||||
BackendsAppendToNonListAttributeException,
|
||||
BackendsResizeNonListAttributeException,
|
||||
BackendsClearNonListAttributeException,
|
||||
BackendsInsertNonListAttributeException,
|
||||
BackendsRemoveFromNonListAttributeException,
|
||||
BackendsAccessDeletedNodeException,
|
||||
BackendsAppendToNonListAttributeException,
|
||||
BackendsArrayAttributeAccessException,
|
||||
BackendsClearNonListAttributeException,
|
||||
BackendsGetItemInvalidException,
|
||||
BackendsInsertNonListAttributeException,
|
||||
BackendsInvalidAttributeException,
|
||||
BackendsInvalidNodeException,
|
||||
BackendsInvalidAttributeException
|
||||
BackendsRemoveFromNonListAttributeException,
|
||||
BackendsResizeNonListAttributeException,
|
||||
BackendsSetItemInvalidException,
|
||||
)
|
||||
|
||||
# Old type definition mapping.
|
||||
|
|
|
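The long `cimport (...)` blocks above were re-wrapped into the vertical style isort produces under the Black profile: parentheses, one name per line, alphabetical within each type group, and a trailing comma so the block stays exploded. For a plain Python module, the equivalent output looks like the sketch below once an import exceeds the line limit or carries a magic trailing comma (the module and names are just an illustration):

    from collections import (
        ChainMap,
        Counter,
        OrderedDict,
        defaultdict,
        namedtuple,
    )
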
@ -5,11 +5,21 @@
|
|||
#distutils: language = c++
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
from cpython cimport bool
|
||||
from maro.backends.backend cimport BackendAbc, SnapshotListAbc, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_TYPE,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
|
||||
cdef class NumpyBackend(BackendAbc):
|
||||
|
|
|
@ -8,13 +8,24 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
cimport cython
|
||||
|
||||
cimport numpy as np
|
||||
from cpython cimport bool
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_TYPE,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
USHORT,
|
||||
AttributeType,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
# Attribute data type mapping.
|
||||
attribute_type_mapping = {
|
||||
|
@ -44,11 +55,10 @@ attribute_type_range = {
|
|||
|
||||
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
|
||||
# with this flag, we will allocate a big enough memory for all node types, then use this block construct numpy array
|
||||
from cpython cimport Py_INCREF, PyObject, PyTypeObject
|
||||
from cpython.mem cimport PyMem_Free, PyMem_Malloc
|
||||
from libc.string cimport memset
|
||||
|
||||
from cpython cimport PyObject, Py_INCREF, PyTypeObject
|
||||
from cpython.mem cimport PyMem_Malloc, PyMem_Free
|
||||
|
||||
# we need this to avoid seg fault
|
||||
np.import_array()
|
||||
|
||||
|
|
|
@ -5,14 +5,29 @@
|
|||
#distutils: language = c++
|
||||
|
||||
cimport cython
|
||||
|
||||
from cpython cimport bool
|
||||
from libcpp cimport bool as cppbool
|
||||
from libcpp.string cimport string
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, INT, UINT, ULONG, USHORT, NODE_INDEX, SLOT_INDEX,
|
||||
ATTR_CHAR, ATTR_SHORT, ATTR_INT, ATTR_LONG, ATTR_FLOAT, ATTR_DOUBLE, QUERY_FLOAT, ATTR_TYPE, NODE_TYPE)
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_CHAR,
|
||||
ATTR_DOUBLE,
|
||||
ATTR_FLOAT,
|
||||
ATTR_INT,
|
||||
ATTR_LONG,
|
||||
ATTR_SHORT,
|
||||
ATTR_TYPE,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
QUERY_FLOAT,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
USHORT,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
|
||||
cdef extern from "raw/common.h" namespace "maro::backends::raw":
|
||||
|
|
|
@ -8,21 +8,37 @@
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
from cpython cimport bool
|
||||
from cython cimport view
|
||||
from cython.operator cimport dereference as deref
|
||||
|
||||
from cpython cimport bool
|
||||
from libcpp cimport bool as cppbool
|
||||
from libcpp.map cimport map
|
||||
|
||||
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
|
||||
INT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX,
|
||||
ATTR_CHAR, ATTR_UCHAR, ATTR_SHORT, ATTR_USHORT, ATTR_INT, ATTR_UINT,
|
||||
ATTR_LONG, ATTR_ULONG, ATTR_FLOAT, ATTR_DOUBLE)
|
||||
|
||||
from maro.backends.backend cimport (
|
||||
ATTR_CHAR,
|
||||
ATTR_DOUBLE,
|
||||
ATTR_FLOAT,
|
||||
ATTR_INT,
|
||||
ATTR_LONG,
|
||||
ATTR_SHORT,
|
||||
ATTR_TYPE,
|
||||
ATTR_UCHAR,
|
||||
ATTR_UINT,
|
||||
ATTR_ULONG,
|
||||
ATTR_USHORT,
|
||||
INT,
|
||||
NODE_INDEX,
|
||||
NODE_TYPE,
|
||||
SLOT_INDEX,
|
||||
UINT,
|
||||
ULONG,
|
||||
AttributeType,
|
||||
BackendAbc,
|
||||
SnapshotListAbc,
|
||||
)
|
||||
|
||||
# Ensure numpy will not crash, as we use numpy as query result
|
||||
np.import_array()
|
||||
|
|
|
@ -83,9 +83,8 @@ class CitiBikePipeline(DataPipeline):
|
|||
with zipfile.ZipFile(self._download_file, "r") as zip_ref:
|
||||
for filename in zip_ref.namelist():
|
||||
# Only one csv file is expected.
|
||||
if (
|
||||
filename.endswith(".csv") and
|
||||
(not (filename.startswith("__MACOSX") or filename.startswith(".")))
|
||||
if filename.endswith(".csv") and (
|
||||
not (filename.startswith("__MACOSX") or filename.startswith("."))
|
||||
):
|
||||
|
||||
logger.info_green(f"Unzip {filename} from {self._download_file}.")
|
||||
|
@ -106,10 +105,13 @@ class CitiBikePipeline(DataPipeline):
|
|||
with open(self._station_info_file, mode="r", encoding="utf-8") as station_file:
|
||||
# read station to station file
|
||||
raw_station_data = pd.DataFrame.from_dict(pd.read_json(station_file)["data"]["stations"])
|
||||
station_data = raw_station_data.rename(columns={
|
||||
"lon": "station_longitude",
|
||||
"lat": "station_latitude",
|
||||
"region_id": "region"})
|
||||
station_data = raw_station_data.rename(
|
||||
columns={
|
||||
"lon": "station_longitude",
|
||||
"lat": "station_latitude",
|
||||
"region_id": "region",
|
||||
},
|
||||
)
|
||||
|
||||
# group by station to generate station init info
|
||||
full_stations = station_data[
|
||||
|
@ -124,7 +126,8 @@ class CitiBikePipeline(DataPipeline):
|
|||
full_stations["station_latitude"] = pd.to_numeric(full_stations["station_latitude"], downcast="float")
|
||||
full_stations.drop(full_stations[full_stations["capacity"] == 0].index, axis=0, inplace=True)
|
||||
full_stations.dropna(
|
||||
subset=["station_id", "capacity", "station_longitude", "station_latitude"], inplace=True
|
||||
subset=["station_id", "capacity", "station_longitude", "station_latitude"],
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
self._common_data["full_stations"] = full_stations
|
||||
|
@ -141,13 +144,24 @@ class CitiBikePipeline(DataPipeline):
|
|||
with open(file, "r", encoding="utf-8", errors="ignore") as fp:
|
||||
|
||||
ret = pd.read_csv(fp)
|
||||
ret = ret[[
|
||||
"tripduration", "starttime", "start station id", "end station id", "start station latitude",
|
||||
"start station longitude", "end station latitude", "end station longitude", "gender", "usertype",
|
||||
"bikeid"
|
||||
]]
|
||||
ret = ret[
|
||||
[
|
||||
"tripduration",
|
||||
"starttime",
|
||||
"start station id",
|
||||
"end station id",
|
||||
"start station latitude",
|
||||
"start station longitude",
|
||||
"end station latitude",
|
||||
"end station longitude",
|
||||
"gender",
|
||||
"usertype",
|
||||
"bikeid",
|
||||
]
|
||||
]
|
||||
ret["tripduration"] = pd.to_numeric(
|
||||
pd.to_numeric(ret["tripduration"], downcast="integer") / 60, downcast="integer"
|
||||
pd.to_numeric(ret["tripduration"], downcast="integer") / 60,
|
||||
downcast="integer",
|
||||
)
|
||||
ret["starttime"] = pd.to_datetime(ret["starttime"])
|
||||
ret["start station id"] = pd.to_numeric(ret["start station id"], errors="coerce", downcast="integer")
|
||||
|
@ -158,23 +172,34 @@ class CitiBikePipeline(DataPipeline):
|
|||
ret["end station longitude"] = pd.to_numeric(ret["end station longitude"], downcast="float")
|
||||
ret["bikeid"] = pd.to_numeric(ret["bikeid"], errors="coerce", downcast="integer")
|
||||
ret["gender"] = pd.to_numeric(ret["gender"], errors="coerce", downcast="integer")
|
||||
ret["usertype"] = ret["usertype"].apply(str).apply(
|
||||
lambda x: 0 if x in ["Subscriber", "subscriber"] else 1 if x in ["Customer", "customer"] else 2
|
||||
ret["usertype"] = (
|
||||
ret["usertype"]
|
||||
.apply(str)
|
||||
.apply(
|
||||
lambda x: 0 if x in ["Subscriber", "subscriber"] else 1 if x in ["Customer", "customer"] else 2,
|
||||
)
|
||||
)
|
||||
ret.dropna(
|
||||
subset=[
|
||||
"start station id",
|
||||
"end station id",
|
||||
"start station latitude",
|
||||
"end station latitude",
|
||||
"start station longitude",
|
||||
"end station longitude",
|
||||
],
|
||||
inplace=True,
|
||||
)
|
||||
ret.dropna(subset=[
|
||||
"start station id", "end station id", "start station latitude", "end station latitude",
|
||||
"start station longitude", "end station longitude"
|
||||
], inplace=True)
|
||||
ret.drop(
|
||||
ret[
|
||||
(ret["tripduration"] <= 1) |
|
||||
(ret["start station latitude"] == 0) |
|
||||
(ret["start station longitude"] == 0) |
|
||||
(ret["end station latitude"] == 0) |
|
||||
(ret["end station longitude"] == 0)
|
||||
(ret["tripduration"] <= 1)
|
||||
| (ret["start station latitude"] == 0)
|
||||
| (ret["start station longitude"] == 0)
|
||||
| (ret["end station latitude"] == 0)
|
||||
| (ret["end station longitude"] == 0)
|
||||
].index,
|
||||
axis=0,
|
||||
inplace=True
|
||||
inplace=True,
|
||||
)
|
||||
ret = ret.sort_values(by="starttime", ascending=True)
|
||||
|
||||
|
@ -184,16 +209,16 @@ class CitiBikePipeline(DataPipeline):
|
|||
used_bikes = len(src_data[["bikeid"]].drop_duplicates(subset=["bikeid"]))
|
||||
|
||||
trip_data = src_data[
|
||||
(src_data["start station latitude"] > 40.689960) &
|
||||
(src_data["start station latitude"] < 40.768334) &
|
||||
(src_data["start station longitude"] > -74.019623) &
|
||||
(src_data["start station longitude"] < -73.909760)
|
||||
(src_data["start station latitude"] > 40.689960)
|
||||
& (src_data["start station latitude"] < 40.768334)
|
||||
& (src_data["start station longitude"] > -74.019623)
|
||||
& (src_data["start station longitude"] < -73.909760)
|
||||
]
|
||||
trip_data = trip_data[
|
||||
(trip_data["end station latitude"] > 40.689960) &
|
||||
(trip_data["end station latitude"] < 40.768334) &
|
||||
(trip_data["end station longitude"] > -74.019623) &
|
||||
(trip_data["end station longitude"] < -73.909760)
|
||||
(trip_data["end station latitude"] > 40.689960)
|
||||
& (trip_data["end station latitude"] < 40.768334)
|
||||
& (trip_data["end station longitude"] > -74.019623)
|
||||
& (trip_data["end station longitude"] < -73.909760)
|
||||
]
|
||||
|
||||
trip_data["start_station_id"] = trip_data["start station id"]
|
||||
|
@ -202,25 +227,40 @@ class CitiBikePipeline(DataPipeline):
|
|||
# get new stations
|
||||
used_stations = []
|
||||
used_stations.append(
|
||||
trip_data[["start_station_id", "start station latitude", "start station longitude", ]].drop_duplicates(
|
||||
subset=["start_station_id"]).rename(
|
||||
columns={
|
||||
"start_station_id": "station_id",
|
||||
"start station latitude": "latitude",
|
||||
"start station longitude": "longitude"
|
||||
}))
|
||||
trip_data[["start_station_id", "start station latitude", "start station longitude"]]
|
||||
.drop_duplicates(
|
||||
subset=["start_station_id"],
|
||||
)
|
||||
.rename(
|
||||
columns={
|
||||
"start_station_id": "station_id",
|
||||
"start station latitude": "latitude",
|
||||
"start station longitude": "longitude",
|
||||
},
|
||||
),
|
||||
)
|
||||
used_stations.append(
|
||||
trip_data[["end_station_id", "end station latitude", "end station longitude", ]].drop_duplicates(
|
||||
subset=["end_station_id"]).rename(
|
||||
columns={
|
||||
"end_station_id": "station_id",
|
||||
"end station latitude": "latitude",
|
||||
"end station longitude": "longitude"
|
||||
}))
|
||||
trip_data[["end_station_id", "end station latitude", "end station longitude"]]
|
||||
.drop_duplicates(
|
||||
subset=["end_station_id"],
|
||||
)
|
||||
.rename(
|
||||
columns={
|
||||
"end_station_id": "station_id",
|
||||
"end station latitude": "latitude",
|
||||
"end station longitude": "longitude",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
in_data_station = pd.concat(used_stations, ignore_index=True).drop_duplicates(
|
||||
subset=["station_id"]
|
||||
).sort_values(by=["station_id"]).reset_index(drop=True)
|
||||
in_data_station = (
|
||||
pd.concat(used_stations, ignore_index=True)
|
||||
.drop_duplicates(
|
||||
subset=["station_id"],
|
||||
)
|
||||
.sort_values(by=["station_id"])
|
||||
.reset_index(drop=True)
|
||||
)
|
||||
|
||||
stations_existed = pd.DataFrame(in_data_station[["station_id"]])
|
||||
|
||||
|
@ -229,11 +269,11 @@ class CitiBikePipeline(DataPipeline):
|
|||
# get start station id and end station id
|
||||
trip_data = trip_data.join(
|
||||
stations_existed.set_index("station_id"),
|
||||
on="start_station_id"
|
||||
on="start_station_id",
|
||||
).rename(columns={"station_index": "start_station_index"})
|
||||
trip_data = trip_data.join(
|
||||
stations_existed.set_index("station_id"),
|
||||
on="end_station_id"
|
||||
on="end_station_id",
|
||||
).rename(columns={"station_index": "end_station_index"})
|
||||
trip_data = trip_data.rename(columns={"starttime": "start_time", "tripduration": "duration"})
|
||||
|
||||
|
@ -244,13 +284,17 @@ class CitiBikePipeline(DataPipeline):
|
|||
return trip_data, used_bikes, in_data_station, stations_existed
|
||||
|
||||
def _process_current_topo_station_info(
|
||||
self, stations_existed: pd.DataFrame, used_bikes: int, loc_ref: pd.DataFrame):
|
||||
self,
|
||||
stations_existed: pd.DataFrame,
|
||||
used_bikes: int,
|
||||
loc_ref: pd.DataFrame,
|
||||
):
|
||||
data_station_init = stations_existed.join(
|
||||
self._common_data["full_stations"][["station_id", "capacity"]].set_index("station_id"),
|
||||
on="station_id"
|
||||
on="station_id",
|
||||
).join(
|
||||
loc_ref[["station_id", "latitude", "longitude"]].set_index("station_id"),
|
||||
on="station_id"
|
||||
on="station_id",
|
||||
)
|
||||
# data_station_init.rename(columns={"station_id": "station_index"}, inplace=True)
|
||||
avg_capacity = int(self._common_data["full_dock_num"] / self._common_data["full_station_num"])
|
||||
|
@ -259,22 +303,36 @@ class CitiBikePipeline(DataPipeline):
|
|||
data_station_init.fillna(value=values, inplace=True)
|
||||
data_station_init["init"] = (data_station_init["capacity"] * avalible_bike_rate).round().apply(int)
|
||||
data_station_init["capacity"] = pd.to_numeric(
|
||||
data_station_init["capacity"], errors="coerce", downcast="integer"
|
||||
data_station_init["capacity"],
|
||||
errors="coerce",
|
||||
downcast="integer",
|
||||
)
|
||||
data_station_init["station_id"] = pd.to_numeric(
|
||||
data_station_init["station_id"], errors="coerce", downcast="integer"
|
||||
data_station_init["station_id"],
|
||||
errors="coerce",
|
||||
downcast="integer",
|
||||
)
|
||||
|
||||
return data_station_init
|
||||
|
||||
def _process_distance(self, station_info: pd.DataFrame):
|
||||
distance_adj = pd.DataFrame(0, index=station_info["station_index"],
|
||||
columns=station_info["station_index"], dtype=np.float)
|
||||
distance_adj = pd.DataFrame(
|
||||
0,
|
||||
index=station_info["station_index"],
|
||||
columns=station_info["station_index"],
|
||||
dtype=np.float,
|
||||
)
|
||||
look_up_df = station_info[["latitude", "longitude"]]
|
||||
return distance_adj.apply(lambda x: pd.DataFrame(x).apply(lambda y: geopy.distance.distance(
|
||||
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
|
||||
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"])
|
||||
).km, axis=1), axis=1)
|
||||
return distance_adj.apply(
|
||||
lambda x: pd.DataFrame(x).apply(
|
||||
lambda y: geopy.distance.distance(
|
||||
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
|
||||
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"]),
|
||||
).km,
|
||||
axis=1,
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
def _preprocess(self, unzipped_file: str):
|
||||
self._read_common_data()
|
||||
|
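One line in the hunk above still builds the distance matrix with `dtype=np.float`. That is unrelated to the formatting pass, but it is worth flagging: `np.float` was deprecated in NumPy 1.20 and removed in 1.24, so the call fails on newer NumPy versions. A hedged sketch of the usual replacement (the index values are illustrative):

    import numpy as np
    import pandas as pd

    station_index = [0, 1, 2]

    # np.float is gone in NumPy >= 1.24; use the builtin float (or np.float64).
    distance_adj = pd.DataFrame(0, index=station_index, columns=station_index, dtype=float)
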
@ -292,7 +350,9 @@ class CitiBikePipeline(DataPipeline):
|
|||
|
||||
logger.info_green("Processing station info data.")
|
||||
station_info = self._process_current_topo_station_info(
|
||||
stations_existed=stations_existed, used_bikes=used_bikes, loc_ref=in_data_station
|
||||
stations_existed=stations_existed,
|
||||
used_bikes=used_bikes,
|
||||
loc_ref=in_data_station,
|
||||
)
|
||||
with open(self._station_meta_file, mode="w", encoding="utf-8", newline="") as f:
|
||||
station_info.to_csv(f, index=False, header=True)
|
||||
|
@ -415,7 +475,13 @@ class CitiBikeTopology(DataTopology):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, topology: str, trip_source: str, station_info: str, weather_source: str, is_temp: bool = False):
|
||||
self,
|
||||
topology: str,
|
||||
trip_source: str,
|
||||
station_info: str,
|
||||
weather_source: str,
|
||||
is_temp: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._data_pipeline["trip"] = CitiBikePipeline(topology, trip_source, station_info, is_temp)
|
||||
self._data_pipeline["weather"] = NOAAWeatherPipeline(topology, weather_source, is_temp)
|
||||
|
@ -453,7 +519,14 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
_meta_file_name = "trips.yml"
|
||||
|
||||
def __init__(
|
||||
self, start_time: str, end_time: str, stations: list, trips: list, topology: str, is_temp: bool = False):
|
||||
self,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
stations: list,
|
||||
trips: list,
|
||||
topology: str,
|
||||
is_temp: bool = False,
|
||||
):
|
||||
super().__init__("citi_bike", topology, "", is_temp)
|
||||
self._start_time = start_time
|
||||
self._end_time = end_time
|
||||
|
@ -465,7 +538,6 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
|
||||
def download(self, is_force: bool):
|
||||
"""Toy datapipeline not need download process."""
|
||||
pass
|
||||
|
||||
def _station_dict_to_pd(self, station_dict):
|
||||
"""Convert dictionary of station information to pd series."""
|
||||
|
@ -477,7 +549,8 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
station_dict["lat"],
|
||||
station_dict["lon"],
|
||||
],
|
||||
index=["station_index", "capacity", "init", "latitude", "longitude"])
|
||||
index=["station_index", "capacity", "init", "latitude", "longitude"],
|
||||
)
|
||||
|
||||
def _gen_stations(self):
|
||||
"""Generate station meta csv."""
|
||||
|
@ -523,10 +596,14 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
trips_df = pd.DataFrame.from_dict(trips)
|
||||
|
||||
trips_df["start_station_index"] = pd.to_numeric(
|
||||
trips_df["start_station_index"], errors="coerce", downcast="integer"
|
||||
trips_df["start_station_index"],
|
||||
errors="coerce",
|
||||
downcast="integer",
|
||||
)
|
||||
trips_df["end_station_index"] = pd.to_numeric(
|
||||
trips_df["end_station_index"], errors="coerce", downcast="integer"
|
||||
trips_df["end_station_index"],
|
||||
errors="coerce",
|
||||
downcast="integer",
|
||||
)
|
||||
self._new_file_list.append(self._clean_file)
|
||||
with open(self._clean_file, "w", encoding="utf-8", newline="") as f:
|
||||
|
@ -540,13 +617,19 @@ class CitiBikeToyPipeline(DataPipeline):
|
|||
0,
|
||||
index=station_init["station_index"],
|
||||
columns=station_init["station_index"],
|
||||
dtype=np.float
|
||||
dtype=np.float,
|
||||
)
|
||||
look_up_df = station_init[["latitude", "longitude"]]
|
||||
distance_df = distance_adj.apply(lambda x: pd.DataFrame(x).apply(lambda y: geopy.distance.distance(
|
||||
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
|
||||
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"])
|
||||
).km, axis=1), axis=1)
|
||||
distance_df = distance_adj.apply(
|
||||
lambda x: pd.DataFrame(x).apply(
|
||||
lambda y: geopy.distance.distance(
|
||||
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
|
||||
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"]),
|
||||
).km,
|
||||
axis=1,
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
self._new_file_list.append(self._distance_file)
|
||||
with open(self._distance_file, "w", encoding="utf-8", newline="") as f:
|
||||
distance_df.to_csv(f, index=False, header=True)
|
||||
|
@ -589,7 +672,6 @@ class WeatherToyPipeline(WeatherPipeline):
|
|||
|
||||
def download(self, is_force: bool):
|
||||
"""Toy datapipeline not need download process."""
|
||||
pass
|
||||
|
||||
def clean(self):
|
||||
"""Clean the original data file."""
|
||||
|
@ -660,13 +742,13 @@ class CitiBikeToyTopology(DataTopology):
|
|||
stations=cfg["stations"],
|
||||
trips=cfg["trips"],
|
||||
topology=topology,
|
||||
is_temp=is_temp
|
||||
is_temp=is_temp,
|
||||
)
|
||||
self._data_pipeline["weather"] = WeatherToyPipeline(
|
||||
topology=topology,
|
||||
start_time=cfg["start_time"],
|
||||
end_time=cfg["end_time"],
|
||||
is_temp=is_temp
|
||||
is_temp=is_temp,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Config file {config_path} for toy topology {topology} not found.")
|
||||
|
@ -701,7 +783,7 @@ class CitiBikeProcess:
|
|||
self.topologies[topology] = CitiBikeToyTopology(
|
||||
topology=topology,
|
||||
config_path=self._conf["trips"][topology]["toy_meta_path"],
|
||||
is_temp=is_temp
|
||||
is_temp=is_temp,
|
||||
)
|
||||
else:
|
||||
self.topologies[topology] = CitiBikeTopology(
|
||||
|
@ -709,7 +791,7 @@ class CitiBikeProcess:
|
|||
trip_source=self._conf["trips"][topology]["trip_remote_url"],
|
||||
station_info=self._conf["station_info"]["ny_station_info_url"],
|
||||
weather_source=self._conf["weather"][topology]["noaa_weather_url"],
|
||||
is_temp=is_temp
|
||||
is_temp=is_temp,
|
||||
)
|
||||
|
||||
|
||||
|
@ -782,8 +864,8 @@ class NOAAWeatherPipeline(WeatherPipeline):
|
|||
|
||||
def _gen_fall_back_file(self):
|
||||
fall_back_content = [
|
||||
"\"STATION\",\"DATE\",\"AWND\",\"PRCP\",\"SNOW\",\"TMAX\",\"TMIN\"\n",
|
||||
",,,,,,\n"
|
||||
'"STATION","DATE","AWND","PRCP","SNOW","TMAX","TMIN"\n',
|
||||
",,,,,,\n",
|
||||
]
|
||||
with open(self._download_file, mode="w", encoding="utf-8", newline="") as f:
|
||||
f.writelines(fall_back_content)
|
||||
|
|
|
@@ -16,7 +16,7 @@ scenario_map["vm_scheduling"] = VmSchedulingProcess
def generate(scenario: str, topology: str = "", forced: bool = False, **kwargs):
logger.info_green(
f"Generating data files for scenario {scenario} topology {topology}"
f" {'forced redownload.' if forced else ', not forced redownload.'}"
f" {'forced redownload.' if forced else ', not forced redownload.'}",
)
if scenario in scenario_map:
process = scenario_map[scenario]()

@ -33,7 +33,7 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
|
||||
_meta_file_name = "vmtable.yml"
|
||||
# VM category includes three types, converting to 0, 1, 2.
|
||||
_category_map = {'Delay-insensitive': 0, 'Interactive': 1, 'Unknown': 2}
|
||||
_category_map = {"Delay-insensitive": 0, "Interactive": 1, "Unknown": 2}
|
||||
|
||||
def __init__(self, topology: str, source: str, sample: int, seed: int, is_temp: bool = False):
|
||||
super().__init__(scenario="vm_scheduling", topology=topology, source=source, is_temp=is_temp)
|
||||
|
@ -56,8 +56,8 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
aria2p.Client(
|
||||
host="http://localhost",
|
||||
port=6800,
|
||||
secret=""
|
||||
)
|
||||
secret="",
|
||||
),
|
||||
)
|
||||
self._download_file_list = []
|
||||
|
||||
|
@ -97,14 +97,14 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
with open(self._download_file, mode="r", encoding="utf-8") as urls:
|
||||
for remote_url in urls.read().splitlines():
|
||||
# Get the file name.
|
||||
file_name = remote_url.split('/')[-1]
|
||||
file_name = remote_url.split("/")[-1]
|
||||
# Two kinds of required files "vmtable" and "vm_cpu_readings-" start with vm.
|
||||
if file_name.startswith("vmtable"):
|
||||
if (not is_force) and os.path.exists(self._vm_table_file):
|
||||
logger.info_green(f"{self._vm_table_file} already exists, skipping download.")
|
||||
else:
|
||||
logger.info_green(f"Downloading vmtable from {remote_url} to {self._vm_table_file}.")
|
||||
self.aria2.add_uris(uris=[remote_url], options={'dir': self._download_folder})
|
||||
self.aria2.add_uris(uris=[remote_url], options={"dir": self._download_folder})
|
||||
|
||||
elif file_name.startswith("vm_cpu_readings") and num_files > 0:
|
||||
num_files -= 1
|
||||
|
@ -115,7 +115,7 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
logger.info_green(f"{cpu_readings_file} already exists, skipping download.")
|
||||
else:
|
||||
logger.info_green(f"Downloading cpu_readings from {remote_url} to {cpu_readings_file}.")
|
||||
self.aria2.add_uris(uris=[remote_url], options={'dir': f"{self._download_folder}"})
|
||||
self.aria2.add_uris(uris=[remote_url], options={"dir": f"{self._download_folder}"})
|
||||
|
||||
self._check_all_download_completed()
|
||||
|
||||
|
@ -149,7 +149,7 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
# Unzip gz file.
|
||||
with gzip.open(original_file, mode="rb") as f_in:
|
||||
logger.info_green(
|
||||
f"Unzip {original_file} to {raw_file}."
|
||||
f"Unzip {original_file} to {raw_file}.",
|
||||
)
|
||||
with open(raw_file, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
|
@ -184,63 +184,80 @@ class VmSchedulingPipeline(DataPipeline):
|
|||
"""Process vmtable file."""
|
||||
|
||||
headers = [
|
||||
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'maxcpu', 'avgcpu', 'p95maxcpu',
|
||||
'vmcategory', 'vmcorecountbucket', 'vmmemorybucket'
|
||||
"vmid",
|
||||
"subscriptionid",
|
||||
"deploymentid",
|
||||
"vmcreated",
|
||||
"vmdeleted",
|
||||
"maxcpu",
|
||||
"avgcpu",
|
||||
"p95maxcpu",
|
||||
"vmcategory",
|
||||
"vmcorecountbucket",
|
||||
"vmmemorybucket",
|
||||
]
|
||||
|
||||
required_headers = [
|
||||
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'vmcategory',
|
||||
'vmcorecountbucket', 'vmmemorybucket'
|
||||
"vmid",
|
||||
"subscriptionid",
|
||||
"deploymentid",
|
||||
"vmcreated",
|
||||
"vmdeleted",
|
||||
"vmcategory",
|
||||
"vmcorecountbucket",
|
||||
"vmmemorybucket",
|
||||
]
|
||||
|
||||
vm_table = pd.read_csv(raw_vm_table_file, header=None, index_col=False, names=headers)
|
||||
vm_table = vm_table.loc[:, required_headers]
|
||||
# Convert to tick by dividing by 300 (5 minutes).
|
||||
vm_table['vmcreated'] = pd.to_numeric(vm_table['vmcreated'], errors="coerce", downcast="integer") // 300
|
||||
vm_table['vmdeleted'] = pd.to_numeric(vm_table['vmdeleted'], errors="coerce", downcast="integer") // 300
|
||||
vm_table["vmcreated"] = pd.to_numeric(vm_table["vmcreated"], errors="coerce", downcast="integer") // 300
|
||||
vm_table["vmdeleted"] = pd.to_numeric(vm_table["vmdeleted"], errors="coerce", downcast="integer") // 300
|
||||
# The lifetime of the VM is deleted time - created time + 1 (tick).
|
||||
vm_table['lifetime'] = vm_table['vmdeleted'] - vm_table['vmcreated'] + 1
|
||||
vm_table["lifetime"] = vm_table["vmdeleted"] - vm_table["vmcreated"] + 1
|
||||
|
||||
vm_table['vmcategory'] = vm_table['vmcategory'].map(self._category_map)
|
||||
vm_table["vmcategory"] = vm_table["vmcategory"].map(self._category_map)
|
||||
|
||||
# Transform vmcorecount '>24' bucket to 32 and vmmemory '>64' to 128.
|
||||
vm_table = vm_table.replace({'vmcorecountbucket': '>24'}, 32)
|
||||
vm_table = vm_table.replace({'vmmemorybucket': '>64'}, 128)
|
||||
vm_table['vmcorecountbucket'] = pd.to_numeric(
|
||||
vm_table['vmcorecountbucket'], errors="coerce", downcast="integer"
|
||||
vm_table = vm_table.replace({"vmcorecountbucket": ">24"}, 32)
|
||||
vm_table = vm_table.replace({"vmmemorybucket": ">64"}, 128)
|
||||
vm_table["vmcorecountbucket"] = pd.to_numeric(
|
||||
vm_table["vmcorecountbucket"],
|
||||
errors="coerce",
|
||||
downcast="integer",
|
||||
)
|
||||
vm_table['vmmemorybucket'] = pd.to_numeric(vm_table['vmmemorybucket'], errors="coerce", downcast="integer")
|
||||
vm_table["vmmemorybucket"] = pd.to_numeric(vm_table["vmmemorybucket"], errors="coerce", downcast="integer")
|
||||
vm_table.dropna(inplace=True)
|
||||
|
||||
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
|
||||
vm_table = vm_table.sort_values(by="vmcreated", ascending=True)
|
||||
|
||||
# Generate ID map.
|
||||
vm_id_map = self._generate_id_map(vm_table['vmid'].unique())
|
||||
sub_id_map = self._generate_id_map(vm_table['subscriptionid'].unique())
|
||||
deployment_id_map = self._generate_id_map(vm_table['deploymentid'].unique())
|
||||
vm_id_map = self._generate_id_map(vm_table["vmid"].unique())
|
||||
sub_id_map = self._generate_id_map(vm_table["subscriptionid"].unique())
|
||||
deployment_id_map = self._generate_id_map(vm_table["deploymentid"].unique())
|
||||
|
||||
id_maps = (vm_id_map, sub_id_map, deployment_id_map)
|
||||
|
||||
# Mapping IDs.
|
||||
vm_table['vmid'] = vm_table['vmid'].map(vm_id_map)
|
||||
vm_table['subscriptionid'] = vm_table['subscriptionid'].map(sub_id_map)
|
||||
vm_table['deploymentid'] = vm_table['deploymentid'].map(deployment_id_map)
|
||||
vm_table["vmid"] = vm_table["vmid"].map(vm_id_map)
|
||||
vm_table["subscriptionid"] = vm_table["subscriptionid"].map(sub_id_map)
|
||||
vm_table["deploymentid"] = vm_table["deploymentid"].map(deployment_id_map)
|
||||
|
||||
# Sampling the VM table.
|
||||
# 2695548 is the total number of vms in the original Azure public dataset.
|
||||
if self._sample < 2695548:
|
||||
vm_table = vm_table.sample(n=self._sample, random_state=self._seed)
|
||||
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
|
||||
vm_table = vm_table.sort_values(by="vmcreated", ascending=True)
|
||||
|
||||
return id_maps, vm_table
|
||||
|
||||
def _convert_cpu_readings_id(self, old_data_path: str, new_data_path: str, vm_id_map: dict):
|
||||
"""Convert vmid in each cpu readings file."""
|
||||
with open(old_data_path, 'r') as f_in:
|
||||
with open(old_data_path, "r") as f_in:
|
||||
csv_reader = reader(f_in)
|
||||
with open(new_data_path, 'w') as f_out:
|
||||
with open(new_data_path, "w") as f_out:
|
||||
csv_writer = writer(f_out)
|
||||
csv_writer.writerow(['timestamp', 'vmid', 'maxcpu'])
|
||||
csv_writer.writerow(["timestamp", "vmid", "maxcpu"])
|
||||
for row in csv_reader:
|
||||
# [timestamp, vmid, mincpu, maxcpu, avgcpu]
|
||||
if row[1] in vm_id_map:
|
||||
|
@ -248,12 +265,12 @@ class VmSchedulingPipeline(DataPipeline):
csv_writer.writerow(new_row)

def _write_id_map_to_csv(self, id_maps):
file_name = ['vm_id_map', 'sub_id_map', 'deployment_id_map']
file_name = ["vm_id_map", "sub_id_map", "deployment_id_map"]
for index in range(len(id_maps)):
id_map = id_maps[index]
with open(os.path.join(self._raw_folder, file_name[index]) + '.csv', 'w') as f:
with open(os.path.join(self._raw_folder, file_name[index]) + ".csv", "w") as f:
csv_writer = writer(f)
csv_writer.writerow(['original_id', 'new_id'])
csv_writer.writerow(["original_id", "new_id"])
for key, value in id_map.items():
csv_writer.writerow([key, value])

@ -288,7 +305,7 @@ class VmSchedulingPipeline(DataPipeline):
self._convert_cpu_readings_id(
old_data_path=raw_cpu_readings_file,
new_data_path=clean_cpu_readings_file,
vm_id_map=filtered_vm_id_map
vm_id_map=filtered_vm_id_map,
)

def build(self):
@ -313,7 +330,7 @@ class VmSchedulingTopology(DataTopology):
source=source,
sample=sample,
seed=seed,
is_temp=is_temp
is_temp=is_temp,
)


@ -336,5 +353,5 @@ class VmSchedulingProcess:
source=self._conf["vm_data"][topology]["remote_url"],
sample=self._conf["vm_data"][topology]["sample"],
seed=self._conf["vm_data"][topology]["seed"],
is_temp=is_temp
is_temp=is_temp,
)
|
|
|
@ -77,12 +77,12 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Simultaneously capture image and init master
|
||||
build_node_image_thread = threading.Thread(
|
||||
target=GrassAzureExecutor._build_node_image,
|
||||
args=(cluster_details,)
|
||||
args=(cluster_details,),
|
||||
)
|
||||
build_node_image_thread.start()
|
||||
create_and_init_master_thread = threading.Thread(
|
||||
target=GrassAzureExecutor._create_and_init_master,
|
||||
args=(cluster_details,)
|
||||
args=(cluster_details,),
|
||||
)
|
||||
create_and_init_master_thread.start()
|
||||
build_node_image_thread.join()
|
||||
|
@ -125,14 +125,14 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
"root['connection']['ssh']": {"port": GlobalParams.DEFAULT_SSH_PORT},
|
||||
"root['connection']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
|
||||
"root['connection']['api_server']": {"port": GrassParams.DEFAULT_API_SERVER_PORT},
|
||||
"root['connection']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT
|
||||
"root['connection']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT,
|
||||
}
|
||||
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_azure_create.yml") as fr:
|
||||
create_deployment_template = yaml.safe_load(fr)
|
||||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=create_deployment_template,
|
||||
actual_dict=create_deployment,
|
||||
optional_key_to_value=optional_key_to_value
|
||||
optional_key_to_value=optional_key_to_value,
|
||||
)
|
||||
|
||||
# Init runtime fields.
|
||||
|
@ -169,7 +169,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
else:
|
||||
AzureController.create_resource_group(
|
||||
resource_group=resource_group,
|
||||
location=cluster_details["cloud"]["location"]
|
||||
location=cluster_details["cloud"]["location"],
|
||||
)
|
||||
logger.info_green(f"Resource group '{resource_group}' is created")
|
||||
|
||||
|
@ -192,13 +192,13 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
)
|
||||
ArmTemplateParameterBuilder.create_vnet(
|
||||
cluster_details=cluster_details,
|
||||
export_path=parameters_file_path
|
||||
export_path=parameters_file_path,
|
||||
)
|
||||
AzureController.start_deployment(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
deployment_name="vnet",
|
||||
template_file_path=template_file_path,
|
||||
parameters_file_path=parameters_file_path
|
||||
parameters_file_path=parameters_file_path,
|
||||
)
|
||||
|
||||
logger.info_green("Vnet is created")
|
||||
|
@ -233,13 +233,13 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
ArmTemplateParameterBuilder.create_build_node_image_vm(
|
||||
cluster_details=cluster_details,
|
||||
node_size=cluster_details["master"]["node_size"],
|
||||
export_path=parameters_file_path
|
||||
export_path=parameters_file_path,
|
||||
)
|
||||
AzureController.start_deployment(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
deployment_name=resource_name,
|
||||
template_file_path=template_file_path,
|
||||
parameters_file_path=parameters_file_path
|
||||
parameters_file_path=parameters_file_path,
|
||||
)
|
||||
# Gracefully wait
|
||||
time.sleep(10)
|
||||
|
@ -247,7 +247,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Get public ip address
|
||||
ip_addresses = AzureController.list_ip_addresses(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
vm_name=vm_name
|
||||
vm_name=vm_name,
|
||||
)
|
||||
public_ip_address = ip_addresses[0]["virtualMachine"]["network"]["publicIpAddresses"][0]["ipAddress"]
|
||||
|
||||
|
@ -255,7 +255,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
GrassAzureExecutor.retry_connection(
|
||||
node_username=cluster_details["cloud"]["default_username"],
|
||||
node_hostname=public_ip_address,
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Run init image script
|
||||
|
@ -264,12 +264,12 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
remote_dir="~/",
|
||||
node_username=cluster_details["cloud"]["default_username"],
|
||||
node_hostname=public_ip_address,
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
|
||||
)
|
||||
GrassAzureExecutor.remote_init_build_node_image_vm(
|
||||
node_username=cluster_details["cloud"]["default_username"],
|
||||
node_hostname=public_ip_address,
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
|
||||
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Extract image
|
||||
|
@ -278,14 +278,14 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
AzureController.create_image_from_vm(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
image_name=image_name,
|
||||
vm_name=vm_name
|
||||
vm_name=vm_name,
|
||||
)
|
||||
|
||||
# Delete resources
|
||||
GrassAzureExecutor._delete_resources(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
resource_name=resource_name,
|
||||
cluster_id=cluster_details["id"]
|
||||
cluster_id=cluster_details["id"],
|
||||
)
|
||||
|
||||
logger.info_green("MARO Node Image is built")
|
||||
|
@ -313,7 +313,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
user_id=cluster_details["user"]["id"],
|
||||
master_to_dev_encryption_private_key=cluster_details["user"]["master_to_dev_encryption_private_key"],
|
||||
dev_to_master_encryption_public_key=cluster_details["user"]["dev_to_master_encryption_public_key"],
|
||||
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"]
|
||||
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"],
|
||||
)
|
||||
master_api_client.create_master(master_details=cluster_details["master"])
|
||||
master_api_client.create_cluster(cluster_details=cluster_details)
|
||||
|
@ -338,25 +338,24 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Create ARM parameters and start deployment
|
||||
template_file_path = f"{GrassPaths.ABS_MARO_GRASS_LIB}/modes/azure/create_master/template.json"
|
||||
parameters_file_path = (
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}"
|
||||
f"/master/arm_create_master_parameters.json"
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}" f"/master/arm_create_master_parameters.json"
|
||||
)
|
||||
ArmTemplateParameterBuilder.create_master(
|
||||
cluster_details=cluster_details,
|
||||
node_size=cluster_details["master"]["node_size"],
|
||||
export_path=parameters_file_path
|
||||
export_path=parameters_file_path,
|
||||
)
|
||||
AzureController.start_deployment(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
deployment_name="master",
|
||||
template_file_path=template_file_path,
|
||||
parameters_file_path=parameters_file_path
|
||||
parameters_file_path=parameters_file_path,
|
||||
)
|
||||
|
||||
# Get master IP addresses
|
||||
ip_addresses = AzureController.list_ip_addresses(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
vm_name=vm_name
|
||||
vm_name=vm_name,
|
||||
)
|
||||
public_ip_address = ip_addresses[0]["virtualMachine"]["network"]["publicIpAddresses"][0]["ipAddress"]
|
||||
private_ip_address = ip_addresses[0]["virtualMachine"]["network"]["privateIpAddresses"][0]
|
||||
|
@ -433,12 +432,12 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
if node_size_to_count[node_size] > replicas:
|
||||
self._delete_nodes(
|
||||
num=node_size_to_count[node_size] - replicas,
|
||||
node_size=node_size
|
||||
node_size=node_size,
|
||||
)
|
||||
elif node_size_to_count[node_size] < replicas:
|
||||
self._create_nodes(
|
||||
num=replicas - node_size_to_count[node_size],
|
||||
node_size=node_size
|
||||
node_size=node_size,
|
||||
)
|
||||
else:
|
||||
logger.warning_yellow("Replica is match, no create or delete")
|
||||
|
@ -459,7 +458,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
with ThreadPool(GlobalParams.PARALLELS) as pool:
|
||||
pool.starmap(
|
||||
self._create_node,
|
||||
[[node_size]] * num
|
||||
[[node_size]] * num,
|
||||
)
|
||||
|
||||
def _create_node(self, node_size: str) -> None:
|
||||
|
@ -478,7 +477,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Create node
|
||||
join_cluster_deployment = self._create_vm(
|
||||
node_name=node_name,
|
||||
node_size=node_size
|
||||
node_size=node_size,
|
||||
)
|
||||
|
||||
# Start joining cluster
|
||||
|
@ -512,12 +511,12 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
with ThreadPool(GlobalParams.PARALLELS) as pool:
|
||||
pool.starmap(
|
||||
self._delete_node,
|
||||
params
|
||||
params,
|
||||
)
|
||||
else:
|
||||
logger.warning_yellow(
|
||||
"Unable to scale down.\n"
|
||||
f"Only {len(deletable_nodes)} nodes are deletable, but need to delete {num} to meet the replica"
|
||||
f"Only {len(deletable_nodes)} nodes are deletable, but need to delete {num} to meet the replica",
|
||||
)
|
||||
|
||||
def _create_vm(self, node_name: str, node_size: str) -> dict:
|
||||
|
@ -543,19 +542,19 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
node_name=node_name,
|
||||
cluster_details=self.cluster_details,
|
||||
node_size=node_size,
|
||||
export_path=parameters_file_path
|
||||
export_path=parameters_file_path,
|
||||
)
|
||||
AzureController.start_deployment(
|
||||
resource_group=self.resource_group,
|
||||
deployment_name=node_name,
|
||||
template_file_path=template_file_path,
|
||||
parameters_file_path=parameters_file_path
|
||||
parameters_file_path=parameters_file_path,
|
||||
)
|
||||
|
||||
# Get node IP addresses
|
||||
ip_addresses = AzureController.list_ip_addresses(
|
||||
resource_group=self.resource_group,
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm"
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm",
|
||||
)
|
||||
|
||||
logger.info_green(f"VM '{node_name}' is created")
|
||||
|
@ -566,11 +565,11 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
"master": {
|
||||
"private_ip_address": self.master_private_ip_address,
|
||||
"api_server": {
|
||||
"port": self.master_api_server_port
|
||||
"port": self.master_api_server_port,
|
||||
},
|
||||
"redis": {
|
||||
"port": self.master_redis_port
|
||||
}
|
||||
"port": self.master_redis_port,
|
||||
},
|
||||
},
|
||||
"node": {
|
||||
"name": node_name,
|
||||
|
@ -584,23 +583,23 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
"resources": {
|
||||
"cpu": "all",
|
||||
"memory": "all",
|
||||
"gpu": "all"
|
||||
"gpu": "all",
|
||||
},
|
||||
"api_server": {
|
||||
"port": self.api_server_port
|
||||
"port": self.api_server_port,
|
||||
},
|
||||
"ssh": {
|
||||
"port": self.ssh_port
|
||||
}
|
||||
"port": self.ssh_port,
|
||||
},
|
||||
},
|
||||
"configs": {
|
||||
"install_node_runtime": False,
|
||||
"install_node_gpu_support": False
|
||||
}
|
||||
"install_node_gpu_support": False,
|
||||
},
|
||||
}
|
||||
with open(
|
||||
file=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/join_cluster_deployment.yml",
|
||||
mode="w"
|
||||
mode="w",
|
||||
) as fw:
|
||||
yaml.safe_dump(data=join_cluster_deployment, stream=fw)
|
||||
|
||||
|
@ -624,13 +623,13 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
self._delete_resources(
|
||||
resource_group=self.resource_group,
|
||||
cluster_id=self.cluster_id,
|
||||
resource_name=node_name
|
||||
resource_name=node_name,
|
||||
)
|
||||
|
||||
# Delete azure deployment
|
||||
AzureController.delete_deployment(
|
||||
resource_group=self.resource_group,
|
||||
deployment_name=node_name
|
||||
deployment_name=node_name,
|
||||
)
|
||||
|
||||
# Delete node related files
|
||||
|
@ -655,21 +654,22 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
self.retry_connection(
|
||||
node_username=node_details["username"],
|
||||
node_hostname=node_details["public_ip_address"],
|
||||
node_ssh_port=node_details["ssh"]["port"]
|
||||
node_ssh_port=node_details["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Copy required files
|
||||
local_path_to_remote_dir = {
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/join_cluster_deployment.yml":
|
||||
f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
|
||||
}
|
||||
local_path = (
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/" + "join_cluster_deployment.yml"
|
||||
)
|
||||
remote_dir = f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
|
||||
local_path_to_remote_dir = {local_path: remote_dir}
|
||||
for local_path, remote_dir in local_path_to_remote_dir.items():
|
||||
FileSynchronizer.copy_files_to_node(
|
||||
local_path=local_path,
|
||||
remote_dir=remote_dir,
|
||||
node_username=node_details["username"],
|
||||
node_hostname=node_details["public_ip_address"],
|
||||
node_ssh_port=node_details["ssh"]["port"]
|
||||
node_ssh_port=node_details["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Remote join cluster
|
||||
|
@ -682,7 +682,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
deployment_path=(
|
||||
f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
|
||||
f"/join_cluster_deployment.yml"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
logger.info_green(f"Node '{node_name}' is joined")
|
||||
|
@ -710,7 +710,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Check replicas
|
||||
if len(startable_nodes) < replicas:
|
||||
raise BadRequestError(
|
||||
f"No enough '{node_size}' nodes can be started, only {len(startable_nodes)} is able to start"
|
||||
f"No enough '{node_size}' nodes can be started, only {len(startable_nodes)} is able to start",
|
||||
)
|
||||
|
||||
# Parallel start
|
||||
|
@ -718,7 +718,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
with ThreadPool(GlobalParams.PARALLELS) as pool:
|
||||
pool.starmap(
|
||||
self._start_node,
|
||||
params
|
||||
params,
|
||||
)
|
||||
|
||||
def _start_node(self, node_name: str):
|
||||
|
@ -735,7 +735,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Start node vm
|
||||
AzureController.start_vm(
|
||||
resource_group=self.resource_group,
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm"
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm",
|
||||
)
|
||||
|
||||
# Start node
|
||||
|
@ -761,16 +761,16 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
stoppable_nodes_details = []
|
||||
for node_details in nodes_details:
|
||||
if (
|
||||
node_details["node_size"] == node_size and
|
||||
node_details["state"]["status"] == NodeStatus.RUNNING and
|
||||
self._count_running_containers(node_details) == 0
|
||||
node_details["node_size"] == node_size
|
||||
and node_details["state"]["status"] == NodeStatus.RUNNING
|
||||
and self._count_running_containers(node_details) == 0
|
||||
):
|
||||
stoppable_nodes_details.append(node_details)
|
||||
|
||||
# Check replicas
|
||||
if len(stoppable_nodes_details) < replicas:
|
||||
raise BadRequestError(
|
||||
f"No more '{node_size}' nodes can be stopped, only {len(stoppable_nodes_details)} are stoppable"
|
||||
f"No more '{node_size}' nodes can be stopped, only {len(stoppable_nodes_details)} are stoppable",
|
||||
)
|
||||
|
||||
# Parallel stop
|
||||
|
@ -778,7 +778,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
with ThreadPool(GlobalParams.PARALLELS) as pool:
|
||||
pool.starmap(
|
||||
self._stop_node,
|
||||
params
|
||||
params,
|
||||
)
|
||||
|
||||
def _stop_node(self, node_details: dict):
|
||||
|
@ -800,7 +800,7 @@ class GrassAzureExecutor(GrassExecutor):
|
|||
# Stop node vm
|
||||
AzureController.stop_vm(
|
||||
resource_group=self.resource_group,
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm"
|
||||
vm_name=f"{self.cluster_id}-{node_name}-vm",
|
||||
)
|
||||
|
||||
logger.info_green(f"Node '{node_name}' is stopped")
|
||||
|
@ -964,7 +964,7 @@ class ArmTemplateParameterBuilder:
|
|||
# Load and update parameters
|
||||
with open(
|
||||
file=f"{GrassPaths.ABS_MARO_GRASS_LIB}/modes/azure/create_build_node_image_vm/parameters.json",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
base_parameters = json.load(fr)
|
||||
parameters = base_parameters["parameters"]
|
||||
|
@ -1008,7 +1008,7 @@ class ArmTemplateParameterBuilder:
|
|||
parameters["adminUsername"]["value"] = cluster_details["cloud"]["default_username"]
|
||||
parameters["imageResourceId"]["value"] = AzureController.get_image_resource_id(
|
||||
resource_group=cluster_details["cloud"]["resource_group"],
|
||||
image_name=f"{cluster_details['id']}-node-image"
|
||||
image_name=f"{cluster_details['id']}-node-image",
|
||||
)
|
||||
parameters["location"]["value"] = cluster_details["cloud"]["location"]
|
||||
parameters["networkInterfaceName"]["value"] = f"{cluster_details['id']}-{node_name}-nic"
|
||||
|
|
|
@ -21,7 +21,11 @@ from maro.cli.utils.name_creator import NameCreator
|
|||
from maro.cli.utils.params import GlobalPaths
|
||||
from maro.cli.utils.subprocess import Subprocess
|
||||
from maro.utils.exception.cli_exception import (
|
||||
BadRequestError, CliError, ClusterInternalError, CommandExecutionError, FileOperationError
|
||||
BadRequestError,
|
||||
CliError,
|
||||
ClusterInternalError,
|
||||
CommandExecutionError,
|
||||
FileOperationError,
|
||||
)
|
||||
from maro.utils.logger import CliLogger
|
||||
|
||||
|
@ -56,7 +60,7 @@ class GrassExecutor:
|
|||
user_id=self.user_details["id"],
|
||||
master_to_dev_encryption_private_key=self.user_details["master_to_dev_encryption_private_key"],
|
||||
dev_to_master_encryption_public_key=self.user_details["dev_to_master_encryption_public_key"],
|
||||
dev_to_master_signing_private_key=self.user_details["dev_to_master_signing_private_key"]
|
||||
dev_to_master_signing_private_key=self.user_details["dev_to_master_signing_private_key"],
|
||||
)
|
||||
self.master_ssh_port = self.cluster_details["master"]["ssh"]["port"]
|
||||
self.master_api_server_port = self.cluster_details["master"]["api_server"]["port"]
|
||||
|
@ -79,18 +83,18 @@ class GrassExecutor:
|
|||
GrassExecutor.retry_connection(
|
||||
node_username=cluster_details["master"]["username"],
|
||||
node_hostname=cluster_details["master"]["public_ip_address"],
|
||||
node_ssh_port=cluster_details["master"]["ssh"]["port"]
|
||||
node_ssh_port=cluster_details["master"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
DetailsWriter.save_cluster_details(
|
||||
cluster_name=cluster_details["name"],
|
||||
cluster_details=cluster_details
|
||||
cluster_details=cluster_details,
|
||||
)
|
||||
|
||||
# Copy required files
|
||||
local_path_to_remote_dir = {
|
||||
GrassPaths.ABS_MARO_GRASS_LIB: f"{GlobalPaths.MARO_SHARED}/lib",
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}": f"{GlobalPaths.MARO_SHARED}/clusters"
|
||||
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}": f"{GlobalPaths.MARO_SHARED}/clusters",
|
||||
}
|
||||
for local_path, remote_dir in local_path_to_remote_dir.items():
|
||||
FileSynchronizer.copy_files_to_node(
|
||||
|
@ -98,7 +102,7 @@ class GrassExecutor:
|
|||
remote_dir=remote_dir,
|
||||
node_username=cluster_details["master"]["username"],
|
||||
node_hostname=cluster_details["master"]["public_ip_address"],
|
||||
node_ssh_port=cluster_details["master"]["ssh"]["port"]
|
||||
node_ssh_port=cluster_details["master"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Remote init master
|
||||
|
@ -106,7 +110,7 @@ class GrassExecutor:
|
|||
master_username=cluster_details["master"]["username"],
|
||||
master_hostname=cluster_details["master"]["public_ip_address"],
|
||||
master_ssh_port=cluster_details["master"]["ssh"]["port"],
|
||||
cluster_name=cluster_details["name"]
|
||||
cluster_name=cluster_details["name"],
|
||||
)
|
||||
# Gracefully wait
|
||||
time.sleep(10)
|
||||
|
@ -129,7 +133,7 @@ class GrassExecutor:
|
|||
master_hostname=cluster_details["master"]["public_ip_address"],
|
||||
master_ssh_port=cluster_details["master"]["ssh"]["port"],
|
||||
user_id=cluster_details["user"]["admin_id"],
|
||||
user_role=UserRole.ADMIN
|
||||
user_role=UserRole.ADMIN,
|
||||
)
|
||||
|
||||
# Update user_details, "admin_id" change to "id"
|
||||
|
@ -138,15 +142,15 @@ class GrassExecutor:
|
|||
# Save dev_to_master private key
|
||||
os.makedirs(
|
||||
name=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}/users/{user_details['id']}",
|
||||
exist_ok=True
|
||||
exist_ok=True,
|
||||
)
|
||||
with open(
|
||||
file=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}/users/{user_details['id']}/user_details",
|
||||
mode="w"
|
||||
mode="w",
|
||||
) as fw:
|
||||
yaml.safe_dump(
|
||||
data=user_details,
|
||||
stream=fw
|
||||
stream=fw,
|
||||
)
|
||||
|
||||
# Save default user
|
||||
|
@ -168,8 +172,9 @@ class GrassExecutor:
|
|||
logger.info(
|
||||
json.dumps(
|
||||
nodes_details,
|
||||
indent=4, sort_keys=True
|
||||
)
|
||||
indent=4,
|
||||
sort_keys=True,
|
||||
),
|
||||
)
|
||||
|
||||
# maro grass image
|
||||
|
@ -194,7 +199,7 @@ class GrassExecutor:
|
|||
abs_image_path = f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/image_files/{new_file_name}"
|
||||
DockerController.save_image(
|
||||
image_name=image_name,
|
||||
abs_export_path=abs_image_path
|
||||
abs_export_path=abs_image_path,
|
||||
)
|
||||
else:
|
||||
# Push image from local image file.
|
||||
|
@ -204,7 +209,7 @@ class GrassExecutor:
|
|||
FileSynchronizer.copy_and_rename(
|
||||
source_path=image_path,
|
||||
target_dir=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/image_files",
|
||||
new_name=new_file_name
|
||||
new_name=new_file_name,
|
||||
)
|
||||
# Use md5_checksum to skip existed image file.
|
||||
remote_image_file_details = self.master_api_client.get_image_file(image_file_name=new_file_name)
|
||||
|
@ -220,13 +225,13 @@ class GrassExecutor:
|
|||
remote_dir=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/image_files",
|
||||
node_username=self.master_username,
|
||||
node_hostname=self.master_public_ip_address,
|
||||
node_ssh_port=self.master_ssh_port
|
||||
node_ssh_port=self.master_ssh_port,
|
||||
)
|
||||
self.master_api_client.create_image_file(
|
||||
image_file_details={
|
||||
"name": new_file_name,
|
||||
"md5_checksum": local_md5_checksum
|
||||
}
|
||||
"md5_checksum": local_md5_checksum,
|
||||
},
|
||||
)
|
||||
logger.info_green(f"Image {image_name} is loaded")
|
||||
else:
|
||||
|
@ -251,7 +256,7 @@ class GrassExecutor:
|
|||
remote_dir=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/data{remote_path}",
|
||||
node_username=self.master_username,
|
||||
node_hostname=self.master_public_ip_address,
|
||||
node_ssh_port=self.master_ssh_port
|
||||
node_ssh_port=self.master_ssh_port,
|
||||
)
|
||||
|
||||
def pull_data(self, local_path: str, remote_path: str) -> None:
|
||||
|
@ -271,7 +276,7 @@ class GrassExecutor:
|
|||
remote_path=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/data{remote_path}",
|
||||
node_username=self.master_username,
|
||||
node_hostname=self.master_public_ip_address,
|
||||
node_ssh_port=self.master_ssh_port
|
||||
node_ssh_port=self.master_ssh_port,
|
||||
)
|
||||
|
||||
# maro grass job
|
||||
|
@ -333,8 +338,9 @@ class GrassExecutor:
|
|||
logger.info(
|
||||
json.dumps(
|
||||
jobs_details,
|
||||
indent=4, sort_keys=True
|
||||
)
|
||||
indent=4,
|
||||
sort_keys=True,
|
||||
),
|
||||
)
|
||||
|
||||
def get_job_logs(self, job_name: str, export_dir: str = "./") -> None:
|
||||
|
@ -357,7 +363,7 @@ class GrassExecutor:
|
|||
remote_path=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/logs/{job_details['id']}",
|
||||
node_username=self.master_username,
|
||||
node_hostname=self.master_public_ip_address,
|
||||
node_ssh_port=self.master_ssh_port
|
||||
node_ssh_port=self.master_ssh_port,
|
||||
)
|
||||
except CommandExecutionError:
|
||||
logger.error_red("No logs have been created at this time")
|
||||
|
@ -375,14 +381,14 @@ class GrassExecutor:
|
|||
"""
|
||||
# Validate grass_azure_start_job
|
||||
optional_key_to_value = {
|
||||
"root['tags']": {}
|
||||
"root['tags']": {},
|
||||
}
|
||||
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_azure_start_job.yml") as fr:
|
||||
start_job_template = yaml.safe_load(fr)
|
||||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=start_job_template,
|
||||
actual_dict=start_job_deployment,
|
||||
optional_key_to_value=optional_key_to_value
|
||||
optional_key_to_value=optional_key_to_value,
|
||||
)
|
||||
|
||||
# Validate component
|
||||
|
@ -393,7 +399,7 @@ class GrassExecutor:
|
|||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=start_job_component_template,
|
||||
actual_dict=component_details,
|
||||
optional_key_to_value={}
|
||||
optional_key_to_value={},
|
||||
)
|
||||
|
||||
# Init runtime fields
|
||||
|
@ -457,7 +463,7 @@ class GrassExecutor:
|
|||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=start_job_template,
|
||||
actual_dict=start_schedule_deployment,
|
||||
optional_key_to_value={}
|
||||
optional_key_to_value={},
|
||||
)
|
||||
|
||||
# Validate component
|
||||
|
@ -468,7 +474,7 @@ class GrassExecutor:
|
|||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=start_job_component_template,
|
||||
actual_dict=component_details,
|
||||
optional_key_to_value={}
|
||||
optional_key_to_value={},
|
||||
)
|
||||
|
||||
return start_schedule_deployment
|
||||
|
@ -497,8 +503,9 @@ class GrassExecutor:
|
|||
logger.info(
|
||||
json.dumps(
|
||||
return_status,
|
||||
indent=4, sort_keys=True
|
||||
)
|
||||
indent=4,
|
||||
sort_keys=True,
|
||||
),
|
||||
)
|
||||
|
||||
# maro grass template
|
||||
|
@ -582,8 +589,11 @@ class GrassExecutor:
|
|||
|
||||
@staticmethod
|
||||
def remote_create_user(
|
||||
master_username: str, master_hostname: str, master_ssh_port: int,
|
||||
user_id: str, user_role: str
|
||||
master_username: str,
|
||||
master_hostname: str,
|
||||
master_ssh_port: int,
|
||||
user_id: str,
|
||||
user_role: str,
|
||||
) -> dict:
|
||||
"""Remote create MARO User.
|
||||
|
||||
|
@ -609,8 +619,12 @@ class GrassExecutor:
|
|||
|
||||
@staticmethod
|
||||
def remote_join_cluster(
|
||||
node_username: str, node_hostname: str, node_ssh_port: int,
|
||||
master_private_ip_address: str, master_api_server_port: int, deployment_path: str
|
||||
node_username: str,
|
||||
node_hostname: str,
|
||||
node_ssh_port: int,
|
||||
master_private_ip_address: str,
|
||||
master_api_server_port: int,
|
||||
deployment_path: str,
|
||||
) -> None:
|
||||
"""Remote join cluster.
|
||||
|
||||
|
@ -735,14 +749,14 @@ class GrassExecutor:
|
|||
GrassExecutor.test_ssh_default_port_connection(
|
||||
node_ssh_port=node_ssh_port,
|
||||
node_username=node_username,
|
||||
node_hostname=node_hostname
|
||||
node_hostname=node_hostname,
|
||||
)
|
||||
return
|
||||
except (CliError, TimeoutExpired):
|
||||
remain_retries -= 1
|
||||
logger.debug(
|
||||
f"Unable to connect to {node_hostname} with port {node_ssh_port}, "
|
||||
f"remains {remain_retries} retries"
|
||||
f"remains {remain_retries} retries",
|
||||
)
|
||||
time.sleep(5)
|
||||
raise ClusterInternalError(f"Unable to connect to {node_hostname} with port {node_ssh_port}")
|
||||
|
@ -751,7 +765,7 @@ class GrassExecutor:
|
|||
|
||||
@staticmethod
|
||||
def _get_md5_checksum(path: str, block_size=128) -> str:
|
||||
""" Get md5 checksum of a local file.
|
||||
"""Get md5 checksum of a local file.
|
||||
|
||||
Args:
|
||||
path (str): path of the local file.
|
||||
|
|
|
@ -27,8 +27,9 @@ logger = CliLogger(name=__name__)
|
|||
class GrassLocalExecutor(AbsVisibleExecutor):
|
||||
def __init__(self, cluster_name: str, cluster_details: dict = None):
|
||||
self.cluster_name = cluster_name
|
||||
self.cluster_details = DetailsReader.load_cluster_details(cluster_name=cluster_name) \
|
||||
if not cluster_details else cluster_details
|
||||
self.cluster_details = (
|
||||
DetailsReader.load_cluster_details(cluster_name=cluster_name) if not cluster_details else cluster_details
|
||||
)
|
||||
|
||||
# Connection with Redis
|
||||
redis_port = self.cluster_details["master"]["redis"]["port"]
|
||||
|
@ -37,7 +38,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
self._redis_connection.ping()
|
||||
except Exception:
|
||||
redis_process = subprocess.Popen(
|
||||
["redis-server", "--port", str(redis_port), "--daemonize yes"]
|
||||
["redis-server", "--port", str(redis_port), "--daemonize yes"],
|
||||
)
|
||||
redis_process.wait(timeout=2)
|
||||
|
||||
|
@ -53,7 +54,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
deployment["total_request_resource"] = {
|
||||
"cpu": total_cpu,
|
||||
"memory": total_memory,
|
||||
"gpu": total_gpu
|
||||
"gpu": total_gpu,
|
||||
}
|
||||
|
||||
deployment["status"] = JobStatus.PENDING
|
||||
|
@ -76,7 +77,9 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
|
||||
# Update resource
|
||||
is_satisfied, updated_resource = resource_op(
|
||||
available_resource, cluster_resource, ResourceOperation.ALLOCATION
|
||||
available_resource,
|
||||
cluster_resource,
|
||||
ResourceOperation.ALLOCATION,
|
||||
)
|
||||
if not is_satisfied:
|
||||
self._resource_redis.sub_cluster()
|
||||
|
@ -91,13 +94,13 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
self._redis_connection.hset(
|
||||
f"{self.cluster_name}:runtime_detail",
|
||||
"available_resource",
|
||||
json.dumps(cluster_resource)
|
||||
json.dumps(cluster_resource),
|
||||
)
|
||||
|
||||
# Save cluster config locally.
|
||||
DetailsWriter.save_cluster_details(
|
||||
cluster_name=self.cluster_name,
|
||||
cluster_details=self.cluster_details
|
||||
cluster_details=self.cluster_details,
|
||||
)
|
||||
|
||||
logger.info(f"{self.cluster_name} is created.")
|
||||
|
@ -117,7 +120,9 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
# Update resource
|
||||
cluster_resource = self.cluster_details["master"]["resource"]
|
||||
_, updated_resource = resource_op(
|
||||
available_resource, cluster_resource, ResourceOperation.RELEASE
|
||||
available_resource,
|
||||
cluster_resource,
|
||||
ResourceOperation.RELEASE,
|
||||
)
|
||||
self._resource_redis.set_available_resource(updated_resource)
|
||||
|
||||
|
@ -156,7 +161,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
is_satisfied, _ = resource_op(
|
||||
self.cluster_details["master"]["resource"],
|
||||
start_job_deployment["total_request_resource"],
|
||||
ResourceOperation.ALLOCATION
|
||||
ResourceOperation.ALLOCATION,
|
||||
)
|
||||
if not is_satisfied:
|
||||
raise BadRequestError(f"No enough resource to start job {start_job_deployment['name']}.")
|
||||
|
@ -171,13 +176,13 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
self._redis_connection.hset(
|
||||
f"{self.cluster_name}:job_details",
|
||||
job_name,
|
||||
json.dumps(job_details)
|
||||
json.dumps(job_details),
|
||||
)
|
||||
|
||||
# Push job name to pending_job_tickets
|
||||
self._redis_connection.lpush(
|
||||
f"{self.cluster_name}:pending_job_tickets",
|
||||
job_name
|
||||
job_name,
|
||||
)
|
||||
logger.info(f"Sending {job_name} into pending job tickets.")
|
||||
|
||||
|
@ -189,7 +194,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
# push job_name into killed_job_tickets
|
||||
self._redis_connection.lpush(
|
||||
f"{self.cluster_name}:killed_job_tickets",
|
||||
job_name
|
||||
job_name,
|
||||
)
|
||||
logger.info(f"Sending {job_name} into killed job tickets.")
|
||||
|
||||
|
@ -231,7 +236,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
is_satisfied, _ = resource_op(
|
||||
self.cluster_details["master"]["resource"],
|
||||
start_schedule_deployment["total_request_resource"],
|
||||
ResourceOperation.ALLOCATION
|
||||
ResourceOperation.ALLOCATION,
|
||||
)
|
||||
if not is_satisfied:
|
||||
raise BadRequestError(f"No enough resource to start schedule {schedule_name} in {self.cluster_name}.")
|
||||
|
@ -240,7 +245,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
self._redis_connection.hset(
|
||||
f"{self.cluster_name}:job_details",
|
||||
schedule_name,
|
||||
json.dumps(start_schedule_deployment)
|
||||
json.dumps(start_schedule_deployment),
|
||||
)
|
||||
|
||||
job_list = start_schedule_deployment["job_names"]
|
||||
|
@ -256,7 +261,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
def stop_schedule(self, schedule_name: str):
|
||||
try:
|
||||
schedule_details = json.loads(
|
||||
self._redis_connection.hget(f"{self.cluster_name}:job_details", schedule_name)
|
||||
self._redis_connection.hget(f"{self.cluster_name}:job_details", schedule_name),
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f"No such schedule '{schedule_name}' in Redis.")
|
||||
|
@ -281,15 +286,17 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
def get_job_queue(self):
|
||||
pending_job_queue = self._redis_connection.lrange(
|
||||
f"{self.cluster_name}:pending_job_tickets",
|
||||
0, -1
|
||||
0,
|
||||
-1,
|
||||
)
|
||||
killed_job_queue = self._redis_connection.lrange(
|
||||
f"{self.cluster_name}:killed_job_tickets",
|
||||
0, -1
|
||||
0,
|
||||
-1,
|
||||
)
|
||||
return {
|
||||
"pending_jobs": pending_job_queue,
|
||||
"killed_jobs": killed_job_queue
|
||||
"killed_jobs": killed_job_queue,
|
||||
}
|
||||
|
||||
def get_resource(self):
|
||||
|
@ -298,6 +305,6 @@ class GrassLocalExecutor(AbsVisibleExecutor):
|
|||
def get_resource_usage(self, previous_length: int = 0):
|
||||
available_resource = self._redis_connection.hget(
|
||||
f"{self.cluster_name}:runtime_detail",
|
||||
"available_resource"
|
||||
"available_resource",
|
||||
)
|
||||
return json.loads(available_resource)
|
||||
|
|
|
@ -63,7 +63,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
user_id=cluster_details["user"]["id"],
|
||||
master_to_dev_encryption_private_key=cluster_details["user"]["master_to_dev_encryption_private_key"],
|
||||
dev_to_master_encryption_public_key=cluster_details["user"]["dev_to_master_encryption_public_key"],
|
||||
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"]
|
||||
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"],
|
||||
)
|
||||
master_api_client.create_master(master_details=cluster_details["master"])
|
||||
master_api_client.create_cluster(cluster_details=cluster_details)
|
||||
|
@ -95,20 +95,20 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
"root['master']['fluentd']": {"port": GlobalParams.DEFAULT_FLUENTD_PORT},
|
||||
"root['master']['fluentd']['port']": GlobalParams.DEFAULT_FLUENTD_PORT,
|
||||
"root['master']['samba']": {
|
||||
"password": samba_password
|
||||
"password": samba_password,
|
||||
},
|
||||
"root['master']['samba']['password']": samba_password,
|
||||
"root['master']['ssh']": {"port": GlobalParams.DEFAULT_SSH_PORT},
|
||||
"root['master']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
|
||||
"root['master']['api_server']": {"port": GrassParams.DEFAULT_API_SERVER_PORT},
|
||||
"root['master']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT
|
||||
"root['master']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT,
|
||||
}
|
||||
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_on_premises_create.yml") as fr:
|
||||
create_deployment_template = yaml.safe_load(fr)
|
||||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=create_deployment_template,
|
||||
actual_dict=create_deployment,
|
||||
optional_key_to_value=optional_key_to_value
|
||||
optional_key_to_value=optional_key_to_value,
|
||||
)
|
||||
|
||||
# Init runtime fields.
|
||||
|
@ -134,13 +134,13 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
self.remote_leave_cluster(
|
||||
node_username=node_details["username"],
|
||||
node_hostname=node_details["public_ip_address"],
|
||||
node_ssh_port=node_details["ssh"]["port"]
|
||||
node_ssh_port=node_details["ssh"]["port"],
|
||||
)
|
||||
|
||||
self.remote_delete_master(
|
||||
master_username=self.master_username,
|
||||
master_hostname=self.master_public_ip_address,
|
||||
master_ssh_port=self.master_ssh_port
|
||||
master_ssh_port=self.master_ssh_port,
|
||||
)
|
||||
|
||||
shutil.rmtree(path=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}")
|
||||
|
@ -177,7 +177,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
|
||||
# Get standardized join_cluster_deployment
|
||||
join_cluster_deployment = GrassOnPremisesExecutor._standardize_join_cluster_deployment(
|
||||
join_cluster_deployment=join_cluster_deployment
|
||||
join_cluster_deployment=join_cluster_deployment,
|
||||
)
|
||||
|
||||
# Save join_cluster_deployment TODO: do checking, already join another node
|
||||
|
@ -186,7 +186,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
|
||||
# Copy required files
|
||||
local_path_to_remote_dir = {
|
||||
f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml": GlobalPaths.MARO_LOCAL_TMP
|
||||
f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml": GlobalPaths.MARO_LOCAL_TMP,
|
||||
}
|
||||
for local_path, remote_dir in local_path_to_remote_dir.items():
|
||||
FileSynchronizer.copy_files_to_node(
|
||||
|
@ -194,7 +194,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
remote_dir=remote_dir,
|
||||
node_username=join_cluster_deployment["node"]["username"],
|
||||
node_hostname=join_cluster_deployment["node"]["public_ip_address"],
|
||||
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"]
|
||||
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
# Remote join node
|
||||
|
@ -204,7 +204,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"],
|
||||
master_private_ip_address=join_cluster_deployment["master"]["private_ip_address"],
|
||||
master_api_server_port=join_cluster_deployment["master"]["api_server"]["port"],
|
||||
deployment_path=f"{GlobalPaths.MARO_LOCAL_TMP}/join_cluster_deployment.yml"
|
||||
deployment_path=f"{GlobalPaths.MARO_LOCAL_TMP}/join_cluster_deployment.yml",
|
||||
)
|
||||
|
||||
os.remove(f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml")
|
||||
|
@ -230,7 +230,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
"root['node']['resources']": {
|
||||
"cpu": "all",
|
||||
"memory": "all",
|
||||
"gpu": "all"
|
||||
"gpu": "all",
|
||||
},
|
||||
"root['node']['resources']['cpu']": "all",
|
||||
"root['node']['resources']['memory']": "all",
|
||||
|
@ -241,21 +241,21 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
"root['node']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
|
||||
"root['configs']": {
|
||||
"install_node_runtime": False,
|
||||
"install_node_gpu_support": False
|
||||
"install_node_gpu_support": False,
|
||||
},
|
||||
"root['configs']['install_node_runtime']": False,
|
||||
"root['configs']['install_node_gpu_support']": False
|
||||
"root['configs']['install_node_gpu_support']": False,
|
||||
}
|
||||
with open(
|
||||
file=f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_on_premises_join_cluster.yml",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
create_deployment_template = yaml.safe_load(stream=fr)
|
||||
|
||||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=create_deployment_template,
|
||||
actual_dict=join_cluster_deployment,
|
||||
optional_key_to_value=optional_key_to_value
|
||||
optional_key_to_value=optional_key_to_value,
|
||||
)
|
||||
|
||||
return join_cluster_deployment
|
||||
|
@ -283,7 +283,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
|
|||
GrassOnPremisesExecutor.remote_leave_cluster(
|
||||
node_username=leave_cluster_deployment["node"]["username"],
|
||||
node_hostname=leave_cluster_deployment["node"]["hostname"],
|
||||
node_ssh_port=leave_cluster_deployment["node"]["ssh"]["port"]
|
||||
node_ssh_port=leave_cluster_deployment["node"]["ssh"]["port"],
|
||||
)
|
||||
|
||||
logger.info_green("Node is left")
|
||||
|
|
|
@ -9,8 +8,7 @@ from maro.cli.utils.operation_lock_wrapper import operation_lock
@check_details_validity
@operation_lock
def push_image(
cluster_name: str, image_name: str, image_path: str, remote_context_path: str, remote_image_name: str,
**kwargs
cluster_name: str, image_name: str, image_path: str, remote_context_path: str, remote_image_name: str, **kwargs
):
# Late imports.
from maro.cli.grass.executors.grass_azure_executor import GrassAzureExecutor
@ -26,7 +25,7 @@ def push_image(
image_name=image_name,
image_path=image_path,
remote_context_path=remote_context_path,
remote_image_name=remote_image_name
remote_image_name=remote_image_name,
)
elif cluster_details["mode"] == "grass/on-premises":
executor = GrassOnPremisesExecutor(cluster_name=cluster_name)
@ -34,7 +33,7 @@ def push_image(
image_name=image_name,
image_path=image_path,
remote_context_path=remote_context_path,
remote_image_name=remote_image_name
remote_image_name=remote_image_name,
)
else:
raise BadRequestError(f"Unsupported operation in mode '{cluster_details['mode']}'.")
@ -72,7 +72,7 @@ if __name__ == "__main__":
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
universal_newlines=True,
)
while True:
next_line = process.stdout.readline()
@ -31,7 +31,7 @@ class UserCreator:
self._local_cluster_details = local_cluster_details
self._redis_controller = RedisController(
host="localhost",
port=self._local_cluster_details["master"]["redis"]["port"]
port=self._local_cluster_details["master"]["redis"]["port"],
)

@staticmethod
@ -39,22 +39,22 @@ class UserCreator:
rsa_key = rsa.generate_private_key(
backend=default_backend(),
public_exponent=65537,
key_size=2048
key_size=2048,
)

# Format and encoding are diff from OpenSSH
private_key = rsa_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)
public_key = rsa_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return {
"public_key": public_key.decode("utf-8"),
"private_key": private_key.decode("utf-8")
"private_key": private_key.decode("utf-8"),
}

def create_user(self, user_id: str, user_role: str) -> None:
@ -72,8 +72,8 @@ class UserCreator:
"master_to_dev_encryption_public_key": master_to_dev_encryption_key_pair["public_key"],
"dev_to_master_encryption_private_key": dev_to_master_encryption_key_pair["private_key"],
"dev_to_master_signing_public_key": dev_to_master_signing_key_pair["public_key"],
"master_to_dev_signing_private_key": master_to_dev_signing_key_pair["private_key"]
}
"master_to_dev_signing_private_key": master_to_dev_signing_key_pair["private_key"],
},
)

# Write private key to console.
@ -85,9 +85,9 @@ class UserCreator:
"master_to_dev_encryption_private_key": master_to_dev_encryption_key_pair["private_key"],
"dev_to_master_encryption_public_key": dev_to_master_encryption_key_pair["public_key"],
"dev_to_master_signing_private_key": dev_to_master_signing_key_pair["private_key"],
"master_to_dev_signing_public_key": master_to_dev_signing_key_pair["public_key"]
}
)
"master_to_dev_signing_public_key": master_to_dev_signing_key_pair["public_key"],
},
),
)
@ -28,6 +28,7 @@ REMOVE_CONTAINERS = """sudo docker rm -f maro-fluentd-{cluster_id} maro-redis-{c

# Master Deleter.


class MasterDeleter:
def __init__(self, local_cluster_details: dict):
self._local_cluster_details = local_cluster_details
@ -83,7 +84,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
@ -106,7 +106,7 @@ class MasterInitializer:
master_redis_port=self.master_details["redis"]["port"],
cluster_name=cluster_details["name"],
master_fluentd_port=self.master_details["fluentd"]["port"],
steps=5
steps=5,
)
Subprocess.interactive_run(command=command)

@ -115,7 +115,7 @@ class MasterInitializer:
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/master_agent/maro-master-agent.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)
@ -131,13 +131,13 @@ class MasterInitializer:
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/master_api_server/maro-master-api-server.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(
home_path=str(pathlib.Path.home()),
maro_shared_path=Paths.ABS_MARO_SHARED,
master_api_server_port=self.master_details["api_server"]["port"]
master_api_server_port=self.master_details["api_server"]["port"],
)
os.makedirs(name=os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
with open(file=os.path.expanduser("~/.config/systemd/user/maro-master-api-server.service"), mode="w") as fw:
@ -152,7 +152,7 @@ class MasterInitializer:
os.makedirs(name=f"{Paths.ABS_MARO_LOCAL}/scripts", exist_ok=True)
shutil.copy2(
src=f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/master/delete_master.py",
dst=f"{Paths.ABS_MARO_LOCAL}/scripts"
dst=f"{Paths.ABS_MARO_LOCAL}/scripts",
)
@ -47,7 +47,7 @@ class RedisController:
def delete_node_details(self, node_name: str) -> None:
self._redis.hdel(
"name_to_node_details",
node_name
node_name,
)


@ -71,14 +71,14 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
sys.stderr.write(completed_process.stderr)


if __name__ == '__main__':
if __name__ == "__main__":
local_cluster_details = DetailsReader.load_local_cluster_details()
local_node_details = DetailsReader.load_local_node_details()

@ -87,6 +87,6 @@ if __name__ == '__main__':

redis_controller = RedisController(
host=local_cluster_details["master"]["private_ip_address"],
port=local_cluster_details["master"]["redis"]["port"]
port=local_cluster_details["master"]["redis"]["port"],
)
redis_controller.delete_node_details(node_name=local_node_details["name"])
|
|
|
@ -111,6 +111,7 @@ echo "{public_key}" >> ~/.ssh/authorized_keys
|
|||
|
||||
# Node Joiner.
|
||||
|
||||
|
||||
class NodeJoiner:
|
||||
def __init__(self, join_cluster_deployment: dict):
|
||||
self.join_cluster_deployment = join_cluster_deployment
|
||||
|
@ -118,7 +119,7 @@ class NodeJoiner:
|
|||
|
||||
redis_controller = RedisController(
|
||||
host=join_cluster_deployment["master"]["private_ip_address"],
|
||||
port=join_cluster_deployment["master"]["redis"]["port"]
|
||||
port=join_cluster_deployment["master"]["redis"]["port"],
|
||||
)
|
||||
self.cluster_details = redis_controller.get_cluster_details()
|
||||
|
||||
|
@ -138,7 +139,7 @@ class NodeJoiner:
|
|||
node_details["image_files"] = {}
|
||||
node_details["containers"] = {}
|
||||
node_details["state"] = {
|
||||
"status": NodeStatus.PENDING
|
||||
"status": NodeStatus.PENDING,
|
||||
}
|
||||
|
||||
return node_details
|
||||
|
@ -163,7 +164,7 @@ class NodeJoiner:
|
|||
master_username=self.master_details["username"],
|
||||
master_hostname=self.master_details["private_ip_address"],
|
||||
master_samba_password=self.master_details["samba"]["password"],
|
||||
maro_shared_path=Paths.ABS_MARO_SHARED
|
||||
maro_shared_path=Paths.ABS_MARO_SHARED,
|
||||
)
|
||||
Subprocess.run(command=command)
|
||||
|
||||
|
@ -171,7 +172,7 @@ class NodeJoiner:
|
|||
# Rewrite data in .service and write it to systemd folder.
|
||||
with open(
|
||||
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_agent/maro-node-agent.service",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
service_file = fr.read()
|
||||
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)
|
||||
|
@ -186,13 +187,13 @@ class NodeJoiner:
|
|||
# Rewrite data in .service and write it to systemd folder.
|
||||
with open(
|
||||
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_api_server/maro-node-api-server.service",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
service_file = fr.read()
|
||||
service_file = service_file.format(
|
||||
home_path=str(pathlib.Path.home()),
|
||||
maro_shared_path=Paths.ABS_MARO_SHARED,
|
||||
node_api_server_port=self.node_details["api_server"]["port"]
|
||||
node_api_server_port=self.node_details["api_server"]["port"],
|
||||
)
|
||||
os.makedirs(os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
|
||||
with open(file=os.path.expanduser("~/.config/systemd/user/maro-node-api-server.service"), mode="w") as fw:
|
||||
|
@ -205,13 +206,13 @@ class NodeJoiner:
|
|||
def copy_leave_script():
|
||||
src_files = [
|
||||
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/leave_cluster.py",
|
||||
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/activate_leave_cluster.py"
|
||||
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/activate_leave_cluster.py",
|
||||
]
|
||||
os.makedirs(name=f"{Paths.ABS_MARO_LOCAL}/scripts", exist_ok=True)
|
||||
for src_file in src_files:
|
||||
shutil.copy2(
|
||||
src=src_file,
|
||||
dst=f"{Paths.ABS_MARO_LOCAL}/scripts"
|
||||
dst=f"{Paths.ABS_MARO_LOCAL}/scripts",
|
||||
)
|
||||
|
||||
def load_master_public_key(self):
|
||||
|
@ -227,11 +228,11 @@ class NodeJoiner:
|
|||
"master": {
|
||||
"private_ip_address": "",
|
||||
"api_server": {
|
||||
"port": ""
|
||||
"port": "",
|
||||
},
|
||||
"redis": {
|
||||
"port": ""
|
||||
}
|
||||
"port": "",
|
||||
},
|
||||
},
|
||||
"node": {
|
||||
"hostname": "",
|
||||
|
@ -241,19 +242,19 @@ class NodeJoiner:
|
|||
"resources": {
|
||||
"cpu": "",
|
||||
"memory": "",
|
||||
"gpu": ""
|
||||
"gpu": "",
|
||||
},
|
||||
"ssh": {
|
||||
"port": ""
|
||||
"port": "",
|
||||
},
|
||||
"api_server": {
|
||||
"port": ""
|
||||
}
|
||||
"port": "",
|
||||
},
|
||||
},
|
||||
"configs": {
|
||||
"install_node_runtime": "",
|
||||
"install_node_gpu_support": ""
|
||||
}
|
||||
"install_node_gpu_support": "",
|
||||
},
|
||||
}
|
||||
DeploymentValidator.validate_and_fill_dict(
|
||||
template_dict=join_cluster_deployment_template,
|
||||
|
@ -266,7 +267,7 @@ class NodeJoiner:
|
|||
"root['node']['resources']": {
|
||||
"cpu": "all",
|
||||
"memory": "all",
|
||||
"gpu": "all"
|
||||
"gpu": "all",
|
||||
},
|
||||
"root['node']['resources']['cpu']": "all",
|
||||
"root['node']['resources']['memory']": "all",
|
||||
|
@ -277,11 +278,11 @@ class NodeJoiner:
|
|||
"root['node']['ssh']['port']": Params.DEFAULT_SSH_PORT,
|
||||
"root['configs']": {
|
||||
"install_node_runtime": False,
|
||||
"install_node_gpu_support": False
|
||||
"install_node_gpu_support": False,
|
||||
},
|
||||
"root['configs']['install_node_runtime']": False,
|
||||
"root['configs']['install_node_gpu_support']": False
|
||||
}
|
||||
"root['configs']['install_node_gpu_support']": False,
|
||||
},
|
||||
)
|
||||
|
||||
return join_cluster_deployment
|
||||
|
@ -289,6 +290,7 @@ class NodeJoiner:
|
|||
|
||||
# Utils Classes.
|
||||
|
||||
|
||||
class Params:
|
||||
DEFAULT_SSH_PORT = 22
|
||||
DEFAULT_REDIS_PORT = 6379
|
||||
|
@ -340,13 +342,13 @@ class RedisController:
|
|||
self._redis.hset(
|
||||
"name_to_node_details",
|
||||
node_details["name"],
|
||||
json.dumps(node_details)
|
||||
json.dumps(node_details),
|
||||
)
|
||||
|
||||
"""Utils."""
|
||||
|
||||
def lock(self, name: str) -> Lock:
|
||||
""" Get a new lock with redis.
|
||||
"""Get a new lock with redis.
|
||||
|
||||
Use 'with lock(name):' paradigm to do the locking.
|
||||
|
||||
|
@ -386,7 +388,7 @@ class DeploymentValidator:
|
|||
DeploymentValidator._set_value(
|
||||
original_dict=actual_dict,
|
||||
key_list=DeploymentValidator._get_parent_to_child_key_list(deep_diff_str=missing_key_str),
|
||||
value=optional_key_to_value[missing_key_str]
|
||||
value=optional_key_to_value[missing_key_str],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -454,7 +456,7 @@ class Subprocess:
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
)
|
||||
if completed_process.returncode != 0:
|
||||
raise Exception(completed_process.stderr)
|
||||
|
@ -477,7 +479,7 @@ class Subprocess:
|
|||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True
|
||||
universal_newlines=True,
|
||||
)
|
||||
while True:
|
||||
next_line = process.stdout.readline()
|
||||
|
@ -523,7 +525,7 @@ if __name__ == "__main__":
|
|||
join_cluster_deployment = yaml.safe_load(stream=fr)
|
||||
|
||||
join_cluster_deployment = NodeJoiner.standardize_join_cluster_deployment(
|
||||
join_cluster_deployment=join_cluster_deployment
|
||||
join_cluster_deployment=join_cluster_deployment,
|
||||
)
|
||||
node_joiner = NodeJoiner(join_cluster_deployment=join_cluster_deployment)
|
||||
node_joiner.init_node_runtime_env()
|
||||
|
|
|
@ -92,7 +92,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
|
|
|
@ -23,7 +23,7 @@ if __name__ == "__main__":
|
|||
# Rewrite data in .service and write it to systemd folder
|
||||
with open(
|
||||
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_agent/maro-node-agent.service",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
service_file = fr.read()
|
||||
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)
|
||||
|
|
|
@ -28,13 +28,13 @@ if __name__ == "__main__":
|
|||
# Rewrite data in .service and write it to systemd folder
|
||||
with open(
|
||||
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_api_server/maro-node-api-server.service",
|
||||
mode="r"
|
||||
mode="r",
|
||||
) as fr:
|
||||
service_file = fr.read()
|
||||
service_file = service_file.format(
|
||||
home_path=str(pathlib.Path.home()),
|
||||
maro_shared_path=Paths.ABS_MARO_SHARED,
|
||||
node_api_server_port=local_node_details["api_server"]["port"]
|
||||
node_api_server_port=local_node_details["api_server"]["port"],
|
||||
)
|
||||
os.makedirs(os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
|
||||
with open(file=os.path.expanduser("~/.config/systemd/user/maro-node-api-server.service"), mode="w") as fw:
|
||||
|
|
|
@ -8,8 +8,7 @@ from .params import Paths


class DetailsReader:
"""Reader class for details.
"""
"""Reader class for details."""

@staticmethod
def load_cluster_details(cluster_name: str) -> dict:

@ -10,8 +10,7 @@ from .params import Paths


class DetailsWriter:
"""Writer class for details.
"""
"""Writer class for details."""

@staticmethod
def save_local_cluster_details(cluster_details: dict) -> None:

@ -7,8 +7,7 @@ import redis
|
|||
|
||||
|
||||
class RedisController:
|
||||
"""Controller class for Redis.
|
||||
"""
|
||||
"""Controller class for Redis."""
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self._redis = redis.Redis(host=host, port=port, encoding="utf-8", decode_responses=True)
|
||||
|
@ -19,5 +18,5 @@ class RedisController:
|
|||
return self._redis.hset(
|
||||
"id_to_user_details",
|
||||
user_id,
|
||||
json.dumps(obj=user_details)
|
||||
json.dumps(obj=user_details),
|
||||
)
|
||||
|
|
|
@ -7,8 +7,7 @@ import sys
|
|||
|
||||
|
||||
class Subprocess:
|
||||
"""Wrapper class of subprocess, with CliException integrated.
|
||||
"""
|
||||
"""Wrapper class of subprocess, with CliException integrated."""
|
||||
|
||||
@staticmethod
|
||||
def run(command: str, timeout: int = None) -> str:
|
||||
|
@ -29,7 +28,7 @@ class Subprocess:
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
)
|
||||
if completed_process.returncode != 0:
|
||||
raise Exception(completed_process.stderr)
|
||||
|
@ -52,7 +51,7 @@ class Subprocess:
|
|||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True
|
||||
universal_newlines=True,
|
||||
)
|
||||
while True:
|
||||
next_line = process.stdout.readline()
|
||||
|
|
|
@ -21,13 +21,15 @@ logger = logging.getLogger(__name__)
AVAILABLE_METRICS = {
"cpu",
"memory",
"gpu"
"gpu",
}

ERROR_CODE_FOR_NOT_RESTART = 64
ERROR_CODE_FOR_STOP_JOB = 65
ERROR_CODES_FOR_NOT_RESTART_CONTAINER = {
0, ERROR_CODE_FOR_NOT_RESTART, ERROR_CODE_FOR_STOP_JOB
0,
ERROR_CODE_FOR_NOT_RESTART,
ERROR_CODE_FOR_STOP_JOB,
}

@ -44,27 +46,27 @@ class MasterAgent:
|
|||
"""
|
||||
job_tracking_agent = JobTrackingAgent(
|
||||
local_cluster_details=self._local_cluster_details,
|
||||
local_master_details=self._local_master_details
|
||||
local_master_details=self._local_master_details,
|
||||
)
|
||||
job_tracking_agent.start()
|
||||
container_tracking_agent = ContainerTrackingAgent(
|
||||
local_cluster_details=self._local_cluster_details,
|
||||
local_master_details=self._local_master_details
|
||||
local_master_details=self._local_master_details,
|
||||
)
|
||||
container_tracking_agent.start()
|
||||
pending_job_agent = PendingJobAgent(
|
||||
local_cluster_details=self._local_cluster_details,
|
||||
local_master_details=self._local_master_details
|
||||
local_master_details=self._local_master_details,
|
||||
)
|
||||
pending_job_agent.start()
|
||||
container_runtime_agent = ContainerRuntimeAgent(
|
||||
local_cluster_details=self._local_cluster_details,
|
||||
local_master_details=self._local_master_details
|
||||
local_master_details=self._local_master_details,
|
||||
)
|
||||
container_runtime_agent.start()
|
||||
killed_job_agent = KilledJobAgent(
|
||||
local_cluster_details=self._local_cluster_details,
|
||||
local_master_details=self._local_master_details
|
||||
local_master_details=self._local_master_details,
|
||||
)
|
||||
killed_job_agent.start()
|
||||
|
||||
|
@ -81,7 +83,7 @@ class JobTrackingAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host="localhost",
|
||||
port=local_master_details["redis"]["port"]
|
||||
port=local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -138,7 +140,7 @@ class JobTrackingAgent(multiprocessing.Process):
|
|||
job_details["status"] = job_state
|
||||
self._redis_controller.set_job_details(
|
||||
job_name=job_name,
|
||||
job_details=job_details
|
||||
job_details=job_details,
|
||||
)
|
||||
|
||||
# Utils.
|
||||
|
@ -170,7 +172,7 @@ class ContainerTrackingAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host="localhost",
|
||||
port=local_master_details["redis"]["port"]
|
||||
port=local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -221,7 +223,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host="localhost",
|
||||
port=local_master_details["redis"]["port"]
|
||||
port=local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -258,7 +260,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
# Remove container.
|
||||
is_remove_container = self._is_remove_container(
|
||||
container_details=container_details,
|
||||
job_runtime_details=job_runtime_details
|
||||
job_runtime_details=job_runtime_details,
|
||||
)
|
||||
if is_remove_container:
|
||||
node_name = container_details["node_name"]
|
||||
|
@ -272,7 +274,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
# Restart container.
|
||||
if self._is_restart_container(
|
||||
container_details=container_details,
|
||||
job_runtime_details=job_runtime_details
|
||||
job_runtime_details=job_runtime_details,
|
||||
):
|
||||
self._restart_container(container_name=container_name, container_details=container_details)
|
||||
|
||||
|
@ -309,7 +311,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
"""
|
||||
exceed_maximum_restart_times = self._redis_controller.get_rejoin_component_restart_times(
|
||||
job_id=container_details["job_id"],
|
||||
component_id=container_details["component_id"]
|
||||
component_id=container_details["component_id"],
|
||||
) >= int(job_runtime_details.get("rejoin:max_restart_times", sys.maxsize))
|
||||
return (
|
||||
container_details["state"]["Status"] == "exited"
|
||||
|
@ -346,13 +348,13 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
"""
|
||||
# Get component_name_to_container_name.
|
||||
rejoin_container_name_to_component_name = self._redis_controller.get_rejoin_container_name_to_component_name(
|
||||
job_id=container_details["job_id"]
|
||||
job_id=container_details["job_id"],
|
||||
)
|
||||
|
||||
# If the mapping not exists, or the container is not in the mapping, skip the restart operation.
|
||||
if (
|
||||
rejoin_container_name_to_component_name is None or
|
||||
container_name not in rejoin_container_name_to_component_name
|
||||
rejoin_container_name_to_component_name is None
|
||||
or container_name not in rejoin_container_name_to_component_name
|
||||
):
|
||||
logger.warning(f"Container {container_name} is not found in container_name_to_component_name mapping")
|
||||
return
|
||||
|
@ -364,24 +366,24 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
# Get resources and allocation plan.
|
||||
free_resources = ResourceController.get_free_resources(
|
||||
redis_controller=self._redis_controller,
|
||||
cluster_name=self._cluster_name
|
||||
cluster_name=self._cluster_name,
|
||||
)
|
||||
required_resources = [
|
||||
ContainerResource(
|
||||
container_name=ContainerController.build_container_name(
|
||||
job_id=container_details["job_id"],
|
||||
component_id=container_details["component_id"],
|
||||
component_index=container_details["component_index"]
|
||||
component_index=container_details["component_index"],
|
||||
),
|
||||
cpu=float(container_details["cpu"]),
|
||||
memory=float(container_details["memory"].replace("m", "")),
|
||||
gpu=float(container_details["gpu"])
|
||||
)
|
||||
gpu=float(container_details["gpu"]),
|
||||
),
|
||||
]
|
||||
allocation_plan = ResourceController._get_single_metric_balanced_allocation_plan(
|
||||
allocation_details={"metric": "cpu"},
|
||||
required_resources=required_resources,
|
||||
free_resources=free_resources
|
||||
free_resources=free_resources,
|
||||
)
|
||||
|
||||
# Start a new container.
|
||||
|
@ -392,11 +394,11 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
container_name=container_name,
|
||||
node_details=node_details,
|
||||
job_details=job_details,
|
||||
component_name=component_name
|
||||
component_name=component_name,
|
||||
)
|
||||
self._redis_controller.incr_rejoin_component_restart_times(
|
||||
job_id=container_details["job_id"],
|
||||
component_id=container_details["component_id"]
|
||||
component_id=container_details["component_id"],
|
||||
)
|
||||
except ResourceAllocationFailed as e:
|
||||
logger.warning(f"{e}")
|
||||
|
@ -436,13 +438,13 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
NodeApiClientV1.remove_container(
|
||||
node_hostname=node_details["hostname"],
|
||||
node_api_server_port=node_details["api_server"]["port"],
|
||||
container_name=container_name
|
||||
container_name=container_name,
|
||||
)
|
||||
else:
|
||||
NodeApiClientV1.stop_container(
|
||||
node_hostname=node_details["hostname"],
|
||||
node_api_server_port=node_details["api_server"]["port"],
|
||||
container_name=container_name
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
def _start_container(self, container_name: str, node_details: dict, job_details: dict, component_name: str) -> None:
|
||||
|
@ -486,7 +488,6 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
"command": job_details["components"][component_type]["command"],
|
||||
"image_name": job_details["components"][component_type]["image"],
|
||||
"volumes": [f"{maro_mount_source}:{mount_target}"],
|
||||
|
||||
# System related.
|
||||
"container_name": container_name,
|
||||
"fluentd_address": f"{self._master_hostname}:{self._master_fluentd_port}",
|
||||
|
@ -503,7 +504,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
"COMPONENT_INDEX": component_index,
|
||||
"CONTAINER_NAME": container_name,
|
||||
"COMPONENT_NAME": component_name,
|
||||
"PYTHONUNBUFFERED": 0
|
||||
"PYTHONUNBUFFERED": 0,
|
||||
},
|
||||
"labels": {
|
||||
"cluster_id": cluster_id,
|
||||
|
@ -518,8 +519,8 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
"container_name": container_name,
|
||||
"component_name": component_name,
|
||||
"cpu": cpu,
|
||||
"memory": memory
|
||||
}
|
||||
"memory": memory,
|
||||
},
|
||||
}
|
||||
|
||||
if gpu != 0:
|
||||
|
@ -529,7 +530,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
|
|||
NodeApiClientV1.create_container(
|
||||
node_hostname=node_details["hostname"],
|
||||
node_api_server_port=node_details["api_server"]["port"],
|
||||
create_config=create_config
|
||||
create_config=create_config,
|
||||
)
|
||||
|
||||
|
||||
|
@ -548,7 +549,7 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host="localhost",
|
||||
port=local_master_details["redis"]["port"]
|
||||
port=local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -580,7 +581,7 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
# Get free resources at the very beginning.
|
||||
free_resources = ResourceController.get_free_resources(
|
||||
redis_controller=self._redis_controller,
|
||||
cluster_name=self._cluster_name
|
||||
cluster_name=self._cluster_name,
|
||||
)
|
||||
|
||||
# Iterate tickets.
|
||||
|
@ -596,14 +597,14 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
allocation_plan = ResourceController.get_allocation_plan(
|
||||
allocation_details=job_details["allocation"],
|
||||
required_resources=required_resources,
|
||||
free_resources=free_resources
|
||||
free_resources=free_resources,
|
||||
)
|
||||
for container_name, node_name in allocation_plan.items():
|
||||
node_details = self._redis_controller.get_node_details(node_name=node_name)
|
||||
self._start_container(
|
||||
container_name=container_name,
|
||||
node_details=node_details,
|
||||
job_details=job_details
|
||||
job_details=job_details,
|
||||
)
|
||||
self._redis_controller.remove_pending_job_ticket(job_name=pending_job_name)
|
||||
job_details["status"] = JobStatus.RUNNING
|
||||
|
@ -654,7 +655,6 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
"command": job_details["components"][component_type]["command"],
|
||||
"image_name": job_details["components"][component_type]["image"],
|
||||
"volumes": [f"{maro_mount_source}:{mount_target}"],
|
||||
|
||||
# System related.
|
||||
"container_name": container_name,
|
||||
"fluentd_address": f"{self._master_hostname}:{self._master_fluentd_port}",
|
||||
|
@ -670,7 +670,7 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
"COMPONENT_TYPE": component_type,
|
||||
"COMPONENT_INDEX": component_index,
|
||||
"CONTAINER_NAME": container_name,
|
||||
"PYTHONUNBUFFERED": 0
|
||||
"PYTHONUNBUFFERED": 0,
|
||||
},
|
||||
"labels": {
|
||||
"cluster_id": cluster_id,
|
||||
|
@ -684,8 +684,8 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
"component_index": component_index,
|
||||
"container_name": container_name,
|
||||
"cpu": cpu,
|
||||
"memory": memory
|
||||
}
|
||||
"memory": memory,
|
||||
},
|
||||
}
|
||||
|
||||
if gpu != 0:
|
||||
|
@ -695,7 +695,7 @@ class PendingJobAgent(multiprocessing.Process):
|
|||
NodeApiClientV1.create_container(
|
||||
node_hostname=node_details["hostname"],
|
||||
node_api_server_port=node_details["api_server"]["port"],
|
||||
create_config=create_config
|
||||
create_config=create_config,
|
||||
)
|
||||
|
||||
|
||||
|
@ -712,7 +712,7 @@ class KilledJobAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host="localhost",
|
||||
port=local_master_details["redis"]["port"]
|
||||
port=local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -791,13 +791,12 @@ class KilledJobAgent(multiprocessing.Process):
|
|||
NodeApiClientV1.remove_container(
|
||||
node_hostname=node_details["hostname"],
|
||||
node_api_server_port=node_details["api_server"]["port"],
|
||||
container_name=container_name
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
|
||||
class ResourceController:
|
||||
"""Controller class for computing resources in MARO Nodes.
|
||||
"""
|
||||
"""Controller class for computing resources in MARO Nodes."""
|
||||
|
||||
@staticmethod
|
||||
def get_allocation_plan(allocation_details: dict, required_resources: list, free_resources: list) -> dict:
|
||||
|
@ -815,13 +814,13 @@ class ResourceController:
|
|||
return ResourceController._get_single_metric_balanced_allocation_plan(
|
||||
allocation_details=allocation_details,
|
||||
required_resources=required_resources,
|
||||
free_resources=free_resources
|
||||
free_resources=free_resources,
|
||||
)
|
||||
elif allocation_details["mode"] == "single-metric-compacted":
|
||||
return ResourceController._get_single_metric_compacted_allocation_plan(
|
||||
allocation_details=allocation_details,
|
||||
required_resources=required_resources,
|
||||
free_resources=free_resources
|
||||
free_resources=free_resources,
|
||||
)
|
||||
else:
|
||||
raise ResourceAllocationFailed("Invalid allocation mode.")
|
||||
|
@ -829,7 +828,8 @@ class ResourceController:
|
|||
@staticmethod
|
||||
def _get_single_metric_compacted_allocation_plan(
|
||||
allocation_details: dict,
|
||||
required_resources: list, free_resources: list
|
||||
required_resources: list,
|
||||
free_resources: list,
|
||||
) -> dict:
|
||||
"""Get single_metric_compacted allocation plan.
|
||||
|
||||
|
@ -856,13 +856,13 @@ class ResourceController:
|
|||
for required_resource in required_resources:
|
||||
heapq.heappush(
|
||||
required_resources_pq,
|
||||
(-getattr(required_resource, metric), required_resource)
|
||||
(-getattr(required_resource, metric), required_resource),
|
||||
)
|
||||
free_resources_pq = []
|
||||
for free_resource in free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(getattr(free_resource, metric), free_resource)
|
||||
(getattr(free_resource, metric), free_resource),
|
||||
)
|
||||
|
||||
# Get allocation.
|
||||
|
@ -890,23 +890,23 @@ class ResourceController:
|
|||
free_resource.gpu -= required_resource.gpu
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(getattr(free_resource, metric), free_resource)
|
||||
(getattr(free_resource, metric), free_resource),
|
||||
)
|
||||
for not_usable_free_resource in not_usable_free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(getattr(not_usable_free_resource, metric), not_usable_free_resource)
|
||||
(getattr(not_usable_free_resource, metric), not_usable_free_resource),
|
||||
)
|
||||
else:
|
||||
# add previous resources back, to do printing.
|
||||
for not_usable_free_resource in not_usable_free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(getattr(not_usable_free_resource, metric), not_usable_free_resource)
|
||||
(getattr(not_usable_free_resource, metric), not_usable_free_resource),
|
||||
)
|
||||
heapq.heappush(
|
||||
required_resources_pq,
|
||||
(-getattr(required_resource, metric), required_resource)
|
||||
(-getattr(required_resource, metric), required_resource),
|
||||
)
|
||||
|
||||
logger.warning(allocation_plan)
|
||||
|
@ -921,7 +921,8 @@ class ResourceController:
|
|||
@staticmethod
|
||||
def _get_single_metric_balanced_allocation_plan(
|
||||
allocation_details: dict,
|
||||
required_resources: list, free_resources: list
|
||||
required_resources: list,
|
||||
free_resources: list,
|
||||
) -> dict:
|
||||
"""Get single_metric_balanced allocation plan.
|
||||
|
||||
|
@ -948,13 +949,13 @@ class ResourceController:
|
|||
for required_resource in required_resources:
|
||||
heapq.heappush(
|
||||
required_resources_pq,
|
||||
(-getattr(required_resource, metric), required_resource)
|
||||
(-getattr(required_resource, metric), required_resource),
|
||||
)
|
||||
free_resources_pq = []
|
||||
for free_resource in free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(-getattr(free_resource, metric), free_resource)
|
||||
(-getattr(free_resource, metric), free_resource),
|
||||
)
|
||||
|
||||
# Get allocation.
|
||||
|
@ -982,23 +983,23 @@ class ResourceController:
|
|||
free_resource.gpu -= required_resource.gpu
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(-getattr(free_resource, metric), free_resource)
|
||||
(-getattr(free_resource, metric), free_resource),
|
||||
)
|
||||
for not_usable_free_resource in not_usable_free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(-getattr(not_usable_free_resource, metric), not_usable_free_resource)
|
||||
(-getattr(not_usable_free_resource, metric), not_usable_free_resource),
|
||||
)
|
||||
else:
|
||||
# add previous resources back, to do printing.
|
||||
for not_usable_free_resource in not_usable_free_resources:
|
||||
heapq.heappush(
|
||||
free_resources_pq,
|
||||
(-getattr(not_usable_free_resource, metric), not_usable_free_resource)
|
||||
(-getattr(not_usable_free_resource, metric), not_usable_free_resource),
|
||||
)
|
||||
heapq.heappush(
|
||||
required_resources_pq,
|
||||
(-getattr(required_resource, metric), required_resource)
|
||||
(-getattr(required_resource, metric), required_resource),
|
||||
)
|
||||
|
||||
logger.warning(allocation_plan)
|
||||
|
@ -1038,8 +1039,8 @@ class ResourceController:
|
|||
node_name=node_name,
|
||||
cpu=target_free_cpu,
|
||||
memory=target_free_memory,
|
||||
gpu=target_free_gpu
|
||||
)
|
||||
gpu=target_free_gpu,
|
||||
),
|
||||
)
|
||||
except KeyError:
|
||||
# node_details is not in stable state.
|
||||
|
@ -1076,14 +1077,13 @@ class ResourceController:
|
|||
cpu=required_cpu,
|
||||
memory=required_memory,
|
||||
gpu=required_gpu,
|
||||
)
|
||||
),
|
||||
)
|
||||
return resources_list
|
||||
|
||||
|
||||
class ContainerController:
|
||||
"""Controller class for container.
|
||||
"""
|
||||
"""Controller class for container."""
|
||||
|
||||
@staticmethod
|
||||
def build_container_name(job_id: str, component_id: str, component_index: int) -> str:
|
||||
|
@ -1103,8 +1103,7 @@ class ContainerController:
|
|||
|
||||
|
||||
class JobController:
|
||||
"""Controller class for MARO Job.
|
||||
"""
|
||||
"""Controller class for MARO Job."""
|
||||
|
||||
@staticmethod
|
||||
def get_component_id_to_component_type(job_details: dict) -> dict:
|
||||
|
@ -1130,11 +1129,11 @@ if __name__ == "__main__":
|
|||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
master_agent = MasterAgent(
|
||||
local_cluster_details=DetailsReader.load_local_cluster_details(),
|
||||
local_master_details=DetailsReader.load_local_master_details()
|
||||
local_master_details=DetailsReader.load_local_master_details(),
|
||||
)
|
||||
master_agent.start()
|
||||
|
|
|
@ -67,12 +67,12 @@ class PendingJobAgent(mp.Process):
|
|||
|
||||
# Allocation
|
||||
cluster_resource = json.loads(
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
|
||||
)
|
||||
is_satisfied, updated_resource = resource_op(
|
||||
cluster_resource,
|
||||
job_detail["total_request_resource"],
|
||||
ResourceOperation.ALLOCATION
|
||||
ResourceOperation.ALLOCATION,
|
||||
)
|
||||
if not is_satisfied:
|
||||
continue
|
||||
|
@ -82,17 +82,17 @@ class PendingJobAgent(mp.Process):
|
|||
self.redis_connection.lrem(f"{self.cluster_name}:pending_job_tickets", 0, job_name)
|
||||
# Update resource
|
||||
cluster_resource = json.loads(
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
|
||||
)
|
||||
is_satisfied, updated_resource = resource_op(
|
||||
cluster_resource,
|
||||
job_detail["total_request_resource"],
|
||||
ResourceOperation.ALLOCATION
|
||||
ResourceOperation.ALLOCATION,
|
||||
)
|
||||
self.redis_connection.hset(
|
||||
f"{self.cluster_name}:runtime_detail",
|
||||
"available_resource",
|
||||
json.dumps(updated_resource)
|
||||
json.dumps(updated_resource),
|
||||
)
|
||||
|
||||
def _start_job(self, job_detail: dict):
|
||||
|
@ -100,14 +100,8 @@ class PendingJobAgent(mp.Process):
|
|||
for component_type, command_info in job_detail["components"].items():
|
||||
for number in range(command_info["num"]):
|
||||
container_name = NameCreator.create_name_with_uuid(prefix=component_type)
|
||||
environment_parameters = (
|
||||
f"-e CONTAINER_NAME={container_name} "
|
||||
f"-e JOB_NAME={job_detail['name']} "
|
||||
)
|
||||
labels = (
|
||||
f"-l CONTAINER_NAME={container_name} "
|
||||
f"-l JOB_NAME={job_detail['name']} "
|
||||
)
|
||||
environment_parameters = f"-e CONTAINER_NAME={container_name} " f"-e JOB_NAME={job_detail['name']} "
|
||||
labels = f"-l CONTAINER_NAME={container_name} " f"-l JOB_NAME={job_detail['name']} "
|
||||
if int(command_info["resources"]["gpu"]) == 0:
|
||||
component_command = START_CONTAINER_COMMAND.format(
|
||||
cpu=command_info["resources"]["cpu"],
|
||||
|
@ -117,7 +111,7 @@ class PendingJobAgent(mp.Process):
|
|||
environment_parameters=environment_parameters,
|
||||
labels=labels,
|
||||
image_name=command_info["image"],
|
||||
command=command_info["command"]
|
||||
command=command_info["command"],
|
||||
)
|
||||
else:
|
||||
component_command = START_CONTAINER_WITH_GPU_COMMAND.format(
|
||||
|
@ -129,12 +123,15 @@ class PendingJobAgent(mp.Process):
|
|||
environment_parameters=environment_parameters,
|
||||
labels=labels,
|
||||
image_name=command_info["image"],
|
||||
command=command_info["command"]
|
||||
command=command_info["command"],
|
||||
)
|
||||
|
||||
completed_process = subprocess.run(
|
||||
component_command,
|
||||
shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf8",
|
||||
)
|
||||
if completed_process.returncode != 0:
|
||||
raise ResourceAllocationFailed(completed_process.stderr)
|
||||
|
@ -145,7 +142,7 @@ class PendingJobAgent(mp.Process):
|
|||
self.redis_connection.hset(
|
||||
f"{self.cluster_name}:job_details",
|
||||
job_detail["name"],
|
||||
json.dumps(job_detail)
|
||||
json.dumps(job_detail),
|
||||
)
|
||||
|
||||
|
||||
|
@ -163,7 +160,7 @@ class ContainerTrackingAgent(mp.Process):
|
|||
|
||||
def _check_container_status(self):
|
||||
running_jobs = ContainerTrackingAgent.get_running_jobs(
|
||||
self.redis_connection.hgetall(f"{self.cluster_name}:job_details")
|
||||
self.redis_connection.hgetall(f"{self.cluster_name}:job_details"),
|
||||
)
|
||||
|
||||
for job_name, job_detail in running_jobs.items():
|
||||
|
@ -174,7 +171,10 @@ class ContainerTrackingAgent(mp.Process):
|
|||
command = f"docker inspect {container_name}"
|
||||
completed_process = subprocess.run(
|
||||
command,
|
||||
shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf8",
|
||||
)
|
||||
return_str = completed_process.stdout
|
||||
inspect_details_list = json.loads(return_str)
|
||||
|
@ -229,7 +229,7 @@ class JobTrackingAgent(mp.Process):
|
|||
|
||||
def _check_job_state(self):
|
||||
finished_jobs = self._get_finished_jobs(
|
||||
self.redis_connection.hgetall(f"{self.cluster_name}:job_details")
|
||||
self.redis_connection.hgetall(f"{self.cluster_name}:job_details"),
|
||||
)
|
||||
|
||||
for job_name, job_detail in finished_jobs.items():
|
||||
|
@ -254,25 +254,31 @@ class JobTrackingAgent(mp.Process):
|
|||
for container_name in container_list:
|
||||
command = f"docker stop {container_name}"
|
||||
completed_process = subprocess.run(
|
||||
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
|
||||
command,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf8",
|
||||
)
|
||||
if completed_process.returncode != 0:
|
||||
raise ResourceAllocationFailed(completed_process.stderr)
|
||||
|
||||
def _job_clear(self, job_name: str, release_resource: dict):
|
||||
cluster_resource = json.loads(
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
|
||||
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
|
||||
)
|
||||
|
||||
# resource release
|
||||
_, updated_resource = resource_op(
|
||||
cluster_resource, release_resource, ResourceOperation.RELEASE
|
||||
cluster_resource,
|
||||
release_resource,
|
||||
ResourceOperation.RELEASE,
|
||||
)
|
||||
|
||||
self.redis_connection.hset(
|
||||
f"{self.cluster_name}:runtime_detail",
|
||||
"available_resource",
|
||||
json.dumps(updated_resource)
|
||||
json.dumps(updated_resource),
|
||||
)
|
||||
|
||||
|
||||
|
@ -312,7 +318,7 @@ class MasterAgent:
|
|||
self.check_interval = self.cluster_detail["master"]["agents"]["check_interval"]
|
||||
self.redis_connection = redis.Redis(
|
||||
host="localhost",
|
||||
port=self.cluster_detail["master"]["redis"]["port"]
|
||||
port=self.cluster_detail["master"]["redis"]["port"],
|
||||
)
|
||||
self.redis_connection.hset(f"{self.cluster_name}:runtime_detail", "agent_id", os.getpid())
|
||||
|
||||
|
@ -321,28 +327,28 @@ class MasterAgent:
|
|||
pending_job_agent = PendingJobAgent(
|
||||
cluster_name=self.cluster_name,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
check_interval=self.check_interval,
|
||||
)
|
||||
pending_job_agent.start()
|
||||
|
||||
killed_job_agent = KilledJobAgent(
|
||||
cluster_name=self.cluster_name,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
check_interval=self.check_interval,
|
||||
)
|
||||
killed_job_agent.start()
|
||||
|
||||
job_tracking_agent = JobTrackingAgent(
|
||||
cluster_name=self.cluster_name,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
check_interval=self.check_interval,
|
||||
)
|
||||
job_tracking_agent.start()
|
||||
|
||||
container_tracking_agent = ContainerTrackingAgent(
|
||||
cluster_name=self.cluster_name,
|
||||
redis_connection=self.redis_connection,
|
||||
check_interval=self.check_interval
|
||||
check_interval=self.check_interval,
|
||||
)
|
||||
container_tracking_agent.start()
|
||||
|
||||
|
|
|
@ -6,27 +6,26 @@ import requests


class NodeApiClientV1:
"""Client class for Node API Server.
"""
"""Client class for Node API Server."""

@staticmethod
def create_container(node_hostname: str, node_api_server_port: int, create_config: dict) -> dict:
response = requests.post(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers",
json=create_config
json=create_config,
)
return response.json()

@staticmethod
def stop_container(node_hostname: str, node_api_server_port: int, container_name: str) -> dict:
response = requests.post(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}:stop"
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}:stop",
)
return response.json()

@staticmethod
def remove_container(node_hostname: str, node_api_server_port: int, container_name: str) -> dict:
response = requests.delete(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}"
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}",
)
return response.json()

@ -15,6 +15,7 @@ URL_PREFIX = "/v1/cluster"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def get_cluster():
|
||||
|
|
|
@ -15,6 +15,7 @@ URL_PREFIX = "/v1/containers"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def list_containers():
|
||||
|
|
|
@ -24,7 +24,7 @@ def get_job_queue():
|
|||
killed_job_queue = redis_controller.get_killed_job_ticket()
|
||||
return {
|
||||
"pending_jobs": pending_job_queue,
|
||||
"killed_jobs": killed_job_queue
|
||||
"killed_jobs": killed_job_queue,
|
||||
}
|
||||
|
||||
|
||||
|
@ -66,10 +66,10 @@ def create_job(**kwargs):
|
|||
job_details = kwargs["json_dict"]
|
||||
redis_controller.set_job_details(
|
||||
job_name=job_details["name"],
|
||||
job_details=job_details
|
||||
job_details=job_details,
|
||||
)
|
||||
redis_controller.push_pending_job_ticket(
|
||||
job_name=job_details["name"]
|
||||
job_name=job_details["name"],
|
||||
)
|
||||
return {}
|
||||
|
||||
|
@ -118,6 +118,6 @@ def clean_jobs():
|
|||
node_hostname = node_details["hostname"]
|
||||
for container_name in node_details["containers"]:
|
||||
requests.delete(
|
||||
url=f"http://{node_hostname}:{node_details['api_server']['port']}/containers/{container_name}"
|
||||
url=f"http://{node_hostname}:{node_details['api_server']['port']}/containers/{container_name}",
|
||||
)
|
||||
return {}
|
||||
|
|
|
@ -18,7 +18,7 @@ def get_init_node_script():
|
|||
return send_from_directory(
|
||||
directory=f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node",
|
||||
filename="join_cluster.py",
|
||||
as_attachment=True
|
||||
as_attachment=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
abort(404)
|
||||
|
|
|
@ -22,6 +22,7 @@ URL_PREFIX = "/v1/master"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def get_master():
|
||||
|
@ -79,7 +80,7 @@ def save_master_key(private_key: str) -> None:
|
|||
fw.write(private_key)
|
||||
os.chmod(
|
||||
path=f"{Paths.ABS_MARO_LOCAL}/cluster/{cluster_name}/master_to_node_openssh_private_key",
|
||||
mode=stat.S_IRWXU
|
||||
mode=stat.S_IRWXU,
|
||||
)
|
||||
|
||||
|
||||
|
@ -87,20 +88,20 @@ def generate_rsa_openssh_key_pair() -> dict:
|
|||
rsa_key = rsa.generate_private_key(
|
||||
backend=default_backend(),
|
||||
public_exponent=65537,
|
||||
key_size=2048
|
||||
key_size=2048,
|
||||
)
|
||||
|
||||
# Format and encoding are diff from OpenSSH
|
||||
private_key = rsa_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
public_key = rsa_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
)
|
||||
return {
|
||||
"public_key": public_key.decode("utf-8"),
|
||||
"private_key": private_key.decode("utf-8")
|
||||
"private_key": private_key.decode("utf-8"),
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ URL_PREFIX = "/v1/nodes"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def list_nodes():
|
||||
|
@ -65,14 +66,14 @@ def create_node(**kwargs):
|
|||
node_details["image_files"] = {}
|
||||
node_details["containers"] = {}
|
||||
node_details["state"] = {
|
||||
"status": NodeStatus.PENDING
|
||||
"status": NodeStatus.PENDING,
|
||||
}
|
||||
|
||||
node_name = node_details["name"]
|
||||
with redis_controller.lock(f"lock:name_to_node_details:{node_name}"):
|
||||
redis_controller.set_node_details(
|
||||
node_name=node_name,
|
||||
node_details=node_details
|
||||
node_details=node_details,
|
||||
)
|
||||
return node_details
|
||||
|
||||
|
@ -121,7 +122,7 @@ def start_node(node_name: str):
|
|||
node_username=node_details["username"],
|
||||
node_hostname=node_details["hostname"],
|
||||
node_ssh_port=node_details["ssh"]["port"],
|
||||
cluster_name=local_cluster_details["name"]
|
||||
cluster_name=local_cluster_details["name"],
|
||||
)
|
||||
except ConnectionFailed:
|
||||
abort(400)
|
||||
|
@ -162,7 +163,7 @@ def stop_node(node_name: str):
|
|||
node_username=node_details["username"],
|
||||
node_hostname=node_details["hostname"],
|
||||
node_ssh_port=node_details["ssh"]["port"],
|
||||
cluster_name=local_cluster_details["name"]
|
||||
cluster_name=local_cluster_details["name"],
|
||||
)
|
||||
except ConnectionFailed:
|
||||
abort(400)
|
||||
|
|
|
@ -18,6 +18,7 @@ URL_PREFIX = "/v1/schedules"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def list_schedules():
|
||||
|
@ -57,14 +58,14 @@ def create_schedule(**kwargs):
|
|||
|
||||
redis_controller.set_schedule_details(
|
||||
schedule_name=schedule_details["name"],
|
||||
schedule_details=schedule_details
|
||||
schedule_details=schedule_details,
|
||||
)
|
||||
|
||||
# Build individual jobs
|
||||
for job_name in schedule_details["job_names"]:
|
||||
redis_controller.set_job_details(
|
||||
job_name=job_name,
|
||||
job_details=_build_job_details(schedule_details=schedule_details, job_name=job_name)
|
||||
job_details=_build_job_details(schedule_details=schedule_details, job_name=job_name),
|
||||
)
|
||||
redis_controller.push_pending_job_ticket(job_name=job_name)
|
||||
return {}
|
||||
|
@ -111,7 +112,7 @@ def _build_job_details(schedule_details: dict, job_name: str) -> dict:
|
|||
job_details["name"] = job_name
|
||||
job_details["tags"] = {
|
||||
"schedule_name": schedule_details["name"],
|
||||
"schedule_id": schedule_details["id"]
|
||||
"schedule_id": schedule_details["id"],
|
||||
}
|
||||
job_details.pop("job_names")
|
||||
|
||||
|
|
|
@ -19,5 +19,5 @@ URL_PREFIX = "/v1/status"
def status():
return {
"status": "OK",
"time": time.time()
"time": time.time(),
}

@ -11,6 +11,7 @@ URL_PREFIX = "/v1/visible"
|
|||
|
||||
# Api functions.
|
||||
|
||||
|
||||
@blueprint.route(f"{URL_PREFIX}/static", methods=["GET"])
|
||||
@check_jwt_validity
|
||||
def get_static_resource():
|
||||
|
@ -34,6 +35,6 @@ def get_dynamic_resource(previous_length: str):
|
|||
"""
|
||||
|
||||
name_to_node_usage = redis_controller.get_resource_usage(
|
||||
previous_length=int(previous_length)
|
||||
previous_length=int(previous_length),
|
||||
)
|
||||
return name_to_node_usage
|
||||
|
|
|
@ -42,11 +42,11 @@ def check_jwt_validity(func):
|
|||
user_details = redis_controller.get_user_details(user_id=payload["user_id"])
|
||||
|
||||
# Get decrypted_bytes
|
||||
if request.data != b'':
|
||||
if request.data != b"":
|
||||
decrypted_bytes = _get_decrypted_bytes(
|
||||
payload=payload,
|
||||
encrypted_bytes=request.data,
|
||||
user_details=user_details
|
||||
user_details=user_details,
|
||||
)
|
||||
kwargs["json_dict"] = json.loads(decrypted_bytes.decode("utf-8"))
|
||||
|
||||
|
@ -69,7 +69,7 @@ def check_jwt_validity(func):
|
|||
def _get_encrypted_bytes(json_dict: dict, aes_key: bytes, aes_ctr_nonce: bytes) -> bytes:
|
||||
cipher = Cipher(
|
||||
algorithm=algorithms.AES(key=aes_key),
|
||||
mode=modes.CTR(nonce=aes_ctr_nonce)
|
||||
mode=modes.CTR(nonce=aes_ctr_nonce),
|
||||
)
|
||||
encryptor = cipher.encryptor()
|
||||
return_bytes = encryptor.update(json.dumps(json_dict).encode("utf-8")) + encryptor.finalize()
|
||||
|
@ -80,21 +80,21 @@ def _get_decrypted_bytes(payload: dict, encrypted_bytes: bytes, user_details: di
|
|||
# Decrypted aes_key and aes_ctr_nonce
|
||||
dev_to_master_encryption_private_key_obj = serialization.load_pem_private_key(
|
||||
data=user_details["dev_to_master_encryption_private_key"].encode("utf-8"),
|
||||
password=None
|
||||
password=None,
|
||||
)
|
||||
aes_key = dev_to_master_encryption_private_key_obj.decrypt(
|
||||
ciphertext=base64.b64decode(payload["aes_key"].encode("ascii")),
|
||||
padding=_get_asymmetric_padding()
|
||||
padding=_get_asymmetric_padding(),
|
||||
)
|
||||
aes_ctr_nonce = dev_to_master_encryption_private_key_obj.decrypt(
|
||||
ciphertext=base64.b64decode(payload["aes_ctr_nonce"].encode("ascii")),
|
||||
padding=_get_asymmetric_padding()
|
||||
padding=_get_asymmetric_padding(),
|
||||
)
|
||||
|
||||
# Return decrypted_bytes
|
||||
cipher = Cipher(
|
||||
algorithm=algorithms.AES(key=aes_key),
|
||||
mode=modes.CTR(nonce=aes_ctr_nonce)
|
||||
mode=modes.CTR(nonce=aes_ctr_nonce),
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
return decryptor.update(encrypted_bytes) + decryptor.finalize()
|
||||
|
@ -109,12 +109,12 @@ def build_response(return_json: dict, user_details: dict) -> Response:
|
|||
encrypted_bytes = _get_encrypted_bytes(
|
||||
json_dict=return_json,
|
||||
aes_key=aes_key,
|
||||
aes_ctr_nonce=aes_ctr_nonce
|
||||
aes_ctr_nonce=aes_ctr_nonce,
|
||||
)
|
||||
|
||||
# Encrypt aes_key and aes_ctr_nonce with rsa_key_pair
|
||||
master_to_dev_encryption_public_key_obj = serialization.load_pem_public_key(
|
||||
data=user_details["master_to_dev_encryption_public_key"].encode("utf-8")
|
||||
data=user_details["master_to_dev_encryption_public_key"].encode("utf-8"),
|
||||
)
|
||||
|
||||
# Build jwt_token
|
||||
|
@ -123,18 +123,18 @@ def build_response(return_json: dict, user_details: dict) -> Response:
|
|||
"aes_key": base64.b64encode(
|
||||
master_to_dev_encryption_public_key_obj.encrypt(
|
||||
plaintext=aes_key,
|
||||
padding=_get_asymmetric_padding()
|
||||
)
|
||||
padding=_get_asymmetric_padding(),
|
||||
),
|
||||
).decode("ascii"),
|
||||
"aes_ctr_nonce": base64.b64encode(
|
||||
master_to_dev_encryption_public_key_obj.encrypt(
|
||||
plaintext=aes_ctr_nonce,
|
||||
padding=_get_asymmetric_padding()
|
||||
)
|
||||
).decode("ascii")
|
||||
padding=_get_asymmetric_padding(),
|
||||
),
|
||||
).decode("ascii"),
|
||||
},
|
||||
key=user_details["master_to_dev_signing_private_key"],
|
||||
algorithm="RS512"
|
||||
algorithm="RS512",
|
||||
)
|
||||
|
||||
# Build response
|
||||
|
@ -147,5 +147,5 @@ def _get_asymmetric_padding():
|
|||
return padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
label=None,
|
||||
)
|
||||
|
|
|
@ -33,24 +33,24 @@ class NodeAgent:
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host=self._local_master_details["hostname"],
|
||||
port=self._local_master_details["redis"]["port"]
|
||||
port=self._local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
# Init agents.
|
||||
self.load_image_agent = LoadImageAgent(
|
||||
local_cluster_details=local_cluster_details,
|
||||
local_master_details=local_master_details,
|
||||
local_node_details=local_node_details
|
||||
local_node_details=local_node_details,
|
||||
)
|
||||
self.node_tracking_agent = NodeTrackingAgent(
|
||||
local_cluster_details=local_cluster_details,
|
||||
local_master_details=local_master_details,
|
||||
local_node_details=local_node_details
|
||||
local_node_details=local_node_details,
|
||||
)
|
||||
self.resource_tracking_agent = ResourceTrackingAgent(
|
||||
local_cluster_details=local_cluster_details,
|
||||
local_master_details=local_master_details,
|
||||
local_node_details=local_node_details
|
||||
local_node_details=local_node_details,
|
||||
)
|
||||
|
||||
# When SIGTERM, gracefully exit.
|
||||
|
@ -70,7 +70,7 @@ class NodeAgent:
|
|||
self.resource_tracking_agent.join()
|
||||
|
||||
def gracefully_exit(self, signum, frame) -> None:
|
||||
""" Gracefully exit when SIGTERM.
|
||||
"""Gracefully exit when SIGTERM.
|
||||
|
||||
If we get SIGKILL here, it means that the node is not stopped properly,
|
||||
the status of the node remains 'RUNNING'.
|
||||
|
@ -87,7 +87,7 @@ class NodeAgent:
|
|||
# Set STOPPED state
|
||||
state_details = {
|
||||
"status": NodeStatus.STOPPED,
|
||||
"check_time": self._redis_controller.get_time()
|
||||
"check_time": self._redis_controller.get_time(),
|
||||
}
|
||||
with self._redis_controller.lock(f"lock:name_to_node_details:{self._local_node_details['name']}"):
|
||||
node_details = self._redis_controller.get_node_details(node_name=self._local_node_details["name"])
|
||||
|
@ -96,7 +96,7 @@ class NodeAgent:
|
|||
node_details["state"] = state_details
|
||||
self._redis_controller.set_node_details(
|
||||
node_name=self._local_node_details["name"],
|
||||
node_details=node_details
|
||||
node_details=node_details,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
@ -116,7 +116,7 @@ class NodeAgent:
|
|||
if not isinstance(resources_details["memory"], (float, int)):
|
||||
if resources_details["memory"] != "all":
|
||||
logger.warning("Invalid memory assignment, will use all memories in this node")
|
||||
resources_details["memory"] = psutil.virtual_memory().total / (1024 ** 2) # (float) in MByte
|
||||
resources_details["memory"] = psutil.virtual_memory().total / (1024**2) # (float) in MByte
|
||||
|
||||
if not isinstance(resources_details["gpu"], (float, int)):
|
||||
if resources_details["gpu"] != "all":
|
||||
|
@ -140,7 +140,7 @@ class NodeAgent:
|
|||
node_details["resources"] = resources_details
|
||||
self._redis_controller.set_node_details(
|
||||
node_name=self._local_node_details["name"],
|
||||
node_details=node_details
|
||||
node_details=node_details,
|
||||
)
|
||||
|
||||
|
||||
|
@ -152,8 +152,10 @@ class NodeTrackingAgent(multiprocessing.Process):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
local_cluster_details: dict, local_master_details: dict, local_node_details: dict,
|
||||
check_interval: int = 5
|
||||
local_cluster_details: dict,
|
||||
local_master_details: dict,
|
||||
local_node_details: dict,
|
||||
check_interval: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
self._local_cluster_details = local_cluster_details
|
||||
|
@ -162,7 +164,7 @@ class NodeTrackingAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host=self._local_master_details["hostname"],
|
||||
port=self._local_master_details["redis"]["port"]
|
||||
port=self._local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -201,20 +203,20 @@ class NodeTrackingAgent(multiprocessing.Process):
|
|||
container_name_to_inspect_details = self._get_container_name_to_inspect_details()
|
||||
self._update_name_to_container_details(
|
||||
container_name_to_inspect_details=container_name_to_inspect_details,
|
||||
name_to_container_details=name_to_container_details
|
||||
name_to_container_details=name_to_container_details,
|
||||
)
|
||||
|
||||
# Resources related
|
||||
resources_details = node_details["resources"]
|
||||
self._update_occupied_resources(
|
||||
container_name_to_inspect_details=container_name_to_inspect_details,
|
||||
resources_details=resources_details
|
||||
resources_details=resources_details,
|
||||
)
|
||||
|
||||
# State related.
|
||||
state_details = {
|
||||
"status": NodeStatus.RUNNING,
|
||||
"check_time": self._redis_controller.get_time()
|
||||
"check_time": self._redis_controller.get_time(),
|
||||
}
|
||||
|
||||
# Save details.
|
||||
|
@ -225,13 +227,13 @@ class NodeTrackingAgent(multiprocessing.Process):
|
|||
node_details["state"] = state_details
|
||||
self._redis_controller.set_node_details(
|
||||
node_name=self._local_node_details["name"],
|
||||
node_details=node_details
|
||||
node_details=node_details,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _update_name_to_container_details(
|
||||
container_name_to_inspect_details: dict,
|
||||
name_to_container_details: dict
|
||||
name_to_container_details: dict,
|
||||
) -> None:
|
||||
"""Update name_to_container_details from container_name_to_inspect_details.
|
||||
|
||||
|
@ -246,10 +248,10 @@ class NodeTrackingAgent(multiprocessing.Process):
|
|||
for container_name, inspect_details in container_name_to_inspect_details.items():
|
||||
# Extract container state and labels.
|
||||
name_to_container_details[container_name] = NodeTrackingAgent._extract_labels(
|
||||
inspect_details=inspect_details
|
||||
inspect_details=inspect_details,
|
||||
)
|
||||
name_to_container_details[container_name]["state"] = NodeTrackingAgent._extract_state(
|
||||
inspect_details=inspect_details
|
||||
inspect_details=inspect_details,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -333,8 +335,10 @@ class LoadImageAgent(multiprocessing.Process):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
local_cluster_details: dict, local_master_details: dict, local_node_details: dict,
|
||||
check_interval: int = 10
|
||||
local_cluster_details: dict,
|
||||
local_master_details: dict,
|
||||
local_node_details: dict,
|
||||
check_interval: int = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self._local_cluster_details = local_cluster_details
|
||||
|
@ -343,7 +347,7 @@ class LoadImageAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host=self._local_master_details["hostname"],
|
||||
port=self._local_master_details["redis"]["port"]
|
||||
port=self._local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -382,21 +386,25 @@ class LoadImageAgent(multiprocessing.Process):
|
|||
for image_file_name, image_file_details in name_to_image_file_details_in_master.items():
|
||||
if (
|
||||
image_file_name not in name_to_image_file_details_in_node
|
||||
or name_to_image_file_details_in_node[image_file_name]["md5_checksum"] !=
|
||||
image_file_details["md5_checksum"]
|
||||
or name_to_image_file_details_in_node[image_file_name]["md5_checksum"]
|
||||
!= image_file_details["md5_checksum"]
|
||||
):
|
||||
unloaded_image_names.append(image_file_name)
|
||||
|
||||
# Parallel load
|
||||
with ThreadPool(5) as pool:
|
||||
params = [
|
||||
[os.path.expanduser(
|
||||
f"~/.maro-shared/clusters/{self._local_cluster_details['name']}/image_files/{unloaded_image_name}")]
|
||||
[
|
||||
os.path.expanduser(
|
||||
f"~/.maro-shared/clusters/{self._local_cluster_details['name']}/"
|
||||
+ "image_files/{unloaded_image_name}",
|
||||
),
|
||||
]
|
||||
for unloaded_image_name in unloaded_image_names
|
||||
]
|
||||
pool.starmap(
|
||||
self._load_image,
|
||||
params
|
||||
params,
|
||||
)
|
||||
|
||||
with self._redis_controller.lock(f"lock:name_to_node_details:{self._local_node_details['name']}"):
|
||||
|
@ -405,7 +413,7 @@ class LoadImageAgent(multiprocessing.Process):
|
|||
node_details["image_files"] = name_to_image_file_details_in_master
|
||||
self._redis_controller.set_node_details(
|
||||
node_name=self._local_node_details["name"],
|
||||
node_details=node_details
|
||||
node_details=node_details,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -421,7 +429,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
|
|||
local_cluster_details: dict,
|
||||
local_master_details: dict,
|
||||
local_node_details: dict,
|
||||
check_interval: int = 30
|
||||
check_interval: int = 30,
|
||||
):
|
||||
super().__init__()
|
||||
self._local_cluster_details = local_cluster_details
|
||||
|
@ -430,7 +438,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
|
|||
|
||||
self._redis_controller = RedisController(
|
||||
host=self._local_master_details["hostname"],
|
||||
port=self._local_master_details["redis"]["port"]
|
||||
port=self._local_master_details["redis"]["port"],
|
||||
)
|
||||
|
||||
self._check_interval = check_interval
|
||||
|
@ -468,7 +476,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
|
|||
node_name=self._local_node_details["name"],
|
||||
cpu_usage=cpu_usage_per_core,
|
||||
memory_usage=memory_usage,
|
||||
gpu_memory_usage=gpu_memory_usage
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
)
|
||||
|
||||
|
||||
|
@ -476,12 +484,12 @@ if __name__ == "__main__":
|
|||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
node_agent = NodeAgent(
|
||||
local_cluster_details=DetailsReader.load_local_cluster_details(),
|
||||
local_master_details=DetailsReader.load_local_master_details(),
|
||||
local_node_details=DetailsReader.load_local_node_details()
|
||||
local_node_details=DetailsReader.load_local_node_details(),
|
||||
)
|
||||
node_agent.start()
|
||||
|
|
|
@ -15,6 +15,7 @@ URL_PREFIX = "/v1/containers"

# Api functions.


@blueprint.route(f"{URL_PREFIX}", methods=["POST"])
def create_container():
"""Create a container, aka 'docker run'.

@ -16,5 +16,5 @@ URL_PREFIX = "/v1/status"
def status():
return {
"status": "OK",
"time": time.time()
"time": time.time(),
}

@ -14,8 +14,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ConnectionTester:
|
||||
"""Tester class for connection.
|
||||
"""
|
||||
"""Tester class for connection."""
|
||||
|
||||
@staticmethod
|
||||
def test_ssh_default_port_connection(node_username: str, node_hostname: str, node_ssh_port: int, cluster_name: str):
|
||||
|
@ -36,14 +35,14 @@ class ConnectionTester:
|
|||
node_ssh_port=node_ssh_port,
|
||||
node_username=node_username,
|
||||
node_hostname=node_hostname,
|
||||
cluster_name=cluster_name
|
||||
cluster_name=cluster_name,
|
||||
)
|
||||
return True
|
||||
except (CommandExecutionError, TimeoutExpired):
|
||||
remain_retries -= 1
|
||||
logger.debug(
|
||||
f"Unable to connect to {node_hostname} with port {node_ssh_port}, "
|
||||
f"remains {remain_retries} retries"
|
||||
f"remains {remain_retries} retries",
|
||||
)
|
||||
time.sleep(5)
|
||||
raise ConnectionFailed(f"Unable to connect to {node_hostname} with port {node_ssh_port}")
|
||||
|
|
|
@ -8,8 +8,7 @@ from .params import Paths


class DetailsReader:
"""Reader class for details.
"""
"""Reader class for details."""

@staticmethod
def load_cluster_details(cluster_name: str) -> dict:

@ -8,8 +8,7 @@ from .subprocess import Subprocess
|
|||
|
||||
|
||||
class DockerController:
|
||||
"""Controller class for docker.
|
||||
"""
|
||||
"""Controller class for docker."""
|
||||
|
||||
@staticmethod
|
||||
def remove_container(container_name: str) -> None:
|
||||
|
@ -69,13 +68,12 @@ class DockerController:
|
|||
command=create_config["command"],
|
||||
image_name=create_config["image_name"],
|
||||
volumes=DockerController._build_list_params_str(params=create_config["volumes"], option="-v"),
|
||||
|
||||
# System related.
|
||||
container_name=create_config["container_name"],
|
||||
fluentd_address=create_config["fluentd_address"],
|
||||
fluentd_tag=create_config["fluentd_tag"],
|
||||
environments=DockerController._build_dict_params_str(params=create_config["environments"], option="-e"),
|
||||
labels=DockerController._build_dict_params_str(params=create_config["labels"], option="-l")
|
||||
labels=DockerController._build_dict_params_str(params=create_config["labels"], option="-l"),
|
||||
)
|
||||
|
||||
# Start creating
|
||||
|
@ -86,7 +84,7 @@ class DockerController:
|
|||
|
||||
@staticmethod
|
||||
def list_container_names() -> list:
|
||||
command = "sudo docker ps -a --format \"{{.Names}}\""
|
||||
command = 'sudo docker ps -a --format "{{.Names}}"'
|
||||
return_str = Subprocess.run(command=command)
|
||||
if return_str == "":
|
||||
return []
|
||||
|
|
|
@ -4,41 +4,36 @@

# First Layer.


class AgentError(Exception):
""" Base error class for all MARO Grass Agents."""
pass
"""Base error class for all MARO Grass Agents."""


# Second Layer.


class UserFault(AgentError):
""" Users should be responsible for the errors."""
pass
"""Users should be responsible for the errors."""


class ServiceError(AgentError):
""" MARO Services should be responsible for the errors."""
pass
"""MARO Services should be responsible for the errors."""


# Third Layer.


class ResourceAllocationFailed(UserFault):
""" Resources are insufficient, unable to allocate."""
pass
"""Resources are insufficient, unable to allocate."""


class StartContainerError(ServiceError):
""" Error when starting containers."""
pass
"""Error when starting containers."""


class CommandExecutionError(ServiceError):
""" Failed to execute shell commands."""
pass
"""Failed to execute shell commands."""


class ConnectionFailed(ServiceError):
""" Failed to connect to other nodes."""
pass
"""Failed to connect to other nodes."""

@ -6,8 +6,7 @@ import uuid


class NameCreator:
"""Creator class for MARO Resource namings.
"""
"""Creator class for MARO Resource namings."""

@staticmethod
def create_name_with_uuid(prefix: str, uuid_len: int = 16) -> str:

@ -9,8 +9,7 @@ from redis.lock import Lock


class RedisController:
    """Controller class for Redis.
    """
    """Controller class for Redis."""

    def __init__(self, host: str, port: int):
        self._redis = redis.Redis(host=host, port=port, encoding="utf-8", decode_responses=True)

@ -23,20 +22,20 @@ class RedisController:
    def set_cluster_details(self, cluster_details: dict):
        self._redis.set(
            "cluster_details",
            json.dumps(cluster_details)
            json.dumps(cluster_details),
        )

    """Master Details Related."""

    def get_master_details(self) -> dict:
        return json.loads(
            self._redis.get("master_details")
            self._redis.get("master_details"),
        )

    def set_master_details(self, master_details: dict) -> None:
        self._redis.set(
            "master_details",
            json.dumps(master_details)
            json.dumps(master_details),
        )

    def delete_master_details(self) -> None:

@ -46,7 +45,7 @@ class RedisController:

    def get_name_to_node_details(self) -> dict:
        name_to_node_details = self._redis.hgetall(
            "name_to_node_details"
            "name_to_node_details",
        )
        for node_name, node_details_str in name_to_node_details.items():
            name_to_node_details[node_name] = json.loads(node_details_str)

@ -54,7 +53,7 @@ class RedisController:

    def get_name_to_node_resources(self) -> dict:
        name_to_node_details = self._redis.hgetall(
            "name_to_node_details"
            "name_to_node_details",
        )
        for node_name, node_details_str in name_to_node_details.items():
            node_details = json.loads(node_details_str)
@ -64,7 +63,7 @@ class RedisController:
    def get_node_details(self, node_name: str) -> dict:
        node_details = self._redis.hget(
            "name_to_node_details",
            node_name
            node_name,
        )
        if node_details is None:
            return {}

@ -75,13 +74,13 @@ class RedisController:
        self._redis.hset(
            "name_to_node_details",
            node_name,
            json.dumps(node_details)
            json.dumps(node_details),
        )

    def delete_node_details(self, node_name: str) -> None:
        self._redis.hdel(
            "name_to_node_details",
            node_name
            node_name,
        )

    def push_resource_usage(

@ -89,29 +88,29 @@ class RedisController:
        node_name: str,
        cpu_usage: list,
        memory_usage: float,
        gpu_memory_usage: list
        gpu_memory_usage: list,
    ):
        # Push cpu usage to redis
        self._redis.rpush(
            f"{node_name}:cpu_usage_per_core",
            json.dumps(cpu_usage)
            json.dumps(cpu_usage),
        )

        # Push memory usage to redis
        self._redis.rpush(
            f"{node_name}:memory_usage",
            json.dumps(memory_usage)
            json.dumps(memory_usage),
        )

        # Push gpu memory usage to redis
        self._redis.rpush(
            f"{node_name}:gpu_memory_usage",
            json.dumps(gpu_memory_usage)
            json.dumps(gpu_memory_usage),
        )

    def get_resource_usage(self, previous_length: int):
        name_to_node_details = self._redis.hgetall(
            "name_to_node_details"
            "name_to_node_details",
        )
        node_name_list = name_to_node_details.keys()
        name_to_node_usage = {}
@ -120,15 +119,18 @@ class RedisController:
            usage_dict = {}
            usage_dict["cpu"] = self._redis.lrange(
                f"{node_name}:cpu_usage_per_core",
                previous_length, -1
                previous_length,
                -1,
            )
            usage_dict["memory"] = self._redis.lrange(
                f"{node_name}:memory_usage",
                previous_length, -1
                previous_length,
                -1,
            )
            usage_dict["gpu"] = self._redis.lrange(
                f"{node_name}:gpu_memory_usage",
                previous_length, -1
                previous_length,
                -1,
            )
            name_to_node_usage[node_name] = usage_dict

@ -147,7 +149,7 @@ class RedisController:
    def get_job_details(self, job_name: str) -> dict:
        return_str = self._redis.hget(
            "name_to_job_details",
            job_name
            job_name,
        )
        return json.loads(return_str) if return_str is not None else {}

@ -155,13 +157,13 @@ class RedisController:
        self._redis.hset(
            "name_to_job_details",
            job_name,
            json.dumps(job_details)
            json.dumps(job_details),
        )

    def delete_job_details(self, job_name: str) -> None:
        self._redis.hdel(
            "name_to_job_details",
            job_name
            job_name,
        )

    """Schedule Details Related."""

@ -177,7 +179,7 @@ class RedisController:
    def get_schedule_details(self, schedule_name: str) -> dict:
        return_str = self._redis.hget(
            "name_to_schedule_details",
            schedule_name
            schedule_name,
        )
        return json.loads(return_str) if return_str is not None else {}

@ -185,13 +187,13 @@ class RedisController:
        self._redis.hset(
            "name_to_schedule_details",
            schedule_name,
            json.dumps(schedule_details)
            json.dumps(schedule_details),
        )

    def delete_schedule_details(self, schedule_name: str) -> None:
        self._redis.hdel(
            "name_to_schedule_details",
            schedule_name
            schedule_name,
        )

    """Container Details Related."""
@ -213,14 +215,14 @@ class RedisController:
            name_to_container_details[container_name] = json.dumps(container_details)
        self._redis.hmset(
            "name_to_container_details",
            name_to_container_details
            name_to_container_details,
        )

    def set_container_details(self, container_name: str, container_details: dict) -> None:
        self._redis.hset(
            "name_to_container_details",
            container_name,
            container_details
            container_details,
        )

    """Pending Job Tickets Related."""

@ -229,20 +231,20 @@ class RedisController:
        return self._redis.lrange(
            "pending_job_tickets",
            0,
            -1
            -1,
        )

    def push_pending_job_ticket(self, job_name: str):
        self._redis.rpush(
            "pending_job_tickets",
            job_name
            job_name,
        )

    def remove_pending_job_ticket(self, job_name: str):
        self._redis.lrem(
            "pending_job_tickets",
            0,
            job_name
            job_name,
        )

    def delete_pending_jobs_queue(self):

@ -254,20 +256,20 @@ class RedisController:
        return self._redis.lrange(
            "killed_job_tickets",
            0,
            -1
            -1,
        )

    def push_killed_job_ticket(self, job_name: str):
        self._redis.rpush(
            "killed_job_tickets",
            job_name
            job_name,
        )

    def remove_killed_job_ticket(self, job_name: str):
        self._redis.lrem(
            "killed_job_tickets",
            0,
            job_name
            job_name,
        )

    def delete_killed_jobs_queue(self):
@ -277,7 +279,7 @@ class RedisController:

    def get_rejoin_component_name_to_container_name(self, job_id: str) -> dict:
        return self._redis.hgetall(
            f"job:{job_id}:rejoin_component_name_to_container_name"
            f"job:{job_id}:rejoin_component_name_to_container_name",
        )

    def get_rejoin_container_name_to_component_name(self, job_id: str) -> dict:

@ -286,18 +288,18 @@ class RedisController:

    def delete_rejoin_container_name_to_component_name(self, job_id: str) -> None:
        self._redis.delete(
            f"job:{job_id}:rejoin_component_name_to_container_name"
            f"job:{job_id}:rejoin_component_name_to_container_name",
        )

    def get_job_runtime_details(self, job_id: str) -> dict:
        return self._redis.hgetall(
            f"job:{job_id}:runtime_details"
            f"job:{job_id}:runtime_details",
        )

    def get_rejoin_component_restart_times(self, job_id: str, component_id: str) -> int:
        restart_times = self._redis.hget(
            f"job:{job_id}:component_id_to_restart_times",
            component_id
            component_id,
        )
        return 0 if restart_times is None else int(restart_times)

@ -305,7 +307,7 @@ class RedisController:
        self._redis.hincrby(
            f"job:{job_id}:component_id_to_restart_times",
            component_id,
            1
            1,
        )

    """User Related."""

@ -313,14 +315,14 @@ class RedisController:
    def get_user_details(self, user_id: str) -> dict:
        return_str = self._redis.hget(
            "id_to_user_details",
            user_id
            user_id,
        )
        return json.loads(return_str) if return_str is not None else {}

    """Utils."""

    def get_time(self) -> int:
        """ Get current unix timestamp (seconds) from Redis server.
        """Get current unix timestamp (seconds) from Redis server.

        Returns:
            int: current timestamp.

@ -328,7 +330,7 @@ class RedisController:
        return self._redis.time()[0]

    def lock(self, name: str) -> Lock:
        """ Get a new lock with redis.
        """Get a new lock with redis.

        Use 'with lock(name):' paradigm to do the locking.
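All of the RedisController hunks follow one storage convention: details are kept in Redis hashes keyed by name, values are JSON-encoded strings, and a missing entry falls back to an empty dict. A minimal standalone sketch of that convention, assuming a Redis server is reachable on localhost:6379 (the node name and resource numbers are invented):

# Minimal sketch of the RedisController storage convention, not MARO's API.
# Assumes a Redis server on localhost:6379; names and values are invented.
import json

import redis

r = redis.Redis(host="localhost", port=6379, encoding="utf-8", decode_responses=True)

node_details = {"hostname": "node-0", "resources": {"cpu": 8, "memory": 32768, "gpu": 0}}

# set_node_details equivalent: JSON-encode the value into a hash field.
r.hset("name_to_node_details", "node-0", json.dumps(node_details))

# get_node_details equivalent: missing names fall back to an empty dict.
raw = r.hget("name_to_node_details", "node-0")
details = json.loads(raw) if raw is not None else {}
print(details["resources"]["cpu"])

decode_responses=True makes the client return str rather than bytes, which is why json.loads can be applied to hget results directly.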
@ -3,8 +3,7 @@


class BasicResource:
    """An abstraction class for computing resources.
    """
    """An abstraction class for computing resources."""

    def __init__(self, cpu: float, memory: float, gpu: float):
        self.cpu = cpu
@ -8,8 +8,7 @@ from .exception import CommandExecutionError


class Subprocess:
    """Wrapper class of subprocess
    """
    """Wrapper class of subprocess"""

    @staticmethod
    def run(command: str, timeout: int = None) -> str:

@ -30,7 +29,7 @@ class Subprocess:
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            timeout=timeout
            timeout=timeout,
        )
        if completed_process.returncode != 0:
            raise CommandExecutionError(completed_process.stderr)
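The Subprocess hunk only adds a trailing comma to the subprocess.run call; behaviour is unchanged: a non-zero return code raises CommandExecutionError with the captured stderr. A self-contained sketch of the wrapper; returning the captured stdout is an assumption, since the diff shows only the error branch:

# Self-contained sketch of the Subprocess.run wrapper above. Returning the
# captured stdout is an assumption; the diff shows only the error branch.
import subprocess


class CommandExecutionError(Exception):
    """Failed to execute shell commands."""


def run(command: str, timeout: int = None) -> str:
    completed_process = subprocess.run(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
        timeout=timeout,
    )
    if completed_process.returncode != 0:
        raise CommandExecutionError(completed_process.stderr)
    return completed_process.stdout


print(run("echo hello"))

Passing timeout=None lets the command run indefinitely; supplying a number makes subprocess.run raise TimeoutExpired, which the wrapper deliberately does not swallow.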
@ -7,5 +7,5 @@ def template(export_path: str, **kwargs):
    from maro.cli.grass.executors.grass_executor import GrassExecutor

    GrassExecutor.template(
        export_path=export_path
        export_path=export_path,
    )
@ -8,8 +8,7 @@ from maro.cli.utils.subprocess import Subprocess


class DockerController:
    """Controller class for docker.
    """
    """Controller class for docker."""

    @staticmethod
    def save_image(image_name: str, abs_export_path: str):
@ -15,24 +15,23 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes


class EncryptedRequests:
    """Wrapper class for requests with encryption/decryption integrated.
    """
    """Wrapper class for requests with encryption/decryption integrated."""

    def __init__(
        self,
        user_id: str,
        master_to_dev_encryption_private_key: str,
        dev_to_master_encryption_public_key: str,
        dev_to_master_signing_private_key: str
        dev_to_master_signing_private_key: str,
    ):
        self._user_id = user_id
        self._dev_to_master_signing_private_key = dev_to_master_signing_private_key
        self._master_to_dev_encryption_private_key_obj = serialization.load_pem_private_key(
            master_to_dev_encryption_private_key.encode("utf-8"),
            password=None
            password=None,
        )
        self._dev_to_master_encryption_public_key_obj = serialization.load_pem_public_key(
            dev_to_master_encryption_public_key.encode("utf-8")
            dev_to_master_encryption_public_key.encode("utf-8"),
        )

    def get(self, url: str):

@ -45,8 +44,8 @@ class EncryptedRequests:
            url=url,
            headers=self._get_new_headers(
                aes_key=aes_key,
                aes_ctr_nonce=aes_ctr_nonce
            )
                aes_ctr_nonce=aes_ctr_nonce,
            ),
        )

        # Parse response
@ -63,13 +62,15 @@ class EncryptedRequests:
            url=url,
            headers=self._get_new_headers(
                aes_key=aes_key,
                aes_ctr_nonce=aes_ctr_nonce
                aes_ctr_nonce=aes_ctr_nonce,
            ),
            data=None if json_dict is None else self._get_encrypted_bytes(
            data=None
            if json_dict is None
            else self._get_encrypted_bytes(
                json_dict=json_dict,
                aes_key=aes_key,
                aes_ctr_nonce=aes_ctr_nonce
            )
                aes_ctr_nonce=aes_ctr_nonce,
            ),
        )

        # Parse response

@ -86,13 +87,15 @@ class EncryptedRequests:
            url=url,
            headers=self._get_new_headers(
                aes_key=aes_key,
                aes_ctr_nonce=aes_ctr_nonce
                aes_ctr_nonce=aes_ctr_nonce,
            ),
            data=None if json_dict is None else self._get_encrypted_bytes(
            data=None
            if json_dict is None
            else self._get_encrypted_bytes(
                json_dict=json_dict,
                aes_key=aes_key,
                aes_ctr_nonce=aes_ctr_nonce
            )
                aes_ctr_nonce=aes_ctr_nonce,
            ),
        )

        # Parse response
@ -105,7 +108,7 @@ class EncryptedRequests:
    def _get_encrypted_bytes(json_dict: dict, aes_key: bytes, aes_ctr_nonce: bytes) -> bytes:
        cipher = Cipher(
            algorithm=algorithms.AES(key=aes_key),
            mode=modes.CTR(nonce=aes_ctr_nonce)
            mode=modes.CTR(nonce=aes_ctr_nonce),
        )
        encryptor = cipher.encryptor()
        return_bytes = encryptor.update(json.dumps(json_dict).encode("utf-8")) + encryptor.finalize()

@ -121,17 +124,17 @@ class EncryptedRequests:
        payload = jwt.decode(jwt=jwt_token, options={"verify_signature": False})
        aes_key = self._master_to_dev_encryption_private_key_obj.decrypt(
            ciphertext=base64.b64decode(payload["aes_key"].encode("ascii")),
            padding=self._get_asymmetric_padding()
            padding=self._get_asymmetric_padding(),
        )
        aes_ctr_nonce = self._master_to_dev_encryption_private_key_obj.decrypt(
            ciphertext=base64.b64decode(payload["aes_ctr_nonce"].encode("ascii")),
            padding=self._get_asymmetric_padding()
            padding=self._get_asymmetric_padding(),
        )

        # Return decrypted_bytes
        cipher = Cipher(
            algorithm=algorithms.AES(key=aes_key),
            mode=modes.CTR(nonce=aes_ctr_nonce)
            mode=modes.CTR(nonce=aes_ctr_nonce),
        )
        decryptor = cipher.decryptor()
        return decryptor.update(encrypted_bytes) + decryptor.finalize()
@ -145,22 +148,22 @@ class EncryptedRequests:
                "aes_key": base64.b64encode(
                    self._dev_to_master_encryption_public_key_obj.encrypt(
                        plaintext=aes_key,
                        padding=self._get_asymmetric_padding()
                    )
                        padding=self._get_asymmetric_padding(),
                    ),
                ).decode("ascii"),
                "aes_ctr_nonce": base64.b64encode(
                    self._dev_to_master_encryption_public_key_obj.encrypt(
                        plaintext=aes_ctr_nonce,
                        padding=self._get_asymmetric_padding()
                    )
                ).decode("ascii")
                        padding=self._get_asymmetric_padding(),
                    ),
                ).decode("ascii"),
            },
            key=self._dev_to_master_signing_private_key,
            algorithm="RS512"
            algorithm="RS512",
        )

        return {
            "Authorization": f"Bearer {jwt_token}"
            "Authorization": f"Bearer {jwt_token}",
        }

    @staticmethod

@ -168,5 +171,5 @@ class EncryptedRequests:
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None
            label=None,
        )
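Most of the EncryptedRequests changes are trailing commas plus a reformatted conditional expression for the request body; the cryptographic flow itself is untouched: an RSA-wrapped AES key and CTR nonce travel inside a signed JWT, and the JSON payload is AES-CTR encrypted. A minimal sketch of just the symmetric part, using the same cryptography primitives as the diff and omitting the RSA key wrapping and JWT signing (key, nonce, and payload are invented):

# Minimal sketch of the symmetric part of EncryptedRequests: AES-CTR
# encryption/decryption of a JSON payload. The RSA key wrapping and JWT
# signing from the diff are omitted; key, nonce, and payload are invented.
import json
import os

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

aes_key = os.urandom(32)        # 256-bit AES key
aes_ctr_nonce = os.urandom(16)  # 16-byte initial counter block for CTR mode

payload = {"action": "list_nodes"}

cipher = Cipher(
    algorithm=algorithms.AES(key=aes_key),
    mode=modes.CTR(nonce=aes_ctr_nonce),
)
encryptor = cipher.encryptor()
encrypted_bytes = encryptor.update(json.dumps(payload).encode("utf-8")) + encryptor.finalize()

decryptor = cipher.decryptor()
decrypted = json.loads(decryptor.update(encrypted_bytes) + decryptor.finalize())
assert decrypted == payload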
Some files were not shown because too many files have changed in this diff.