Format code on `master` (black + isort) (#538)

* Config files

* Add autoflake

* Update isort exclude; add pre-commit to requirements

* Manually fix a few bad cases
This commit is contained in:
Huoran Li 2022-06-14 07:37:59 +08:00 committed by GitHub
Parent 17ad48928c
Commit 7e3c1d5893
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
324 changed files with 5620 additions and 3595 deletions
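
Most of the 324 files change only in formatting. The recurring rewrites are black's string normalization (single quotes become double quotes, e.g. pygments_style = 'sphinx' turning into pygments_style = "sphinx" in a hunk below) and its explosion of calls that overrun the 120-column limit into one argument per line with a trailing comma. A condensed before/after, reproduced from one of the hunks further down rather than written fresh:

# before: hand-wrapped call, well over 120 columns if joined
env = Env(scenario="citi_bike", topology="toy.4s_4t", start_tick=start_tick,
          durations=durations, snapshot_resolution=60, options=opts)

# after black with line-length = 120
env = Env(
    scenario="citi_bike",
    topology="toy.4s_4t",
    start_tick=start_tick,
    durations=durations,
    snapshot_resolution=60,
    options=opts,
)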

7
.github/linters/pyproject.toml vendored Normal file
View file

@ -0,0 +1,7 @@
[tool.black]
line-length = 120
[tool.isort]
profile = "black"
line_length = 120
known_first_party = "maro"
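
With profile = "black", isort adopts black-compatible wrapping (parenthesized multi-line imports with trailing commas), so only the line length and the first-party package need to be stated here. A small sketch of the grouping this produces, assembled from imports that appear in the example files below (illustrative, not copied from a single hunk):

# before
from maro.rl.model import DiscreteQNet, FullyConnected
import torch
from typing import Dict

# after isort with profile = "black" and known_first_party = "maro":
# standard library, third-party, and first-party blocks, each alphabetized
from typing import Dict

import torch

from maro.rl.model import DiscreteQNet, FullyConnected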

15
.github/linters/tox.ini vendored
View file

@ -5,7 +5,9 @@ ignore =
# line break after binary operator
W504,
# line break before binary operator
W503
W503,
# whitespace before ':'
E203
exclude =
.git,
@ -27,14 +29,5 @@ max-line-length = 120
per-file-ignores =
# import not used: ignore in __init__.py files
__init__.py:F401
# igore invalid escape sequence in cli main script to show banner
# ignore invalid escape sequence in cli main script to show banner
maro.py:W605
[isort]
indent = " "
line_length = 120
use_parentheses = True
multi_line_output = 6
known_first_party = maro
filter_files = True
skip_glob = maro/__init__.py, tests/*, examples/*, setup.py
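
The E203 entry added above (alongside the existing W503/W504 pair) suppresses a check that flake8 would otherwise raise against black's own output: black pads the colons of slices whose bounds are expressions, and it breaks wrapped expressions before, not after, binary operators. Two lines excerpted verbatim from hunks later in this diff show both patterns:

# E203 (whitespace before ':'): black's slice formatting
vsl_space = vsl_snapshots[self._env.tick : vsl_idx : vessel_attributes][2] if finite_vsl_space else float("inf")

# W503 (line break before binary operator): black's expression wrapping
rewards = np.float32(
    reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
    - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list),
)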

6
.github/workflows/lint.yml vendored
View file

@ -45,12 +45,12 @@ jobs:
uses: github/super-linter@latest
env:
VALIDATE_ALL_CODEBASE: false
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configure it
VALIDATE_PYTHON_BLACK: false # same as above
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configured it
VALIDATE_PYTHON_MYPY: false # same as above
VALIDATE_JSCPD: false # Can not exclude specific file: https://github.com/kucherenko/jscpd/issues/215
PYTHON_FLAKE8_CONFIG_FILE: tox.ini
PYTHON_ISORT_CONFIG_FILE: tox.ini
PYTHON_BLACK_CONFIG_FILE: pyproject.toml
PYTHON_ISORT_CONFIG_FILE: pyproject.toml
EDITORCONFIG_FILE_NAME: ../../.editorconfig
FILTER_REGEX_INCLUDE: maro/.*
FILTER_REGEX_EXCLUDE: tests/.*

51
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
repos:
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
args:
- --in-place
- --remove-unused-variables
- --remove-all-unused-imports
exclude: .*/__init__\.py|setup\.py
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
args:
- --settings-path=.github/linters/pyproject.toml
- --check
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.2.3
hooks:
- id: add-trailing-comma
name: add-trailing-comma (1st round)
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
name: black (1st round)
args:
- --config=.github/linters/pyproject.toml
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.2.3
hooks:
- id: add-trailing-comma
name: add-trailing-comma (2nd round)
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
name: black (2nd round)
args:
- --config=.github/linters/pyproject.toml
- repo: https://gitlab.com/pycqa/flake8
rev: 3.7.9
hooks:
- id: flake8
args:
- --config=.github/linters/tox.ini
exclude: \.git|__pycache__|docs|build|dist|.*\.egg-info|docker_files|\.vscode|\.github|scripts|tests|maro\/backends\/.*.cp|setup.py
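
Two notes on the hook order. add-trailing-comma and black each run twice, presumably so the output converges: a comma added in the first pass can change how black wraps a call, and the second pass settles the result. The trailing comma matters because black treats it as "magic": once a trailing comma sits before a closing bracket, black keeps one argument per line even when the call would fit within 120 columns. A minimal sketch with names borrowed from the examples below:

# no trailing comma and the call fits: black keeps it on one line
trainer = DQNTrainer(name=name, params=params)

# trailing comma present: black preserves the exploded form
trainer = DQNTrainer(
    name=name,
    params=params,
)

autoflake runs before isort, so unused imports and variables are stripped before the remaining imports are ordered. With the hooks pointed at the shared configs under .github/linters/, running pre-commit run --all-files locally should reproduce the formatting in this commit.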

View file

@ -65,7 +65,7 @@ else:
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False

View file

@ -1,14 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import Adam, RMSprop
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
actor_net_conf = {
"hidden_dims": [256, 128, 64],
@ -58,10 +56,10 @@ def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
name=name,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=.0,
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0,
lam=0.0,
),
)

View file

@ -1,15 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import RMSprop
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNTrainer, DQNParams
from maro.rl.training.algorithms import DQNParams, DQNTrainer
q_net_conf = {
"hidden_dims": [256, 128, 64, 32],
@ -38,14 +36,18 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
name=name,
q_net=MyQNet(state_dim, action_num),
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
exploration_scheduling_options=[(
"epsilon", MultiLinearExplorationScheduler, {
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
}
)],
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
},
),
],
warmup=100,
)
@ -54,7 +56,7 @@ def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
reward_discount=.0,
reward_discount=0.0,
update_target_every=5,
num_epochs=10,
soft_update_coef=0.1,

View file

@ -2,22 +2,21 @@
# Licensed under the MIT license.
from functools import partial
from typing import Dict, List
from typing import List
import torch
from torch.optim import Adam, RMSprop
from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams
from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrainer
actor_net_conf = {
"hidden_dims": [256, 128, 64],
"activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True
"head": True,
}
critic_net_conf = {
"hidden_dims": [256, 128, 64],
@ -25,7 +24,7 @@ critic_net_conf = {
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True
"head": True,
}
actor_learning_rate = 0.001
critic_learning_rate = 0.001
@ -64,9 +63,9 @@ def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMAD
return DiscreteMADDPGTrainer(
name=name,
params=DiscreteMADDPGParams(
reward_discount=.0,
reward_discount=0.0,
num_epoch=10,
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
shared_critic=False
)
shared_critic=False,
),
)

View file

@ -15,11 +15,11 @@ def get_ppo(state_dim: int, name: str) -> PPOTrainer:
name=name,
params=PPOParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=.0,
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0,
lam=0.0,
clip_ratio=0.1,
),
)

View file

@ -4,7 +4,7 @@
env_conf = {
"scenario": "cim",
"topology": "toy.4p_ssdd_l0.0",
"durations": 560
"durations": 560,
}
if env_conf["topology"].startswith("toy"):
@ -17,27 +17,26 @@ vessel_attributes = ["empty", "full", "remaining_space"]
state_shaping_conf = {
"look_back": 7,
"max_ports_downstream": 2
"max_ports_downstream": 2,
}
action_shaping_conf = {
"action_space": [(i - 10) / 10 for i in range(21)],
"finite_vessel_space": True,
"has_early_discharge": True
"has_early_discharge": True,
}
reward_shaping_conf = {
"time_window": 99,
"fulfillment_factor": 1.0,
"shortage_factor": 1.0,
"time_decay": 0.97
"time_decay": 0.97,
}
# obtain state dimension from a temporary env_wrapper instance
state_dim = (
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
+ len(vessel_attributes)
)
state_dim = (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(
port_attributes,
) + len(vessel_attributes)
action_num = len(action_shaping_conf["action_space"])

View file

@ -8,29 +8,32 @@ import numpy as np
from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
from .config import (
action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
vessel_attributes,
)
from .config import action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
class CIMEnvSampler(AbsEnvSampler):
def _get_global_and_agent_state_impl(
self, event: DecisionEvent, tick: int = None,
self,
event: DecisionEvent,
tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
tick = self._env.tick
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
port_idx, vessel_idx = event.port_idx, event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
state = np.concatenate([
port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
vessel_snapshots[tick: vessel_idx: vessel_attributes]
])
future_port_list = vessel_snapshots[tick:vessel_idx:"future_stop_list"].astype("int")
state = np.concatenate(
[
port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
vessel_snapshots[tick:vessel_idx:vessel_attributes],
],
)
return state, {port_idx: state}
def _translate_to_env_action(
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent,
self,
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
event: DecisionEvent,
) -> Dict[Any, object]:
action_space = action_shaping_conf["action_space"]
finite_vsl_space = action_shaping_conf["finite_vessel_space"]
@ -40,7 +43,7 @@ class CIMEnvSampler(AbsEnvSampler):
vsl_idx, action_scope = event.vessel_idx, event.action_scope
vsl_snapshots = self._env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")
vsl_space = vsl_snapshots[self._env.tick : vsl_idx : vessel_attributes][2] if finite_vsl_space else float("inf")
percent = abs(action_space[model_action[0]])
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
@ -49,7 +52,9 @@ class CIMEnvSampler(AbsEnvSampler):
actual_action = min(round(percent * action_scope.load), vsl_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
early_discharge = (
vsl_snapshots[self._env.tick : vsl_idx : "early_discharge"][0] if has_early_discharge else 0
)
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
else:
@ -70,7 +75,7 @@ class CIMEnvSampler(AbsEnvSampler):
decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
rewards = np.float32(
reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list),
)
return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

View file

@ -3,19 +3,16 @@ from typing import Any, Callable, Dict, Optional
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from .algorithms.ac import get_ac_policy
from .algorithms.dqn import get_dqn_policy
from .algorithms.maddpg import get_maddpg_policy
from .algorithms.ppo import get_ppo_policy
from .algorithms.ac import get_ac
from .algorithms.ppo import get_ppo
from .algorithms.dqn import get_dqn
from .algorithms.maddpg import get_maddpg
from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
from .algorithms.ppo import get_ppo, get_ppo_policy
class CIMBundle(RLComponentBundle):
@ -29,7 +26,7 @@ class CIMBundle(RLComponentBundle):
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])
def get_agent2policy(self) -> Dict[Any, str]:
return {agent: f"{algorithm}_{agent}.policy"for agent in self.env.agent_idx_list}
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
if algorithm == "ac":
@ -60,23 +57,17 @@ class CIMBundle(RLComponentBundle):
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}")
for i in range(num_agents)
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
}
elif algorithm == "ppo":
trainer_creator = {
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}")
for i in range(num_agents)
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
}
elif algorithm == "dqn":
trainer_creator = {
f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}")
for i in range(num_agents)
}
trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)}
elif algorithm == "discrete_maddpg":
trainer_creator = {
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}")
for i in range(num_agents)
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")

View file

@ -63,8 +63,13 @@ class GreedyPolicy:
if __name__ == "__main__":
env = Env(scenario=config.env.scenario, topology=config.env.topology, start_tick=config.env.start_tick,
durations=config.env.durations, snapshot_resolution=config.env.resolution)
env = Env(
scenario=config.env.scenario,
topology=config.env.topology,
start_tick=config.env.start_tick,
durations=config.env.durations,
snapshot_resolution=config.env.resolution,
)
if config.env.seed is not None:
env.set_seed(config.env.seed)

View file

@ -7,11 +7,15 @@ from pulp import PULP_CBC_CMD, LpInteger, LpMaximize, LpProblem, LpVariable, lpS
from maro.utils import DottableDict
class CitiBikeILP():
class CitiBikeILP:
def __init__(
self, num_station: int, num_neighbor: int,
station_capacity: List[int], station_neighbor_list: List[List[int]],
decision_interval: int, config: DottableDict
self,
num_station: int,
num_neighbor: int,
station_capacity: List[int],
station_neighbor_list: List[List[int]],
decision_interval: int,
config: DottableDict,
):
"""A simple Linear Programming formulation for solving the bike repositioning problem.
@ -68,37 +72,37 @@ class CitiBikeILP():
name=f"T{decision_point}_S{station}_Inv",
lowBound=0,
upBound=self._station_capacity[station],
cat=LpInteger
cat=LpInteger,
)
self._safety_inventory[decision_point][station] = LpVariable(
name=f"T{decision_point}_S{station}_SafetyInv",
lowBound=0,
upBound=round(self._safety_inventory_limit * self._station_capacity[station]),
cat=LpInteger
cat=LpInteger,
)
self._fulfillment[decision_point][station] = LpVariable(
name=f"T{decision_point}_S{station}_Fulfillment",
lowBound=0,
cat=LpInteger
cat=LpInteger,
)
# For intermediate variables.
self._transfer_from[decision_point][station] = LpVariable(
name=f"T{decision_point}_TransferFrom{station}",
lowBound=0,
cat=LpInteger
cat=LpInteger,
)
self._transfer_to[decision_point][station] = LpVariable(
name=f"T{decision_point}_TransferTo{station}",
lowBound=0,
cat=LpInteger
cat=LpInteger,
)
for neighbor_idx in range(self._num_neighbor):
self._transfer[decision_point][station][neighbor_idx] = LpVariable(
name=f"T{decision_point}_Transfer_from{station}_to{neighbor_idx}th",
lowBound=0,
cat=LpInteger
cat=LpInteger,
)
# Initialize inventory of the first decision point with the environment's current inventory.
@ -113,23 +117,26 @@ class CitiBikeILP():
), f"Fulfillment_Limit_T{decision_point}_S{station}"
# For intermediate variables.
problem += (
self._transfer_from[decision_point][station] == lpSum(
self._transfer_from[decision_point][station]
== lpSum(
self._transfer[decision_point][station][neighbor_idx]
for neighbor_idx in range(self._num_neighbor)
)
), f"TotalTransferFrom_T{decision_point}_S{station}"
problem += (
self._transfer_to[decision_point][station] == lpSum(
self._transfer_to[decision_point][station]
== lpSum(
self._transfer[decision_point][neighbor][self._station_neighbor_list[neighbor].index(station)]
for neighbor in range(self._num_station)
if station in self._station_neighbor_list[neighbor][:self._num_neighbor]
if station in self._station_neighbor_list[neighbor][: self._num_neighbor]
)
), f"TotalTransferTo_T{decision_point}_S{station}"
for decision_point in range(1, self._num_decision_point):
for station in range(self._num_station):
problem += (
self._inventory[decision_point][station] == (
self._inventory[decision_point][station]
== (
self._inventory[decision_point - 1][station]
+ supply[decision_point - 1, station]
- self._fulfillment[decision_point - 1][station]
@ -138,7 +145,8 @@ class CitiBikeILP():
)
), f"Inventory_T{decision_point}_S{station}"
problem += (
self._safety_inventory[decision_point][station] <= (
self._safety_inventory[decision_point][station]
<= (
self._inventory[decision_point - 1][station]
+ supply[decision_point - 1, station]
- self._fulfillment[decision_point - 1][station]
@ -148,26 +156,31 @@ class CitiBikeILP():
def _set_objective(self, problem: LpProblem):
fulfillment_gain = lpSum(
math.pow(self._fulfillment_time_decay_factor, decision_point) * lpSum(
self._fulfillment[decision_point][station] for station in range(self._num_station)
) for decision_point in range(self._num_decision_point)
)
safety_inventory_reward = self._safety_inventory_reward_factor * lpSum(
math.pow(self._safety_inventory_reward_time_decay_factor, decision_point) * lpSum(
self._safety_inventory[decision_point][station] for station in range(self._num_station)
) for decision_point in range(self._num_decision_point)
)
transfer_cost = self._transfer_cost_factor * lpSum(
self._transfer_to[decision_point][station] for station in range(self._num_station)
math.pow(self._fulfillment_time_decay_factor, decision_point)
* lpSum(self._fulfillment[decision_point][station] for station in range(self._num_station))
for decision_point in range(self._num_decision_point)
)
problem += (fulfillment_gain + safety_inventory_reward - transfer_cost)
safety_inventory_reward = self._safety_inventory_reward_factor * lpSum(
math.pow(self._safety_inventory_reward_time_decay_factor, decision_point)
* lpSum(self._safety_inventory[decision_point][station] for station in range(self._num_station))
for decision_point in range(self._num_decision_point)
)
transfer_cost = self._transfer_cost_factor * lpSum(
self._transfer_to[decision_point][station]
for station in range(self._num_station)
for decision_point in range(self._num_decision_point)
)
problem += fulfillment_gain + safety_inventory_reward - transfer_cost
def _formulate_and_solve(
self, env_tick: int, init_inventory: np.ndarray, demand: np.ndarray, supply: np.ndarray
self,
env_tick: int,
init_inventory: np.ndarray,
demand: np.ndarray,
supply: np.ndarray,
):
problem = LpProblem(
name=f"Citi_Bike_Repositioning_from_tick_{env_tick}",
@ -181,7 +194,11 @@ class CitiBikeILP():
# ============================= private end =============================
def get_transfer_list(
self, env_tick: int, init_inventory: np.ndarray, demand: np.ndarray, supply: np.ndarray
self,
env_tick: int,
init_inventory: np.ndarray,
demand: np.ndarray,
supply: np.ndarray,
) -> List[Tuple[int, int, int]]:
"""Get the transfer list for the given env_tick.
@ -201,7 +218,10 @@ class CitiBikeILP():
if env_tick >= self._last_start_tick + self._apply_buffer_size:
self._last_start_tick = env_tick
self._formulate_and_solve(
env_tick=env_tick, init_inventory=init_inventory, demand=demand, supply=supply
env_tick=env_tick,
init_inventory=init_inventory,
demand=demand,
supply=supply,
)
decision_point = (env_tick - self._last_start_tick) // self._decision_interval

View file

@ -5,8 +5,8 @@ from typing import List, Tuple
import numpy as np
import yaml
from citi_bike_ilp import CitiBikeILP
from maro.data_lib import BinaryReader, ItemTickPicker
from maro.event_buffer import AbsEvent
from maro.forecasting import OneStepFixWindowMA as Forecaster
@ -22,9 +22,14 @@ ENV: Env = None
TRIP_PICKER: ItemTickPicker = None
class MaIlpAgent():
class MaIlpAgent:
def __init__(
self, ilp: CitiBikeILP, num_station: int, num_time_interval: int, ticks_per_interval: int, ma_window_size: int
self,
ilp: CitiBikeILP,
num_station: int,
num_time_interval: int,
ticks_per_interval: int,
ma_window_size: int,
):
"""An agent that make decisions by ILP in Citi Bike scenario.
@ -100,15 +105,23 @@ class MaIlpAgent():
The second item indicates the forecasting supply for each station in each time interval,
with shape: (num_time_interval, num_station).
"""
demand = np.array(
[round(self._demand_forecaster[i].forecast()) for i in range(self._num_station)],
dtype=np.int16
).reshape((1, -1)).repeat(self._num_time_interval, axis=0)
demand = (
np.array(
[round(self._demand_forecaster[i].forecast()) for i in range(self._num_station)],
dtype=np.int16,
)
.reshape((1, -1))
.repeat(self._num_time_interval, axis=0)
)
supply = np.array(
[round(self._supply_forecaster[i].forecast()) for i in range(self._num_station)],
dtype=np.int16
).reshape((1, -1)).repeat(self._num_time_interval, axis=0)
supply = (
np.array(
[round(self._supply_forecaster[i].forecast()) for i in range(self._num_station)],
dtype=np.int16,
)
.reshape((1, -1))
.repeat(self._num_time_interval, axis=0)
)
return demand, supply
@ -147,7 +160,7 @@ class MaIlpAgent():
env_tick=env_tick,
init_inventory=init_inventory,
demand=demand,
supply=supply
supply=supply,
)
action_list = [
@ -160,20 +173,28 @@ class MaIlpAgent():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--peep", action='store_true',
help="If set, peep the future demand and supply of bikes for each station directly from the log data."
"--peep",
action="store_true",
help="If set, peep the future demand and supply of bikes for each station directly from the log data.",
)
parser.add_argument(
"-c", "--config", type=str, default="examples/citi_bike/online_lp/config.yml",
help="The path of the config file."
"-c",
"--config",
type=str,
default="examples/citi_bike/online_lp/config.yml",
help="The path of the config file.",
)
parser.add_argument(
"-t", "--topology", type=str,
help="Which topology to use. If set, it will over-write the topology set in the config file."
"-t",
"--topology",
type=str,
help="Which topology to use. If set, it will over-write the topology set in the config file.",
)
parser.add_argument(
"-r", "--seed", type=int,
help="The random seed for the environment. If set, it will over-write the seed set in the config file."
"-r",
"--seed",
type=int,
help="The random seed for the environment. If set, it will over-write the seed set in the config file.",
)
args = parser.parse_args()
@ -205,7 +226,7 @@ if __name__ == "__main__":
TRIP_PICKER = BinaryReader(env.configs["trip_data"]).items_tick_picker(
start_time_offset=config.env.start_tick,
end_time_offset=(config.env.start_tick + config.env.durations),
time_unit="m"
time_unit="m",
)
if config.env.seed is not None:
@ -220,29 +241,26 @@ if __name__ == "__main__":
# TODO: Update the Env interface.
num_station = len(env.agent_idx_list)
station_distance_adj = np.array(
load_adj_from_csv(env.configs["distance_adj_data"], skiprows=1)
load_adj_from_csv(env.configs["distance_adj_data"], skiprows=1),
).reshape(num_station, num_station)
station_neighbor_list = [
neighbor_list[1:]
for neighbor_list in np.argsort(station_distance_adj, axis=1).tolist()
]
station_neighbor_list = [neighbor_list[1:] for neighbor_list in np.argsort(station_distance_adj, axis=1).tolist()]
# Init a Moving-Average based ILP agent.
decision_interval = env.configs["decision"]["resolution"]
ilp = CitiBikeILP(
num_station=num_station,
num_neighbor=min(config.ilp.num_neighbor, num_station - 1),
station_capacity=env.snapshot_list["stations"][env.frame_index:env.agent_idx_list:"capacity"],
station_capacity=env.snapshot_list["stations"][env.frame_index : env.agent_idx_list : "capacity"],
station_neighbor_list=station_neighbor_list,
decision_interval=decision_interval,
config=config.ilp
config=config.ilp,
)
agent = MaIlpAgent(
ilp=ilp,
num_station=num_station,
num_time_interval=math.ceil(config.ilp.plan_window_size / decision_interval),
ticks_per_interval=decision_interval,
ma_window_size=config.forecasting.ma_window_size
ma_window_size=config.forecasting.ma_window_size,
)
pre_decision_tick: int = -1
@ -252,15 +270,15 @@ if __name__ == "__main__":
else:
action = agent.get_action_list(
env_tick=env.tick,
init_inventory=env.snapshot_list["stations"][
env.frame_index:env.agent_idx_list:"bikes"
].astype(np.int16),
finished_events=env.get_finished_events()
init_inventory=env.snapshot_list["stations"][env.frame_index : env.agent_idx_list : "bikes"].astype(
np.int16,
),
finished_events=env.get_finished_events(),
)
pre_decision_tick = decision_event.tick
_, decision_event, is_done = env.step(action=action)
print(
f"[{'De' if PEEP_AND_USE_REAL_DATA else 'MA'}] "
f"Topology {config.env.topology} with seed {config.env.seed}: {env.metrics}"
f"Topology {config.env.topology} with seed {config.env.seed}: {env.metrics}",
)

View file

@ -1,12 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from random import seed, randint
from random import randint
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType
if __name__ == "__main__":
# Initialize an environment with a specific scenario, related topology.
env = Env(scenario="cim", topology="global_trade.22p_l0.1", start_tick=0, durations=100)
@ -29,7 +28,7 @@ if __name__ == "__main__":
decision_event.vessel_idx,
decision_event.port_idx,
randint(0, action_scope.discharge if to_discharge else action_scope.load),
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
)
# Drive environment with dummy action (no repositioning)

View file

@ -9,13 +9,12 @@ os.environ["MARO_STREAMIT_ENABLED"] = "true"
os.environ["MARO_STREAMIT_EXPERIMENT_NAME"] = "experiment_example"
from random import seed, randint
from random import randint, seed
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionScope, ActionType
from maro.simulator.scenarios.cim.common import Action, ActionType
from maro.streamit import streamit
if __name__ == "__main__":
seed(0)
NUM_EPISODE = 2
@ -42,7 +41,7 @@ if __name__ == "__main__":
decision_event.vessel_idx,
decision_event.port_idx,
randint(0, action_scope.discharge if to_discharge else action_scope.load),
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
)
# Drive environment with dummy action (no repositioning)

View file

@ -1,12 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from random import seed, randint
from random import randint, seed
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType
if __name__ == "__main__":
seed(0)
NUM_EPISODE = 2
@ -17,12 +16,15 @@ if __name__ == "__main__":
If you leave value to empty string, it will dump to current folder.
For getting dump data, please uncomment below line and specify dump destination folder.
"""
opts['enable-dump-snapshot'] = 'YOUR_FOLDER_NAME'
opts["enable-dump-snapshot"] = "YOUR_FOLDER_NAME"
# Initialize an environment with a specific scenario, related topology.
env = Env(
scenario="cim", topology="global_trade.22p_l0.1",
start_tick=0, durations=100, options=opts
scenario="cim",
topology="global_trade.22p_l0.1",
start_tick=0,
durations=100,
options=opts,
)
# To reset environmental data before starting a new experiment.
@ -40,7 +42,7 @@ if __name__ == "__main__":
decision_event.vessel_idx,
decision_event.port_idx,
randint(0, action_scope.discharge if to_discharge else action_scope.load),
ActionType.DISCHARGE if to_discharge else ActionType.LOAD
ActionType.DISCHARGE if to_discharge else ActionType.LOAD,
)
# Drive environment with dummy action (no repositioning)

View file

@ -17,8 +17,14 @@ For getting dump data, please uncomment below line and specify dump destination
"""
# opts['enable-dump-snapshot'] = ''
env = Env(scenario="citi_bike", topology="toy.4s_4t", start_tick=start_tick,
durations=durations, snapshot_resolution=60, options=opts)
env = Env(
scenario="citi_bike",
topology="toy.4s_4t",
start_tick=start_tick,
durations=durations,
snapshot_resolution=60,
options=opts,
)
print(env.summary)
@ -35,11 +41,11 @@ for ep in range(max_ep):
if decision_evt is not None:
action = Action(decision_evt.station_idx, 0, 10)
station_ss = env.snapshot_list['stations']
shortage_states = station_ss[::'shortage']
station_ss = env.snapshot_list["stations"]
shortage_states = station_ss[::"shortage"]
print("total shortage", shortage_states.sum())
trips_states = station_ss[::'trip_requirement']
trips_states = station_ss[::"trip_requirement"]
print("total trip", trips_states.sum())
cost_states = station_ss[::["extra_cost", "transfer_cost"]]
@ -54,7 +60,7 @@ for ep in range(max_ep):
# NOTE: We have not clear the trip adj at each tick so it is an accumulative value,
# then we can just query last snapshot to calc total trips.
trips_adj = matrix_ss[last_snapshot_index::'trips_adj']
trips_adj = matrix_ss[last_snapshot_index::"trips_adj"]
# Reshape it we need an easy way to access.
# trips_adj = trips_adj.reshape((-1, len(station_ss)))

View file

@ -14,9 +14,11 @@ def worker(group_name):
Args:
group_name (str): Identifier for the group of all communication components.
"""
proxy = Proxy(group_name=group_name,
component_type="worker",
expected_peers={"master": 1})
proxy = Proxy(
group_name=group_name,
component_type="worker",
expected_peers={"master": 1},
)
counter = 0
print(f"{proxy.name}'s counter is {counter}.")
@ -45,14 +47,14 @@ def master(group_name: str, worker_num: int, is_immediate: bool = False):
proxy = Proxy(
group_name=group_name,
component_type="master",
expected_peers={"worker": worker_num}
expected_peers={"worker": worker_num},
)
if is_immediate:
session_ids = proxy.ibroadcast(
component_type="worker",
tag="INC",
session_type=SessionType.NOTIFICATION
session_type=SessionType.NOTIFICATION,
)
# Do some tasks with higher priority here.
replied_msgs = proxy.receive_by_id(session_ids, timeout=-1)
@ -61,13 +63,13 @@ def master(group_name: str, worker_num: int, is_immediate: bool = False):
component_type="worker",
tag="INC",
session_type=SessionType.NOTIFICATION,
timeout=-1
timeout=-1,
)
for msg in replied_msgs:
print(
f"{proxy.name} get receive notification from {msg.source} with "
f"message session stage {msg.session_stage}."
f"message session stage {msg.session_stage}.",
)
@ -84,7 +86,7 @@ if __name__ == "__main__":
workers = mp.Pool(worker_number)
master_process = mp.Process(target=master, args=(group_name, worker_number, is_immediate,))
master_process = mp.Process(target=master, args=(group_name, worker_number, is_immediate))
master_process.start()
workers.map(worker, [group_name] * worker_number)

View file

@ -16,9 +16,11 @@ def summation_worker(group_name):
Args:
group_name (str): Identifier for the group of all communication components.
"""
proxy = Proxy(group_name=group_name,
component_type="sum_worker",
expected_peers={"master": 1})
proxy = Proxy(
group_name=group_name,
component_type="sum_worker",
expected_peers={"master": 1},
)
# Nonrecurring receive the message from the proxy.
msg = proxy.receive_once()
@ -36,9 +38,11 @@ def multiplication_worker(group_name):
Args:
group_name (str): Identifier for the group of all communication components.
"""
proxy = Proxy(group_name=group_name,
component_type="multiply_worker",
expected_peers={"master": 1})
proxy = Proxy(
group_name=group_name,
component_type="multiply_worker",
expected_peers={"master": 1},
)
# Nonrecurring receive the message from the proxy.
msg = proxy.receive_once()
@ -62,10 +66,14 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
you can do something with high priority before receiving replied messages from peers.
Sync Mode: It will block until the proxy returns all the replied messages.
"""
proxy = Proxy(group_name=group_name,
component_type="master",
expected_peers={"sum_worker": sum_worker_number,
"multiply_worker": multiply_worker_number})
proxy = Proxy(
group_name=group_name,
component_type="master",
expected_peers={
"sum_worker": sum_worker_number,
"multiply_worker": multiply_worker_number,
},
)
sum_list = np.random.randint(0, 10, 100)
multiple_list = np.random.randint(1, 10, 20)
@ -75,25 +83,30 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
destination_payload_list = []
for idx, peer in enumerate(proxy.peers["sum_worker"]):
data_length_per_peer = int(len(sum_list) / len(proxy.peers["sum_worker"]))
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
destination_payload_list.append((peer, sum_list[idx * data_length_per_peer : (idx + 1) * data_length_per_peer]))
# Assign multiply tasks for multiplication workers.
for idx, peer in enumerate(proxy.peers["multiply_worker"]):
data_length_per_peer = int(len(multiple_list) / len(proxy.peers["multiply_worker"]))
destination_payload_list.append(
(peer, multiple_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer]))
(peer, multiple_list[idx * data_length_per_peer : (idx + 1) * data_length_per_peer]),
)
if is_immediate:
session_ids = proxy.iscatter(tag="job",
session_type=SessionType.TASK,
destination_payload_list=destination_payload_list)
session_ids = proxy.iscatter(
tag="job",
session_type=SessionType.TASK,
destination_payload_list=destination_payload_list,
)
# Do some tasks with higher priority here.
replied_msgs = proxy.receive_by_id(session_ids, timeout=-1)
else:
replied_msgs = proxy.scatter(tag="job",
session_type=SessionType.TASK,
destination_payload_list=destination_payload_list,
timeout=-1)
replied_msgs = proxy.scatter(
tag="job",
session_type=SessionType.TASK,
destination_payload_list=destination_payload_list,
timeout=-1,
)
sum_result, multiply_result = 0, 1
for msg in replied_msgs:
@ -105,8 +118,8 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int,
multiply_result *= msg.body
# Check task result correction.
assert(sum(sum_list) == sum_result)
assert(np.prod(multiple_list) == multiply_result)
assert sum(sum_list) == sum_result
assert np.prod(multiple_list) == multiply_result
if __name__ == "__main__":
@ -124,8 +137,10 @@ if __name__ == "__main__":
# Worker's pool for sum_worker and prod_worker.
workers = mp.Pool(sum_worker_number + multiply_worker_number)
master_process = mp.Process(target=master,
args=(group_name, sum_worker_number, multiply_worker_number, is_immediate,))
master_process = mp.Process(
target=master,
args=(group_name, sum_worker_number, multiply_worker_number, is_immediate),
)
master_process.start()
for s in range(sum_worker_number):

View file

@ -16,9 +16,11 @@ def worker(group_name):
Args:
group_name (str): Identifier for the group of all communication components
"""
proxy = Proxy(group_name=group_name,
component_type="worker",
expected_peers={"master": 1})
proxy = Proxy(
group_name=group_name,
component_type="worker",
expected_peers={"master": 1},
)
# Nonrecurring receive the message from the proxy.
msg = proxy.receive_once()
@ -40,19 +42,23 @@ def master(group_name: str, is_immediate: bool = False):
you can do something with high priority before receiving replied messages from peers.
Sync Mode: It will block until the proxy returns all the replied messages.
"""
proxy = Proxy(group_name=group_name,
component_type="master",
expected_peers={"worker": 1})
proxy = Proxy(
group_name=group_name,
component_type="master",
expected_peers={"worker": 1},
)
random_integer_list = np.random.randint(0, 100, 5)
print(f"generate random integer list: {random_integer_list}.")
for peer in proxy.peers["worker"]:
message = SessionMessage(tag="sum",
source=proxy.name,
destination=peer,
body=random_integer_list,
session_type=SessionType.TASK)
message = SessionMessage(
tag="sum",
source=proxy.name,
destination=peer,
body=random_integer_list,
session_type=SessionType.TASK,
)
if is_immediate:
session_id = proxy.isend(message)
# Do some tasks with higher priority here.
@ -74,7 +80,7 @@ if __name__ == "__main__":
group_name = "proxy_send_simple_example"
is_immediate = False
master_process = mp.Process(target=master, args=(group_name, is_immediate,))
master_process = mp.Process(target=master, args=(group_name, is_immediate))
worker_process = mp.Process(target=worker, args=(group_name,))
master_process.start()
worker_process.start()

View file

@ -5,7 +5,7 @@ from maro.cli.local.commands import run
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("conf_path", help='Path of the job deployment')
parser.add_argument("conf_path", help="Path of the job deployment")
parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow")
return parser.parse_args()

View file

@ -80,5 +80,5 @@ retail_frame.stores[0].shortages[:] = [i + 1 for i in range(TOTAL_PRODUCT_CATEGO
retail_frame.take_snapshot(1)
# Query shortage information of first and second store at first and second tick.
store_shortage_history = snapshot_list["store"][[0, 1]: [0, 1]: "shortages"].reshape(2, -1)
store_shortage_history = snapshot_list["store"][[0, 1]:[0, 1]:"shortages"].reshape(2, -1)
print(f"First and second store shortage history at the first and second tick (numpy array): {store_shortage_history}")

View file

@ -2,13 +2,16 @@
# Licensed under the MIT license.
from enum import Enum
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
from maro.vector_env import VectorEnv
class VectorEnvUsage(Enum):
PUSH_ONE_FORWARD = "push the first environment forward and left others behind"
PUSH_ALL_FORWARD = "push all environments forward together"
USAGE = VectorEnvUsage.PUSH_ONE_FORWARD
if __name__ == "__main__":
@ -28,7 +31,7 @@ if __name__ == "__main__":
# Showcase: how to access information from snapshot list in vector env.
if env0_dec:
ss0 = env.snapshot_list["vessels"][env0_dec.tick:env0_dec.vessel_idx:"remaining_space"]
ss0 = env.snapshot_list["vessels"][env0_dec.tick : env0_dec.vessel_idx : "remaining_space"]
# 1. Only push specified (1st for this example) environment, leave others behind.
if USAGE == VectorEnvUsage.PUSH_ONE_FORWARD and env0_dec:
@ -38,8 +41,8 @@ if __name__ == "__main__":
vessel_idx=env0_dec.vessel_idx,
port_idx=env0_dec.port_idx,
quantity=env0_dec.action_scope.load,
action_type=ActionType.LOAD
)
action_type=ActionType.LOAD,
),
}
# 2. Only pass action to 1st environment (give None to other environments),
@ -51,7 +54,7 @@ if __name__ == "__main__":
vessel_idx=env0_dec.vessel_idx,
port_idx=env0_dec.port_idx,
quantity=env0_dec.action_scope.load,
action_type=ActionType.LOAD
action_type=ActionType.LOAD,
)
metrics, decision_event, is_done = env.step(action)

View file

@ -3,19 +3,21 @@
from dataclasses import dataclass
@dataclass
class IlpPmCapacity:
core: int
mem: int
@dataclass
class IlpVmInfo:
id: int=-1
pm_idx: int=-2
core: int=-1
mem: int=-1
lifetime: int=-1
arrival_env_tick: int=-1
id: int = -1
pm_idx: int = -2
core: int = -1
mem: int = -1
lifetime: int = -1
arrival_env_tick: int = -1
def remaining_lifetime(self, env_tick: int):
return self.lifetime - (env_tick - self.arrival_env_tick)

View file

@ -1,20 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import timeit
from collections import defaultdict, Counter
from collections import Counter, defaultdict
from typing import List, Set
from maro.data_lib import BinaryReader
from maro.simulator.scenarios.vm_scheduling import PostponeAction, AllocateAction
from maro.simulator.scenarios.vm_scheduling.common import Action
from maro.utils import DottableDict, Logger
import numpy as np
from common import IlpPmCapacity, IlpVmInfo
from vm_scheduling_ilp import NOT_ALLOCATE_NOW, VmSchedulingILP
class IlpAgent():
from maro.data_lib import BinaryReader
from maro.simulator.scenarios.vm_scheduling import AllocateAction, PostponeAction
from maro.simulator.scenarios.vm_scheduling.common import Action
from maro.utils import DottableDict, Logger
class IlpAgent:
def __init__(
self,
ilp_config: DottableDict,
@ -24,11 +24,11 @@ class IlpAgent():
env_duration: int,
simulation_logger: Logger,
ilp_logger: Logger,
log_path: str
log_path: str,
):
self._simulation_logger = simulation_logger
self._ilp_logger = ilp_logger
self._allocation_counter = Counter()
pm_capacity: List[IlpPmCapacity] = [IlpPmCapacity(core=pm[0], mem=pm[1]) for pm in pm_capacity]
@ -41,7 +41,7 @@ class IlpAgent():
self.vm_item_picker = self.vm_reader.items_tick_picker(
env_start_tick,
env_start_tick + env_duration,
time_unit="s"
time_unit="s",
)
# Used to keep the info already read from the vm_item_picker.
@ -87,7 +87,7 @@ class IlpAgent():
core=vm.vm_cpu_cores,
mem=vm.vm_memory,
lifetime=vm.vm_lifetime,
arrival_env_tick=tick
arrival_env_tick=tick,
)
if tick < env_tick + self.ilp_apply_buffer_size:
self.refreshed_allocated_vm_dict[vm.vm_id] = vmInfo
@ -109,7 +109,13 @@ class IlpAgent():
# Choose action by ILP, may trigger a new formulation and solution,
# may directly return the decision if the cur_vm_id is still in the apply buffer size of last solution.
chosen_pm_idx = self.ilp.choose_pm(env_tick, cur_vm_id, self.allocated_vm, self.future_vm_req, self._vm_id_to_idx)
chosen_pm_idx = self.ilp.choose_pm(
env_tick,
cur_vm_id,
self.allocated_vm,
self.future_vm_req,
self._vm_id_to_idx,
)
self._simulation_logger.info(f"tick: {env_tick}, vm: {cur_vm_id} -> pm: {chosen_pm_idx}")
if chosen_pm_idx == NOT_ALLOCATE_NOW:

View file

@ -1,20 +1,19 @@
import io
import os
import pprint
import shutil
import timeit
from typing import List, Set
import pprint
import yaml
from ilp_agent import IlpAgent
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import DecisionPayload
from maro.simulator.scenarios.vm_scheduling.common import Action
from maro.utils import convert_dottable, Logger, LogFormat
from maro.utils import LogFormat, Logger, convert_dottable
from ilp_agent import IlpAgent
os.environ['LOG_LEVEL'] = 'CRITICAL'
os.environ["LOG_LEVEL"] = "CRITICAL"
FILE_PATH = os.path.split(os.path.realpath(__file__))[0]
CONFIG_PATH = os.path.join(FILE_PATH, "config.yml")
with io.open(CONFIG_PATH, "r") as in_file:
@ -32,11 +31,11 @@ if __name__ == "__main__":
topology=config.env.topology,
start_tick=config.env.start_tick,
durations=config.env.durations,
snapshot_resolution=config.env.resolution
snapshot_resolution=config.env.resolution,
)
shutil.copy(
os.path.join(env._business_engine._config_path, "config.yml"),
os.path.join(LOG_PATH, "BEconfig.yml")
os.path.join(LOG_PATH, "BEconfig.yml"),
)
shutil.copy(CONFIG_PATH, os.path.join(LOG_PATH, "config.yml"))
@ -51,9 +50,7 @@ if __name__ == "__main__":
metrics, decision_event, is_done = env.step(None)
# Get the core & memory capacity of all PMs in this environment.
pm_capacity = env.snapshot_list["pms"][
env.frame_index::["cpu_cores_capacity", "memory_capacity"]
].reshape(-1, 2)
pm_capacity = env.snapshot_list["pms"][env.frame_index :: ["cpu_cores_capacity", "memory_capacity"]].reshape(-1, 2)
pm_num = pm_capacity.shape[0]
# ILP agent.
@ -65,7 +62,7 @@ if __name__ == "__main__":
env_duration=config.env.durations,
simulation_logger=simulation_logger,
ilp_logger=ilp_logger,
log_path=LOG_PATH
log_path=LOG_PATH,
)
while not is_done:
@ -83,7 +80,7 @@ if __name__ == "__main__":
end_time = timeit.default_timer()
simulation_logger.info(
f"[Offline ILP] Topology: {config.env.topology}. Total ticks: {config.env.durations}."
f" Start tick: {config.env.start_tick}."
f" Start tick: {config.env.start_tick}.",
)
simulation_logger.info(f"[Timer] {end_time - start_time:.2f} seconds to finish the simulation.")
ilp_agent.report_allocation_summary()

View file

@ -1,23 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import os
import timeit
from typing import List
import math
import numpy as np
from pulp import PULP_CBC_CMD, GLPK, GUROBI_CMD, LpInteger, LpMaximize, LpProblem, LpStatus, LpVariable, lpSum
from common import IlpPmCapacity, IlpVmInfo
from pulp import GLPK, PULP_CBC_CMD, LpInteger, LpMaximize, LpProblem, LpStatus, LpVariable, lpSum
from maro.utils import DottableDict, Logger
from common import IlpPmCapacity, IlpVmInfo
# To indicate the decision of not allocate or cannot allocate any PM for current VM request.
NOT_ALLOCATE_NOW = -1
class VmSchedulingILP():
class VmSchedulingILP:
def __init__(self, config: DottableDict, pm_capacity: List[IlpPmCapacity], logger: Logger, log_path: str):
self._logger = logger
self._log_path = log_path
@ -76,7 +75,7 @@ class VmSchedulingILP():
name=f"Place_VM{vm_idx}_{self._future_vm_req[vm_idx].id}_on_PM{pm_idx}",
lowBound=0,
upBound=1,
cat=LpInteger
cat=LpInteger,
)
def _add_constraints(self, problem: LpProblem):
@ -84,7 +83,7 @@ class VmSchedulingILP():
for vm_idx in range(self._vm_num):
problem += (
lpSum(self._mapping[pm_idx][vm_idx] for pm_idx in range(self._pm_num)) <= 1,
f"Mapping_VM{vm_idx}_to_max_1_PM"
f"Mapping_VM{vm_idx}_to_max_1_PM",
)
# PM capacity limitation: core + mem.
@ -95,16 +94,20 @@ class VmSchedulingILP():
vm.core * self._mapping[pm_idx][vm_idx]
for vm_idx, vm in enumerate(self._future_vm_req)
if (vm.arrival_env_tick - self._env_tick <= t and vm.remaining_lifetime(self._env_tick) >= t)
) + self._pm_allocated_core[t][pm_idx] <= self._pm_capacity[pm_idx].core * self.core_upper_ratio,
f"T{t}_PM{pm_idx}_core_capacity_limit"
)
+ self._pm_allocated_core[t][pm_idx]
<= self._pm_capacity[pm_idx].core * self.core_upper_ratio,
f"T{t}_PM{pm_idx}_core_capacity_limit",
)
problem += (
lpSum(
vm.mem * self._mapping[pm_idx][vm_idx]
for vm_idx, vm in enumerate(self._future_vm_req)
if (vm.arrival_env_tick - self._env_tick <= t and vm.remaining_lifetime(self._env_tick) >= t)
) + self._pm_allocated_mem[t][pm_idx] <= self._pm_capacity[pm_idx].mem * self.mem_upper_ratio,
f"T{t}_PM{pm_idx}_mem_capacity_limit"
)
+ self._pm_allocated_mem[t][pm_idx]
<= self._pm_capacity[pm_idx].mem * self.mem_upper_ratio,
f"T{t}_PM{pm_idx}_mem_capacity_limit",
)
def _set_objective(self, problem: LpProblem):
@ -129,7 +132,7 @@ class VmSchedulingILP():
problem = LpProblem(
name=f"VM_Scheduling_from_tick_{self._env_tick}",
sense=LpMaximize
sense=LpMaximize,
)
self._init_variables()
self._add_constraints(problem=problem)
@ -148,8 +151,9 @@ class VmSchedulingILP():
if self._mapping[pm_idx][vm_idx].varValue:
chosen_pm_idx = pm_idx
break
self._logger.info(f"Solution tick: {self._env_tick}, vm: {self._future_vm_req[vm_idx].id} -> pm: {chosen_pm_idx}")
self._logger.info(
f"Solution tick: {self._env_tick}, vm: {self._future_vm_req[vm_idx].id} -> pm: {chosen_pm_idx}",
)
def choose_pm(
self,

View file

@ -1,15 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict
import torch
from torch.optim import Adam, SGD
from torch.optim import SGD, Adam
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
actor_net_conf = {
"hidden_dims": [64, 32, 32],
@ -39,7 +36,7 @@ class MyActorNet(DiscreteACBasedNet):
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)
def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
features, masks = states[:, :self._num_features], states[:, self._num_features:]
features, masks = states[:, : self._num_features], states[:, self._num_features :]
masks += 1e-8 # this is to prevent zero probability and infinite logP.
return self._actor(features) * masks
@ -52,7 +49,7 @@ class MyCriticNet(VNet):
self._optim = SGD(self._critic.parameters(), lr=critic_learning_rate)
def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
features, masks = states[:, :self._num_features], states[:, self._num_features:]
features, masks = states[:, : self._num_features], states[:, self._num_features :]
masks += 1e-8 # this is to prevent zero probability and infinite logP.
return self._critic(features).squeeze(-1)
@ -70,6 +67,6 @@ def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
grad_iters=100,
critic_loss_cls=torch.nn.MSELoss,
min_logp=-20,
lam=.0,
lam=0.0,
),
)

View file

@ -33,8 +33,8 @@ class MyQNet(DiscreteQNet):
self._lr_scheduler = CosineAnnealingWarmRestarts(self._optim, **q_net_lr_scheduler_params)
def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
masks = states[:, self._num_features:]
q_for_all_actions = self._fc(states[:, :self._num_features])
masks = states[:, self._num_features :]
q_for_all_actions = self._fc(states[:, : self._num_features])
return q_for_all_actions + (masks - 1) * 1e8
@ -44,11 +44,13 @@ class MaskedEpsGreedy:
self._num_features = num_features
def __call__(self, states, actions, num_actions, *, epsilon):
masks = states[:, self._num_features:]
return np.array([
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
for action, mask in zip(actions, masks)
])
masks = states[:, self._num_features :]
return np.array(
[
action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0])
for action, mask in zip(actions, masks)
],
)
def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str) -> ValueBasedPolicy:
@ -56,14 +58,18 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
name=name,
q_net=MyQNet(state_dim, action_num, num_features),
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
exploration_scheduling_options=[(
"epsilon", MultiLinearExplorationScheduler, {
"splits": [(100, 0.32)],
"initial_value": 0.4,
"last_ep": 400,
"final_value": 0.0,
}
)],
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(100, 0.32)],
"initial_value": 0.4,
"last_ep": 400,
"final_value": 0.0,
},
),
],
warmup=100,
)

View file

@ -3,7 +3,6 @@
from maro.simulator import Env
env_conf = {
"scenario": "vm_scheduling",
"topology": "azure.2019.10k",

View file

@ -15,7 +15,13 @@ from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from .config import (
num_features, pm_attributes, pm_window_size, reward_shaping_conf, seed, test_reward_shaping_conf, test_seed,
num_features,
pm_attributes,
pm_window_size,
reward_shaping_conf,
seed,
test_reward_shaping_conf,
test_seed,
)
timestamp = str(time.time())
@ -31,13 +37,15 @@ class VMEnvSampler(AbsEnvSampler):
self._test_env.set_seed(test_seed)
# adjust the ratio of the success allocation and the total income when computing the reward
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms
self._durations = self._learn_env.business_engine._max_tick
self._pm_state_history = np.zeros((pm_window_size - 1, self.num_pms, 2))
self._legal_pm_mask = None
def _get_global_and_agent_state_impl(
self, event: DecisionPayload, tick: int = None,
self,
event: DecisionPayload,
tick: int = None,
) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event)
# get the legal number of PM.
@ -61,7 +69,9 @@ class VMEnvSampler(AbsEnvSampler):
return None, {"AGENT": state}
def _translate_to_env_action(
self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionPayload,
self,
action_dict: Dict[Any, Union[np.ndarray, List[object]]],
event: DecisionPayload,
) -> Dict[Any, object]:
if action_dict["AGENT"] == self.num_pms:
return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)}
@ -71,17 +81,17 @@ class VMEnvSampler(AbsEnvSampler):
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]:
action = env_action_dict["AGENT"]
conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf
if isinstance(action, PostponeAction): # postponement
if isinstance(action, PostponeAction): # postponement
if np.sum(self._legal_pm_mask) != 1:
reward = -0.1 * conf["alpha"] + 0.0 * conf["beta"]
else:
reward = 0.0 * conf["alpha"] + 0.0 * conf["beta"]
else:
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else .0
reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else 0.0
return {"AGENT": np.float32(reward)}
def _get_pm_state(self):
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index::pm_attributes]
total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index :: pm_attributes]
total_pm_info = total_pm_info.reshape(self.num_pms, len(pm_attributes))
# normalize the attributes of pms' cpu and memory
@ -99,21 +109,24 @@ class VMEnvSampler(AbsEnvSampler):
# get the sequence pms' information
self._pm_state_history = np.concatenate((self._pm_state_history, total_pm_info), axis=0)
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2)
def _get_vm_state(self, event):
return np.array([
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
event.vm_memory_requirement / self._max_memory_capacity,
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement)
])
return np.array(
[
event.vm_cpu_cores_requirement / self._max_cpu_capacity,
event.vm_memory_requirement / self._max_memory_capacity,
(self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE
self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement),
],
)
def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float):
vm_unit_price = self._env.business_engine._get_unit_price(
event.vm_cpu_cores_requirement, event.vm_memory_requirement
event.vm_cpu_cores_requirement,
event.vm_memory_requirement,
)
return (alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time))
return alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time)
def _post_step(self, cache_element: CacheElement) -> None:
self._info["env_metric"] = {k: v for k, v in self._env.metrics.items() if k != "total_latency"}
@ -127,7 +140,9 @@ class VMEnvSampler(AbsEnvSampler):
action = cache_element.action_dict["AGENT"]
if cache_element.state:
mask = cache_element.state[num_features:]
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append([action, mask])
self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append(
[action, mask],
)
self._info["action_sequence"].append(action)
def _post_eval_step(self, cache_element: CacheElement) -> None:

View file

@ -1,10 +1,11 @@
from functools import partial
from typing import Any, Callable, Dict, Optional
from examples.vm_scheduling.rl.algorithms.ac import get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn_policy
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
@ -30,14 +31,22 @@ class VMBundle(RLComponentBundle):
if algorithm == "ac":
policy_creator = {
f"{algorithm}.policy": partial(
get_ac_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
)
get_ac_policy,
state_dim,
action_num,
num_features,
f"{algorithm}.policy",
),
}
elif algorithm == "dqn":
policy_creator = {
f"{algorithm}.policy": partial(
get_dqn_policy, state_dim, action_num, num_features, f"{algorithm}.policy",
)
get_dqn_policy,
state_dim,
action_num,
num_features,
f"{algorithm}.policy",
),
}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
@ -46,10 +55,8 @@ class VMBundle(RLComponentBundle):
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
if algorithm == "ac":
from .algorithms.ac import get_ac, get_ac_policy
trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)}
elif algorithm == "dqn":
from .algorithms.dqn import get_dqn, get_dqn_policy
trainer_creator = {algorithm: partial(get_dqn, algorithm)}
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")


@ -1,6 +1,6 @@
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling.common import Action
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from maro.simulator.scenarios.vm_scheduling.common import Action
class VMSchedulingAgent(object):
@ -8,15 +8,14 @@ class VMSchedulingAgent(object):
self._algorithm = algorithm
def choose_action(self, decision_event: DecisionPayload, env: Env) -> Action:
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM.
"""
"""This method will determine whether to postpone the current VM or allocate a PM to the current VM."""
valid_pm_num: int = len(decision_event.valid_pms)
if valid_pm_num <= 0:
# No valid PM now, postpone.
action: PostponeAction = PostponeAction(
vm_id=decision_event.vm_id,
postpone_step=1
postpone_step=1,
)
else:
action: AllocateAction = self._algorithm.allocate_vm(decision_event, env)


@ -1,10 +1,9 @@
import numpy as np
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from rule_based_algorithm import RuleBasedAlgorithm
class BestFit(RuleBasedAlgorithm):
def __init__(self, **kwargs):
@ -17,7 +16,7 @@ class BestFit(RuleBasedAlgorithm):
# Take action to allocate on the chosen PM.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=decision_event.valid_pms[chosen_idx]
pm_id=decision_event.valid_pms[chosen_idx],
)
return action
@ -25,8 +24,12 @@ class BestFit(RuleBasedAlgorithm):
def _pick_pm_func(self, decision_event, env) -> int:
# Get the capacity and allocated cores from snapshot.
valid_pm_info = env.snapshot_list["pms"][
env.frame_index:decision_event.valid_pms:[
"cpu_cores_capacity", "cpu_cores_allocated", "memory_capacity", "memory_allocated", "energy_consumption"
env.frame_index : decision_event.valid_pms : [
"cpu_cores_capacity",
"cpu_cores_allocated",
"memory_capacity",
"memory_allocated",
"energy_consumption",
]
].reshape(-1, 5)
# Calculate to get the remaining cpu cores.
@ -37,23 +40,19 @@ class BestFit(RuleBasedAlgorithm):
energy_consumption = valid_pm_info[:, 4]
# Choose the PM with the preference rule.
chosen_idx: int = 0
if self._metric_type == 'remaining_cpu_cores':
if self._metric_type == "remaining_cpu_cores":
chosen_idx = np.argmin(cpu_cores_remaining)
elif self._metric_type == 'remaining_memory':
elif self._metric_type == "remaining_memory":
chosen_idx = np.argmin(memory_remaining)
elif self._metric_type == 'energy_consumption':
elif self._metric_type == "energy_consumption":
chosen_idx = np.argmax(energy_consumption)
elif self._metric_type == 'remaining_cpu_cores_and_energy_consumption':
elif self._metric_type == "remaining_cpu_cores_and_energy_consumption":
maximum_energy_consumption = energy_consumption[0]
minimum_remaining_cpu_cores = cpu_cores_remaining[0]
for i, remaining in enumerate(cpu_cores_remaining):
energy = energy_consumption[i]
if (
remaining < minimum_remaining_cpu_cores
or (
remaining == minimum_remaining_cpu_cores
and energy > maximum_energy_consumption
)
if remaining < minimum_remaining_cpu_cores or (
remaining == minimum_remaining_cpu_cores and energy > maximum_energy_consumption
):
chosen_idx = i
minimum_remaining_cpu_cores = remaining
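
For reference, the preference rules above boil down to an argmin/argmax over per-PM metrics. A toy numpy version of the "remaining_cpu_cores" rule with invented arrays; validity filtering is simplified here, since BestFit receives the valid PM list from the decision event rather than computing it:

import numpy as np

cpu_cores_capacity = np.array([32, 32, 64, 16])
cpu_cores_allocated = np.array([30, 8, 40, 15])
requested_cores = 2

cpu_cores_remaining = cpu_cores_capacity - cpu_cores_allocated
valid = np.where(cpu_cores_remaining >= requested_cores)[0]   # PMs that can still host the VM
chosen_idx = valid[np.argmin(cpu_cores_remaining[valid])]     # tightest fit among the valid ones
print(chosen_idx, cpu_cores_remaining[chosen_idx])
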


@ -1,21 +1,19 @@
import random
import numpy as np
from yaml import safe_load
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from maro.utils.utils import convert_dottable
from rule_based_algorithm import RuleBasedAlgorithm
class BinPacking(RuleBasedAlgorithm):
def __init__(self, **kwargs):
super().__init__()
self._max_cpu_oversubscription_rate: float = kwargs["env"].configs.MAX_CPU_OVERSUBSCRIPTION_RATE
total_pm_cpu_info = kwargs["env"].snapshot_list["pms"][
kwargs["env"].frame_index::["cpu_cores_capacity"]
].reshape(-1)
total_pm_cpu_info = (
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].reshape(-1)
)
self._pm_num: int = total_pm_cpu_info.shape[0]
self._pm_cpu_core_num: int = int(np.max(total_pm_cpu_info) * self._max_cpu_oversubscription_rate)
@ -29,7 +27,7 @@ class BinPacking(RuleBasedAlgorithm):
# Get the number of PM, maximum CPU core and max cpu oversubscription rate.
total_pm_info = env.snapshot_list["pms"][
env.frame_index::["cpu_cores_capacity", "cpu_cores_allocated"]
env.frame_index :: ["cpu_cores_capacity", "cpu_cores_allocated"]
].reshape(-1, 2)
cpu_cores_remaining = total_pm_info[:, 0] * self._max_cpu_oversubscription_rate - total_pm_info[:, 1]
@ -57,7 +55,7 @@ class BinPacking(RuleBasedAlgorithm):
# Take action to allocate on the chosen pm.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=chosen_idx
pm_id=chosen_idx,
)
return action


@ -1,8 +1,8 @@
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from rule_based_algorithm import RuleBasedAlgorithm
class FirstFit(RuleBasedAlgorithm):
def __init__(self, **kwargs):
@ -14,7 +14,7 @@ class FirstFit(RuleBasedAlgorithm):
# Take action to allocate on the chosen PM.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=chosen_idx
pm_id=chosen_idx,
)
return action


@ -1,19 +1,18 @@
import importlib
import io
import os
import random
import timeit
import yaml
import importlib
from agent import VMSchedulingAgent
from maro.simulator import Env
from maro.utils import convert_dottable
from agent import VMSchedulingAgent
def import_class(name):
components = name.rsplit('.', 1)
components = name.rsplit(".", 1)
mod = importlib.import_module(components[0])
mod = getattr(mod, components[1])
return mod
@ -33,7 +32,7 @@ if __name__ == "__main__":
topology=config.env.topology,
start_tick=config.env.start_tick,
durations=config.env.durations,
snapshot_resolution=config.env.resolution
snapshot_resolution=config.env.resolution,
)
if config.env.seed is not None:
@ -57,7 +56,7 @@ if __name__ == "__main__":
end_time = timeit.default_timer()
print(
f"[{config.algorithm.type.split('.')[1]}] Topology: {config.env.topology}. Total ticks: {config.env.durations}."
f" Start tick: {config.env.start_tick}."
f" Start tick: {config.env.start_tick}.",
)
print(f"[Timer] {end_time - start_time:.2f} seconds to finish the simulation.")
print(metrics)
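
For reference, import_class above resolves a dotted "module.ClassName" string at runtime. A minimal sketch using a standard-library class so it runs anywhere; the dotted names actually used by the launcher's config are project-specific:

import importlib

def import_class(name):
    # Split on the last dot: everything before it is the module, the rest is the attribute.
    module_path, class_name = name.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

OrderedDict = import_class("collections.OrderedDict")
print(OrderedDict([("a", 1)]))
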


@ -1,15 +1,15 @@
import random
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from rule_based_algorithm import RuleBasedAlgorithm
class RandomPick(RuleBasedAlgorithm):
def __init__(self, **kwargs):
super().__init__()
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
valid_pm_num: int = len(decision_event.valid_pms)
# Random choose a valid PM.
@ -17,7 +17,7 @@ class RandomPick(RuleBasedAlgorithm):
# Take action to allocate on the chosen PM.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=decision_event.valid_pms[chosen_idx]
pm_id=decision_event.valid_pms[chosen_idx],
)
return action


@ -1,14 +1,16 @@
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload
from rule_based_algorithm import RuleBasedAlgorithm
class RoundRobin(RuleBasedAlgorithm):
def __init__(self, **kwargs):
super().__init__()
self._prev_idx: int = 0
self._pm_num: int = kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index::["cpu_cores_capacity"]].shape[0]
self._pm_num: int = (
kwargs["env"].snapshot_list["pms"][kwargs["env"].frame_index :: ["cpu_cores_capacity"]].shape[0]
)
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
# Choose the valid PM whose index comes right after the previously chosen PM's index
@ -21,7 +23,7 @@ class RoundRobin(RuleBasedAlgorithm):
# Take action to allocate on the chosen PM.
action: AllocateAction = AllocateAction(
vm_id=decision_event.vm_id,
pm_id=chosen_idx
pm_id=chosen_idx,
)
return action
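
For reference, one plausible reading of the round-robin rule above ("the valid PM whose index follows the previously chosen one") as a self-contained sketch; the real implementation may differ in details, and the valid PM lists are invented:

class RoundRobinPicker:
    def __init__(self, pm_num):
        self._pm_num = pm_num
        self._prev_idx = 0

    def pick(self, valid_pms):
        # Scan forward from the previously chosen index, wrapping around,
        # until a valid PM index is found.
        for step in range(1, self._pm_num + 1):
            idx = (self._prev_idx + step) % self._pm_num
            if idx in valid_pms:
                self._prev_idx = idx
                return idx
        raise RuntimeError("no valid PM available")

picker = RoundRobinPicker(pm_num=5)
print(picker.pick([2, 4]))  # 2
print(picker.pick([2, 4]))  # 4
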


@ -9,6 +9,5 @@ class RuleBasedAlgorithm(object):
@abc.abstractmethod
def allocate_vm(self, decision_event: DecisionPayload, env: Env) -> AllocateAction:
"""This method will determine allocate which PM to the current VM.
"""
pass
"""This method will determine allocate which PM to the current VM."""
raise NotImplementedError
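
For reference, replacing the docstring-plus-pass body with raise NotImplementedError makes an un-overridden abstract method fail loudly instead of silently returning None. A minimal sketch of the pattern with illustrative names (using abc.ABC here to keep the example self-contained):

import abc

class Picker(abc.ABC):
    @abc.abstractmethod
    def pick(self, candidates):
        """Subclasses must override this."""
        raise NotImplementedError

class FirstPicker(Picker):
    def pick(self, candidates):
        return candidates[0]

print(FirstPicker().pick([3, 1, 2]))  # 3
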


@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# isort: skip_file
from .__misc__ import __data_version__, __version__
from maro.utils.utils import check_deployment_status, deploy
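
For reference, the "# isort: skip_file" marker added above is a file-level isort directive: the whole module is left untouched, which is how a file with an intentional import order opts out of sorting. A tiny illustrative module (unrelated to MARO):

# isort: skip_file
# Keep these two imports in this (non-alphabetical) order on purpose;
# without the skip_file directive isort would swap them.
import sys
import os

print(os.path.basename(sys.executable))
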


@ -6,21 +6,37 @@
#distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
cimport cython
cimport numpy as np
from cpython cimport bool
from cython cimport view
from cython.operator cimport dereference as deref
from cpython cimport bool
from libcpp cimport bool as cppbool
from libcpp.map cimport map
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
INT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX,
ATTR_CHAR, ATTR_UCHAR, ATTR_SHORT, ATTR_USHORT, ATTR_INT, ATTR_UINT,
ATTR_LONG, ATTR_ULONG, ATTR_FLOAT, ATTR_DOUBLE)
from maro.backends.backend cimport (
ATTR_CHAR,
ATTR_DOUBLE,
ATTR_FLOAT,
ATTR_INT,
ATTR_LONG,
ATTR_SHORT,
ATTR_TYPE,
ATTR_UCHAR,
ATTR_UINT,
ATTR_ULONG,
ATTR_USHORT,
INT,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
AttributeType,
BackendAbc,
SnapshotListAbc,
)
# Ensure numpy will not crash, as we use numpy as query result
np.import_array()
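
For reference, the import hunks in these backend modules are typical of isort's black-profile output for long imports: one name per parenthesized line with a trailing comma, sorted by isort's default grouping. A hedged sketch of reproducing that style through the two tools' Python APIs; the exact keyword arguments are assumptions based on their documented interfaces, and both packages must be installed:

import black
import isort

messy = "from os.path import (join, dirname,\n    basename, abspath, exists, isdir, isfile, splitext)\n"

sorted_code = isort.code(messy, profile="black", line_length=120)             # sort the imported names
formatted = black.format_str(sorted_code, mode=black.Mode(line_length=120))  # normalize layout
print(formatted)
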


@ -5,8 +5,7 @@
#distutils: language = c++
from cpython cimport bool
from libc.stdint cimport int32_t, int64_t, int16_t, int8_t, uint32_t, uint64_t
from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint32_t, uint64_t
# common types


@ -5,8 +5,10 @@
#distutils: language = c++
from enum import Enum
from cpython cimport bool
cdef class AttributeType:
Byte = b"byte"
UByte = b"ubyte"


@ -6,8 +6,19 @@
from cpython cimport bool
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
INT, USHORT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
from maro.backends.backend cimport (
ATTR_TYPE,
INT,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
USHORT,
AttributeType,
BackendAbc,
SnapshotListAbc,
)
cdef class SnapshotList:


@ -8,31 +8,42 @@
import os
cimport cython
cimport numpy as np
import numpy as np
from cpython cimport bool
from typing import Union
from maro.backends.backend cimport (
ATTR_TYPE,
INT,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
USHORT,
AttributeType,
BackendAbc,
SnapshotListAbc,
)
from maro.backends.np_backend cimport NumpyBackend
from maro.backends.raw_backend cimport RawBackend
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
from maro.utils.exception.backends_exception import (
BackendsGetItemInvalidException,
BackendsSetItemInvalidException,
BackendsArrayAttributeAccessException,
BackendsAppendToNonListAttributeException,
BackendsResizeNonListAttributeException,
BackendsClearNonListAttributeException,
BackendsInsertNonListAttributeException,
BackendsRemoveFromNonListAttributeException,
BackendsAccessDeletedNodeException,
BackendsAppendToNonListAttributeException,
BackendsArrayAttributeAccessException,
BackendsClearNonListAttributeException,
BackendsGetItemInvalidException,
BackendsInsertNonListAttributeException,
BackendsInvalidAttributeException,
BackendsInvalidNodeException,
BackendsInvalidAttributeException
BackendsRemoveFromNonListAttributeException,
BackendsResizeNonListAttributeException,
BackendsSetItemInvalidException,
)
# Old type definition mapping.


@ -5,11 +5,21 @@
#distutils: language = c++
import numpy as np
cimport numpy as np
cimport cython
cimport cython
cimport numpy as np
from cpython cimport bool
from maro.backends.backend cimport BackendAbc, SnapshotListAbc, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX
from maro.backends.backend cimport (
ATTR_TYPE,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
BackendAbc,
SnapshotListAbc,
)
cdef class NumpyBackend(BackendAbc):


@ -8,13 +8,24 @@
import os
import numpy as np
cimport numpy as np
cimport cython
cimport numpy as np
from cpython cimport bool
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
INT, UINT, ULONG, USHORT, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX)
from maro.backends.backend cimport (
ATTR_TYPE,
INT,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
USHORT,
AttributeType,
BackendAbc,
SnapshotListAbc,
)
# Attribute data type mapping.
attribute_type_mapping = {
@ -44,11 +55,10 @@ attribute_type_range = {
IF NODES_MEMORY_LAYOUT == "ONE_BLOCK":
# with this flag, we will allocate one big enough memory block for all node types, then use this block to construct numpy arrays
from cpython cimport Py_INCREF, PyObject, PyTypeObject
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.string cimport memset
from cpython cimport PyObject, Py_INCREF, PyTypeObject
from cpython.mem cimport PyMem_Malloc, PyMem_Free
# we need this to avoid seg fault
np.import_array()


@ -5,14 +5,29 @@
#distutils: language = c++
cimport cython
from cpython cimport bool
from libcpp cimport bool as cppbool
from libcpp.string cimport string
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, INT, UINT, ULONG, USHORT, NODE_INDEX, SLOT_INDEX,
ATTR_CHAR, ATTR_SHORT, ATTR_INT, ATTR_LONG, ATTR_FLOAT, ATTR_DOUBLE, QUERY_FLOAT, ATTR_TYPE, NODE_TYPE)
from maro.backends.backend cimport (
ATTR_CHAR,
ATTR_DOUBLE,
ATTR_FLOAT,
ATTR_INT,
ATTR_LONG,
ATTR_SHORT,
ATTR_TYPE,
INT,
NODE_INDEX,
NODE_TYPE,
QUERY_FLOAT,
SLOT_INDEX,
UINT,
ULONG,
USHORT,
BackendAbc,
SnapshotListAbc,
)
cdef extern from "raw/common.h" namespace "maro::backends::raw":


@ -8,21 +8,37 @@
import warnings
import numpy as np
cimport numpy as np
cimport cython
cimport cython
cimport numpy as np
from cpython cimport bool
from cython cimport view
from cython.operator cimport dereference as deref
from cpython cimport bool
from libcpp cimport bool as cppbool
from libcpp.map cimport map
from maro.backends.backend cimport (BackendAbc, SnapshotListAbc, AttributeType,
INT, UINT, ULONG, NODE_TYPE, ATTR_TYPE, NODE_INDEX, SLOT_INDEX,
ATTR_CHAR, ATTR_UCHAR, ATTR_SHORT, ATTR_USHORT, ATTR_INT, ATTR_UINT,
ATTR_LONG, ATTR_ULONG, ATTR_FLOAT, ATTR_DOUBLE)
from maro.backends.backend cimport (
ATTR_CHAR,
ATTR_DOUBLE,
ATTR_FLOAT,
ATTR_INT,
ATTR_LONG,
ATTR_SHORT,
ATTR_TYPE,
ATTR_UCHAR,
ATTR_UINT,
ATTR_ULONG,
ATTR_USHORT,
INT,
NODE_INDEX,
NODE_TYPE,
SLOT_INDEX,
UINT,
ULONG,
AttributeType,
BackendAbc,
SnapshotListAbc,
)
# Ensure numpy will not crash, as we use numpy as query result
np.import_array()


@ -83,9 +83,8 @@ class CitiBikePipeline(DataPipeline):
with zipfile.ZipFile(self._download_file, "r") as zip_ref:
for filename in zip_ref.namelist():
# Only one csv file is expected.
if (
filename.endswith(".csv") and
(not (filename.startswith("__MACOSX") or filename.startswith(".")))
if filename.endswith(".csv") and (
not (filename.startswith("__MACOSX") or filename.startswith("."))
):
logger.info_green(f"Unzip {filename} from {self._download_file}.")
@ -106,10 +105,13 @@ class CitiBikePipeline(DataPipeline):
with open(self._station_info_file, mode="r", encoding="utf-8") as station_file:
# read station to station file
raw_station_data = pd.DataFrame.from_dict(pd.read_json(station_file)["data"]["stations"])
station_data = raw_station_data.rename(columns={
"lon": "station_longitude",
"lat": "station_latitude",
"region_id": "region"})
station_data = raw_station_data.rename(
columns={
"lon": "station_longitude",
"lat": "station_latitude",
"region_id": "region",
},
)
# group by station to generate station init info
full_stations = station_data[
@ -124,7 +126,8 @@ class CitiBikePipeline(DataPipeline):
full_stations["station_latitude"] = pd.to_numeric(full_stations["station_latitude"], downcast="float")
full_stations.drop(full_stations[full_stations["capacity"] == 0].index, axis=0, inplace=True)
full_stations.dropna(
subset=["station_id", "capacity", "station_longitude", "station_latitude"], inplace=True
subset=["station_id", "capacity", "station_longitude", "station_latitude"],
inplace=True,
)
self._common_data["full_stations"] = full_stations
@ -141,13 +144,24 @@ class CitiBikePipeline(DataPipeline):
with open(file, "r", encoding="utf-8", errors="ignore") as fp:
ret = pd.read_csv(fp)
ret = ret[[
"tripduration", "starttime", "start station id", "end station id", "start station latitude",
"start station longitude", "end station latitude", "end station longitude", "gender", "usertype",
"bikeid"
]]
ret = ret[
[
"tripduration",
"starttime",
"start station id",
"end station id",
"start station latitude",
"start station longitude",
"end station latitude",
"end station longitude",
"gender",
"usertype",
"bikeid",
]
]
ret["tripduration"] = pd.to_numeric(
pd.to_numeric(ret["tripduration"], downcast="integer") / 60, downcast="integer"
pd.to_numeric(ret["tripduration"], downcast="integer") / 60,
downcast="integer",
)
ret["starttime"] = pd.to_datetime(ret["starttime"])
ret["start station id"] = pd.to_numeric(ret["start station id"], errors="coerce", downcast="integer")
@ -158,23 +172,34 @@ class CitiBikePipeline(DataPipeline):
ret["end station longitude"] = pd.to_numeric(ret["end station longitude"], downcast="float")
ret["bikeid"] = pd.to_numeric(ret["bikeid"], errors="coerce", downcast="integer")
ret["gender"] = pd.to_numeric(ret["gender"], errors="coerce", downcast="integer")
ret["usertype"] = ret["usertype"].apply(str).apply(
lambda x: 0 if x in ["Subscriber", "subscriber"] else 1 if x in ["Customer", "customer"] else 2
ret["usertype"] = (
ret["usertype"]
.apply(str)
.apply(
lambda x: 0 if x in ["Subscriber", "subscriber"] else 1 if x in ["Customer", "customer"] else 2,
)
)
ret.dropna(
subset=[
"start station id",
"end station id",
"start station latitude",
"end station latitude",
"start station longitude",
"end station longitude",
],
inplace=True,
)
ret.dropna(subset=[
"start station id", "end station id", "start station latitude", "end station latitude",
"start station longitude", "end station longitude"
], inplace=True)
ret.drop(
ret[
(ret["tripduration"] <= 1) |
(ret["start station latitude"] == 0) |
(ret["start station longitude"] == 0) |
(ret["end station latitude"] == 0) |
(ret["end station longitude"] == 0)
(ret["tripduration"] <= 1)
| (ret["start station latitude"] == 0)
| (ret["start station longitude"] == 0)
| (ret["end station latitude"] == 0)
| (ret["end station longitude"] == 0)
].index,
axis=0,
inplace=True
inplace=True,
)
ret = ret.sort_values(by="starttime", ascending=True)
@ -184,16 +209,16 @@ class CitiBikePipeline(DataPipeline):
used_bikes = len(src_data[["bikeid"]].drop_duplicates(subset=["bikeid"]))
trip_data = src_data[
(src_data["start station latitude"] > 40.689960) &
(src_data["start station latitude"] < 40.768334) &
(src_data["start station longitude"] > -74.019623) &
(src_data["start station longitude"] < -73.909760)
(src_data["start station latitude"] > 40.689960)
& (src_data["start station latitude"] < 40.768334)
& (src_data["start station longitude"] > -74.019623)
& (src_data["start station longitude"] < -73.909760)
]
trip_data = trip_data[
(trip_data["end station latitude"] > 40.689960) &
(trip_data["end station latitude"] < 40.768334) &
(trip_data["end station longitude"] > -74.019623) &
(trip_data["end station longitude"] < -73.909760)
(trip_data["end station latitude"] > 40.689960)
& (trip_data["end station latitude"] < 40.768334)
& (trip_data["end station longitude"] > -74.019623)
& (trip_data["end station longitude"] < -73.909760)
]
trip_data["start_station_id"] = trip_data["start station id"]
@ -202,25 +227,40 @@ class CitiBikePipeline(DataPipeline):
# get new stations
used_stations = []
used_stations.append(
trip_data[["start_station_id", "start station latitude", "start station longitude", ]].drop_duplicates(
subset=["start_station_id"]).rename(
columns={
"start_station_id": "station_id",
"start station latitude": "latitude",
"start station longitude": "longitude"
}))
trip_data[["start_station_id", "start station latitude", "start station longitude"]]
.drop_duplicates(
subset=["start_station_id"],
)
.rename(
columns={
"start_station_id": "station_id",
"start station latitude": "latitude",
"start station longitude": "longitude",
},
),
)
used_stations.append(
trip_data[["end_station_id", "end station latitude", "end station longitude", ]].drop_duplicates(
subset=["end_station_id"]).rename(
columns={
"end_station_id": "station_id",
"end station latitude": "latitude",
"end station longitude": "longitude"
}))
trip_data[["end_station_id", "end station latitude", "end station longitude"]]
.drop_duplicates(
subset=["end_station_id"],
)
.rename(
columns={
"end_station_id": "station_id",
"end station latitude": "latitude",
"end station longitude": "longitude",
},
),
)
in_data_station = pd.concat(used_stations, ignore_index=True).drop_duplicates(
subset=["station_id"]
).sort_values(by=["station_id"]).reset_index(drop=True)
in_data_station = (
pd.concat(used_stations, ignore_index=True)
.drop_duplicates(
subset=["station_id"],
)
.sort_values(by=["station_id"])
.reset_index(drop=True)
)
stations_existed = pd.DataFrame(in_data_station[["station_id"]])
@ -229,11 +269,11 @@ class CitiBikePipeline(DataPipeline):
# get start station id and end station id
trip_data = trip_data.join(
stations_existed.set_index("station_id"),
on="start_station_id"
on="start_station_id",
).rename(columns={"station_index": "start_station_index"})
trip_data = trip_data.join(
stations_existed.set_index("station_id"),
on="end_station_id"
on="end_station_id",
).rename(columns={"station_index": "end_station_index"})
trip_data = trip_data.rename(columns={"starttime": "start_time", "tripduration": "duration"})
@ -244,13 +284,17 @@ class CitiBikePipeline(DataPipeline):
return trip_data, used_bikes, in_data_station, stations_existed
def _process_current_topo_station_info(
self, stations_existed: pd.DataFrame, used_bikes: int, loc_ref: pd.DataFrame):
self,
stations_existed: pd.DataFrame,
used_bikes: int,
loc_ref: pd.DataFrame,
):
data_station_init = stations_existed.join(
self._common_data["full_stations"][["station_id", "capacity"]].set_index("station_id"),
on="station_id"
on="station_id",
).join(
loc_ref[["station_id", "latitude", "longitude"]].set_index("station_id"),
on="station_id"
on="station_id",
)
# data_station_init.rename(columns={"station_id": "station_index"}, inplace=True)
avg_capacity = int(self._common_data["full_dock_num"] / self._common_data["full_station_num"])
@ -259,22 +303,36 @@ class CitiBikePipeline(DataPipeline):
data_station_init.fillna(value=values, inplace=True)
data_station_init["init"] = (data_station_init["capacity"] * avalible_bike_rate).round().apply(int)
data_station_init["capacity"] = pd.to_numeric(
data_station_init["capacity"], errors="coerce", downcast="integer"
data_station_init["capacity"],
errors="coerce",
downcast="integer",
)
data_station_init["station_id"] = pd.to_numeric(
data_station_init["station_id"], errors="coerce", downcast="integer"
data_station_init["station_id"],
errors="coerce",
downcast="integer",
)
return data_station_init
def _process_distance(self, station_info: pd.DataFrame):
distance_adj = pd.DataFrame(0, index=station_info["station_index"],
columns=station_info["station_index"], dtype=np.float)
distance_adj = pd.DataFrame(
0,
index=station_info["station_index"],
columns=station_info["station_index"],
dtype=np.float,
)
look_up_df = station_info[["latitude", "longitude"]]
return distance_adj.apply(lambda x: pd.DataFrame(x).apply(lambda y: geopy.distance.distance(
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"])
).km, axis=1), axis=1)
return distance_adj.apply(
lambda x: pd.DataFrame(x).apply(
lambda y: geopy.distance.distance(
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"]),
).km,
axis=1,
),
axis=1,
)
def _preprocess(self, unzipped_file: str):
self._read_common_data()
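
For reference, _process_distance above builds a station-by-station geodesic distance matrix via nested pandas applies. An equivalent toy computation with geopy and a plain dict comprehension; the coordinates are invented and geopy must be installed:

import geopy.distance

stations = {
    0: (40.7000, -74.0000),
    1: (40.7200, -73.9500),
    2: (40.7500, -73.9800),
}

# Pairwise geodesic distances in kilometres, keyed by (from, to).
distance_km = {
    (a, b): geopy.distance.distance(coord_a, coord_b).km
    for a, coord_a in stations.items()
    for b, coord_b in stations.items()
}
print(round(distance_km[(0, 1)], 2))
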
@ -292,7 +350,9 @@ class CitiBikePipeline(DataPipeline):
logger.info_green("Processing station info data.")
station_info = self._process_current_topo_station_info(
stations_existed=stations_existed, used_bikes=used_bikes, loc_ref=in_data_station
stations_existed=stations_existed,
used_bikes=used_bikes,
loc_ref=in_data_station,
)
with open(self._station_meta_file, mode="w", encoding="utf-8", newline="") as f:
station_info.to_csv(f, index=False, header=True)
@ -415,7 +475,13 @@ class CitiBikeTopology(DataTopology):
"""
def __init__(
self, topology: str, trip_source: str, station_info: str, weather_source: str, is_temp: bool = False):
self,
topology: str,
trip_source: str,
station_info: str,
weather_source: str,
is_temp: bool = False,
):
super().__init__()
self._data_pipeline["trip"] = CitiBikePipeline(topology, trip_source, station_info, is_temp)
self._data_pipeline["weather"] = NOAAWeatherPipeline(topology, weather_source, is_temp)
@ -453,7 +519,14 @@ class CitiBikeToyPipeline(DataPipeline):
_meta_file_name = "trips.yml"
def __init__(
self, start_time: str, end_time: str, stations: list, trips: list, topology: str, is_temp: bool = False):
self,
start_time: str,
end_time: str,
stations: list,
trips: list,
topology: str,
is_temp: bool = False,
):
super().__init__("citi_bike", topology, "", is_temp)
self._start_time = start_time
self._end_time = end_time
@ -465,7 +538,6 @@ class CitiBikeToyPipeline(DataPipeline):
def download(self, is_force: bool):
"""Toy datapipeline not need download process."""
pass
def _station_dict_to_pd(self, station_dict):
"""Convert dictionary of station information to pd series."""
@ -477,7 +549,8 @@ class CitiBikeToyPipeline(DataPipeline):
station_dict["lat"],
station_dict["lon"],
],
index=["station_index", "capacity", "init", "latitude", "longitude"])
index=["station_index", "capacity", "init", "latitude", "longitude"],
)
def _gen_stations(self):
"""Generate station meta csv."""
@ -523,10 +596,14 @@ class CitiBikeToyPipeline(DataPipeline):
trips_df = pd.DataFrame.from_dict(trips)
trips_df["start_station_index"] = pd.to_numeric(
trips_df["start_station_index"], errors="coerce", downcast="integer"
trips_df["start_station_index"],
errors="coerce",
downcast="integer",
)
trips_df["end_station_index"] = pd.to_numeric(
trips_df["end_station_index"], errors="coerce", downcast="integer"
trips_df["end_station_index"],
errors="coerce",
downcast="integer",
)
self._new_file_list.append(self._clean_file)
with open(self._clean_file, "w", encoding="utf-8", newline="") as f:
@ -540,13 +617,19 @@ class CitiBikeToyPipeline(DataPipeline):
0,
index=station_init["station_index"],
columns=station_init["station_index"],
dtype=np.float
dtype=np.float,
)
look_up_df = station_init[["latitude", "longitude"]]
distance_df = distance_adj.apply(lambda x: pd.DataFrame(x).apply(lambda y: geopy.distance.distance(
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"])
).km, axis=1), axis=1)
distance_df = distance_adj.apply(
lambda x: pd.DataFrame(x).apply(
lambda y: geopy.distance.distance(
(look_up_df.at[x.name, "latitude"], look_up_df.at[x.name, "longitude"]),
(look_up_df.at[y.name, "latitude"], look_up_df.at[y.name, "longitude"]),
).km,
axis=1,
),
axis=1,
)
self._new_file_list.append(self._distance_file)
with open(self._distance_file, "w", encoding="utf-8", newline="") as f:
distance_df.to_csv(f, index=False, header=True)
@ -589,7 +672,6 @@ class WeatherToyPipeline(WeatherPipeline):
def download(self, is_force: bool):
"""Toy datapipeline not need download process."""
pass
def clean(self):
"""Clean the original data file."""
@ -660,13 +742,13 @@ class CitiBikeToyTopology(DataTopology):
stations=cfg["stations"],
trips=cfg["trips"],
topology=topology,
is_temp=is_temp
is_temp=is_temp,
)
self._data_pipeline["weather"] = WeatherToyPipeline(
topology=topology,
start_time=cfg["start_time"],
end_time=cfg["end_time"],
is_temp=is_temp
is_temp=is_temp,
)
else:
logger.warning(f"Config file {config_path} for toy topology {topology} not found.")
@ -701,7 +783,7 @@ class CitiBikeProcess:
self.topologies[topology] = CitiBikeToyTopology(
topology=topology,
config_path=self._conf["trips"][topology]["toy_meta_path"],
is_temp=is_temp
is_temp=is_temp,
)
else:
self.topologies[topology] = CitiBikeTopology(
@ -709,7 +791,7 @@ class CitiBikeProcess:
trip_source=self._conf["trips"][topology]["trip_remote_url"],
station_info=self._conf["station_info"]["ny_station_info_url"],
weather_source=self._conf["weather"][topology]["noaa_weather_url"],
is_temp=is_temp
is_temp=is_temp,
)
@ -782,8 +864,8 @@ class NOAAWeatherPipeline(WeatherPipeline):
def _gen_fall_back_file(self):
fall_back_content = [
"\"STATION\",\"DATE\",\"AWND\",\"PRCP\",\"SNOW\",\"TMAX\",\"TMIN\"\n",
",,,,,,\n"
'"STATION","DATE","AWND","PRCP","SNOW","TMAX","TMIN"\n',
",,,,,,\n",
]
with open(self._download_file, mode="w", encoding="utf-8", newline="") as f:
f.writelines(fall_back_content)


@ -16,7 +16,7 @@ scenario_map["vm_scheduling"] = VmSchedulingProcess
def generate(scenario: str, topology: str = "", forced: bool = False, **kwargs):
logger.info_green(
f"Generating data files for scenario {scenario} topology {topology}"
f" {'forced redownload.' if forced else ', not forced redownload.'}"
f" {'forced redownload.' if forced else ', not forced redownload.'}",
)
if scenario in scenario_map:
process = scenario_map[scenario]()


@ -33,7 +33,7 @@ class VmSchedulingPipeline(DataPipeline):
_meta_file_name = "vmtable.yml"
# VM category includes three types, converting to 0, 1, 2.
_category_map = {'Delay-insensitive': 0, 'Interactive': 1, 'Unknown': 2}
_category_map = {"Delay-insensitive": 0, "Interactive": 1, "Unknown": 2}
def __init__(self, topology: str, source: str, sample: int, seed: int, is_temp: bool = False):
super().__init__(scenario="vm_scheduling", topology=topology, source=source, is_temp=is_temp)
@ -56,8 +56,8 @@ class VmSchedulingPipeline(DataPipeline):
aria2p.Client(
host="http://localhost",
port=6800,
secret=""
)
secret="",
),
)
self._download_file_list = []
@ -97,14 +97,14 @@ class VmSchedulingPipeline(DataPipeline):
with open(self._download_file, mode="r", encoding="utf-8") as urls:
for remote_url in urls.read().splitlines():
# Get the file name.
file_name = remote_url.split('/')[-1]
file_name = remote_url.split("/")[-1]
# Two kinds of required files "vmtable" and "vm_cpu_readings-" start with vm.
if file_name.startswith("vmtable"):
if (not is_force) and os.path.exists(self._vm_table_file):
logger.info_green(f"{self._vm_table_file} already exists, skipping download.")
else:
logger.info_green(f"Downloading vmtable from {remote_url} to {self._vm_table_file}.")
self.aria2.add_uris(uris=[remote_url], options={'dir': self._download_folder})
self.aria2.add_uris(uris=[remote_url], options={"dir": self._download_folder})
elif file_name.startswith("vm_cpu_readings") and num_files > 0:
num_files -= 1
@ -115,7 +115,7 @@ class VmSchedulingPipeline(DataPipeline):
logger.info_green(f"{cpu_readings_file} already exists, skipping download.")
else:
logger.info_green(f"Downloading cpu_readings from {remote_url} to {cpu_readings_file}.")
self.aria2.add_uris(uris=[remote_url], options={'dir': f"{self._download_folder}"})
self.aria2.add_uris(uris=[remote_url], options={"dir": f"{self._download_folder}"})
self._check_all_download_completed()
@ -149,7 +149,7 @@ class VmSchedulingPipeline(DataPipeline):
# Unzip gz file.
with gzip.open(original_file, mode="rb") as f_in:
logger.info_green(
f"Unzip {original_file} to {raw_file}."
f"Unzip {original_file} to {raw_file}.",
)
with open(raw_file, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
@ -184,63 +184,80 @@ class VmSchedulingPipeline(DataPipeline):
"""Process vmtable file."""
headers = [
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'maxcpu', 'avgcpu', 'p95maxcpu',
'vmcategory', 'vmcorecountbucket', 'vmmemorybucket'
"vmid",
"subscriptionid",
"deploymentid",
"vmcreated",
"vmdeleted",
"maxcpu",
"avgcpu",
"p95maxcpu",
"vmcategory",
"vmcorecountbucket",
"vmmemorybucket",
]
required_headers = [
'vmid', 'subscriptionid', 'deploymentid', 'vmcreated', 'vmdeleted', 'vmcategory',
'vmcorecountbucket', 'vmmemorybucket'
"vmid",
"subscriptionid",
"deploymentid",
"vmcreated",
"vmdeleted",
"vmcategory",
"vmcorecountbucket",
"vmmemorybucket",
]
vm_table = pd.read_csv(raw_vm_table_file, header=None, index_col=False, names=headers)
vm_table = vm_table.loc[:, required_headers]
# Convert to tick by dividing by 300 (5 minutes).
vm_table['vmcreated'] = pd.to_numeric(vm_table['vmcreated'], errors="coerce", downcast="integer") // 300
vm_table['vmdeleted'] = pd.to_numeric(vm_table['vmdeleted'], errors="coerce", downcast="integer") // 300
vm_table["vmcreated"] = pd.to_numeric(vm_table["vmcreated"], errors="coerce", downcast="integer") // 300
vm_table["vmdeleted"] = pd.to_numeric(vm_table["vmdeleted"], errors="coerce", downcast="integer") // 300
# The lifetime of the VM is deleted time - created time + 1 (tick).
vm_table['lifetime'] = vm_table['vmdeleted'] - vm_table['vmcreated'] + 1
vm_table["lifetime"] = vm_table["vmdeleted"] - vm_table["vmcreated"] + 1
vm_table['vmcategory'] = vm_table['vmcategory'].map(self._category_map)
vm_table["vmcategory"] = vm_table["vmcategory"].map(self._category_map)
# Transform vmcorecount '>24' bucket to 32 and vmmemory '>64' to 128.
vm_table = vm_table.replace({'vmcorecountbucket': '>24'}, 32)
vm_table = vm_table.replace({'vmmemorybucket': '>64'}, 128)
vm_table['vmcorecountbucket'] = pd.to_numeric(
vm_table['vmcorecountbucket'], errors="coerce", downcast="integer"
vm_table = vm_table.replace({"vmcorecountbucket": ">24"}, 32)
vm_table = vm_table.replace({"vmmemorybucket": ">64"}, 128)
vm_table["vmcorecountbucket"] = pd.to_numeric(
vm_table["vmcorecountbucket"],
errors="coerce",
downcast="integer",
)
vm_table['vmmemorybucket'] = pd.to_numeric(vm_table['vmmemorybucket'], errors="coerce", downcast="integer")
vm_table["vmmemorybucket"] = pd.to_numeric(vm_table["vmmemorybucket"], errors="coerce", downcast="integer")
vm_table.dropna(inplace=True)
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
vm_table = vm_table.sort_values(by="vmcreated", ascending=True)
# Generate ID map.
vm_id_map = self._generate_id_map(vm_table['vmid'].unique())
sub_id_map = self._generate_id_map(vm_table['subscriptionid'].unique())
deployment_id_map = self._generate_id_map(vm_table['deploymentid'].unique())
vm_id_map = self._generate_id_map(vm_table["vmid"].unique())
sub_id_map = self._generate_id_map(vm_table["subscriptionid"].unique())
deployment_id_map = self._generate_id_map(vm_table["deploymentid"].unique())
id_maps = (vm_id_map, sub_id_map, deployment_id_map)
# Mapping IDs.
vm_table['vmid'] = vm_table['vmid'].map(vm_id_map)
vm_table['subscriptionid'] = vm_table['subscriptionid'].map(sub_id_map)
vm_table['deploymentid'] = vm_table['deploymentid'].map(deployment_id_map)
vm_table["vmid"] = vm_table["vmid"].map(vm_id_map)
vm_table["subscriptionid"] = vm_table["subscriptionid"].map(sub_id_map)
vm_table["deploymentid"] = vm_table["deploymentid"].map(deployment_id_map)
# Sampling the VM table.
# 2695548 is the total number of vms in the original Azure public dataset.
if self._sample < 2695548:
vm_table = vm_table.sample(n=self._sample, random_state=self._seed)
vm_table = vm_table.sort_values(by='vmcreated', ascending=True)
vm_table = vm_table.sort_values(by="vmcreated", ascending=True)
return id_maps, vm_table
def _convert_cpu_readings_id(self, old_data_path: str, new_data_path: str, vm_id_map: dict):
"""Convert vmid in each cpu readings file."""
with open(old_data_path, 'r') as f_in:
with open(old_data_path, "r") as f_in:
csv_reader = reader(f_in)
with open(new_data_path, 'w') as f_out:
with open(new_data_path, "w") as f_out:
csv_writer = writer(f_out)
csv_writer.writerow(['timestamp', 'vmid', 'maxcpu'])
csv_writer.writerow(["timestamp", "vmid", "maxcpu"])
for row in csv_reader:
# [timestamp, vmid, mincpu, maxcpu, avgcpu]
if row[1] in vm_id_map:
@ -248,12 +265,12 @@ class VmSchedulingPipeline(DataPipeline):
csv_writer.writerow(new_row)
def _write_id_map_to_csv(self, id_maps):
file_name = ['vm_id_map', 'sub_id_map', 'deployment_id_map']
file_name = ["vm_id_map", "sub_id_map", "deployment_id_map"]
for index in range(len(id_maps)):
id_map = id_maps[index]
with open(os.path.join(self._raw_folder, file_name[index]) + '.csv', 'w') as f:
with open(os.path.join(self._raw_folder, file_name[index]) + ".csv", "w") as f:
csv_writer = writer(f)
csv_writer.writerow(['original_id', 'new_id'])
csv_writer.writerow(["original_id", "new_id"])
for key, value in id_map.items():
csv_writer.writerow([key, value])
@ -288,7 +305,7 @@ class VmSchedulingPipeline(DataPipeline):
self._convert_cpu_readings_id(
old_data_path=raw_cpu_readings_file,
new_data_path=clean_cpu_readings_file,
vm_id_map=filtered_vm_id_map
vm_id_map=filtered_vm_id_map,
)
def build(self):
@ -313,7 +330,7 @@ class VmSchedulingTopology(DataTopology):
source=source,
sample=sample,
seed=seed,
is_temp=is_temp
is_temp=is_temp,
)
@ -336,5 +353,5 @@ class VmSchedulingProcess:
source=self._conf["vm_data"][topology]["remote_url"],
sample=self._conf["vm_data"][topology]["sample"],
seed=self._conf["vm_data"][topology]["seed"],
is_temp=is_temp
is_temp=is_temp,
)
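
For reference, two of the vmtable transformations above in isolation: timestamps in seconds become 5-minute ticks via integer division by 300, and opaque ids are remapped to dense integers. The sample frame is invented, and the dict comprehension is an assumption about what _generate_id_map effectively does (it is not shown in this diff):

import pandas as pd

vm_table = pd.DataFrame(
    {
        "vmid": ["a3f", "9bc", "a3f"],
        "vmcreated": [0, 600, 1500],
        "vmdeleted": [900, 1800, 3000],
    },
)

vm_table["vmcreated"] = vm_table["vmcreated"] // 300          # 1 tick = 300 seconds
vm_table["vmdeleted"] = vm_table["vmdeleted"] // 300
vm_table["lifetime"] = vm_table["vmdeleted"] - vm_table["vmcreated"] + 1

vm_id_map = {vm_id: new_id for new_id, vm_id in enumerate(vm_table["vmid"].unique())}
vm_table["vmid"] = vm_table["vmid"].map(vm_id_map)
print(vm_table)
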


@ -77,12 +77,12 @@ class GrassAzureExecutor(GrassExecutor):
# Simultaneously capture image and init master
build_node_image_thread = threading.Thread(
target=GrassAzureExecutor._build_node_image,
args=(cluster_details,)
args=(cluster_details,),
)
build_node_image_thread.start()
create_and_init_master_thread = threading.Thread(
target=GrassAzureExecutor._create_and_init_master,
args=(cluster_details,)
args=(cluster_details,),
)
create_and_init_master_thread.start()
build_node_image_thread.join()
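
For reference, the hunk above runs image building and master creation concurrently with two plain threads. A minimal self-contained sketch of the same start/start/join/join pattern, with sleep calls standing in for the real Azure work:

import threading
import time

def build_node_image():
    time.sleep(0.1)  # stand-in for the image build

def create_and_init_master():
    time.sleep(0.1)  # stand-in for master VM creation and init

build_thread = threading.Thread(target=build_node_image)
master_thread = threading.Thread(target=create_and_init_master)
build_thread.start()
master_thread.start()
build_thread.join()
master_thread.join()
print("both steps finished")
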
@ -125,14 +125,14 @@ class GrassAzureExecutor(GrassExecutor):
"root['connection']['ssh']": {"port": GlobalParams.DEFAULT_SSH_PORT},
"root['connection']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
"root['connection']['api_server']": {"port": GrassParams.DEFAULT_API_SERVER_PORT},
"root['connection']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT
"root['connection']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT,
}
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_azure_create.yml") as fr:
create_deployment_template = yaml.safe_load(fr)
DeploymentValidator.validate_and_fill_dict(
template_dict=create_deployment_template,
actual_dict=create_deployment,
optional_key_to_value=optional_key_to_value
optional_key_to_value=optional_key_to_value,
)
# Init runtime fields.
@ -169,7 +169,7 @@ class GrassAzureExecutor(GrassExecutor):
else:
AzureController.create_resource_group(
resource_group=resource_group,
location=cluster_details["cloud"]["location"]
location=cluster_details["cloud"]["location"],
)
logger.info_green(f"Resource group '{resource_group}' is created")
@ -192,13 +192,13 @@ class GrassAzureExecutor(GrassExecutor):
)
ArmTemplateParameterBuilder.create_vnet(
cluster_details=cluster_details,
export_path=parameters_file_path
export_path=parameters_file_path,
)
AzureController.start_deployment(
resource_group=cluster_details["cloud"]["resource_group"],
deployment_name="vnet",
template_file_path=template_file_path,
parameters_file_path=parameters_file_path
parameters_file_path=parameters_file_path,
)
logger.info_green("Vnet is created")
@ -233,13 +233,13 @@ class GrassAzureExecutor(GrassExecutor):
ArmTemplateParameterBuilder.create_build_node_image_vm(
cluster_details=cluster_details,
node_size=cluster_details["master"]["node_size"],
export_path=parameters_file_path
export_path=parameters_file_path,
)
AzureController.start_deployment(
resource_group=cluster_details["cloud"]["resource_group"],
deployment_name=resource_name,
template_file_path=template_file_path,
parameters_file_path=parameters_file_path
parameters_file_path=parameters_file_path,
)
# Gracefully wait
time.sleep(10)
@ -247,7 +247,7 @@ class GrassAzureExecutor(GrassExecutor):
# Get public ip address
ip_addresses = AzureController.list_ip_addresses(
resource_group=cluster_details["cloud"]["resource_group"],
vm_name=vm_name
vm_name=vm_name,
)
public_ip_address = ip_addresses[0]["virtualMachine"]["network"]["publicIpAddresses"][0]["ipAddress"]
@ -255,7 +255,7 @@ class GrassAzureExecutor(GrassExecutor):
GrassAzureExecutor.retry_connection(
node_username=cluster_details["cloud"]["default_username"],
node_hostname=public_ip_address,
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
)
# Run init image script
@ -264,12 +264,12 @@ class GrassAzureExecutor(GrassExecutor):
remote_dir="~/",
node_username=cluster_details["cloud"]["default_username"],
node_hostname=public_ip_address,
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
)
GrassAzureExecutor.remote_init_build_node_image_vm(
node_username=cluster_details["cloud"]["default_username"],
node_hostname=public_ip_address,
node_ssh_port=cluster_details["connection"]["ssh"]["port"]
node_ssh_port=cluster_details["connection"]["ssh"]["port"],
)
# Extract image
@ -278,14 +278,14 @@ class GrassAzureExecutor(GrassExecutor):
AzureController.create_image_from_vm(
resource_group=cluster_details["cloud"]["resource_group"],
image_name=image_name,
vm_name=vm_name
vm_name=vm_name,
)
# Delete resources
GrassAzureExecutor._delete_resources(
resource_group=cluster_details["cloud"]["resource_group"],
resource_name=resource_name,
cluster_id=cluster_details["id"]
cluster_id=cluster_details["id"],
)
logger.info_green("MARO Node Image is built")
@ -313,7 +313,7 @@ class GrassAzureExecutor(GrassExecutor):
user_id=cluster_details["user"]["id"],
master_to_dev_encryption_private_key=cluster_details["user"]["master_to_dev_encryption_private_key"],
dev_to_master_encryption_public_key=cluster_details["user"]["dev_to_master_encryption_public_key"],
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"]
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"],
)
master_api_client.create_master(master_details=cluster_details["master"])
master_api_client.create_cluster(cluster_details=cluster_details)
@ -338,25 +338,24 @@ class GrassAzureExecutor(GrassExecutor):
# Create ARM parameters and start deployment
template_file_path = f"{GrassPaths.ABS_MARO_GRASS_LIB}/modes/azure/create_master/template.json"
parameters_file_path = (
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}"
f"/master/arm_create_master_parameters.json"
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}" f"/master/arm_create_master_parameters.json"
)
ArmTemplateParameterBuilder.create_master(
cluster_details=cluster_details,
node_size=cluster_details["master"]["node_size"],
export_path=parameters_file_path
export_path=parameters_file_path,
)
AzureController.start_deployment(
resource_group=cluster_details["cloud"]["resource_group"],
deployment_name="master",
template_file_path=template_file_path,
parameters_file_path=parameters_file_path
parameters_file_path=parameters_file_path,
)
# Get master IP addresses
ip_addresses = AzureController.list_ip_addresses(
resource_group=cluster_details["cloud"]["resource_group"],
vm_name=vm_name
vm_name=vm_name,
)
public_ip_address = ip_addresses[0]["virtualMachine"]["network"]["publicIpAddresses"][0]["ipAddress"]
private_ip_address = ip_addresses[0]["virtualMachine"]["network"]["privateIpAddresses"][0]
@ -433,12 +432,12 @@ class GrassAzureExecutor(GrassExecutor):
if node_size_to_count[node_size] > replicas:
self._delete_nodes(
num=node_size_to_count[node_size] - replicas,
node_size=node_size
node_size=node_size,
)
elif node_size_to_count[node_size] < replicas:
self._create_nodes(
num=replicas - node_size_to_count[node_size],
node_size=node_size
node_size=node_size,
)
else:
logger.warning_yellow("Replica is match, no create or delete")
@ -459,7 +458,7 @@ class GrassAzureExecutor(GrassExecutor):
with ThreadPool(GlobalParams.PARALLELS) as pool:
pool.starmap(
self._create_node,
[[node_size]] * num
[[node_size]] * num,
)
def _create_node(self, node_size: str) -> None:
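
For reference, node creation above fans out over a thread pool with starmap, passing the same node size to every call. A minimal sketch; the pool size and the node size string are invented values:

from multiprocessing.pool import ThreadPool

def create_node(node_size):
    return f"created a '{node_size}' node"

PARALLELS = 4
with ThreadPool(PARALLELS) as pool:
    # starmap unpacks each inner list into the function's arguments.
    results = pool.starmap(create_node, [["Standard_D2s_v3"]] * 3)
print(results)
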
@ -478,7 +477,7 @@ class GrassAzureExecutor(GrassExecutor):
# Create node
join_cluster_deployment = self._create_vm(
node_name=node_name,
node_size=node_size
node_size=node_size,
)
# Start joining cluster
@ -512,12 +511,12 @@ class GrassAzureExecutor(GrassExecutor):
with ThreadPool(GlobalParams.PARALLELS) as pool:
pool.starmap(
self._delete_node,
params
params,
)
else:
logger.warning_yellow(
"Unable to scale down.\n"
f"Only {len(deletable_nodes)} nodes are deletable, but need to delete {num} to meet the replica"
f"Only {len(deletable_nodes)} nodes are deletable, but need to delete {num} to meet the replica",
)
def _create_vm(self, node_name: str, node_size: str) -> dict:
@ -543,19 +542,19 @@ class GrassAzureExecutor(GrassExecutor):
node_name=node_name,
cluster_details=self.cluster_details,
node_size=node_size,
export_path=parameters_file_path
export_path=parameters_file_path,
)
AzureController.start_deployment(
resource_group=self.resource_group,
deployment_name=node_name,
template_file_path=template_file_path,
parameters_file_path=parameters_file_path
parameters_file_path=parameters_file_path,
)
# Get node IP addresses
ip_addresses = AzureController.list_ip_addresses(
resource_group=self.resource_group,
vm_name=f"{self.cluster_id}-{node_name}-vm"
vm_name=f"{self.cluster_id}-{node_name}-vm",
)
logger.info_green(f"VM '{node_name}' is created")
@ -566,11 +565,11 @@ class GrassAzureExecutor(GrassExecutor):
"master": {
"private_ip_address": self.master_private_ip_address,
"api_server": {
"port": self.master_api_server_port
"port": self.master_api_server_port,
},
"redis": {
"port": self.master_redis_port
}
"port": self.master_redis_port,
},
},
"node": {
"name": node_name,
@ -584,23 +583,23 @@ class GrassAzureExecutor(GrassExecutor):
"resources": {
"cpu": "all",
"memory": "all",
"gpu": "all"
"gpu": "all",
},
"api_server": {
"port": self.api_server_port
"port": self.api_server_port,
},
"ssh": {
"port": self.ssh_port
}
"port": self.ssh_port,
},
},
"configs": {
"install_node_runtime": False,
"install_node_gpu_support": False
}
"install_node_gpu_support": False,
},
}
with open(
file=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/join_cluster_deployment.yml",
mode="w"
mode="w",
) as fw:
yaml.safe_dump(data=join_cluster_deployment, stream=fw)
@ -624,13 +623,13 @@ class GrassAzureExecutor(GrassExecutor):
self._delete_resources(
resource_group=self.resource_group,
cluster_id=self.cluster_id,
resource_name=node_name
resource_name=node_name,
)
# Delete azure deployment
AzureController.delete_deployment(
resource_group=self.resource_group,
deployment_name=node_name
deployment_name=node_name,
)
# Delete node related files
@ -655,21 +654,22 @@ class GrassAzureExecutor(GrassExecutor):
self.retry_connection(
node_username=node_details["username"],
node_hostname=node_details["public_ip_address"],
node_ssh_port=node_details["ssh"]["port"]
node_ssh_port=node_details["ssh"]["port"],
)
# Copy required files
local_path_to_remote_dir = {
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/join_cluster_deployment.yml":
f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
}
local_path = (
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/nodes/{node_name}/" + "join_cluster_deployment.yml"
)
remote_dir = f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
local_path_to_remote_dir = {local_path: remote_dir}
for local_path, remote_dir in local_path_to_remote_dir.items():
FileSynchronizer.copy_files_to_node(
local_path=local_path,
remote_dir=remote_dir,
node_username=node_details["username"],
node_hostname=node_details["public_ip_address"],
node_ssh_port=node_details["ssh"]["port"]
node_ssh_port=node_details["ssh"]["port"],
)
# Remote join cluster
@ -682,7 +682,7 @@ class GrassAzureExecutor(GrassExecutor):
deployment_path=(
f"{GlobalPaths.MARO_LOCAL}/clusters/{self.cluster_name}/nodes/{node_name}"
f"/join_cluster_deployment.yml"
)
),
)
logger.info_green(f"Node '{node_name}' is joined")
@ -710,7 +710,7 @@ class GrassAzureExecutor(GrassExecutor):
# Check replicas
if len(startable_nodes) < replicas:
raise BadRequestError(
f"No enough '{node_size}' nodes can be started, only {len(startable_nodes)} is able to start"
f"No enough '{node_size}' nodes can be started, only {len(startable_nodes)} is able to start",
)
# Parallel start
@ -718,7 +718,7 @@ class GrassAzureExecutor(GrassExecutor):
with ThreadPool(GlobalParams.PARALLELS) as pool:
pool.starmap(
self._start_node,
params
params,
)
def _start_node(self, node_name: str):
@ -735,7 +735,7 @@ class GrassAzureExecutor(GrassExecutor):
# Start node vm
AzureController.start_vm(
resource_group=self.resource_group,
vm_name=f"{self.cluster_id}-{node_name}-vm"
vm_name=f"{self.cluster_id}-{node_name}-vm",
)
# Start node
@ -761,16 +761,16 @@ class GrassAzureExecutor(GrassExecutor):
stoppable_nodes_details = []
for node_details in nodes_details:
if (
node_details["node_size"] == node_size and
node_details["state"]["status"] == NodeStatus.RUNNING and
self._count_running_containers(node_details) == 0
node_details["node_size"] == node_size
and node_details["state"]["status"] == NodeStatus.RUNNING
and self._count_running_containers(node_details) == 0
):
stoppable_nodes_details.append(node_details)
# Check replicas
if len(stoppable_nodes_details) < replicas:
raise BadRequestError(
f"No more '{node_size}' nodes can be stopped, only {len(stoppable_nodes_details)} are stoppable"
f"No more '{node_size}' nodes can be stopped, only {len(stoppable_nodes_details)} are stoppable",
)
# Parallel stop
@ -778,7 +778,7 @@ class GrassAzureExecutor(GrassExecutor):
with ThreadPool(GlobalParams.PARALLELS) as pool:
pool.starmap(
self._stop_node,
params
params,
)
def _stop_node(self, node_details: dict):
@ -800,7 +800,7 @@ class GrassAzureExecutor(GrassExecutor):
# Stop node vm
AzureController.stop_vm(
resource_group=self.resource_group,
vm_name=f"{self.cluster_id}-{node_name}-vm"
vm_name=f"{self.cluster_id}-{node_name}-vm",
)
logger.info_green(f"Node '{node_name}' is stopped")
@ -964,7 +964,7 @@ class ArmTemplateParameterBuilder:
# Load and update parameters
with open(
file=f"{GrassPaths.ABS_MARO_GRASS_LIB}/modes/azure/create_build_node_image_vm/parameters.json",
mode="r"
mode="r",
) as fr:
base_parameters = json.load(fr)
parameters = base_parameters["parameters"]
@ -1008,7 +1008,7 @@ class ArmTemplateParameterBuilder:
parameters["adminUsername"]["value"] = cluster_details["cloud"]["default_username"]
parameters["imageResourceId"]["value"] = AzureController.get_image_resource_id(
resource_group=cluster_details["cloud"]["resource_group"],
image_name=f"{cluster_details['id']}-node-image"
image_name=f"{cluster_details['id']}-node-image",
)
parameters["location"]["value"] = cluster_details["cloud"]["location"]
parameters["networkInterfaceName"]["value"] = f"{cluster_details['id']}-{node_name}-nic"


@ -21,7 +21,11 @@ from maro.cli.utils.name_creator import NameCreator
from maro.cli.utils.params import GlobalPaths
from maro.cli.utils.subprocess import Subprocess
from maro.utils.exception.cli_exception import (
BadRequestError, CliError, ClusterInternalError, CommandExecutionError, FileOperationError
BadRequestError,
CliError,
ClusterInternalError,
CommandExecutionError,
FileOperationError,
)
from maro.utils.logger import CliLogger
@ -56,7 +60,7 @@ class GrassExecutor:
user_id=self.user_details["id"],
master_to_dev_encryption_private_key=self.user_details["master_to_dev_encryption_private_key"],
dev_to_master_encryption_public_key=self.user_details["dev_to_master_encryption_public_key"],
dev_to_master_signing_private_key=self.user_details["dev_to_master_signing_private_key"]
dev_to_master_signing_private_key=self.user_details["dev_to_master_signing_private_key"],
)
self.master_ssh_port = self.cluster_details["master"]["ssh"]["port"]
self.master_api_server_port = self.cluster_details["master"]["api_server"]["port"]
@ -79,18 +83,18 @@ class GrassExecutor:
GrassExecutor.retry_connection(
node_username=cluster_details["master"]["username"],
node_hostname=cluster_details["master"]["public_ip_address"],
node_ssh_port=cluster_details["master"]["ssh"]["port"]
node_ssh_port=cluster_details["master"]["ssh"]["port"],
)
DetailsWriter.save_cluster_details(
cluster_name=cluster_details["name"],
cluster_details=cluster_details
cluster_details=cluster_details,
)
# Copy required files
local_path_to_remote_dir = {
GrassPaths.ABS_MARO_GRASS_LIB: f"{GlobalPaths.MARO_SHARED}/lib",
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}": f"{GlobalPaths.MARO_SHARED}/clusters"
f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}": f"{GlobalPaths.MARO_SHARED}/clusters",
}
for local_path, remote_dir in local_path_to_remote_dir.items():
FileSynchronizer.copy_files_to_node(
@ -98,7 +102,7 @@ class GrassExecutor:
remote_dir=remote_dir,
node_username=cluster_details["master"]["username"],
node_hostname=cluster_details["master"]["public_ip_address"],
node_ssh_port=cluster_details["master"]["ssh"]["port"]
node_ssh_port=cluster_details["master"]["ssh"]["port"],
)
# Remote init master
@ -106,7 +110,7 @@ class GrassExecutor:
master_username=cluster_details["master"]["username"],
master_hostname=cluster_details["master"]["public_ip_address"],
master_ssh_port=cluster_details["master"]["ssh"]["port"],
cluster_name=cluster_details["name"]
cluster_name=cluster_details["name"],
)
# Gracefully wait
time.sleep(10)
@ -129,7 +133,7 @@ class GrassExecutor:
master_hostname=cluster_details["master"]["public_ip_address"],
master_ssh_port=cluster_details["master"]["ssh"]["port"],
user_id=cluster_details["user"]["admin_id"],
user_role=UserRole.ADMIN
user_role=UserRole.ADMIN,
)
# Update user_details, "admin_id" change to "id"
@ -138,15 +142,15 @@ class GrassExecutor:
# Save dev_to_master private key
os.makedirs(
name=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}/users/{user_details['id']}",
exist_ok=True
exist_ok=True,
)
with open(
file=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{cluster_details['name']}/users/{user_details['id']}/user_details",
mode="w"
mode="w",
) as fw:
yaml.safe_dump(
data=user_details,
stream=fw
stream=fw,
)
# Save default user
@ -168,8 +172,9 @@ class GrassExecutor:
logger.info(
json.dumps(
nodes_details,
indent=4, sort_keys=True
)
indent=4,
sort_keys=True,
),
)
# maro grass image
@ -194,7 +199,7 @@ class GrassExecutor:
abs_image_path = f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/image_files/{new_file_name}"
DockerController.save_image(
image_name=image_name,
abs_export_path=abs_image_path
abs_export_path=abs_image_path,
)
else:
# Push image from local image file.
@ -204,7 +209,7 @@ class GrassExecutor:
FileSynchronizer.copy_and_rename(
source_path=image_path,
target_dir=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}/image_files",
new_name=new_file_name
new_name=new_file_name,
)
# Use md5_checksum to skip existed image file.
remote_image_file_details = self.master_api_client.get_image_file(image_file_name=new_file_name)
@ -220,13 +225,13 @@ class GrassExecutor:
remote_dir=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/image_files",
node_username=self.master_username,
node_hostname=self.master_public_ip_address,
node_ssh_port=self.master_ssh_port
node_ssh_port=self.master_ssh_port,
)
self.master_api_client.create_image_file(
image_file_details={
"name": new_file_name,
"md5_checksum": local_md5_checksum
}
"md5_checksum": local_md5_checksum,
},
)
logger.info_green(f"Image {image_name} is loaded")
else:
@ -251,7 +256,7 @@ class GrassExecutor:
remote_dir=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/data{remote_path}",
node_username=self.master_username,
node_hostname=self.master_public_ip_address,
node_ssh_port=self.master_ssh_port
node_ssh_port=self.master_ssh_port,
)
def pull_data(self, local_path: str, remote_path: str) -> None:
@ -271,7 +276,7 @@ class GrassExecutor:
remote_path=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/data{remote_path}",
node_username=self.master_username,
node_hostname=self.master_public_ip_address,
node_ssh_port=self.master_ssh_port
node_ssh_port=self.master_ssh_port,
)
# maro grass job
@ -333,8 +338,9 @@ class GrassExecutor:
logger.info(
json.dumps(
jobs_details,
indent=4, sort_keys=True
)
indent=4,
sort_keys=True,
),
)
def get_job_logs(self, job_name: str, export_dir: str = "./") -> None:
@ -357,7 +363,7 @@ class GrassExecutor:
remote_path=f"{GlobalPaths.MARO_SHARED}/clusters/{self.cluster_name}/logs/{job_details['id']}",
node_username=self.master_username,
node_hostname=self.master_public_ip_address,
node_ssh_port=self.master_ssh_port
node_ssh_port=self.master_ssh_port,
)
except CommandExecutionError:
logger.error_red("No logs have been created at this time")
@ -375,14 +381,14 @@ class GrassExecutor:
"""
# Validate grass_azure_start_job
optional_key_to_value = {
"root['tags']": {}
"root['tags']": {},
}
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_azure_start_job.yml") as fr:
start_job_template = yaml.safe_load(fr)
DeploymentValidator.validate_and_fill_dict(
template_dict=start_job_template,
actual_dict=start_job_deployment,
optional_key_to_value=optional_key_to_value
optional_key_to_value=optional_key_to_value,
)
# Validate component
@ -393,7 +399,7 @@ class GrassExecutor:
DeploymentValidator.validate_and_fill_dict(
template_dict=start_job_component_template,
actual_dict=component_details,
optional_key_to_value={}
optional_key_to_value={},
)
# Init runtime fields
@ -457,7 +463,7 @@ class GrassExecutor:
DeploymentValidator.validate_and_fill_dict(
template_dict=start_job_template,
actual_dict=start_schedule_deployment,
optional_key_to_value={}
optional_key_to_value={},
)
# Validate component
@ -468,7 +474,7 @@ class GrassExecutor:
DeploymentValidator.validate_and_fill_dict(
template_dict=start_job_component_template,
actual_dict=component_details,
optional_key_to_value={}
optional_key_to_value={},
)
return start_schedule_deployment
@ -497,8 +503,9 @@ class GrassExecutor:
logger.info(
json.dumps(
return_status,
indent=4, sort_keys=True
)
indent=4,
sort_keys=True,
),
)
# maro grass template
@ -582,8 +589,11 @@ class GrassExecutor:
@staticmethod
def remote_create_user(
master_username: str, master_hostname: str, master_ssh_port: int,
user_id: str, user_role: str
master_username: str,
master_hostname: str,
master_ssh_port: int,
user_id: str,
user_role: str,
) -> dict:
"""Remote create MARO User.
@ -609,8 +619,12 @@ class GrassExecutor:
@staticmethod
def remote_join_cluster(
node_username: str, node_hostname: str, node_ssh_port: int,
master_private_ip_address: str, master_api_server_port: int, deployment_path: str
node_username: str,
node_hostname: str,
node_ssh_port: int,
master_private_ip_address: str,
master_api_server_port: int,
deployment_path: str,
) -> None:
"""Remote join cluster.
@ -735,14 +749,14 @@ class GrassExecutor:
GrassExecutor.test_ssh_default_port_connection(
node_ssh_port=node_ssh_port,
node_username=node_username,
node_hostname=node_hostname
node_hostname=node_hostname,
)
return
except (CliError, TimeoutExpired):
remain_retries -= 1
logger.debug(
f"Unable to connect to {node_hostname} with port {node_ssh_port}, "
f"remains {remain_retries} retries"
f"remains {remain_retries} retries",
)
time.sleep(5)
raise ClusterInternalError(f"Unable to connect to {node_hostname} with port {node_ssh_port}")
@ -751,7 +765,7 @@ class GrassExecutor:
@staticmethod
def _get_md5_checksum(path: str, block_size=128) -> str:
""" Get md5 checksum of a local file.
"""Get md5 checksum of a local file.
Args:
path (str): path of the local file.

@ -27,8 +27,9 @@ logger = CliLogger(name=__name__)
class GrassLocalExecutor(AbsVisibleExecutor):
def __init__(self, cluster_name: str, cluster_details: dict = None):
self.cluster_name = cluster_name
self.cluster_details = DetailsReader.load_cluster_details(cluster_name=cluster_name) \
if not cluster_details else cluster_details
self.cluster_details = (
DetailsReader.load_cluster_details(cluster_name=cluster_name) if not cluster_details else cluster_details
)
# Connection with Redis
redis_port = self.cluster_details["master"]["redis"]["port"]
@ -37,7 +38,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
self._redis_connection.ping()
except Exception:
redis_process = subprocess.Popen(
["redis-server", "--port", str(redis_port), "--daemonize yes"]
["redis-server", "--port", str(redis_port), "--daemonize yes"],
)
redis_process.wait(timeout=2)
@ -53,7 +54,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
deployment["total_request_resource"] = {
"cpu": total_cpu,
"memory": total_memory,
"gpu": total_gpu
"gpu": total_gpu,
}
deployment["status"] = JobStatus.PENDING
@ -76,7 +77,9 @@ class GrassLocalExecutor(AbsVisibleExecutor):
# Update resource
is_satisfied, updated_resource = resource_op(
available_resource, cluster_resource, ResourceOperation.ALLOCATION
available_resource,
cluster_resource,
ResourceOperation.ALLOCATION,
)
if not is_satisfied:
self._resource_redis.sub_cluster()
@ -91,13 +94,13 @@ class GrassLocalExecutor(AbsVisibleExecutor):
self._redis_connection.hset(
f"{self.cluster_name}:runtime_detail",
"available_resource",
json.dumps(cluster_resource)
json.dumps(cluster_resource),
)
# Save cluster config locally.
DetailsWriter.save_cluster_details(
cluster_name=self.cluster_name,
cluster_details=self.cluster_details
cluster_details=self.cluster_details,
)
logger.info(f"{self.cluster_name} is created.")
@ -117,7 +120,9 @@ class GrassLocalExecutor(AbsVisibleExecutor):
# Update resource
cluster_resource = self.cluster_details["master"]["resource"]
_, updated_resource = resource_op(
available_resource, cluster_resource, ResourceOperation.RELEASE
available_resource,
cluster_resource,
ResourceOperation.RELEASE,
)
self._resource_redis.set_available_resource(updated_resource)
@ -156,7 +161,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
is_satisfied, _ = resource_op(
self.cluster_details["master"]["resource"],
start_job_deployment["total_request_resource"],
ResourceOperation.ALLOCATION
ResourceOperation.ALLOCATION,
)
if not is_satisfied:
raise BadRequestError(f"No enough resource to start job {start_job_deployment['name']}.")
@ -171,13 +176,13 @@ class GrassLocalExecutor(AbsVisibleExecutor):
self._redis_connection.hset(
f"{self.cluster_name}:job_details",
job_name,
json.dumps(job_details)
json.dumps(job_details),
)
# Push job name to pending_job_tickets
self._redis_connection.lpush(
f"{self.cluster_name}:pending_job_tickets",
job_name
job_name,
)
logger.info(f"Sending {job_name} into pending job tickets.")
@ -189,7 +194,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
# push job_name into killed_job_tickets
self._redis_connection.lpush(
f"{self.cluster_name}:killed_job_tickets",
job_name
job_name,
)
logger.info(f"Sending {job_name} into killed job tickets.")
@ -231,7 +236,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
is_satisfied, _ = resource_op(
self.cluster_details["master"]["resource"],
start_schedule_deployment["total_request_resource"],
ResourceOperation.ALLOCATION
ResourceOperation.ALLOCATION,
)
if not is_satisfied:
raise BadRequestError(f"No enough resource to start schedule {schedule_name} in {self.cluster_name}.")
@ -240,7 +245,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
self._redis_connection.hset(
f"{self.cluster_name}:job_details",
schedule_name,
json.dumps(start_schedule_deployment)
json.dumps(start_schedule_deployment),
)
job_list = start_schedule_deployment["job_names"]
@ -256,7 +261,7 @@ class GrassLocalExecutor(AbsVisibleExecutor):
def stop_schedule(self, schedule_name: str):
try:
schedule_details = json.loads(
self._redis_connection.hget(f"{self.cluster_name}:job_details", schedule_name)
self._redis_connection.hget(f"{self.cluster_name}:job_details", schedule_name),
)
except Exception:
logger.error(f"No such schedule '{schedule_name}' in Redis.")
@ -281,15 +286,17 @@ class GrassLocalExecutor(AbsVisibleExecutor):
def get_job_queue(self):
pending_job_queue = self._redis_connection.lrange(
f"{self.cluster_name}:pending_job_tickets",
0, -1
0,
-1,
)
killed_job_queue = self._redis_connection.lrange(
f"{self.cluster_name}:killed_job_tickets",
0, -1
0,
-1,
)
return {
"pending_jobs": pending_job_queue,
"killed_jobs": killed_job_queue
"killed_jobs": killed_job_queue,
}
def get_resource(self):
@ -298,6 +305,6 @@ class GrassLocalExecutor(AbsVisibleExecutor):
def get_resource_usage(self, previous_length: int = 0):
available_resource = self._redis_connection.hget(
f"{self.cluster_name}:runtime_detail",
"available_resource"
"available_resource",
)
return json.loads(available_resource)
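Side note for readers of the GrassLocalExecutor changes above: the executor repeatedly calls resource_op(available, required, ResourceOperation.ALLOCATION) and expects an (is_satisfied, updated_resource) pair back. The following is only a minimal sketch of such a helper; the cpu/memory/gpu dict layout and the exact semantics are assumptions for illustration, not code taken from this commit.

from enum import Enum


class ResourceOperation(Enum):
    ALLOCATION = "allocation"
    RELEASE = "release"


def resource_op(available: dict, request: dict, op: ResourceOperation) -> tuple:
    """Return (is_satisfied, updated_available) for a cpu/memory/gpu request. Illustrative only."""
    keys = ("cpu", "memory", "gpu")
    updated = dict(available)
    if op == ResourceOperation.ALLOCATION:
        if any(float(available[k]) < float(request[k]) for k in keys):
            # Request cannot be satisfied; leave the available resources untouched.
            return False, available
        for k in keys:
            updated[k] = float(available[k]) - float(request[k])
        return True, updated
    # ResourceOperation.RELEASE: hand the resources back.
    for k in keys:
        updated[k] = float(available[k]) + float(request[k])
    return True, updated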

@ -63,7 +63,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
user_id=cluster_details["user"]["id"],
master_to_dev_encryption_private_key=cluster_details["user"]["master_to_dev_encryption_private_key"],
dev_to_master_encryption_public_key=cluster_details["user"]["dev_to_master_encryption_public_key"],
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"]
dev_to_master_signing_private_key=cluster_details["user"]["dev_to_master_signing_private_key"],
)
master_api_client.create_master(master_details=cluster_details["master"])
master_api_client.create_cluster(cluster_details=cluster_details)
@ -95,20 +95,20 @@ class GrassOnPremisesExecutor(GrassExecutor):
"root['master']['fluentd']": {"port": GlobalParams.DEFAULT_FLUENTD_PORT},
"root['master']['fluentd']['port']": GlobalParams.DEFAULT_FLUENTD_PORT,
"root['master']['samba']": {
"password": samba_password
"password": samba_password,
},
"root['master']['samba']['password']": samba_password,
"root['master']['ssh']": {"port": GlobalParams.DEFAULT_SSH_PORT},
"root['master']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
"root['master']['api_server']": {"port": GrassParams.DEFAULT_API_SERVER_PORT},
"root['master']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT
"root['master']['api_server']['port']": GrassParams.DEFAULT_API_SERVER_PORT,
}
with open(f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_on_premises_create.yml") as fr:
create_deployment_template = yaml.safe_load(fr)
DeploymentValidator.validate_and_fill_dict(
template_dict=create_deployment_template,
actual_dict=create_deployment,
optional_key_to_value=optional_key_to_value
optional_key_to_value=optional_key_to_value,
)
# Init runtime fields.
@ -134,13 +134,13 @@ class GrassOnPremisesExecutor(GrassExecutor):
self.remote_leave_cluster(
node_username=node_details["username"],
node_hostname=node_details["public_ip_address"],
node_ssh_port=node_details["ssh"]["port"]
node_ssh_port=node_details["ssh"]["port"],
)
self.remote_delete_master(
master_username=self.master_username,
master_hostname=self.master_public_ip_address,
master_ssh_port=self.master_ssh_port
master_ssh_port=self.master_ssh_port,
)
shutil.rmtree(path=f"{GlobalPaths.ABS_MARO_CLUSTERS}/{self.cluster_name}")
@ -177,7 +177,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
# Get standardized join_cluster_deployment
join_cluster_deployment = GrassOnPremisesExecutor._standardize_join_cluster_deployment(
join_cluster_deployment=join_cluster_deployment
join_cluster_deployment=join_cluster_deployment,
)
# Save join_cluster_deployment TODO: do checking, already join another node
@ -186,7 +186,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
# Copy required files
local_path_to_remote_dir = {
f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml": GlobalPaths.MARO_LOCAL_TMP
f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml": GlobalPaths.MARO_LOCAL_TMP,
}
for local_path, remote_dir in local_path_to_remote_dir.items():
FileSynchronizer.copy_files_to_node(
@ -194,7 +194,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
remote_dir=remote_dir,
node_username=join_cluster_deployment["node"]["username"],
node_hostname=join_cluster_deployment["node"]["public_ip_address"],
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"]
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"],
)
# Remote join node
@ -204,7 +204,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
node_ssh_port=join_cluster_deployment["node"]["ssh"]["port"],
master_private_ip_address=join_cluster_deployment["master"]["private_ip_address"],
master_api_server_port=join_cluster_deployment["master"]["api_server"]["port"],
deployment_path=f"{GlobalPaths.MARO_LOCAL_TMP}/join_cluster_deployment.yml"
deployment_path=f"{GlobalPaths.MARO_LOCAL_TMP}/join_cluster_deployment.yml",
)
os.remove(f"{GlobalPaths.ABS_MARO_LOCAL_TMP}/join_cluster_deployment.yml")
@ -230,7 +230,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
"root['node']['resources']": {
"cpu": "all",
"memory": "all",
"gpu": "all"
"gpu": "all",
},
"root['node']['resources']['cpu']": "all",
"root['node']['resources']['memory']": "all",
@ -241,21 +241,21 @@ class GrassOnPremisesExecutor(GrassExecutor):
"root['node']['ssh']['port']": GlobalParams.DEFAULT_SSH_PORT,
"root['configs']": {
"install_node_runtime": False,
"install_node_gpu_support": False
"install_node_gpu_support": False,
},
"root['configs']['install_node_runtime']": False,
"root['configs']['install_node_gpu_support']": False
"root['configs']['install_node_gpu_support']": False,
}
with open(
file=f"{GrassPaths.ABS_MARO_GRASS_LIB}/deployments/internal/grass_on_premises_join_cluster.yml",
mode="r"
mode="r",
) as fr:
create_deployment_template = yaml.safe_load(stream=fr)
DeploymentValidator.validate_and_fill_dict(
template_dict=create_deployment_template,
actual_dict=join_cluster_deployment,
optional_key_to_value=optional_key_to_value
optional_key_to_value=optional_key_to_value,
)
return join_cluster_deployment
@ -283,7 +283,7 @@ class GrassOnPremisesExecutor(GrassExecutor):
GrassOnPremisesExecutor.remote_leave_cluster(
node_username=leave_cluster_deployment["node"]["username"],
node_hostname=leave_cluster_deployment["node"]["hostname"],
node_ssh_port=leave_cluster_deployment["node"]["ssh"]["port"]
node_ssh_port=leave_cluster_deployment["node"]["ssh"]["port"],
)
logger.info_green("Node is left")

@ -9,8 +9,7 @@ from maro.cli.utils.operation_lock_wrapper import operation_lock
@check_details_validity
@operation_lock
def push_image(
cluster_name: str, image_name: str, image_path: str, remote_context_path: str, remote_image_name: str,
**kwargs
cluster_name: str, image_name: str, image_path: str, remote_context_path: str, remote_image_name: str, **kwargs
):
# Late imports.
from maro.cli.grass.executors.grass_azure_executor import GrassAzureExecutor
@ -26,7 +25,7 @@ def push_image(
image_name=image_name,
image_path=image_path,
remote_context_path=remote_context_path,
remote_image_name=remote_image_name
remote_image_name=remote_image_name,
)
elif cluster_details["mode"] == "grass/on-premises":
executor = GrassOnPremisesExecutor(cluster_name=cluster_name)
@ -34,7 +33,7 @@ def push_image(
image_name=image_name,
image_path=image_path,
remote_context_path=remote_context_path,
remote_image_name=remote_image_name
remote_image_name=remote_image_name,
)
else:
raise BadRequestError(f"Unsupported operation in mode '{cluster_details['mode']}'.")

@ -72,7 +72,7 @@ if __name__ == "__main__":
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
universal_newlines=True,
)
while True:
next_line = process.stdout.readline()

@ -31,7 +31,7 @@ class UserCreator:
self._local_cluster_details = local_cluster_details
self._redis_controller = RedisController(
host="localhost",
port=self._local_cluster_details["master"]["redis"]["port"]
port=self._local_cluster_details["master"]["redis"]["port"],
)
@staticmethod
@ -39,22 +39,22 @@ class UserCreator:
rsa_key = rsa.generate_private_key(
backend=default_backend(),
public_exponent=65537,
key_size=2048
key_size=2048,
)
# Format and encoding are diff from OpenSSH
private_key = rsa_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)
public_key = rsa_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return {
"public_key": public_key.decode("utf-8"),
"private_key": private_key.decode("utf-8")
"private_key": private_key.decode("utf-8"),
}
def create_user(self, user_id: str, user_role: str) -> None:
@ -72,8 +72,8 @@ class UserCreator:
"master_to_dev_encryption_public_key": master_to_dev_encryption_key_pair["public_key"],
"dev_to_master_encryption_private_key": dev_to_master_encryption_key_pair["private_key"],
"dev_to_master_signing_public_key": dev_to_master_signing_key_pair["public_key"],
"master_to_dev_signing_private_key": master_to_dev_signing_key_pair["private_key"]
}
"master_to_dev_signing_private_key": master_to_dev_signing_key_pair["private_key"],
},
)
# Write private key to console.
@ -85,9 +85,9 @@ class UserCreator:
"master_to_dev_encryption_private_key": master_to_dev_encryption_key_pair["private_key"],
"dev_to_master_encryption_public_key": dev_to_master_encryption_key_pair["public_key"],
"dev_to_master_signing_private_key": dev_to_master_signing_key_pair["private_key"],
"master_to_dev_signing_public_key": master_to_dev_signing_key_pair["public_key"]
}
)
"master_to_dev_signing_public_key": master_to_dev_signing_key_pair["public_key"],
},
),
)

@ -28,6 +28,7 @@ REMOVE_CONTAINERS = """sudo docker rm -f maro-fluentd-{cluster_id} maro-redis-{c
# Master Deleter.
class MasterDeleter:
def __init__(self, local_cluster_details: dict):
self._local_cluster_details = local_cluster_details
@ -83,7 +84,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)

@ -106,7 +106,7 @@ class MasterInitializer:
master_redis_port=self.master_details["redis"]["port"],
cluster_name=cluster_details["name"],
master_fluentd_port=self.master_details["fluentd"]["port"],
steps=5
steps=5,
)
Subprocess.interactive_run(command=command)
@ -115,7 +115,7 @@ class MasterInitializer:
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/master_agent/maro-master-agent.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)
@ -131,13 +131,13 @@ class MasterInitializer:
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/master_api_server/maro-master-api-server.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(
home_path=str(pathlib.Path.home()),
maro_shared_path=Paths.ABS_MARO_SHARED,
master_api_server_port=self.master_details["api_server"]["port"]
master_api_server_port=self.master_details["api_server"]["port"],
)
os.makedirs(name=os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
with open(file=os.path.expanduser("~/.config/systemd/user/maro-master-api-server.service"), mode="w") as fw:
@ -152,7 +152,7 @@ class MasterInitializer:
os.makedirs(name=f"{Paths.ABS_MARO_LOCAL}/scripts", exist_ok=True)
shutil.copy2(
src=f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/master/delete_master.py",
dst=f"{Paths.ABS_MARO_LOCAL}/scripts"
dst=f"{Paths.ABS_MARO_LOCAL}/scripts",
)

@ -47,7 +47,7 @@ class RedisController:
def delete_node_details(self, node_name: str) -> None:
self._redis.hdel(
"name_to_node_details",
node_name
node_name,
)
@ -71,14 +71,14 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
sys.stderr.write(completed_process.stderr)
if __name__ == '__main__':
if __name__ == "__main__":
local_cluster_details = DetailsReader.load_local_cluster_details()
local_node_details = DetailsReader.load_local_node_details()
@ -87,6 +87,6 @@ if __name__ == '__main__':
redis_controller = RedisController(
host=local_cluster_details["master"]["private_ip_address"],
port=local_cluster_details["master"]["redis"]["port"]
port=local_cluster_details["master"]["redis"]["port"],
)
redis_controller.delete_node_details(node_name=local_node_details["name"])

@ -111,6 +111,7 @@ echo "{public_key}" >> ~/.ssh/authorized_keys
# Node Joiner.
class NodeJoiner:
def __init__(self, join_cluster_deployment: dict):
self.join_cluster_deployment = join_cluster_deployment
@ -118,7 +119,7 @@ class NodeJoiner:
redis_controller = RedisController(
host=join_cluster_deployment["master"]["private_ip_address"],
port=join_cluster_deployment["master"]["redis"]["port"]
port=join_cluster_deployment["master"]["redis"]["port"],
)
self.cluster_details = redis_controller.get_cluster_details()
@ -138,7 +139,7 @@ class NodeJoiner:
node_details["image_files"] = {}
node_details["containers"] = {}
node_details["state"] = {
"status": NodeStatus.PENDING
"status": NodeStatus.PENDING,
}
return node_details
@ -163,7 +164,7 @@ class NodeJoiner:
master_username=self.master_details["username"],
master_hostname=self.master_details["private_ip_address"],
master_samba_password=self.master_details["samba"]["password"],
maro_shared_path=Paths.ABS_MARO_SHARED
maro_shared_path=Paths.ABS_MARO_SHARED,
)
Subprocess.run(command=command)
@ -171,7 +172,7 @@ class NodeJoiner:
# Rewrite data in .service and write it to systemd folder.
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_agent/maro-node-agent.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)
@ -186,13 +187,13 @@ class NodeJoiner:
# Rewrite data in .service and write it to systemd folder.
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_api_server/maro-node-api-server.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(
home_path=str(pathlib.Path.home()),
maro_shared_path=Paths.ABS_MARO_SHARED,
node_api_server_port=self.node_details["api_server"]["port"]
node_api_server_port=self.node_details["api_server"]["port"],
)
os.makedirs(os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
with open(file=os.path.expanduser("~/.config/systemd/user/maro-node-api-server.service"), mode="w") as fw:
@ -205,13 +206,13 @@ class NodeJoiner:
def copy_leave_script():
src_files = [
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/leave_cluster.py",
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/activate_leave_cluster.py"
f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node/activate_leave_cluster.py",
]
os.makedirs(name=f"{Paths.ABS_MARO_LOCAL}/scripts", exist_ok=True)
for src_file in src_files:
shutil.copy2(
src=src_file,
dst=f"{Paths.ABS_MARO_LOCAL}/scripts"
dst=f"{Paths.ABS_MARO_LOCAL}/scripts",
)
def load_master_public_key(self):
@ -227,11 +228,11 @@ class NodeJoiner:
"master": {
"private_ip_address": "",
"api_server": {
"port": ""
"port": "",
},
"redis": {
"port": ""
}
"port": "",
},
},
"node": {
"hostname": "",
@ -241,19 +242,19 @@ class NodeJoiner:
"resources": {
"cpu": "",
"memory": "",
"gpu": ""
"gpu": "",
},
"ssh": {
"port": ""
"port": "",
},
"api_server": {
"port": ""
}
"port": "",
},
},
"configs": {
"install_node_runtime": "",
"install_node_gpu_support": ""
}
"install_node_gpu_support": "",
},
}
DeploymentValidator.validate_and_fill_dict(
template_dict=join_cluster_deployment_template,
@ -266,7 +267,7 @@ class NodeJoiner:
"root['node']['resources']": {
"cpu": "all",
"memory": "all",
"gpu": "all"
"gpu": "all",
},
"root['node']['resources']['cpu']": "all",
"root['node']['resources']['memory']": "all",
@ -277,11 +278,11 @@ class NodeJoiner:
"root['node']['ssh']['port']": Params.DEFAULT_SSH_PORT,
"root['configs']": {
"install_node_runtime": False,
"install_node_gpu_support": False
"install_node_gpu_support": False,
},
"root['configs']['install_node_runtime']": False,
"root['configs']['install_node_gpu_support']": False
}
"root['configs']['install_node_gpu_support']": False,
},
)
return join_cluster_deployment
@ -289,6 +290,7 @@ class NodeJoiner:
# Utils Classes.
class Params:
DEFAULT_SSH_PORT = 22
DEFAULT_REDIS_PORT = 6379
@ -340,13 +342,13 @@ class RedisController:
self._redis.hset(
"name_to_node_details",
node_details["name"],
json.dumps(node_details)
json.dumps(node_details),
)
"""Utils."""
def lock(self, name: str) -> Lock:
""" Get a new lock with redis.
"""Get a new lock with redis.
Use 'with lock(name):' paradigm to do the locking.
@ -386,7 +388,7 @@ class DeploymentValidator:
DeploymentValidator._set_value(
original_dict=actual_dict,
key_list=DeploymentValidator._get_parent_to_child_key_list(deep_diff_str=missing_key_str),
value=optional_key_to_value[missing_key_str]
value=optional_key_to_value[missing_key_str],
)
@staticmethod
@ -454,7 +456,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
@ -477,7 +479,7 @@ class Subprocess:
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
universal_newlines=True,
)
while True:
next_line = process.stdout.readline()
@ -523,7 +525,7 @@ if __name__ == "__main__":
join_cluster_deployment = yaml.safe_load(stream=fr)
join_cluster_deployment = NodeJoiner.standardize_join_cluster_deployment(
join_cluster_deployment=join_cluster_deployment
join_cluster_deployment=join_cluster_deployment,
)
node_joiner = NodeJoiner(join_cluster_deployment=join_cluster_deployment)
node_joiner.init_node_runtime_env()

@ -92,7 +92,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)

@ -23,7 +23,7 @@ if __name__ == "__main__":
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_agent/maro-node-agent.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(maro_shared_path=Paths.ABS_MARO_SHARED)

@ -28,13 +28,13 @@ if __name__ == "__main__":
# Rewrite data in .service and write it to systemd folder
with open(
file=f"{Paths.ABS_MARO_SHARED}/lib/grass/services/node_api_server/maro-node-api-server.service",
mode="r"
mode="r",
) as fr:
service_file = fr.read()
service_file = service_file.format(
home_path=str(pathlib.Path.home()),
maro_shared_path=Paths.ABS_MARO_SHARED,
node_api_server_port=local_node_details["api_server"]["port"]
node_api_server_port=local_node_details["api_server"]["port"],
)
os.makedirs(os.path.expanduser("~/.config/systemd/user/"), exist_ok=True)
with open(file=os.path.expanduser("~/.config/systemd/user/maro-node-api-server.service"), mode="w") as fw:

@ -8,8 +8,7 @@ from .params import Paths
class DetailsReader:
"""Reader class for details.
"""
"""Reader class for details."""
@staticmethod
def load_cluster_details(cluster_name: str) -> dict:

@ -10,8 +10,7 @@ from .params import Paths
class DetailsWriter:
"""Writer class for details.
"""
"""Writer class for details."""
@staticmethod
def save_local_cluster_details(cluster_details: dict) -> None:

@ -7,8 +7,7 @@ import redis
class RedisController:
"""Controller class for Redis.
"""
"""Controller class for Redis."""
def __init__(self, host: str, port: int):
self._redis = redis.Redis(host=host, port=port, encoding="utf-8", decode_responses=True)
@ -19,5 +18,5 @@ class RedisController:
return self._redis.hset(
"id_to_user_details",
user_id,
json.dumps(obj=user_details)
json.dumps(obj=user_details),
)

@ -7,8 +7,7 @@ import sys
class Subprocess:
"""Wrapper class of subprocess, with CliException integrated.
"""
"""Wrapper class of subprocess, with CliException integrated."""
@staticmethod
def run(command: str, timeout: int = None) -> str:
@ -29,7 +28,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise Exception(completed_process.stderr)
@ -52,7 +51,7 @@ class Subprocess:
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True
universal_newlines=True,
)
while True:
next_line = process.stdout.readline()

@ -21,13 +21,15 @@ logger = logging.getLogger(__name__)
AVAILABLE_METRICS = {
"cpu",
"memory",
"gpu"
"gpu",
}
ERROR_CODE_FOR_NOT_RESTART = 64
ERROR_CODE_FOR_STOP_JOB = 65
ERROR_CODES_FOR_NOT_RESTART_CONTAINER = {
0, ERROR_CODE_FOR_NOT_RESTART, ERROR_CODE_FOR_STOP_JOB
0,
ERROR_CODE_FOR_NOT_RESTART,
ERROR_CODE_FOR_STOP_JOB,
}
@ -44,27 +46,27 @@ class MasterAgent:
"""
job_tracking_agent = JobTrackingAgent(
local_cluster_details=self._local_cluster_details,
local_master_details=self._local_master_details
local_master_details=self._local_master_details,
)
job_tracking_agent.start()
container_tracking_agent = ContainerTrackingAgent(
local_cluster_details=self._local_cluster_details,
local_master_details=self._local_master_details
local_master_details=self._local_master_details,
)
container_tracking_agent.start()
pending_job_agent = PendingJobAgent(
local_cluster_details=self._local_cluster_details,
local_master_details=self._local_master_details
local_master_details=self._local_master_details,
)
pending_job_agent.start()
container_runtime_agent = ContainerRuntimeAgent(
local_cluster_details=self._local_cluster_details,
local_master_details=self._local_master_details
local_master_details=self._local_master_details,
)
container_runtime_agent.start()
killed_job_agent = KilledJobAgent(
local_cluster_details=self._local_cluster_details,
local_master_details=self._local_master_details
local_master_details=self._local_master_details,
)
killed_job_agent.start()
@ -81,7 +83,7 @@ class JobTrackingAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host="localhost",
port=local_master_details["redis"]["port"]
port=local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -138,7 +140,7 @@ class JobTrackingAgent(multiprocessing.Process):
job_details["status"] = job_state
self._redis_controller.set_job_details(
job_name=job_name,
job_details=job_details
job_details=job_details,
)
# Utils.
@ -170,7 +172,7 @@ class ContainerTrackingAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host="localhost",
port=local_master_details["redis"]["port"]
port=local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -221,7 +223,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host="localhost",
port=local_master_details["redis"]["port"]
port=local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -258,7 +260,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
# Remove container.
is_remove_container = self._is_remove_container(
container_details=container_details,
job_runtime_details=job_runtime_details
job_runtime_details=job_runtime_details,
)
if is_remove_container:
node_name = container_details["node_name"]
@ -272,7 +274,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
# Restart container.
if self._is_restart_container(
container_details=container_details,
job_runtime_details=job_runtime_details
job_runtime_details=job_runtime_details,
):
self._restart_container(container_name=container_name, container_details=container_details)
@ -309,7 +311,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
"""
exceed_maximum_restart_times = self._redis_controller.get_rejoin_component_restart_times(
job_id=container_details["job_id"],
component_id=container_details["component_id"]
component_id=container_details["component_id"],
) >= int(job_runtime_details.get("rejoin:max_restart_times", sys.maxsize))
return (
container_details["state"]["Status"] == "exited"
@ -346,13 +348,13 @@ class ContainerRuntimeAgent(multiprocessing.Process):
"""
# Get component_name_to_container_name.
rejoin_container_name_to_component_name = self._redis_controller.get_rejoin_container_name_to_component_name(
job_id=container_details["job_id"]
job_id=container_details["job_id"],
)
# If the mapping not exists, or the container is not in the mapping, skip the restart operation.
if (
rejoin_container_name_to_component_name is None or
container_name not in rejoin_container_name_to_component_name
rejoin_container_name_to_component_name is None
or container_name not in rejoin_container_name_to_component_name
):
logger.warning(f"Container {container_name} is not found in container_name_to_component_name mapping")
return
@ -364,24 +366,24 @@ class ContainerRuntimeAgent(multiprocessing.Process):
# Get resources and allocation plan.
free_resources = ResourceController.get_free_resources(
redis_controller=self._redis_controller,
cluster_name=self._cluster_name
cluster_name=self._cluster_name,
)
required_resources = [
ContainerResource(
container_name=ContainerController.build_container_name(
job_id=container_details["job_id"],
component_id=container_details["component_id"],
component_index=container_details["component_index"]
component_index=container_details["component_index"],
),
cpu=float(container_details["cpu"]),
memory=float(container_details["memory"].replace("m", "")),
gpu=float(container_details["gpu"])
)
gpu=float(container_details["gpu"]),
),
]
allocation_plan = ResourceController._get_single_metric_balanced_allocation_plan(
allocation_details={"metric": "cpu"},
required_resources=required_resources,
free_resources=free_resources
free_resources=free_resources,
)
# Start a new container.
@ -392,11 +394,11 @@ class ContainerRuntimeAgent(multiprocessing.Process):
container_name=container_name,
node_details=node_details,
job_details=job_details,
component_name=component_name
component_name=component_name,
)
self._redis_controller.incr_rejoin_component_restart_times(
job_id=container_details["job_id"],
component_id=container_details["component_id"]
component_id=container_details["component_id"],
)
except ResourceAllocationFailed as e:
logger.warning(f"{e}")
@ -436,13 +438,13 @@ class ContainerRuntimeAgent(multiprocessing.Process):
NodeApiClientV1.remove_container(
node_hostname=node_details["hostname"],
node_api_server_port=node_details["api_server"]["port"],
container_name=container_name
container_name=container_name,
)
else:
NodeApiClientV1.stop_container(
node_hostname=node_details["hostname"],
node_api_server_port=node_details["api_server"]["port"],
container_name=container_name
container_name=container_name,
)
def _start_container(self, container_name: str, node_details: dict, job_details: dict, component_name: str) -> None:
@ -486,7 +488,6 @@ class ContainerRuntimeAgent(multiprocessing.Process):
"command": job_details["components"][component_type]["command"],
"image_name": job_details["components"][component_type]["image"],
"volumes": [f"{maro_mount_source}:{mount_target}"],
# System related.
"container_name": container_name,
"fluentd_address": f"{self._master_hostname}:{self._master_fluentd_port}",
@ -503,7 +504,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
"COMPONENT_INDEX": component_index,
"CONTAINER_NAME": container_name,
"COMPONENT_NAME": component_name,
"PYTHONUNBUFFERED": 0
"PYTHONUNBUFFERED": 0,
},
"labels": {
"cluster_id": cluster_id,
@ -518,8 +519,8 @@ class ContainerRuntimeAgent(multiprocessing.Process):
"container_name": container_name,
"component_name": component_name,
"cpu": cpu,
"memory": memory
}
"memory": memory,
},
}
if gpu != 0:
@ -529,7 +530,7 @@ class ContainerRuntimeAgent(multiprocessing.Process):
NodeApiClientV1.create_container(
node_hostname=node_details["hostname"],
node_api_server_port=node_details["api_server"]["port"],
create_config=create_config
create_config=create_config,
)
@ -548,7 +549,7 @@ class PendingJobAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host="localhost",
port=local_master_details["redis"]["port"]
port=local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -580,7 +581,7 @@ class PendingJobAgent(multiprocessing.Process):
# Get free resources at the very beginning.
free_resources = ResourceController.get_free_resources(
redis_controller=self._redis_controller,
cluster_name=self._cluster_name
cluster_name=self._cluster_name,
)
# Iterate tickets.
@ -596,14 +597,14 @@ class PendingJobAgent(multiprocessing.Process):
allocation_plan = ResourceController.get_allocation_plan(
allocation_details=job_details["allocation"],
required_resources=required_resources,
free_resources=free_resources
free_resources=free_resources,
)
for container_name, node_name in allocation_plan.items():
node_details = self._redis_controller.get_node_details(node_name=node_name)
self._start_container(
container_name=container_name,
node_details=node_details,
job_details=job_details
job_details=job_details,
)
self._redis_controller.remove_pending_job_ticket(job_name=pending_job_name)
job_details["status"] = JobStatus.RUNNING
@ -654,7 +655,6 @@ class PendingJobAgent(multiprocessing.Process):
"command": job_details["components"][component_type]["command"],
"image_name": job_details["components"][component_type]["image"],
"volumes": [f"{maro_mount_source}:{mount_target}"],
# System related.
"container_name": container_name,
"fluentd_address": f"{self._master_hostname}:{self._master_fluentd_port}",
@ -670,7 +670,7 @@ class PendingJobAgent(multiprocessing.Process):
"COMPONENT_TYPE": component_type,
"COMPONENT_INDEX": component_index,
"CONTAINER_NAME": container_name,
"PYTHONUNBUFFERED": 0
"PYTHONUNBUFFERED": 0,
},
"labels": {
"cluster_id": cluster_id,
@ -684,8 +684,8 @@ class PendingJobAgent(multiprocessing.Process):
"component_index": component_index,
"container_name": container_name,
"cpu": cpu,
"memory": memory
}
"memory": memory,
},
}
if gpu != 0:
@ -695,7 +695,7 @@ class PendingJobAgent(multiprocessing.Process):
NodeApiClientV1.create_container(
node_hostname=node_details["hostname"],
node_api_server_port=node_details["api_server"]["port"],
create_config=create_config
create_config=create_config,
)
@ -712,7 +712,7 @@ class KilledJobAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host="localhost",
port=local_master_details["redis"]["port"]
port=local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -791,13 +791,12 @@ class KilledJobAgent(multiprocessing.Process):
NodeApiClientV1.remove_container(
node_hostname=node_details["hostname"],
node_api_server_port=node_details["api_server"]["port"],
container_name=container_name
container_name=container_name,
)
class ResourceController:
"""Controller class for computing resources in MARO Nodes.
"""
"""Controller class for computing resources in MARO Nodes."""
@staticmethod
def get_allocation_plan(allocation_details: dict, required_resources: list, free_resources: list) -> dict:
@ -815,13 +814,13 @@ class ResourceController:
return ResourceController._get_single_metric_balanced_allocation_plan(
allocation_details=allocation_details,
required_resources=required_resources,
free_resources=free_resources
free_resources=free_resources,
)
elif allocation_details["mode"] == "single-metric-compacted":
return ResourceController._get_single_metric_compacted_allocation_plan(
allocation_details=allocation_details,
required_resources=required_resources,
free_resources=free_resources
free_resources=free_resources,
)
else:
raise ResourceAllocationFailed("Invalid allocation mode.")
@ -829,7 +828,8 @@ class ResourceController:
@staticmethod
def _get_single_metric_compacted_allocation_plan(
allocation_details: dict,
required_resources: list, free_resources: list
required_resources: list,
free_resources: list,
) -> dict:
"""Get single_metric_compacted allocation plan.
@ -856,13 +856,13 @@ class ResourceController:
for required_resource in required_resources:
heapq.heappush(
required_resources_pq,
(-getattr(required_resource, metric), required_resource)
(-getattr(required_resource, metric), required_resource),
)
free_resources_pq = []
for free_resource in free_resources:
heapq.heappush(
free_resources_pq,
(getattr(free_resource, metric), free_resource)
(getattr(free_resource, metric), free_resource),
)
# Get allocation.
@ -890,23 +890,23 @@ class ResourceController:
free_resource.gpu -= required_resource.gpu
heapq.heappush(
free_resources_pq,
(getattr(free_resource, metric), free_resource)
(getattr(free_resource, metric), free_resource),
)
for not_usable_free_resource in not_usable_free_resources:
heapq.heappush(
free_resources_pq,
(getattr(not_usable_free_resource, metric), not_usable_free_resource)
(getattr(not_usable_free_resource, metric), not_usable_free_resource),
)
else:
# add previous resources back, to do printing.
for not_usable_free_resource in not_usable_free_resources:
heapq.heappush(
free_resources_pq,
(getattr(not_usable_free_resource, metric), not_usable_free_resource)
(getattr(not_usable_free_resource, metric), not_usable_free_resource),
)
heapq.heappush(
required_resources_pq,
(-getattr(required_resource, metric), required_resource)
(-getattr(required_resource, metric), required_resource),
)
logger.warning(allocation_plan)
@ -921,7 +921,8 @@ class ResourceController:
@staticmethod
def _get_single_metric_balanced_allocation_plan(
allocation_details: dict,
required_resources: list, free_resources: list
required_resources: list,
free_resources: list,
) -> dict:
"""Get single_metric_balanced allocation plan.
@ -948,13 +949,13 @@ class ResourceController:
for required_resource in required_resources:
heapq.heappush(
required_resources_pq,
(-getattr(required_resource, metric), required_resource)
(-getattr(required_resource, metric), required_resource),
)
free_resources_pq = []
for free_resource in free_resources:
heapq.heappush(
free_resources_pq,
(-getattr(free_resource, metric), free_resource)
(-getattr(free_resource, metric), free_resource),
)
# Get allocation.
@ -982,23 +983,23 @@ class ResourceController:
free_resource.gpu -= required_resource.gpu
heapq.heappush(
free_resources_pq,
(-getattr(free_resource, metric), free_resource)
(-getattr(free_resource, metric), free_resource),
)
for not_usable_free_resource in not_usable_free_resources:
heapq.heappush(
free_resources_pq,
(-getattr(not_usable_free_resource, metric), not_usable_free_resource)
(-getattr(not_usable_free_resource, metric), not_usable_free_resource),
)
else:
# add previous resources back, to do printing.
for not_usable_free_resource in not_usable_free_resources:
heapq.heappush(
free_resources_pq,
(-getattr(not_usable_free_resource, metric), not_usable_free_resource)
(-getattr(not_usable_free_resource, metric), not_usable_free_resource),
)
heapq.heappush(
required_resources_pq,
(-getattr(required_resource, metric), required_resource)
(-getattr(required_resource, metric), required_resource),
)
logger.warning(allocation_plan)
@ -1038,8 +1039,8 @@ class ResourceController:
node_name=node_name,
cpu=target_free_cpu,
memory=target_free_memory,
gpu=target_free_gpu
)
gpu=target_free_gpu,
),
)
except KeyError:
# node_details is not in stable state.
@ -1076,14 +1077,13 @@ class ResourceController:
cpu=required_cpu,
memory=required_memory,
gpu=required_gpu,
)
),
)
return resources_list
class ContainerController:
"""Controller class for container.
"""
"""Controller class for container."""
@staticmethod
def build_container_name(job_id: str, component_id: str, component_index: int) -> str:
@ -1103,8 +1103,7 @@ class ContainerController:
class JobController:
"""Controller class for MARO Job.
"""
"""Controller class for MARO Job."""
@staticmethod
def get_component_id_to_component_type(job_details: dict) -> dict:
@ -1130,11 +1129,11 @@ if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
datefmt="%Y-%m-%d %H:%M:%S",
)
master_agent = MasterAgent(
local_cluster_details=DetailsReader.load_local_cluster_details(),
local_master_details=DetailsReader.load_local_master_details()
local_master_details=DetailsReader.load_local_master_details(),
)
master_agent.start()
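A note on the allocation planners reformatted above: they push (metric, resource) pairs onto heaps, negating the metric so that the node with the most free capacity is popped first. Below is a toy, self-contained illustration of that balanced-placement idea; container and node names are invented, only CPU is tracked, and no capacity check is done, so this is a sketch of the pattern rather than the planner used in this commit.

import heapq
from collections import namedtuple

Node = namedtuple("Node", ["name", "cpu"])

free_nodes = [Node("node-a", cpu=8.0), Node("node-b", cpu=4.0)]
required_cpu = {"job.learner.0": 4.0, "job.actor.0": 2.0, "job.actor.1": 2.0}

# Max-heap on free CPU via negated keys, mirroring the (-metric, resource) pattern above.
free_pq = [(-node.cpu, node) for node in free_nodes]
heapq.heapify(free_pq)

allocation_plan = {}
for container_name, cpu in sorted(required_cpu.items(), key=lambda kv: -kv[1]):
    neg_free_cpu, node = heapq.heappop(free_pq)  # node with the most free CPU
    allocation_plan[container_name] = node.name
    remaining = -neg_free_cpu - cpu  # no insufficiency handling in this toy version
    heapq.heappush(free_pq, (-remaining, Node(node.name, remaining)))

print(allocation_plan)  # e.g. {'job.learner.0': 'node-a', 'job.actor.0': 'node-a', 'job.actor.1': 'node-b'}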

@ -67,12 +67,12 @@ class PendingJobAgent(mp.Process):
# Allocation
cluster_resource = json.loads(
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
)
is_satisfied, updated_resource = resource_op(
cluster_resource,
job_detail["total_request_resource"],
ResourceOperation.ALLOCATION
ResourceOperation.ALLOCATION,
)
if not is_satisfied:
continue
@ -82,17 +82,17 @@ class PendingJobAgent(mp.Process):
self.redis_connection.lrem(f"{self.cluster_name}:pending_job_tickets", 0, job_name)
# Update resource
cluster_resource = json.loads(
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
)
is_satisfied, updated_resource = resource_op(
cluster_resource,
job_detail["total_request_resource"],
ResourceOperation.ALLOCATION
ResourceOperation.ALLOCATION,
)
self.redis_connection.hset(
f"{self.cluster_name}:runtime_detail",
"available_resource",
json.dumps(updated_resource)
json.dumps(updated_resource),
)
def _start_job(self, job_detail: dict):
@ -100,14 +100,8 @@ class PendingJobAgent(mp.Process):
for component_type, command_info in job_detail["components"].items():
for number in range(command_info["num"]):
container_name = NameCreator.create_name_with_uuid(prefix=component_type)
environment_parameters = (
f"-e CONTAINER_NAME={container_name} "
f"-e JOB_NAME={job_detail['name']} "
)
labels = (
f"-l CONTAINER_NAME={container_name} "
f"-l JOB_NAME={job_detail['name']} "
)
environment_parameters = f"-e CONTAINER_NAME={container_name} " f"-e JOB_NAME={job_detail['name']} "
labels = f"-l CONTAINER_NAME={container_name} " f"-l JOB_NAME={job_detail['name']} "
if int(command_info["resources"]["gpu"]) == 0:
component_command = START_CONTAINER_COMMAND.format(
cpu=command_info["resources"]["cpu"],
@ -117,7 +111,7 @@ class PendingJobAgent(mp.Process):
environment_parameters=environment_parameters,
labels=labels,
image_name=command_info["image"],
command=command_info["command"]
command=command_info["command"],
)
else:
component_command = START_CONTAINER_WITH_GPU_COMMAND.format(
@ -129,12 +123,15 @@ class PendingJobAgent(mp.Process):
environment_parameters=environment_parameters,
labels=labels,
image_name=command_info["image"],
command=command_info["command"]
command=command_info["command"],
)
completed_process = subprocess.run(
component_command,
shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf8",
)
if completed_process.returncode != 0:
raise ResourceAllocationFailed(completed_process.stderr)
@ -145,7 +142,7 @@ class PendingJobAgent(mp.Process):
self.redis_connection.hset(
f"{self.cluster_name}:job_details",
job_detail["name"],
json.dumps(job_detail)
json.dumps(job_detail),
)
@ -163,7 +160,7 @@ class ContainerTrackingAgent(mp.Process):
def _check_container_status(self):
running_jobs = ContainerTrackingAgent.get_running_jobs(
self.redis_connection.hgetall(f"{self.cluster_name}:job_details")
self.redis_connection.hgetall(f"{self.cluster_name}:job_details"),
)
for job_name, job_detail in running_jobs.items():
@ -174,7 +171,10 @@ class ContainerTrackingAgent(mp.Process):
command = f"docker inspect {container_name}"
completed_process = subprocess.run(
command,
shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf8",
)
return_str = completed_process.stdout
inspect_details_list = json.loads(return_str)
@ -229,7 +229,7 @@ class JobTrackingAgent(mp.Process):
def _check_job_state(self):
finished_jobs = self._get_finished_jobs(
self.redis_connection.hgetall(f"{self.cluster_name}:job_details")
self.redis_connection.hgetall(f"{self.cluster_name}:job_details"),
)
for job_name, job_detail in finished_jobs.items():
@ -254,25 +254,31 @@ class JobTrackingAgent(mp.Process):
for container_name in container_list:
command = f"docker stop {container_name}"
completed_process = subprocess.run(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf8",
)
if completed_process.returncode != 0:
raise ResourceAllocationFailed(completed_process.stderr)
def _job_clear(self, job_name: str, release_resource: dict):
cluster_resource = json.loads(
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource")
self.redis_connection.hget(f"{self.cluster_name}:runtime_detail", "available_resource"),
)
# resource release
_, updated_resource = resource_op(
cluster_resource, release_resource, ResourceOperation.RELEASE
cluster_resource,
release_resource,
ResourceOperation.RELEASE,
)
self.redis_connection.hset(
f"{self.cluster_name}:runtime_detail",
"available_resource",
json.dumps(updated_resource)
json.dumps(updated_resource),
)
@ -312,7 +318,7 @@ class MasterAgent:
self.check_interval = self.cluster_detail["master"]["agents"]["check_interval"]
self.redis_connection = redis.Redis(
host="localhost",
port=self.cluster_detail["master"]["redis"]["port"]
port=self.cluster_detail["master"]["redis"]["port"],
)
self.redis_connection.hset(f"{self.cluster_name}:runtime_detail", "agent_id", os.getpid())
@ -321,28 +327,28 @@ class MasterAgent:
pending_job_agent = PendingJobAgent(
cluster_name=self.cluster_name,
redis_connection=self.redis_connection,
check_interval=self.check_interval
check_interval=self.check_interval,
)
pending_job_agent.start()
killed_job_agent = KilledJobAgent(
cluster_name=self.cluster_name,
redis_connection=self.redis_connection,
check_interval=self.check_interval
check_interval=self.check_interval,
)
killed_job_agent.start()
job_tracking_agent = JobTrackingAgent(
cluster_name=self.cluster_name,
redis_connection=self.redis_connection,
check_interval=self.check_interval
check_interval=self.check_interval,
)
job_tracking_agent.start()
container_tracking_agent = ContainerTrackingAgent(
cluster_name=self.cluster_name,
redis_connection=self.redis_connection,
check_interval=self.check_interval
check_interval=self.check_interval,
)
container_tracking_agent.start()

@ -6,27 +6,26 @@ import requests
class NodeApiClientV1:
"""Client class for Node API Server.
"""
"""Client class for Node API Server."""
@staticmethod
def create_container(node_hostname: str, node_api_server_port: int, create_config: dict) -> dict:
response = requests.post(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers",
json=create_config
json=create_config,
)
return response.json()
@staticmethod
def stop_container(node_hostname: str, node_api_server_port: int, container_name: str) -> dict:
response = requests.post(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}:stop"
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}:stop",
)
return response.json()
@staticmethod
def remove_container(node_hostname: str, node_api_server_port: int, container_name: str) -> dict:
response = requests.delete(
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}"
url=f"http://{node_hostname}:{node_api_server_port}/v1/containers/{container_name}",
)
return response.json()
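For orientation only: the client above is a thin wrapper over the node API server's /v1/containers routes. A hypothetical call from the master side might look like the sketch below; the hostname, port, and config values are placeholders, not values taken from this commit.

create_config = {
    "image_name": "maro:latest",        # placeholder image
    "command": "python run_job.py",     # placeholder command
    "container_name": "jobid.actor.0",  # placeholder container name
    "cpu": 2,
    "memory": "2048m",
}
NodeApiClientV1.create_container(
    node_hostname="10.0.0.4",       # placeholder node address
    node_api_server_port=51812,     # placeholder API server port
    create_config=create_config,
)
NodeApiClientV1.stop_container(
    node_hostname="10.0.0.4",
    node_api_server_port=51812,
    container_name="jobid.actor.0",
)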

@ -15,6 +15,7 @@ URL_PREFIX = "/v1/cluster"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
@check_jwt_validity
def get_cluster():

@ -15,6 +15,7 @@ URL_PREFIX = "/v1/containers"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
@check_jwt_validity
def list_containers():

@ -24,7 +24,7 @@ def get_job_queue():
killed_job_queue = redis_controller.get_killed_job_ticket()
return {
"pending_jobs": pending_job_queue,
"killed_jobs": killed_job_queue
"killed_jobs": killed_job_queue,
}
@ -66,10 +66,10 @@ def create_job(**kwargs):
job_details = kwargs["json_dict"]
redis_controller.set_job_details(
job_name=job_details["name"],
job_details=job_details
job_details=job_details,
)
redis_controller.push_pending_job_ticket(
job_name=job_details["name"]
job_name=job_details["name"],
)
return {}
@ -118,6 +118,6 @@ def clean_jobs():
node_hostname = node_details["hostname"]
for container_name in node_details["containers"]:
requests.delete(
url=f"http://{node_hostname}:{node_details['api_server']['port']}/containers/{container_name}"
url=f"http://{node_hostname}:{node_details['api_server']['port']}/containers/{container_name}",
)
return {}

@ -18,7 +18,7 @@ def get_init_node_script():
return send_from_directory(
directory=f"{Paths.ABS_MARO_SHARED}/lib/grass/scripts/node",
filename="join_cluster.py",
as_attachment=True
as_attachment=True,
)
except FileNotFoundError:
abort(404)

@ -22,6 +22,7 @@ URL_PREFIX = "/v1/master"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
@check_jwt_validity
def get_master():
@ -79,7 +80,7 @@ def save_master_key(private_key: str) -> None:
fw.write(private_key)
os.chmod(
path=f"{Paths.ABS_MARO_LOCAL}/cluster/{cluster_name}/master_to_node_openssh_private_key",
mode=stat.S_IRWXU
mode=stat.S_IRWXU,
)
@ -87,20 +88,20 @@ def generate_rsa_openssh_key_pair() -> dict:
rsa_key = rsa.generate_private_key(
backend=default_backend(),
public_exponent=65537,
key_size=2048
key_size=2048,
)
# Format and encoding are diff from OpenSSH
private_key = rsa_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)
public_key = rsa_key.public_key().public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
format=serialization.PublicFormat.OpenSSH,
)
return {
"public_key": public_key.decode("utf-8"),
"private_key": private_key.decode("utf-8")
"private_key": private_key.decode("utf-8"),
}
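As a hedged sketch of how a key pair like the one generated above might be consumed (file paths here are invented for the example): the OpenSSH-format public key is appended to a node's authorized_keys, while the PEM private key stays on the master with owner-only permissions, mirroring the os.chmod call in save_master_key.

import os
import stat

key_pair = generate_rsa_openssh_key_pair()

# Append the OpenSSH public key so the master can SSH into the node (placeholder path).
with open(os.path.expanduser("~/.ssh/authorized_keys"), mode="a") as fw:
    fw.write(key_pair["public_key"] + "\n")

# Keep the PEM private key readable by the owner only (placeholder path).
private_key_path = "/tmp/master_to_node_openssh_private_key"
with open(private_key_path, mode="w") as fw:
    fw.write(key_pair["private_key"])
os.chmod(path=private_key_path, mode=stat.S_IRWXU)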

@ -20,6 +20,7 @@ URL_PREFIX = "/v1/nodes"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
@check_jwt_validity
def list_nodes():
@ -65,14 +66,14 @@ def create_node(**kwargs):
node_details["image_files"] = {}
node_details["containers"] = {}
node_details["state"] = {
"status": NodeStatus.PENDING
"status": NodeStatus.PENDING,
}
node_name = node_details["name"]
with redis_controller.lock(f"lock:name_to_node_details:{node_name}"):
redis_controller.set_node_details(
node_name=node_name,
node_details=node_details
node_details=node_details,
)
return node_details
@ -121,7 +122,7 @@ def start_node(node_name: str):
node_username=node_details["username"],
node_hostname=node_details["hostname"],
node_ssh_port=node_details["ssh"]["port"],
cluster_name=local_cluster_details["name"]
cluster_name=local_cluster_details["name"],
)
except ConnectionFailed:
abort(400)
@ -162,7 +163,7 @@ def stop_node(node_name: str):
node_username=node_details["username"],
node_hostname=node_details["hostname"],
node_ssh_port=node_details["ssh"]["port"],
cluster_name=local_cluster_details["name"]
cluster_name=local_cluster_details["name"],
)
except ConnectionFailed:
abort(400)


@ -18,6 +18,7 @@ URL_PREFIX = "/v1/schedules"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["GET"])
@check_jwt_validity
def list_schedules():
@ -57,14 +58,14 @@ def create_schedule(**kwargs):
redis_controller.set_schedule_details(
schedule_name=schedule_details["name"],
schedule_details=schedule_details
schedule_details=schedule_details,
)
# Build individual jobs
for job_name in schedule_details["job_names"]:
redis_controller.set_job_details(
job_name=job_name,
job_details=_build_job_details(schedule_details=schedule_details, job_name=job_name)
job_details=_build_job_details(schedule_details=schedule_details, job_name=job_name),
)
redis_controller.push_pending_job_ticket(job_name=job_name)
return {}
@ -111,7 +112,7 @@ def _build_job_details(schedule_details: dict, job_name: str) -> dict:
job_details["name"] = job_name
job_details["tags"] = {
"schedule_name": schedule_details["name"],
"schedule_id": schedule_details["id"]
"schedule_id": schedule_details["id"],
}
job_details.pop("job_names")


@ -19,5 +19,5 @@ URL_PREFIX = "/v1/status"
def status():
return {
"status": "OK",
"time": time.time()
"time": time.time(),
}


@ -11,6 +11,7 @@ URL_PREFIX = "/v1/visible"
# Api functions.
@blueprint.route(f"{URL_PREFIX}/static", methods=["GET"])
@check_jwt_validity
def get_static_resource():
@ -34,6 +35,6 @@ def get_dynamic_resource(previous_length: str):
"""
name_to_node_usage = redis_controller.get_resource_usage(
previous_length=int(previous_length)
previous_length=int(previous_length),
)
return name_to_node_usage


@ -42,11 +42,11 @@ def check_jwt_validity(func):
user_details = redis_controller.get_user_details(user_id=payload["user_id"])
# Get decrypted_bytes
if request.data != b'':
if request.data != b"":
decrypted_bytes = _get_decrypted_bytes(
payload=payload,
encrypted_bytes=request.data,
user_details=user_details
user_details=user_details,
)
kwargs["json_dict"] = json.loads(decrypted_bytes.decode("utf-8"))
@ -69,7 +69,7 @@ def check_jwt_validity(func):
def _get_encrypted_bytes(json_dict: dict, aes_key: bytes, aes_ctr_nonce: bytes) -> bytes:
cipher = Cipher(
algorithm=algorithms.AES(key=aes_key),
mode=modes.CTR(nonce=aes_ctr_nonce)
mode=modes.CTR(nonce=aes_ctr_nonce),
)
encryptor = cipher.encryptor()
return_bytes = encryptor.update(json.dumps(json_dict).encode("utf-8")) + encryptor.finalize()
@ -80,21 +80,21 @@ def _get_decrypted_bytes(payload: dict, encrypted_bytes: bytes, user_details: di
# Decrypted aes_key and aes_ctr_nonce
dev_to_master_encryption_private_key_obj = serialization.load_pem_private_key(
data=user_details["dev_to_master_encryption_private_key"].encode("utf-8"),
password=None
password=None,
)
aes_key = dev_to_master_encryption_private_key_obj.decrypt(
ciphertext=base64.b64decode(payload["aes_key"].encode("ascii")),
padding=_get_asymmetric_padding()
padding=_get_asymmetric_padding(),
)
aes_ctr_nonce = dev_to_master_encryption_private_key_obj.decrypt(
ciphertext=base64.b64decode(payload["aes_ctr_nonce"].encode("ascii")),
padding=_get_asymmetric_padding()
padding=_get_asymmetric_padding(),
)
# Return decrypted_bytes
cipher = Cipher(
algorithm=algorithms.AES(key=aes_key),
mode=modes.CTR(nonce=aes_ctr_nonce)
mode=modes.CTR(nonce=aes_ctr_nonce),
)
decryptor = cipher.decryptor()
return decryptor.update(encrypted_bytes) + decryptor.finalize()
@ -109,12 +109,12 @@ def build_response(return_json: dict, user_details: dict) -> Response:
encrypted_bytes = _get_encrypted_bytes(
json_dict=return_json,
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
aes_ctr_nonce=aes_ctr_nonce,
)
# Encrypt aes_key and aes_ctr_nonce with rsa_key_pair
master_to_dev_encryption_public_key_obj = serialization.load_pem_public_key(
data=user_details["master_to_dev_encryption_public_key"].encode("utf-8")
data=user_details["master_to_dev_encryption_public_key"].encode("utf-8"),
)
# Build jwt_token
@ -123,18 +123,18 @@ def build_response(return_json: dict, user_details: dict) -> Response:
"aes_key": base64.b64encode(
master_to_dev_encryption_public_key_obj.encrypt(
plaintext=aes_key,
padding=_get_asymmetric_padding()
)
padding=_get_asymmetric_padding(),
),
).decode("ascii"),
"aes_ctr_nonce": base64.b64encode(
master_to_dev_encryption_public_key_obj.encrypt(
plaintext=aes_ctr_nonce,
padding=_get_asymmetric_padding()
)
).decode("ascii")
padding=_get_asymmetric_padding(),
),
).decode("ascii"),
},
key=user_details["master_to_dev_signing_private_key"],
algorithm="RS512"
algorithm="RS512",
)
# Build response
@ -147,5 +147,5 @@ def _get_asymmetric_padding():
return padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
label=None,
)


@ -33,24 +33,24 @@ class NodeAgent:
self._redis_controller = RedisController(
host=self._local_master_details["hostname"],
port=self._local_master_details["redis"]["port"]
port=self._local_master_details["redis"]["port"],
)
# Init agents.
self.load_image_agent = LoadImageAgent(
local_cluster_details=local_cluster_details,
local_master_details=local_master_details,
local_node_details=local_node_details
local_node_details=local_node_details,
)
self.node_tracking_agent = NodeTrackingAgent(
local_cluster_details=local_cluster_details,
local_master_details=local_master_details,
local_node_details=local_node_details
local_node_details=local_node_details,
)
self.resource_tracking_agent = ResourceTrackingAgent(
local_cluster_details=local_cluster_details,
local_master_details=local_master_details,
local_node_details=local_node_details
local_node_details=local_node_details,
)
# When SIGTERM, gracefully exit.
@ -70,7 +70,7 @@ class NodeAgent:
self.resource_tracking_agent.join()
def gracefully_exit(self, signum, frame) -> None:
""" Gracefully exit when SIGTERM.
"""Gracefully exit when SIGTERM.
If we get SIGKILL here, it means that the node is not stopped properly,
the status of the node remains 'RUNNING'.
@ -87,7 +87,7 @@ class NodeAgent:
# Set STOPPED state
state_details = {
"status": NodeStatus.STOPPED,
"check_time": self._redis_controller.get_time()
"check_time": self._redis_controller.get_time(),
}
with self._redis_controller.lock(f"lock:name_to_node_details:{self._local_node_details['name']}"):
node_details = self._redis_controller.get_node_details(node_name=self._local_node_details["name"])
@ -96,7 +96,7 @@ class NodeAgent:
node_details["state"] = state_details
self._redis_controller.set_node_details(
node_name=self._local_node_details["name"],
node_details=node_details
node_details=node_details,
)
sys.exit(0)
@ -116,7 +116,7 @@ class NodeAgent:
if not isinstance(resources_details["memory"], (float, int)):
if resources_details["memory"] != "all":
logger.warning("Invalid memory assignment, will use all memories in this node")
resources_details["memory"] = psutil.virtual_memory().total / (1024 ** 2) # (float) in MByte
resources_details["memory"] = psutil.virtual_memory().total / (1024**2) # (float) in MByte
if not isinstance(resources_details["gpu"], (float, int)):
if resources_details["gpu"] != "all":
@ -140,7 +140,7 @@ class NodeAgent:
node_details["resources"] = resources_details
self._redis_controller.set_node_details(
node_name=self._local_node_details["name"],
node_details=node_details
node_details=node_details,
)
@ -152,8 +152,10 @@ class NodeTrackingAgent(multiprocessing.Process):
def __init__(
self,
local_cluster_details: dict, local_master_details: dict, local_node_details: dict,
check_interval: int = 5
local_cluster_details: dict,
local_master_details: dict,
local_node_details: dict,
check_interval: int = 5,
):
super().__init__()
self._local_cluster_details = local_cluster_details
@ -162,7 +164,7 @@ class NodeTrackingAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host=self._local_master_details["hostname"],
port=self._local_master_details["redis"]["port"]
port=self._local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -201,20 +203,20 @@ class NodeTrackingAgent(multiprocessing.Process):
container_name_to_inspect_details = self._get_container_name_to_inspect_details()
self._update_name_to_container_details(
container_name_to_inspect_details=container_name_to_inspect_details,
name_to_container_details=name_to_container_details
name_to_container_details=name_to_container_details,
)
# Resources related
resources_details = node_details["resources"]
self._update_occupied_resources(
container_name_to_inspect_details=container_name_to_inspect_details,
resources_details=resources_details
resources_details=resources_details,
)
# State related.
state_details = {
"status": NodeStatus.RUNNING,
"check_time": self._redis_controller.get_time()
"check_time": self._redis_controller.get_time(),
}
# Save details.
@ -225,13 +227,13 @@ class NodeTrackingAgent(multiprocessing.Process):
node_details["state"] = state_details
self._redis_controller.set_node_details(
node_name=self._local_node_details["name"],
node_details=node_details
node_details=node_details,
)
@staticmethod
def _update_name_to_container_details(
container_name_to_inspect_details: dict,
name_to_container_details: dict
name_to_container_details: dict,
) -> None:
"""Update name_to_container_details from container_name_to_inspect_details.
@ -246,10 +248,10 @@ class NodeTrackingAgent(multiprocessing.Process):
for container_name, inspect_details in container_name_to_inspect_details.items():
# Extract container state and labels.
name_to_container_details[container_name] = NodeTrackingAgent._extract_labels(
inspect_details=inspect_details
inspect_details=inspect_details,
)
name_to_container_details[container_name]["state"] = NodeTrackingAgent._extract_state(
inspect_details=inspect_details
inspect_details=inspect_details,
)
@staticmethod
@ -333,8 +335,10 @@ class LoadImageAgent(multiprocessing.Process):
def __init__(
self,
local_cluster_details: dict, local_master_details: dict, local_node_details: dict,
check_interval: int = 10
local_cluster_details: dict,
local_master_details: dict,
local_node_details: dict,
check_interval: int = 10,
):
super().__init__()
self._local_cluster_details = local_cluster_details
@ -343,7 +347,7 @@ class LoadImageAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host=self._local_master_details["hostname"],
port=self._local_master_details["redis"]["port"]
port=self._local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -382,21 +386,25 @@ class LoadImageAgent(multiprocessing.Process):
for image_file_name, image_file_details in name_to_image_file_details_in_master.items():
if (
image_file_name not in name_to_image_file_details_in_node
or name_to_image_file_details_in_node[image_file_name]["md5_checksum"] !=
image_file_details["md5_checksum"]
or name_to_image_file_details_in_node[image_file_name]["md5_checksum"]
!= image_file_details["md5_checksum"]
):
unloaded_image_names.append(image_file_name)
# Parallel load
with ThreadPool(5) as pool:
params = [
[os.path.expanduser(
f"~/.maro-shared/clusters/{self._local_cluster_details['name']}/image_files/{unloaded_image_name}")]
[
os.path.expanduser(
f"~/.maro-shared/clusters/{self._local_cluster_details['name']}/"
                        + f"image_files/{unloaded_image_name}",
),
]
for unloaded_image_name in unloaded_image_names
]
pool.starmap(
self._load_image,
params
params,
)
with self._redis_controller.lock(f"lock:name_to_node_details:{self._local_node_details['name']}"):
@ -405,7 +413,7 @@ class LoadImageAgent(multiprocessing.Process):
node_details["image_files"] = name_to_image_file_details_in_master
self._redis_controller.set_node_details(
node_name=self._local_node_details["name"],
node_details=node_details
node_details=node_details,
)
@staticmethod
@ -421,7 +429,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
local_cluster_details: dict,
local_master_details: dict,
local_node_details: dict,
check_interval: int = 30
check_interval: int = 30,
):
super().__init__()
self._local_cluster_details = local_cluster_details
@ -430,7 +438,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
self._redis_controller = RedisController(
host=self._local_master_details["hostname"],
port=self._local_master_details["redis"]["port"]
port=self._local_master_details["redis"]["port"],
)
self._check_interval = check_interval
@ -468,7 +476,7 @@ class ResourceTrackingAgent(multiprocessing.Process):
node_name=self._local_node_details["name"],
cpu_usage=cpu_usage_per_core,
memory_usage=memory_usage,
gpu_memory_usage=gpu_memory_usage
gpu_memory_usage=gpu_memory_usage,
)
@ -476,12 +484,12 @@ if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
datefmt="%Y-%m-%d %H:%M:%S",
)
node_agent = NodeAgent(
local_cluster_details=DetailsReader.load_local_cluster_details(),
local_master_details=DetailsReader.load_local_master_details(),
local_node_details=DetailsReader.load_local_node_details()
local_node_details=DetailsReader.load_local_node_details(),
)
node_agent.start()


@ -15,6 +15,7 @@ URL_PREFIX = "/v1/containers"
# Api functions.
@blueprint.route(f"{URL_PREFIX}", methods=["POST"])
def create_container():
"""Create a container, aka 'docker run'.


@ -16,5 +16,5 @@ URL_PREFIX = "/v1/status"
def status():
return {
"status": "OK",
"time": time.time()
"time": time.time(),
}


@ -14,8 +14,7 @@ logger = logging.getLogger(__name__)
class ConnectionTester:
"""Tester class for connection.
"""
"""Tester class for connection."""
@staticmethod
def test_ssh_default_port_connection(node_username: str, node_hostname: str, node_ssh_port: int, cluster_name: str):
@ -36,14 +35,14 @@ class ConnectionTester:
node_ssh_port=node_ssh_port,
node_username=node_username,
node_hostname=node_hostname,
cluster_name=cluster_name
cluster_name=cluster_name,
)
return True
except (CommandExecutionError, TimeoutExpired):
remain_retries -= 1
logger.debug(
f"Unable to connect to {node_hostname} with port {node_ssh_port}, "
f"remains {remain_retries} retries"
f"remains {remain_retries} retries",
)
time.sleep(5)
raise ConnectionFailed(f"Unable to connect to {node_hostname} with port {node_ssh_port}")


@ -8,8 +8,7 @@ from .params import Paths
class DetailsReader:
"""Reader class for details.
"""
"""Reader class for details."""
@staticmethod
def load_cluster_details(cluster_name: str) -> dict:


@ -8,8 +8,7 @@ from .subprocess import Subprocess
class DockerController:
"""Controller class for docker.
"""
"""Controller class for docker."""
@staticmethod
def remove_container(container_name: str) -> None:
@ -69,13 +68,12 @@ class DockerController:
command=create_config["command"],
image_name=create_config["image_name"],
volumes=DockerController._build_list_params_str(params=create_config["volumes"], option="-v"),
# System related.
container_name=create_config["container_name"],
fluentd_address=create_config["fluentd_address"],
fluentd_tag=create_config["fluentd_tag"],
environments=DockerController._build_dict_params_str(params=create_config["environments"], option="-e"),
labels=DockerController._build_dict_params_str(params=create_config["labels"], option="-l")
labels=DockerController._build_dict_params_str(params=create_config["labels"], option="-l"),
)
# Start creating
@ -86,7 +84,7 @@ class DockerController:
@staticmethod
def list_container_names() -> list:
command = "sudo docker ps -a --format \"{{.Names}}\""
command = 'sudo docker ps -a --format "{{.Names}}"'
return_str = Subprocess.run(command=command)
if return_str == "":
return []


@ -4,41 +4,36 @@
# First Layer.
class AgentError(Exception):
""" Base error class for all MARO Grass Agents."""
pass
"""Base error class for all MARO Grass Agents."""
# Second Layer.
class UserFault(AgentError):
""" Users should be responsible for the errors."""
pass
"""Users should be responsible for the errors."""
class ServiceError(AgentError):
""" MARO Services should be responsible for the errors."""
pass
"""MARO Services should be responsible for the errors."""
# Third Layer.
class ResourceAllocationFailed(UserFault):
""" Resources are insufficient, unable to allocate."""
pass
"""Resources are insufficient, unable to allocate."""
class StartContainerError(ServiceError):
""" Error when starting containers."""
pass
"""Error when starting containers."""
class CommandExecutionError(ServiceError):
""" Failed to execute shell commands."""
pass
"""Failed to execute shell commands."""
class ConnectionFailed(ServiceError):
""" Failed to connect to other nodes."""
pass
"""Failed to connect to other nodes."""


@ -6,8 +6,7 @@ import uuid
class NameCreator:
"""Creator class for MARO Resource namings.
"""
"""Creator class for MARO Resource namings."""
@staticmethod
def create_name_with_uuid(prefix: str, uuid_len: int = 16) -> str:


@ -9,8 +9,7 @@ from redis.lock import Lock
class RedisController:
"""Controller class for Redis.
"""
"""Controller class for Redis."""
def __init__(self, host: str, port: int):
self._redis = redis.Redis(host=host, port=port, encoding="utf-8", decode_responses=True)
@ -23,20 +22,20 @@ class RedisController:
def set_cluster_details(self, cluster_details: dict):
self._redis.set(
"cluster_details",
json.dumps(cluster_details)
json.dumps(cluster_details),
)
"""Master Details Related."""
def get_master_details(self) -> dict:
return json.loads(
self._redis.get("master_details")
self._redis.get("master_details"),
)
def set_master_details(self, master_details: dict) -> None:
self._redis.set(
"master_details",
json.dumps(master_details)
json.dumps(master_details),
)
def delete_master_details(self) -> None:
@ -46,7 +45,7 @@ class RedisController:
def get_name_to_node_details(self) -> dict:
name_to_node_details = self._redis.hgetall(
"name_to_node_details"
"name_to_node_details",
)
for node_name, node_details_str in name_to_node_details.items():
name_to_node_details[node_name] = json.loads(node_details_str)
@ -54,7 +53,7 @@ class RedisController:
def get_name_to_node_resources(self) -> dict:
name_to_node_details = self._redis.hgetall(
"name_to_node_details"
"name_to_node_details",
)
for node_name, node_details_str in name_to_node_details.items():
node_details = json.loads(node_details_str)
@ -64,7 +63,7 @@ class RedisController:
def get_node_details(self, node_name: str) -> dict:
node_details = self._redis.hget(
"name_to_node_details",
node_name
node_name,
)
if node_details is None:
return {}
@ -75,13 +74,13 @@ class RedisController:
self._redis.hset(
"name_to_node_details",
node_name,
json.dumps(node_details)
json.dumps(node_details),
)
def delete_node_details(self, node_name: str) -> None:
self._redis.hdel(
"name_to_node_details",
node_name
node_name,
)
def push_resource_usage(
@ -89,29 +88,29 @@ class RedisController:
node_name: str,
cpu_usage: list,
memory_usage: float,
gpu_memory_usage: list
gpu_memory_usage: list,
):
# Push cpu usage to redis
self._redis.rpush(
f"{node_name}:cpu_usage_per_core",
json.dumps(cpu_usage)
json.dumps(cpu_usage),
)
# Push memory usage to redis
self._redis.rpush(
f"{node_name}:memory_usage",
json.dumps(memory_usage)
json.dumps(memory_usage),
)
# Push gpu memory usage to redis
self._redis.rpush(
f"{node_name}:gpu_memory_usage",
json.dumps(gpu_memory_usage)
json.dumps(gpu_memory_usage),
)
def get_resource_usage(self, previous_length: int):
name_to_node_details = self._redis.hgetall(
"name_to_node_details"
"name_to_node_details",
)
node_name_list = name_to_node_details.keys()
name_to_node_usage = {}
@ -120,15 +119,18 @@ class RedisController:
usage_dict = {}
usage_dict["cpu"] = self._redis.lrange(
f"{node_name}:cpu_usage_per_core",
previous_length, -1
previous_length,
-1,
)
usage_dict["memory"] = self._redis.lrange(
f"{node_name}:memory_usage",
previous_length, -1
previous_length,
-1,
)
usage_dict["gpu"] = self._redis.lrange(
f"{node_name}:gpu_memory_usage",
previous_length, -1
previous_length,
-1,
)
name_to_node_usage[node_name] = usage_dict
@ -147,7 +149,7 @@ class RedisController:
def get_job_details(self, job_name: str) -> dict:
return_str = self._redis.hget(
"name_to_job_details",
job_name
job_name,
)
return json.loads(return_str) if return_str is not None else {}
@ -155,13 +157,13 @@ class RedisController:
self._redis.hset(
"name_to_job_details",
job_name,
json.dumps(job_details)
json.dumps(job_details),
)
def delete_job_details(self, job_name: str) -> None:
self._redis.hdel(
"name_to_job_details",
job_name
job_name,
)
"""Schedule Details Related."""
@ -177,7 +179,7 @@ class RedisController:
def get_schedule_details(self, schedule_name: str) -> dict:
return_str = self._redis.hget(
"name_to_schedule_details",
schedule_name
schedule_name,
)
return json.loads(return_str) if return_str is not None else {}
@ -185,13 +187,13 @@ class RedisController:
self._redis.hset(
"name_to_schedule_details",
schedule_name,
json.dumps(schedule_details)
json.dumps(schedule_details),
)
def delete_schedule_details(self, schedule_name: str) -> None:
self._redis.hdel(
"name_to_schedule_details",
schedule_name
schedule_name,
)
"""Container Details Related."""
@ -213,14 +215,14 @@ class RedisController:
name_to_container_details[container_name] = json.dumps(container_details)
self._redis.hmset(
"name_to_container_details",
name_to_container_details
name_to_container_details,
)
def set_container_details(self, container_name: str, container_details: dict) -> None:
self._redis.hset(
"name_to_container_details",
container_name,
container_details
container_details,
)
"""Pending Job Tickets Related."""
@ -229,20 +231,20 @@ class RedisController:
return self._redis.lrange(
"pending_job_tickets",
0,
-1
-1,
)
def push_pending_job_ticket(self, job_name: str):
self._redis.rpush(
"pending_job_tickets",
job_name
job_name,
)
def remove_pending_job_ticket(self, job_name: str):
self._redis.lrem(
"pending_job_tickets",
0,
job_name
job_name,
)
def delete_pending_jobs_queue(self):
@ -254,20 +256,20 @@ class RedisController:
return self._redis.lrange(
"killed_job_tickets",
0,
-1
-1,
)
def push_killed_job_ticket(self, job_name: str):
self._redis.rpush(
"killed_job_tickets",
job_name
job_name,
)
def remove_killed_job_ticket(self, job_name: str):
self._redis.lrem(
"killed_job_tickets",
0,
job_name
job_name,
)
def delete_killed_jobs_queue(self):
@ -277,7 +279,7 @@ class RedisController:
def get_rejoin_component_name_to_container_name(self, job_id: str) -> dict:
return self._redis.hgetall(
f"job:{job_id}:rejoin_component_name_to_container_name"
f"job:{job_id}:rejoin_component_name_to_container_name",
)
def get_rejoin_container_name_to_component_name(self, job_id: str) -> dict:
@ -286,18 +288,18 @@ class RedisController:
def delete_rejoin_container_name_to_component_name(self, job_id: str) -> None:
self._redis.delete(
f"job:{job_id}:rejoin_component_name_to_container_name"
f"job:{job_id}:rejoin_component_name_to_container_name",
)
def get_job_runtime_details(self, job_id: str) -> dict:
return self._redis.hgetall(
f"job:{job_id}:runtime_details"
f"job:{job_id}:runtime_details",
)
def get_rejoin_component_restart_times(self, job_id: str, component_id: str) -> int:
restart_times = self._redis.hget(
f"job:{job_id}:component_id_to_restart_times",
component_id
component_id,
)
return 0 if restart_times is None else int(restart_times)
@ -305,7 +307,7 @@ class RedisController:
self._redis.hincrby(
f"job:{job_id}:component_id_to_restart_times",
component_id,
1
1,
)
"""User Related."""
@ -313,14 +315,14 @@ class RedisController:
def get_user_details(self, user_id: str) -> dict:
return_str = self._redis.hget(
"id_to_user_details",
user_id
user_id,
)
return json.loads(return_str) if return_str is not None else {}
"""Utils."""
def get_time(self) -> int:
""" Get current unix timestamp (seconds) from Redis server.
"""Get current unix timestamp (seconds) from Redis server.
Returns:
int: current timestamp.
@ -328,7 +330,7 @@ class RedisController:
return self._redis.time()[0]
def lock(self, name: str) -> Lock:
""" Get a new lock with redis.
"""Get a new lock with redis.
Use 'with lock(name):' paradigm to do the locking.


@ -3,8 +3,7 @@
class BasicResource:
"""An abstraction class for computing resources.
"""
"""An abstraction class for computing resources."""
def __init__(self, cpu: float, memory: float, gpu: float):
self.cpu = cpu


@ -8,8 +8,7 @@ from .exception import CommandExecutionError
class Subprocess:
"""Wrapper class of subprocess
"""
"""Wrapper class of subprocess"""
@staticmethod
def run(command: str, timeout: int = None) -> str:
@ -30,7 +29,7 @@ class Subprocess:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
timeout=timeout
timeout=timeout,
)
if completed_process.returncode != 0:
raise CommandExecutionError(completed_process.stderr)

Просмотреть файл

@ -7,5 +7,5 @@ def template(export_path: str, **kwargs):
from maro.cli.grass.executors.grass_executor import GrassExecutor
GrassExecutor.template(
export_path=export_path
export_path=export_path,
)


@ -8,8 +8,7 @@ from maro.cli.utils.subprocess import Subprocess
class DockerController:
"""Controller class for docker.
"""
"""Controller class for docker."""
@staticmethod
def save_image(image_name: str, abs_export_path: str):


@ -15,24 +15,23 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
class EncryptedRequests:
"""Wrapper class for requests with encryption/decryption integrated.
"""
"""Wrapper class for requests with encryption/decryption integrated."""
def __init__(
self,
user_id: str,
master_to_dev_encryption_private_key: str,
dev_to_master_encryption_public_key: str,
dev_to_master_signing_private_key: str
dev_to_master_signing_private_key: str,
):
self._user_id = user_id
self._dev_to_master_signing_private_key = dev_to_master_signing_private_key
self._master_to_dev_encryption_private_key_obj = serialization.load_pem_private_key(
master_to_dev_encryption_private_key.encode("utf-8"),
password=None
password=None,
)
self._dev_to_master_encryption_public_key_obj = serialization.load_pem_public_key(
dev_to_master_encryption_public_key.encode("utf-8")
dev_to_master_encryption_public_key.encode("utf-8"),
)
def get(self, url: str):
@ -45,8 +44,8 @@ class EncryptedRequests:
url=url,
headers=self._get_new_headers(
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
)
aes_ctr_nonce=aes_ctr_nonce,
),
)
# Parse response
@ -63,13 +62,15 @@ class EncryptedRequests:
url=url,
headers=self._get_new_headers(
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
aes_ctr_nonce=aes_ctr_nonce,
),
data=None if json_dict is None else self._get_encrypted_bytes(
data=None
if json_dict is None
else self._get_encrypted_bytes(
json_dict=json_dict,
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
)
aes_ctr_nonce=aes_ctr_nonce,
),
)
# Parse response
@ -86,13 +87,15 @@ class EncryptedRequests:
url=url,
headers=self._get_new_headers(
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
aes_ctr_nonce=aes_ctr_nonce,
),
data=None if json_dict is None else self._get_encrypted_bytes(
data=None
if json_dict is None
else self._get_encrypted_bytes(
json_dict=json_dict,
aes_key=aes_key,
aes_ctr_nonce=aes_ctr_nonce
)
aes_ctr_nonce=aes_ctr_nonce,
),
)
# Parse response
@ -105,7 +108,7 @@ class EncryptedRequests:
def _get_encrypted_bytes(json_dict: dict, aes_key: bytes, aes_ctr_nonce: bytes) -> bytes:
cipher = Cipher(
algorithm=algorithms.AES(key=aes_key),
mode=modes.CTR(nonce=aes_ctr_nonce)
mode=modes.CTR(nonce=aes_ctr_nonce),
)
encryptor = cipher.encryptor()
return_bytes = encryptor.update(json.dumps(json_dict).encode("utf-8")) + encryptor.finalize()
@ -121,17 +124,17 @@ class EncryptedRequests:
payload = jwt.decode(jwt=jwt_token, options={"verify_signature": False})
aes_key = self._master_to_dev_encryption_private_key_obj.decrypt(
ciphertext=base64.b64decode(payload["aes_key"].encode("ascii")),
padding=self._get_asymmetric_padding()
padding=self._get_asymmetric_padding(),
)
aes_ctr_nonce = self._master_to_dev_encryption_private_key_obj.decrypt(
ciphertext=base64.b64decode(payload["aes_ctr_nonce"].encode("ascii")),
padding=self._get_asymmetric_padding()
padding=self._get_asymmetric_padding(),
)
# Return decrypted_bytes
cipher = Cipher(
algorithm=algorithms.AES(key=aes_key),
mode=modes.CTR(nonce=aes_ctr_nonce)
mode=modes.CTR(nonce=aes_ctr_nonce),
)
decryptor = cipher.decryptor()
return decryptor.update(encrypted_bytes) + decryptor.finalize()
@ -145,22 +148,22 @@ class EncryptedRequests:
"aes_key": base64.b64encode(
self._dev_to_master_encryption_public_key_obj.encrypt(
plaintext=aes_key,
padding=self._get_asymmetric_padding()
)
padding=self._get_asymmetric_padding(),
),
).decode("ascii"),
"aes_ctr_nonce": base64.b64encode(
self._dev_to_master_encryption_public_key_obj.encrypt(
plaintext=aes_ctr_nonce,
padding=self._get_asymmetric_padding()
)
).decode("ascii")
padding=self._get_asymmetric_padding(),
),
).decode("ascii"),
},
key=self._dev_to_master_signing_private_key,
algorithm="RS512"
algorithm="RS512",
)
return {
"Authorization": f"Bearer {jwt_token}"
"Authorization": f"Bearer {jwt_token}",
}
@staticmethod
@ -168,5 +171,5 @@ class EncryptedRequests:
return padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
label=None,
)

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше