Add local file config in isort & Add missing copyright info (#547)

* Add local file config; add missing copyright info.

* Add empty line at the end of pyproject.toml

* Reward calculating bug fix
This commit is contained in:
Huoran Li 2022-06-14 14:47:47 +08:00 коммит произвёл GitHub
Родитель 7e3c1d5893
Коммит 8c0ad5a13d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
71 изменённых файлов: 169 добавлений и 37 удалений

3
.github/linters/pyproject.toml поставляемый
Просмотреть файл

@ -4,4 +4,5 @@ line-length = 120
[tool.isort]
profile = "black"
line_length = 120
known_first_party = "maro"
known_first_party = ["maro"]
known_local_folder = ["examples", "scripts", "tests"]

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from maro.rl.policy import DiscretePolicyGradient

Просмотреть файл

@ -1,9 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Any, Callable, Dict, Optional
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
@ -13,6 +13,8 @@ from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
from .algorithms.ppo import get_ppo, get_ppo_policy
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
class CIMBundle(RLComponentBundle):

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import heapq
import io
import os

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import List, Tuple

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import io
import math

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from maro.cli.local.commands import run

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.backends.frame import FrameBase, FrameNode, NodeAttribute, NodeBase, node
TOTAL_PRODUCT_CATEGORIES = 10

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import io
import os
import pprint

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,16 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Any, Callable, Dict, Optional
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from examples.vm_scheduling.rl.algorithms.ac import get_ac, get_ac_policy
from examples.vm_scheduling.rl.algorithms.dqn import get_dqn, get_dqn_policy
from examples.vm_scheduling.rl.config import algorithm, env_conf, num_features, num_pms, state_dim, test_env_conf
from examples.vm_scheduling.rl.env_sampler import VMEnvSampler
class VMBundle(RLComponentBundle):
def get_env_config(self) -> dict:

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction
from maro.simulator.scenarios.vm_scheduling.common import Action

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
from rule_based_algorithm import RuleBasedAlgorithm

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import numpy as np

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import io
import os

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
from rule_based_algorithm import RuleBasedAlgorithm

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from rule_based_algorithm import RuleBasedAlgorithm
from maro.simulator import Env

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from maro.simulator import Env

Просмотреть файл

@ -1,2 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import csv
import os

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from maro.cli.data_pipeline.citi_bike import CitiBikeProcess
from maro.cli.data_pipeline.vm_scheduling import VmSchedulingProcess

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import gzip
import os
import shutil

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from flask import Blueprint
from ..jwt_wrapper import check_jwt_validity

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import fcntl
import json

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from math import ceil
from typing import Callable

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import Optional

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Optional, Union
from .event import AbsEvent, ActualEvent, CascadeEvent, DummyEvent

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from itertools import count
from typing import Iterable, Iterator, List, Union

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
from collections import deque
from typing import Iterable

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABCMeta
from typing import Tuple

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABCMeta
from typing import Tuple

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABCMeta
from typing import Tuple

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -376,6 +376,7 @@ class AbsEnvSampler(object, metaclass=ABCMeta):
cache_element.event,
cache_element.tick,
)
cache_element.reward_dict = {agent: cache_element.reward_dict[agent] for agent in cache_element.agent_names}
def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None:
"""`cache_element` == None means we are processing the last element in trans_cache"""

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import os
from abc import ABCMeta, abstractmethod

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from dataclasses import dataclass

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import importlib
import os

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import os
import sys

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import os
import sys

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import os
import random

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.s
# Licensed under the MIT license.
from datetime import datetime

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import io
import os
from abc import ABC, abstractmethod

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_exception import MAROException
from .error_code import ERROR_CODE

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
This script is used to launch data and vis services, and the start the experiment script.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import csv
import os
import pickle
@ -13,8 +16,6 @@ from maro.simulator.utils import random
os.environ["MARO_STREAMIT_ENABLED"] = "true"
os.environ["MARO_STREAMIT_EXPERIMENT_NAME"] = "cim_testing"
from tests.utils import backends_to_test, compare_dictionary
from maro.data_lib.cim import dump_from_config
from maro.data_lib.cim.entities import PortSetting, Stop, SyntheticPortSetting, VesselSetting
from maro.data_lib.cim.vessel_stop_wrapper import VesselStopsWrapper
@ -23,6 +24,8 @@ from maro.simulator.scenarios.cim.business_engine import CimBusinessEngine
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent
from maro.simulator.scenarios.cim.ports_order_export import PortOrderExporter
from tests.utils import backends_to_test, compare_dictionary
TOPOLOGY_PATH_CONFIG = "tests/data/cim/case_data/config_folder"
TOPOLOGY_PATH_DUMP = "tests/data/cim/case_data/dump_folder"
TOPOLOGY_PATH_REAL_BIN = "tests/data/cim/case_data/real_folder_bin"

Просмотреть файл

@ -4,13 +4,13 @@
import os
import unittest
from tests.utils import backends_to_test, be_run_to_end, next_step
from maro.data_lib import BinaryConverter
from maro.event_buffer import EventBuffer
from maro.simulator.scenarios.citi_bike.business_engine import CitibikeBusinessEngine
from maro.simulator.scenarios.citi_bike.events import CitiBikeEvents
from tests.utils import backends_to_test, be_run_to_end, next_step
def setup_case(case_name: str, max_tick: int):
config_path = os.path.join("tests/data/citi_bike", case_name)

Просмотреть файл

@ -12,7 +12,6 @@ import unittest
import uuid
import yaml
from tests.cli.utils import record_running_time
from maro.cli.grass.utils.params import NodeStatus
from maro.cli.utils.azure_controller import AzureController
@ -20,6 +19,8 @@ from maro.cli.utils.params import GlobalParams, GlobalPaths
from maro.cli.utils.subprocess import Subprocess
from maro.utils.exception.cli_exception import CommandExecutionError
from tests.cli.utils import record_running_time
@unittest.skipUnless(os.environ.get("test_with_cli", False), "Require CLI prerequisites.")
class TestGrassAzure(unittest.TestCase):

Просмотреть файл

@ -12,7 +12,6 @@ import unittest
import uuid
import yaml
from tests.cli.utils import record_running_time
from maro.cli.grass.utils.params import GrassParams, NodeStatus
from maro.cli.utils.azure_controller import AzureController
@ -20,6 +19,8 @@ from maro.cli.utils.params import GlobalParams, GlobalPaths
from maro.cli.utils.subprocess import Subprocess
from maro.utils.exception.cli_exception import CommandExecutionError
from tests.cli.utils import record_running_time
@unittest.skipUnless(os.environ.get("test_with_cli", False), "Require CLI prerequisites.")
class TestGrassOnPremises(unittest.TestCase):

Просмотреть файл

@ -13,13 +13,14 @@ import unittest
import uuid
import yaml
from tests.cli.utils import record_running_time
from maro.cli.utils.azure_controller import AzureController
from maro.cli.utils.params import GlobalParams, GlobalPaths
from maro.cli.utils.subprocess import Subprocess
from maro.utils.exception.cli_exception import CommandExecutionError
from tests.cli.utils import record_running_time
@unittest.skipUnless(os.environ.get("test_with_cli", False), "Require cli prerequisites.")
class TestK8s(unittest.TestCase):

Просмотреть файл

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Просмотреть файл

@ -8,10 +8,10 @@ import threading
import unittest
from concurrent.futures import ThreadPoolExecutor
from tests.communication.utils import get_random_port, proxy_generator
from maro.communication import SessionMessage, dist
from tests.communication.utils import get_random_port, proxy_generator
def handler_function(that, proxy, message):
replied_payload = {"counter": message.body["counter"] + 1}

Просмотреть файл

@ -6,10 +6,10 @@ import subprocess
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
from tests.communication.utils import get_random_port, proxy_generator
from maro.communication import SessionMessage, SessionType
from tests.communication.utils import get_random_port, proxy_generator
def message_receive(proxy):
return proxy.receive_once().body

Просмотреть файл

@ -8,10 +8,10 @@ import sys
import time
import unittest
from tests.communication.utils import get_random_port
from maro.communication import Proxy, SessionMessage, SessionType
from tests.communication.utils import get_random_port
PROXY_PARAMETER = {
"group_name": "communication_unit_test",
"enable_rejoin": True,

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT licence
# Licensed under the MIT license.
import os
import tempfile

Просмотреть файл

@ -4,13 +4,12 @@
import os
import unittest
from tests.utils import backends_to_test
from maro.simulator.core import Env
from maro.simulator.utils import get_available_envs, get_scenarios, get_topologies
from maro.simulator.utils.common import frame_index_to_ticks, tick_to_frame_index
from .dummy.dummy_business_engine import DummyEngine
from tests.utils import backends_to_test
def run_to_end(env: Env):

Просмотреть файл

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tempfile
import time
import unittest

Просмотреть файл

@ -6,7 +6,6 @@ import unittest
import numpy as np
import pandas as pd
from tests.utils import backends_to_test
from maro.backends.backend import AttributeType
from maro.backends.frame import FrameBase, FrameNode, NodeAttribute, NodeBase, node
@ -16,6 +15,8 @@ from maro.utils.exception.backends_exception import (
BackendsSetItemInvalidException,
)
from tests.utils import backends_to_test
STATIC_NODE_NUM = 5
DYNAMIC_NODE_NUM = 10

Просмотреть файл

@ -3,9 +3,8 @@
import unittest
from tests.utils import backends_to_test
from .test_frame import DYNAMIC_NODE_NUM, STATIC_NODE_NUM, build_frame
from tests.utils import backends_to_test
class TestFrame(unittest.TestCase):