зеркало из https://github.com/microsoft/MLOS.git
mypy static type checking for mlos_bench (#306)
Adds static type checking via mypy for mlos_bench as well. Builds on #305, #307 - [x] Add `Protocol`s for different `Service` types so `Environment`s can check that the appropriate `Service` mix-ins have been setup. - [x] Config Loader - [x] Local Exec - [x] Remote Exec - [x] Remote Fileshare - [x] VM Ops - [ ] Future PR: Split VM Ops to VM/Host and OS Ops (for local SSH support)
This commit is contained in:
Родитель
10e4fa5b91
Коммит
9fe180a9a5
|
@ -1,7 +1,10 @@
|
|||
// vim: set ft=jsonc:
|
||||
{
|
||||
"makefile.extensionOutputFolder": "./.vscode",
|
||||
"python.defaultInterpreterPath": "${env:HOME}${env:USERPROFILE}/.conda/envs/mlos_core/bin/python",
|
||||
// Note: this only works in WSL/Linux currently.
|
||||
"python.defaultInterpreterPath": "${env:HOME}/.conda/envs/mlos_core/bin/python",
|
||||
// For Windows it should be this instead:
|
||||
//"python.defaultInterpreterPath": "${env:USERPROFILE}/.conda/envs/mlos_core/python.exe",
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.pylintEnabled": true,
|
||||
|
|
5
Makefile
5
Makefile
|
@ -12,7 +12,7 @@ MLOS_BENCH_PYTHON_FILES := $(shell find ./mlos_bench/ -type f -name '*.py' 2>/de
|
|||
|
||||
DOCKER := $(shell which docker)
|
||||
# Make sure the build directory exists.
|
||||
MKDIR_BUILD := $(shell mkdir -p build)
|
||||
MKDIR_BUILD := $(shell test -d build || mkdir build)
|
||||
|
||||
# Allow overriding the default verbosity of conda for CI jobs.
|
||||
#CONDA_INFO_LEVEL ?= -q
|
||||
|
@ -113,7 +113,7 @@ build/pylint.%.${CONDA_ENV_NAME}.build-stamp: build/conda-env.${CONDA_ENV_NAME}.
|
|||
touch $@
|
||||
|
||||
.PHONY: mypy
|
||||
mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp # TODO: build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp
|
||||
mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp
|
||||
|
||||
build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
|
||||
build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
|
||||
|
@ -339,6 +339,7 @@ build/check-doc.build-stamp: doc/build/html/index.html doc/build/html/htmlcov/in
|
|||
-e 'Problems with "include" directive path:' \
|
||||
-e 'duplicate object description' \
|
||||
-e "document isn't included in any toctree" \
|
||||
-e "more than one target found for cross-reference" \
|
||||
-e "toctree contains reference to nonexisting document 'auto_examples/index'" \
|
||||
-e "failed to import function 'create' from module '(SpaceAdapter|Optimizer)Factory'" \
|
||||
-e "No module named '(SpaceAdapter|Optimizer)Factory'" \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
FROM nginx:latest
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends linklint curl
|
||||
# Our ngxinx config overrides the default listening port.
|
||||
# Our nginx config overrides the default listening port.
|
||||
ARG NGINX_PORT=81
|
||||
ENV NGINX_PORT=${NGINX_PORT}
|
||||
EXPOSE ${NGINX_PORT}
|
||||
|
|
|
@ -156,6 +156,12 @@ Service Mix-ins
|
|||
|
||||
Service
|
||||
FileShareService
|
||||
|
||||
.. currentmodule:: mlos_bench.service.config_persistence
|
||||
.. autosummary::
|
||||
:toctree: generated/
|
||||
:template: class.rst
|
||||
|
||||
ConfigPersistenceService
|
||||
|
||||
Local Services
|
||||
|
|
|
@ -13,7 +13,7 @@ import json
|
|||
import argparse
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_output: str):
|
||||
def _main(fname_input: str, fname_output: str) -> None:
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
|
||||
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in json.load(fh_tunables).items():
|
||||
|
|
|
@ -12,7 +12,7 @@ import argparse
|
|||
import pandas as pd
|
||||
|
||||
|
||||
def _main(input_file: str, output_file: str):
|
||||
def _main(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Re-shape Redis benchmark CSV results from wide to long.
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,7 @@ import json
|
|||
import argparse
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_output: str):
|
||||
def _main(fname_input: str, fname_output: str) -> None:
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
|
||||
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in json.load(fh_tunables).items():
|
||||
|
|
|
@ -13,7 +13,7 @@ import json
|
|||
import argparse
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_output: str):
|
||||
def _main(fname_input: str, fname_output: str) -> None:
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
|
||||
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in json.load(fh_tunables).items():
|
||||
|
|
|
@ -9,9 +9,12 @@ A hierarchy of benchmark environments.
|
|||
import abc
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import instantiate_from_config
|
||||
|
||||
|
@ -19,13 +22,21 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Environment(metaclass=abc.ABCMeta):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
An abstract base of all benchmark environments.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def new(cls, env_name, class_name, config, global_config=None, tunables=None, service=None):
|
||||
# pylint: disable=too-many-arguments
|
||||
def new(cls,
|
||||
*,
|
||||
env_name: str,
|
||||
class_name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> "Environment":
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
|
@ -58,8 +69,13 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
return instantiate_from_config(cls, class_name, env_name, config,
|
||||
global_config, tunables, service)
|
||||
|
||||
def __init__(self, name, config, global_config=None, tunables=None, service=None):
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
"""
|
||||
Create a new environment with a given config.
|
||||
|
||||
|
@ -84,10 +100,17 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
self.config = config
|
||||
self._service = service
|
||||
self._is_ready = False
|
||||
self._params = {}
|
||||
self._params: Dict[str, TunableValue] = {}
|
||||
|
||||
self._config_loader_service: SupportsConfigLoading
|
||||
if self._service is not None and isinstance(self._service, SupportsConfigLoading):
|
||||
self._config_loader_service = self._service
|
||||
|
||||
if global_config is None:
|
||||
global_config = {}
|
||||
|
||||
self._const_args = config.get("const_args", {})
|
||||
for key in set(self._const_args).intersection(global_config or {}):
|
||||
for key in set(self._const_args).intersection(global_config):
|
||||
self._const_args[key] = global_config[key]
|
||||
|
||||
for key in config.get("required_args", []):
|
||||
|
@ -110,13 +133,13 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
_LOG.debug("Config for: %s\n%s",
|
||||
name, json.dumps(self.config, indent=2))
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Env: {self.__class__} :: '{self.name}'"
|
||||
|
||||
def _combine_tunables(self, tunables):
|
||||
def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Plug tunable values into the base config. If the tunable group is unknown,
|
||||
ignore it (it might belong to another environment). This method should
|
||||
|
@ -130,14 +153,14 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
|
||||
Returns
|
||||
-------
|
||||
params : dict
|
||||
params : Dict[str, Union[int, float, str]]
|
||||
Free-format dictionary that contains the new environment configuration.
|
||||
"""
|
||||
return tunables.get_param_values(
|
||||
group_names=self._tunable_params.get_names(),
|
||||
group_names=list(self._tunable_params.get_names()),
|
||||
into_params=self._const_args.copy())
|
||||
|
||||
def tunable_params(self):
|
||||
def tunable_params(self) -> TunableGroups:
|
||||
"""
|
||||
Get the configuration space of the given environment.
|
||||
|
||||
|
@ -148,7 +171,7 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
return self._tunable_params
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Set up a new benchmark environment, if necessary. This method must be
|
||||
idempotent, i.e., calling it several times in a row should be
|
||||
|
@ -170,15 +193,18 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
_LOG.info("Setup %s :: %s", self, tunables)
|
||||
assert isinstance(tunables, TunableGroups)
|
||||
|
||||
if global_config is None:
|
||||
global_config = {}
|
||||
|
||||
self._params = self._combine_tunables(tunables)
|
||||
for key in set(self._params).intersection(global_config or {}):
|
||||
for key in set(self._params).intersection(global_config):
|
||||
self._params[key] = global_config[key]
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))
|
||||
|
||||
return True
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Tear down the benchmark environment. This method must be idempotent,
|
||||
i.e., calling it several times in a row should be equivalent to a
|
||||
|
|
|
@ -7,7 +7,7 @@ Composite benchmark environment.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.environment.status import Status
|
||||
|
@ -22,9 +22,13 @@ class CompositeEnv(Environment):
|
|||
Composite benchmark environment.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: dict, global_config: dict = None,
|
||||
tunables: TunableGroups = None, service: Service = None):
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
"""
|
||||
Create a new environment with a given config.
|
||||
|
||||
|
@ -44,23 +48,23 @@ class CompositeEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name, config, global_config, tunables, service)
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
self._children = []
|
||||
self._children: List[Environment] = []
|
||||
|
||||
for child_config_file in config.get("include_children", []):
|
||||
for env in self._service.load_environment_list(
|
||||
for env in self._config_loader_service.load_environment_list(
|
||||
child_config_file, global_config, tunables, self._service):
|
||||
self._add_child(env)
|
||||
|
||||
for child_config in config.get("children", []):
|
||||
self._add_child(self._service.build_environment(
|
||||
self._add_child(self._config_loader_service.build_environment(
|
||||
child_config, global_config, tunables, self._service))
|
||||
|
||||
if not self._children:
|
||||
raise ValueError("At least one child environment must be present")
|
||||
|
||||
def _add_child(self, env: Environment):
|
||||
def _add_child(self, env: Environment) -> None:
|
||||
"""
|
||||
Add a new child environment to the composite environment.
|
||||
This method is called from the constructor only.
|
||||
|
@ -68,7 +72,7 @@ class CompositeEnv(Environment):
|
|||
self._children.append(env)
|
||||
self._tunable_params.update(env.tunable_params())
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Set up the children environments.
|
||||
|
||||
|
@ -92,7 +96,7 @@ class CompositeEnv(Environment):
|
|||
)
|
||||
return self._is_ready
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Tear down the children environments. This method is idempotent,
|
||||
i.e., calling it several times is equivalent to a single call.
|
||||
|
|
|
@ -17,23 +17,25 @@ import pandas
|
|||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalEnv(Environment):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Scheduler-side Environment that runs scripts locally.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
# pylint: disable=too-many-arguments
|
||||
"""
|
||||
Create a new environment for local execution.
|
||||
|
||||
|
@ -56,15 +58,19 @@ class LocalEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name, config, global_config, tunables, service)
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
|
||||
"LocalEnv requires a service that supports local execution"
|
||||
self._local_exec_service: SupportsLocalExec = self._service
|
||||
|
||||
self._temp_dir = self.config.get("temp_dir")
|
||||
self._script_setup = self.config.get("setup")
|
||||
self._script_run = self.config.get("run")
|
||||
self._script_teardown = self.config.get("teardown")
|
||||
|
||||
self._dump_params_file = self.config.get("dump_params_file")
|
||||
self._read_results_file = self.config.get("read_results_file")
|
||||
self._dump_params_file: Optional[str] = self.config.get("dump_params_file")
|
||||
self._read_results_file: Optional[str] = self.config.get("read_results_file")
|
||||
|
||||
if self._script_setup is None and \
|
||||
self._script_run is None and \
|
||||
|
@ -77,7 +83,7 @@ class LocalEnv(Environment):
|
|||
if self._script_run is None and self._read_results_file is not None:
|
||||
raise ValueError("'run' must be present if 'read_results_file' is specified")
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if the environment is ready and set up the application
|
||||
and benchmarks, if necessary.
|
||||
|
@ -105,7 +111,7 @@ class LocalEnv(Environment):
|
|||
self._is_ready = True
|
||||
return True
|
||||
|
||||
with self._service.temp_dir_context(self._temp_dir) as temp_dir:
|
||||
with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir:
|
||||
|
||||
_LOG.info("Set up the environment locally: %s at %s", self, temp_dir)
|
||||
|
||||
|
@ -116,7 +122,7 @@ class LocalEnv(Environment):
|
|||
json.dump(tunables.get_param_values(self._tunable_params.get_names()),
|
||||
fh_tunables)
|
||||
|
||||
(return_code, _stdout, stderr) = self._service.local_exec(
|
||||
(return_code, _stdout, stderr) = self._local_exec_service.local_exec(
|
||||
self._script_setup, env=self._params, cwd=temp_dir)
|
||||
|
||||
if return_code == 0:
|
||||
|
@ -125,7 +131,7 @@ class LocalEnv(Environment):
|
|||
_LOG.warning("ERROR: Local setup returns with code %d stderr:\n%s",
|
||||
return_code, stderr)
|
||||
|
||||
self._is_ready = return_code == 0
|
||||
self._is_ready = bool(return_code == 0)
|
||||
|
||||
return self._is_ready
|
||||
|
||||
|
@ -145,10 +151,10 @@ class LocalEnv(Environment):
|
|||
if not (status.is_ready and self._script_run):
|
||||
return result
|
||||
|
||||
with self._service.temp_dir_context(self._temp_dir) as temp_dir:
|
||||
with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir:
|
||||
|
||||
_LOG.info("Run script locally on: %s at %s", self, temp_dir)
|
||||
(return_code, _stdout, stderr) = self._service.local_exec(
|
||||
(return_code, _stdout, stderr) = self._local_exec_service.local_exec(
|
||||
self._script_run, env=self._params, cwd=temp_dir)
|
||||
|
||||
if return_code != 0:
|
||||
|
@ -156,23 +162,24 @@ class LocalEnv(Environment):
|
|||
return_code, stderr)
|
||||
return (Status.FAILED, None)
|
||||
|
||||
data = pandas.read_csv(self._service.resolve_path(
|
||||
assert self._read_results_file is not None
|
||||
data = pandas.read_csv(self._config_loader_service.resolve_path(
|
||||
self._read_results_file, extra_paths=[temp_dir]))
|
||||
|
||||
_LOG.debug("Read data:\n%s", data)
|
||||
if len(data) != 1:
|
||||
_LOG.warning("Local run has %d results - returning the last one", len(data))
|
||||
|
||||
data = data.iloc[-1].to_dict()
|
||||
_LOG.info("Local run complete: %s ::\n%s", self, data)
|
||||
return (Status.SUCCEEDED, data)
|
||||
data_dict = data.iloc[-1].to_dict()
|
||||
_LOG.info("Local run complete: %s ::\n%s", self, data_dict)
|
||||
return (Status.SUCCEEDED, data_dict)
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up the local environment.
|
||||
"""
|
||||
if self._script_teardown:
|
||||
_LOG.info("Local teardown: %s", self)
|
||||
(status, _) = self._service.local_exec(self._script_teardown, env=self._params)
|
||||
(status, _, _) = self._local_exec_service.local_exec(self._script_teardown, env=self._params)
|
||||
_LOG.info("Local teardown complete: %s :: %s", self, status)
|
||||
super().teardown()
|
||||
|
|
|
@ -10,11 +10,14 @@ and upload/download data to the shared storage.
|
|||
import logging
|
||||
|
||||
from string import Template
|
||||
from typing import Optional, Dict, List, Tuple, Any
|
||||
from typing import List, Generator, Iterable, Mapping, Optional, Tuple
|
||||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.environment.local.local_env import LocalEnv
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -27,12 +30,12 @@ class LocalFileShareEnv(LocalEnv):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
# pylint: disable=too-many-arguments
|
||||
"""
|
||||
Create a new application environment with a given config.
|
||||
|
||||
|
@ -56,7 +59,16 @@ class LocalFileShareEnv(LocalEnv):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name, config, global_config, tunables, service)
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
|
||||
"LocalEnv requires a service that supports local execution"
|
||||
self._local_exec_service: SupportsLocalExec = self._service
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
|
||||
"LocalEnv requires a service that supports file upload/download operations"
|
||||
self._file_share_service: SupportsFileShareOps = self._service
|
||||
|
||||
self._upload = self._template_from_to("upload")
|
||||
self._download = self._template_from_to("download")
|
||||
|
||||
|
@ -71,7 +83,8 @@ class LocalFileShareEnv(LocalEnv):
|
|||
]
|
||||
|
||||
@staticmethod
|
||||
def _expand(from_to: List[Tuple[Template, Template]], params: Dict[str, Any]):
|
||||
def _expand(from_to: Iterable[Tuple[Template, Template]],
|
||||
params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
|
||||
"""
|
||||
Substitute $var parameters in from/to path templates.
|
||||
Return a generator of (str, str) pairs of paths.
|
||||
|
@ -81,7 +94,7 @@ class LocalFileShareEnv(LocalEnv):
|
|||
for (path_from, path_to) in from_to
|
||||
)
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Run setup scripts locally and upload the scripts and data to the shared storage.
|
||||
|
||||
|
@ -102,14 +115,14 @@ class LocalFileShareEnv(LocalEnv):
|
|||
True if operation is successful, false otherwise.
|
||||
"""
|
||||
prev_temp_dir = self._temp_dir
|
||||
with self._service.temp_dir_context(self._temp_dir) as self._temp_dir:
|
||||
with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir:
|
||||
# Override _temp_dir so that setup and upload both use the same path.
|
||||
self._is_ready = super().setup(tunables, global_config)
|
||||
if self._is_ready:
|
||||
params = self._params.copy()
|
||||
params["PWD"] = self._temp_dir
|
||||
for (path_from, path_to) in self._expand(self._upload, params):
|
||||
self._service.upload(self._service.resolve_path(
|
||||
self._file_share_service.upload(self._config_loader_service.resolve_path(
|
||||
path_from, extra_paths=[self._temp_dir]), path_to)
|
||||
self._temp_dir = prev_temp_dir
|
||||
return self._is_ready
|
||||
|
@ -128,13 +141,13 @@ class LocalFileShareEnv(LocalEnv):
|
|||
be in the `score` field.
|
||||
"""
|
||||
prev_temp_dir = self._temp_dir
|
||||
with self._service.temp_dir_context(self._temp_dir) as self._temp_dir:
|
||||
with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir:
|
||||
# Override _temp_dir so that download and run both use the same path.
|
||||
params = self._params.copy()
|
||||
params["PWD"] = self._temp_dir
|
||||
for (path_from, path_to) in self._expand(self._download, params):
|
||||
self._service.download(
|
||||
path_from, self._service.resolve_path(
|
||||
self._file_share_service.download(
|
||||
path_from, self._config_loader_service.resolve_path(
|
||||
path_to, extra_paths=[self._temp_dir]))
|
||||
result = super().run()
|
||||
self._temp_dir = prev_temp_dir
|
||||
|
|
|
@ -12,9 +12,9 @@ from typing import Optional, Tuple
|
|||
|
||||
import numpy
|
||||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.tunables import Tunable, TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -28,12 +28,12 @@ class MockEnv(Environment):
|
|||
_NOISE_VAR = 0.2 # Variance of the Gaussian noise added to the benchmark value.
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
# pylint: disable=too-many-arguments
|
||||
"""
|
||||
Create a new environment that produces mock benchmark data.
|
||||
|
||||
|
@ -52,7 +52,7 @@ class MockEnv(Environment):
|
|||
service: Service
|
||||
An optional service object. Not used by this class.
|
||||
"""
|
||||
super().__init__(name, config, global_config, tunables, service)
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
seed = self.config.get("seed")
|
||||
self._random = random.Random(seed) if seed is not None else None
|
||||
self._range = self.config.get("range")
|
||||
|
@ -95,15 +95,14 @@ class MockEnv(Environment):
|
|||
That is, map current value to the [0, 1] range.
|
||||
"""
|
||||
val = None
|
||||
if tunable.type == "categorical":
|
||||
val = (tunable.categorical_values.index(tunable.value) /
|
||||
if tunable.is_categorical:
|
||||
val = (tunable.categorical_values.index(tunable.categorical_value) /
|
||||
float(len(tunable.categorical_values) - 1))
|
||||
elif tunable.type in {"int", "float"}:
|
||||
if not tunable.range:
|
||||
raise ValueError("Tunable must have a range: " + tunable.name)
|
||||
val = ((tunable.value - tunable.range[0]) /
|
||||
elif tunable.is_numerical:
|
||||
val = ((tunable.numerical_value - tunable.range[0]) /
|
||||
float(tunable.range[1] - tunable.range[0]))
|
||||
else:
|
||||
raise ValueError("Invalid parameter type: " + tunable.type)
|
||||
# Explicitly clip the value in case of numerical errors.
|
||||
return numpy.clip(val, 0, 1)
|
||||
ret: float = numpy.clip(val, 0, 1)
|
||||
return ret
|
||||
|
|
|
@ -6,9 +6,14 @@
|
|||
OS-level remote Environment on Azure.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from mlos_bench.environment import Environment, Status
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -19,7 +24,43 @@ class OSEnv(Environment):
|
|||
OS Level Environment for a host.
|
||||
"""
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
"""
|
||||
Create a new environment for remote execution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Human-readable name of the environment.
|
||||
config : dict
|
||||
Free-format dictionary that contains the benchmark environment
|
||||
configuration. Each config must have at least the "tunable_params"
|
||||
and the "const_args" sections.
|
||||
`RemoteEnv` must also have at least some of the following parameters:
|
||||
{setup, run, teardown, wait_boot}
|
||||
global_config : dict
|
||||
Free-format dictionary of global parameters (e.g., security credentials)
|
||||
to be mixed in into the "const_args" section of the local config.
|
||||
tunables : TunableGroups
|
||||
A collection of tunable parameters for *all* environments.
|
||||
service: Service
|
||||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
# TODO: Refactor this as "host" and "os" operations to accommodate SSH service.
|
||||
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
self._host_service: SupportsVMOps = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if the host is up and running; boot it, if necessary.
|
||||
|
||||
|
@ -43,21 +84,21 @@ class OSEnv(Environment):
|
|||
if not super().setup(tunables, global_config):
|
||||
return False
|
||||
|
||||
(status, params) = self._service.vm_start(self._params)
|
||||
(status, params) = self._host_service.vm_start(self._params)
|
||||
if status.is_pending:
|
||||
(status, _) = self._service.wait_vm_operation(params)
|
||||
(status, _) = self._host_service.wait_vm_operation(params)
|
||||
|
||||
self._is_ready = status in {Status.SUCCEEDED, Status.READY}
|
||||
return self._is_ready
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up and shut down the host without deprovisioning it.
|
||||
"""
|
||||
_LOG.info("OS tear down: %s", self)
|
||||
(status, params) = self._service.vm_stop()
|
||||
(status, params) = self._host_service.vm_stop()
|
||||
if status.is_pending:
|
||||
(status, _) = self._service.wait_vm_operation(params)
|
||||
(status, _) = self._host_service.wait_vm_operation(params)
|
||||
|
||||
super().teardown()
|
||||
_LOG.debug("Final status of OS stopping: %s :: %s", self, status)
|
||||
|
|
|
@ -7,11 +7,13 @@ Remotely executed benchmark/script environment.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -23,12 +25,12 @@ class RemoteEnv(Environment):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
# pylint: disable=too-many-arguments
|
||||
"""
|
||||
Create a new environment for remote execution.
|
||||
|
||||
|
@ -51,7 +53,16 @@ class RemoteEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name, config, global_config, tunables, service)
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \
|
||||
"RemoteEnv requires a service that supports remote execution operations"
|
||||
self._remote_exec_service: SupportsRemoteExec = self._service
|
||||
|
||||
# TODO: Refactor this as "host" and "os" operations to accommodate SSH service.
|
||||
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
self._host_service: SupportsVMOps = self._service
|
||||
|
||||
self._wait_boot = self.config.get("wait_boot", False)
|
||||
self._script_setup = self.config.get("setup")
|
||||
|
@ -65,7 +76,7 @@ class RemoteEnv(Environment):
|
|||
raise ValueError("At least one of {setup, run, teardown}" +
|
||||
" must be present or wait_boot set to True.")
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if the environment is ready and set up the application
|
||||
and benchmarks on a remote host.
|
||||
|
@ -89,9 +100,9 @@ class RemoteEnv(Environment):
|
|||
|
||||
if self._wait_boot:
|
||||
_LOG.info("Wait for the remote environment to start: %s", self)
|
||||
(status, params) = self._service.vm_start(self._params)
|
||||
(status, params) = self._host_service.vm_start(self._params)
|
||||
if status.is_pending:
|
||||
(status, _) = self._service.wait_vm_operation(params)
|
||||
(status, _) = self._host_service.wait_vm_operation(params)
|
||||
if not status.is_succeeded:
|
||||
return False
|
||||
|
||||
|
@ -130,7 +141,7 @@ class RemoteEnv(Environment):
|
|||
_LOG.info("Remote run complete: %s :: %s", self, result)
|
||||
return result
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up and shut down the remote environment.
|
||||
"""
|
||||
|
@ -140,7 +151,7 @@ class RemoteEnv(Environment):
|
|||
_LOG.info("Remote teardown complete: %s :: %s", self, status)
|
||||
super().teardown()
|
||||
|
||||
def _remote_exec(self, script: List[str]) -> Tuple[Status, Optional[dict]]:
|
||||
def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, Optional[dict]]:
|
||||
"""
|
||||
Run a script on the remote host.
|
||||
|
||||
|
@ -156,10 +167,10 @@ class RemoteEnv(Environment):
|
|||
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
|
||||
"""
|
||||
_LOG.debug("Submit script: %s", self)
|
||||
(status, output) = self._service.remote_exec(script, self._params)
|
||||
(status, output) = self._remote_exec_service.remote_exec(script, self._params)
|
||||
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)
|
||||
if status in {Status.PENDING, Status.SUCCEEDED}:
|
||||
(status, output) = self._service.get_remote_exec_results(output)
|
||||
(status, output) = self._remote_exec_service.get_remote_exec_results(output)
|
||||
# TODO: extract the results from `output`.
|
||||
_LOG.debug("Status: %s :: %s", status, output)
|
||||
return (status, output)
|
||||
|
|
|
@ -6,9 +6,13 @@
|
|||
"Remote" VM Environment.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
from mlos_bench.environment import Environment
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -19,7 +23,40 @@ class VMEnv(Environment):
|
|||
"Remote" VM environment.
|
||||
"""
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
"""
|
||||
Create a new environment for VM operations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Human-readable name of the environment.
|
||||
config : dict
|
||||
Free-format dictionary that contains the benchmark environment
|
||||
configuration. Each config must have at least the "tunable_params"
|
||||
and the "const_args" sections.
|
||||
global_config : dict
|
||||
Free-format dictionary of global parameters (e.g., security credentials)
|
||||
to be mixed in into the "const_args" section of the local config.
|
||||
tunables : TunableGroups
|
||||
A collection of tunable parameters for *all* environments.
|
||||
service: Service
|
||||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
|
||||
"VMEnv requires a service that supports VM operations"
|
||||
self._vm_service: SupportsVMOps = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if VM is ready. (Re)provision and start it, if necessary.
|
||||
|
||||
|
@ -43,21 +80,21 @@ class VMEnv(Environment):
|
|||
if not super().setup(tunables, global_config):
|
||||
return False
|
||||
|
||||
(status, params) = self._service.vm_provision(self._params)
|
||||
(status, params) = self._vm_service.vm_provision(self._params)
|
||||
if status.is_pending:
|
||||
(status, _) = self._service.wait_vm_deployment(True, params)
|
||||
(status, _) = self._vm_service.wait_vm_deployment(True, params)
|
||||
|
||||
self._is_ready = status.is_succeeded
|
||||
return self._is_ready
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Shut down the VM and release it.
|
||||
"""
|
||||
_LOG.info("VM tear down: %s", self)
|
||||
(status, params) = self._service.vm_deprovision()
|
||||
(status, params) = self._vm_service.vm_deprovision()
|
||||
if status.is_pending:
|
||||
(status, _) = self._service.wait_vm_deployment(False, params)
|
||||
(status, _) = self._vm_service.wait_vm_deployment(False, params)
|
||||
|
||||
super().teardown()
|
||||
_LOG.debug("Final status of VM deprovisioning: %s :: %s", self, status)
|
||||
|
|
|
@ -24,7 +24,7 @@ class Status(enum.Enum):
|
|||
TIMED_OUT = 7
|
||||
|
||||
@property
|
||||
def is_good(self):
|
||||
def is_good(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is good.
|
||||
"""
|
||||
|
@ -36,42 +36,42 @@ class Status(enum.Enum):
|
|||
}
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
def is_pending(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is PENDING.
|
||||
"""
|
||||
return self == Status.PENDING
|
||||
|
||||
@property
|
||||
def is_ready(self):
|
||||
def is_ready(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is READY.
|
||||
"""
|
||||
return self == Status.READY
|
||||
|
||||
@property
|
||||
def is_succeeded(self):
|
||||
def is_succeeded(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is SUCCEEDED.
|
||||
"""
|
||||
return self == Status.SUCCEEDED
|
||||
|
||||
@property
|
||||
def is_failed(self):
|
||||
def is_failed(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is FAILED.
|
||||
"""
|
||||
return self == Status.FAILED
|
||||
|
||||
@property
|
||||
def is_canceled(self):
|
||||
def is_canceled(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is CANCELED.
|
||||
"""
|
||||
return self == Status.CANCELED
|
||||
|
||||
@property
|
||||
def is_timed_out(self):
|
||||
def is_timed_out(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is TIMEDOUT.
|
||||
"""
|
||||
|
|
|
@ -9,10 +9,11 @@ Helper functions to launch the benchmark and the optimizer from the command line
|
|||
import logging
|
||||
import argparse
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from mlos_bench.environment import Environment
|
||||
from mlos_bench.service import LocalExecService, ConfigPersistenceService
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.local.local_exec import LocalExecService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
|
||||
_LOG_LEVEL = logging.INFO
|
||||
_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
|
||||
|
@ -30,9 +31,9 @@ class Launcher:
|
|||
|
||||
_LOG.info("Launch: %s", description)
|
||||
|
||||
self._config_loader = None
|
||||
self._env_config_file = None
|
||||
self._global_config = {}
|
||||
self._config_loader: ConfigPersistenceService
|
||||
self._env_config_file: str
|
||||
self._global_config: Dict[str, Any] = {}
|
||||
self._parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
self._parser.add_argument(
|
||||
|
@ -89,7 +90,9 @@ class Launcher:
|
|||
|
||||
if args.globals is not None:
|
||||
for config_file in args.globals:
|
||||
self._global_config.update(self._config_loader.load_config(config_file))
|
||||
conf = self._config_loader.load_config(config_file)
|
||||
assert isinstance(conf, dict)
|
||||
self._global_config.update(conf)
|
||||
|
||||
self._global_config.update(Launcher._try_parse_extra_args(args_rest))
|
||||
if args.config_path:
|
||||
|
@ -98,7 +101,7 @@ class Launcher:
|
|||
return args
|
||||
|
||||
@staticmethod
|
||||
def _try_parse_extra_args(cmdline: List[str]) -> Dict[str, str]:
|
||||
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, str]:
|
||||
"""
|
||||
Helper function to parse global key/value pairs from the command line.
|
||||
"""
|
||||
|
@ -132,7 +135,9 @@ class Launcher:
|
|||
Load JSON config file. Use path relative to `config_path` if required.
|
||||
"""
|
||||
assert self._config_loader is not None, "Call after invoking .parse_args()"
|
||||
return self._config_loader.load_config(json_file_name)
|
||||
conf = self._config_loader.load_config(json_file_name)
|
||||
assert isinstance(conf, dict)
|
||||
return conf
|
||||
|
||||
def load_env(self) -> Environment:
|
||||
"""
|
||||
|
|
|
@ -8,7 +8,7 @@ and mlos_core optimizers.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
|
@ -24,7 +24,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(tunables: TunableGroups, config: dict, global_config: dict = None):
|
||||
def load(tunables: TunableGroups, config: dict, global_config: Optional[dict] = None) -> "Optimizer":
|
||||
"""
|
||||
Instantiate the Optimizer shim from the configuration.
|
||||
|
||||
|
@ -48,7 +48,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
return opt
|
||||
|
||||
@classmethod
|
||||
def new(cls, class_name: str, tunables: TunableGroups, config: dict):
|
||||
def new(cls, class_name: str, tunables: TunableGroups, config: dict) -> "Optimizer":
|
||||
"""
|
||||
Factory method for a new optimizer with a given config.
|
||||
|
||||
|
@ -89,11 +89,13 @@ class Optimizer(metaclass=ABCMeta):
|
|||
self._tunables = tunables
|
||||
self._iter = 1
|
||||
self._max_iter = int(self._config.pop('max_iterations', 10))
|
||||
self._opt_target = self._config.pop('maximize', None)
|
||||
if self._opt_target is None:
|
||||
self._opt_target: str
|
||||
_opt_target = self._config.pop('maximize', None)
|
||||
if _opt_target is None:
|
||||
self._opt_target = self._config.pop('minimize', 'score')
|
||||
self._opt_sign = 1
|
||||
else:
|
||||
self._opt_target = _opt_target
|
||||
if 'minimize' in self._config:
|
||||
raise ValueError("Cannot specify both 'maximize' and 'minimize'.")
|
||||
self._opt_sign = -1
|
||||
|
@ -110,18 +112,18 @@ class Optimizer(metaclass=ABCMeta):
|
|||
return self._opt_target
|
||||
|
||||
@abstractmethod
|
||||
def bulk_register(self, configs: List[dict], scores: List[float],
|
||||
status: Optional[List[Status]] = None) -> bool:
|
||||
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
"""
|
||||
Pre-load the optimizer with the bulk data from previous experiments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
configs : List[dict]
|
||||
configs : Sequence[dict]
|
||||
Records of tunable values from other experiments.
|
||||
scores : List[float]
|
||||
scores : Sequence[float]
|
||||
Benchmark results from experiments that correspond to `configs`.
|
||||
status : Optional[List[float]]
|
||||
status : Optional[Sequence[float]]
|
||||
Status of the experiments that correspond to `configs`.
|
||||
|
||||
Returns
|
||||
|
@ -152,7 +154,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Union[float, dict] = None) -> float:
|
||||
score: Optional[Union[float, Dict[str, float]]] = None) -> Optional[float]:
|
||||
"""
|
||||
Register the observation for the given configuration.
|
||||
|
||||
|
@ -163,7 +165,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
Usually it's the same config that the `.suggest()` method returned.
|
||||
status : Status
|
||||
Final status of the experiment (e.g., SUCCEEDED or FAILED).
|
||||
score : Union[float, dict]
|
||||
score : Union[float, Dict[str, float]]
|
||||
A scalar or a dict with the final benchmark results.
|
||||
None if the experiment was not successful.
|
||||
|
||||
|
@ -178,7 +180,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
raise ValueError("Status and score must be consistent.")
|
||||
return self._get_score(status, score)
|
||||
|
||||
def _get_score(self, status: Status, score: Union[float, dict]) -> float:
|
||||
def _get_score(self, status: Status, score: Optional[Union[float, Dict[str, float]]]) -> Optional[float]:
|
||||
"""
|
||||
Extract a scalar benchmark score from the dataframe.
|
||||
Change the sign if we are maximizing.
|
||||
|
@ -187,7 +189,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
----------
|
||||
status : Status
|
||||
Final status of the experiment (e.g., SUCCEEDED or FAILED).
|
||||
score : Union[float, dict]
|
||||
score : Union[float, Dict[str, float]]
|
||||
A scalar or a dict with the final benchmark results.
|
||||
None if the experiment was not successful.
|
||||
|
||||
|
@ -198,6 +200,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
"""
|
||||
if not status.is_succeeded:
|
||||
return None
|
||||
assert score is not None
|
||||
if isinstance(score, dict):
|
||||
score = score[self._opt_target]
|
||||
return float(score) * self._opt_sign
|
||||
|
@ -210,7 +213,7 @@ class Optimizer(metaclass=ABCMeta):
|
|||
return self._iter <= self._max_iter
|
||||
|
||||
@abstractmethod
|
||||
def get_best_observation(self) -> Tuple[float, TunableGroups]:
|
||||
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
|
||||
"""
|
||||
Get the best observation so far.
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@ Functions to convert TunableGroups to ConfigSpace for use with the mlos_core opt
|
|||
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ConfigSpace.hyperparameters import Hyperparameter
|
||||
from ConfigSpace import UniformIntegerHyperparameter
|
||||
from ConfigSpace import UniformFloatHyperparameter
|
||||
|
@ -21,7 +23,7 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _tunable_to_hyperparameter(
|
||||
tunable: Tunable, group_name: str = None, cost: int = 0) -> Hyperparameter:
|
||||
tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> Hyperparameter:
|
||||
"""
|
||||
Convert a single Tunable to an equivalent ConfigSpace Hyperparameter object.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ A wrapper for mlos_core optimizers for mlos_bench.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
@ -47,17 +47,17 @@ class MlosCoreOptimizer(Optimizer):
|
|||
space_adapter_type=space_adapter_type,
|
||||
space_adapter_kwargs=space_adapter_config)
|
||||
|
||||
def bulk_register(self, configs: List[dict], scores: List[float],
|
||||
status: Optional[List[Status]] = None) -> bool:
|
||||
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
if not super().bulk_register(configs, scores, status):
|
||||
return False
|
||||
tunables_names = list(self._tunables.get_param_values().keys())
|
||||
df_configs = pd.DataFrame(configs)[tunables_names]
|
||||
df_scores = pd.Series(scores, dtype=float) * self._opt_sign
|
||||
if status is not None:
|
||||
df_status_succ = pd.Series(status) == Status.SUCCEEDED
|
||||
df_configs = df_configs[df_status_succ]
|
||||
df_scores = df_scores[df_status_succ]
|
||||
df_status_ok = pd.Series(status) == Status.SUCCEEDED
|
||||
df_configs = df_configs[df_status_ok]
|
||||
df_scores = df_scores[df_status_ok]
|
||||
self._opt.register(df_configs, df_scores)
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
(score, _) = self.get_best_observation()
|
||||
|
@ -70,7 +70,7 @@ class MlosCoreOptimizer(Optimizer):
|
|||
return self._tunables.copy().assign(df_config.loc[0].to_dict())
|
||||
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Union[float, dict] = None) -> float:
|
||||
score: Optional[Union[float, dict]] = None) -> Optional[float]:
|
||||
score = super().register(tunables, status, score)
|
||||
# TODO: mlos_core currently does not support registration of failed trials:
|
||||
if status.is_succeeded:
|
||||
|
@ -81,7 +81,7 @@ class MlosCoreOptimizer(Optimizer):
|
|||
self._iter += 1
|
||||
return score
|
||||
|
||||
def get_best_observation(self) -> Tuple[float, TunableGroups]:
|
||||
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
|
||||
df_config = self._opt.get_best_observation()
|
||||
if len(df_config) == 0:
|
||||
return (None, None)
|
||||
|
|
|
@ -8,9 +8,11 @@ Mock optimizer for mlos_bench.
|
|||
|
||||
import random
|
||||
import logging
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
from mlos_bench.optimizer.base_optimizer import Optimizer
|
||||
|
@ -26,16 +28,16 @@ class MockOptimizer(Optimizer):
|
|||
def __init__(self, tunables: TunableGroups, config: dict):
|
||||
super().__init__(tunables, config)
|
||||
rnd = random.Random(config.get("seed", 42))
|
||||
self._random = {
|
||||
self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
|
||||
"categorical": lambda tunable: rnd.choice(tunable.categorical_values),
|
||||
"float": lambda tunable: rnd.uniform(*tunable.range),
|
||||
"int": lambda tunable: rnd.randint(*tunable.range),
|
||||
}
|
||||
self._best_config = None
|
||||
self._best_score = None
|
||||
self._best_config: Optional[TunableGroups] = None
|
||||
self._best_score: Optional[float] = None
|
||||
|
||||
def bulk_register(self, configs: List[dict], scores: List[float],
|
||||
status: Optional[List[Status]] = None) -> bool:
|
||||
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
if not super().bulk_register(configs, scores, status):
|
||||
return False
|
||||
if status is None:
|
||||
|
@ -60,15 +62,18 @@ class MockOptimizer(Optimizer):
|
|||
return tunables
|
||||
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Union[float, dict] = None) -> float:
|
||||
score = super().register(tunables, status, score)
|
||||
if status.is_succeeded and (self._best_score is None or score < self._best_score):
|
||||
self._best_score = score
|
||||
score: Optional[Union[float, dict]] = None) -> Optional[float]:
|
||||
registered_score = super().register(tunables, status, score)
|
||||
if status.is_succeeded and (
|
||||
self._best_score is None or (registered_score is not None and registered_score < self._best_score)
|
||||
):
|
||||
self._best_score = registered_score
|
||||
self._best_config = tunables.copy()
|
||||
self._iter += 1
|
||||
return score
|
||||
return registered_score
|
||||
|
||||
def get_best_observation(self) -> Tuple[float, TunableGroups]:
|
||||
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
|
||||
if self._best_score is None:
|
||||
return (None, None)
|
||||
assert self._best_config is not None
|
||||
return (self._best_score * self._opt_sign, self._best_config)
|
||||
|
|
|
@ -17,7 +17,7 @@ from mlos_bench.launcher import Launcher
|
|||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _main():
|
||||
def _main() -> None:
|
||||
|
||||
launcher = Launcher("mlos_bench run_bench")
|
||||
|
||||
|
|
|
@ -9,16 +9,20 @@ OS Autotune main optimization loop.
|
|||
See `--help` output for details.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import logging
|
||||
|
||||
from mlos_bench.launcher import Launcher
|
||||
from mlos_bench.optimizer import Optimizer
|
||||
from mlos_bench.environment import Status, Environment
|
||||
from mlos_bench.optimizer.base_optimizer import Optimizer
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _main():
|
||||
def _main() -> None:
|
||||
|
||||
launcher = Launcher("mlos_bench run_opt")
|
||||
|
||||
|
@ -40,7 +44,7 @@ def _main():
|
|||
_LOG.info("Final result: %s", result)
|
||||
|
||||
|
||||
def _optimize(env: Environment, opt: Optimizer, no_teardown: bool):
|
||||
def _optimize(env: Environment, opt: Optimizer, no_teardown: bool) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
|
||||
"""
|
||||
Main optimization loop.
|
||||
"""
|
||||
|
@ -57,7 +61,11 @@ def _optimize(env: Environment, opt: Optimizer, no_teardown: bool):
|
|||
continue
|
||||
|
||||
(status, output) = env.run() # Block and wait for the final result
|
||||
value = output['score'] if status.is_succeeded else None
|
||||
if status.is_succeeded:
|
||||
assert output is not None
|
||||
value = output['score']
|
||||
else:
|
||||
value = None
|
||||
|
||||
_LOG.info("Result: %s = %s :: %s", tunables, status, value)
|
||||
opt.register(tunables, status, value)
|
||||
|
|
|
@ -8,14 +8,11 @@ Services for implementing Environments for mlos_bench.
|
|||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.base_fileshare import FileShareService
|
||||
|
||||
from mlos_bench.service.local.local_exec import LocalExecService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Service',
|
||||
'LocalExecService',
|
||||
'FileShareService',
|
||||
'ConfigPersistenceService',
|
||||
'LocalExecService',
|
||||
]
|
||||
|
|
|
@ -11,11 +11,12 @@ import logging
|
|||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileShareService(Service, metaclass=ABCMeta):
|
||||
class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
|
||||
"""
|
||||
An abstract base of all file shares.
|
||||
"""
|
||||
|
@ -41,7 +42,7 @@ class FileShareService(Service, metaclass=ABCMeta):
|
|||
])
|
||||
|
||||
@abstractmethod
|
||||
def download(self, remote_path: str, local_path: str, recursive: bool = True):
|
||||
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
"""
|
||||
Downloads contents from a remote share path to a local path.
|
||||
|
||||
|
@ -57,10 +58,10 @@ class FileShareService(Service, metaclass=ABCMeta):
|
|||
if True (the default), download the entire directory tree.
|
||||
"""
|
||||
_LOG.info("Download from File Share %s recursively: %s -> %s",
|
||||
"" if recursive else "non-", remote_path, local_path)
|
||||
"" if recursive else "non", remote_path, local_path)
|
||||
|
||||
@abstractmethod
|
||||
def upload(self, local_path: str, remote_path: str, recursive: bool = True):
|
||||
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
"""
|
||||
Uploads contents from a local path to remote share path.
|
||||
|
||||
|
@ -75,4 +76,4 @@ class FileShareService(Service, metaclass=ABCMeta):
|
|||
if True (the default), upload the entire directory tree.
|
||||
"""
|
||||
_LOG.info("Upload to File Share %s recursively: %s -> %s",
|
||||
"" if recursive else "non-", local_path, remote_path)
|
||||
"" if recursive else "non", local_path, remote_path)
|
||||
|
|
|
@ -9,8 +9,9 @@ Base class for the service mix-ins.
|
|||
import json
|
||||
import logging
|
||||
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
|
||||
from mlos_bench.util import instantiate_from_config
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -22,7 +23,7 @@ class Service:
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def new(cls, class_name: str, config: dict, parent):
|
||||
def new(cls, class_name: str, config: dict, parent: Optional["Service"]) -> "Service":
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -46,7 +47,7 @@ class Service:
|
|||
"""
|
||||
return instantiate_from_config(cls, class_name, config, parent)
|
||||
|
||||
def __init__(self, config: dict = None, parent=None):
|
||||
def __init__(self, config: Optional[dict] = None, parent: Optional["Service"] = None):
|
||||
"""
|
||||
Create a new service with a given config.
|
||||
|
||||
|
@ -61,11 +62,15 @@ class Service:
|
|||
"""
|
||||
self.config = config or {}
|
||||
self._parent = parent
|
||||
self._services = {}
|
||||
self._services: Dict[str, Callable] = {}
|
||||
|
||||
if parent:
|
||||
self.register(parent.export())
|
||||
|
||||
self._config_loader_service: SupportsConfigLoading
|
||||
if parent and isinstance(parent, SupportsConfigLoading):
|
||||
self._config_loader_service = parent
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Service: %s Config:\n%s",
|
||||
self.__class__.__name__, json.dumps(self.config, indent=2))
|
||||
|
@ -73,7 +78,7 @@ class Service:
|
|||
self.__class__.__name__,
|
||||
[] if parent is None else list(parent._services.keys()))
|
||||
|
||||
def register(self, services):
|
||||
def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
|
||||
"""
|
||||
Register new mix-in services.
|
||||
|
||||
|
|
|
@ -12,24 +12,25 @@ import os
|
|||
import json # For logging only
|
||||
import logging
|
||||
|
||||
from typing import Optional, List
|
||||
from typing import List, Iterable, Optional, Union
|
||||
|
||||
import json5 # To read configs with comments and other JSON5 syntax features
|
||||
|
||||
from mlos_bench.util import prepare_class_load
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigPersistenceService(Service):
|
||||
class ConfigPersistenceService(Service, SupportsConfigLoading):
|
||||
"""
|
||||
Collection of methods to deserialize the Environment, Service, and TunableGroups objects.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None, parent: Service = None):
|
||||
def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None):
|
||||
"""
|
||||
Create a new instance of config persistence service.
|
||||
|
||||
|
@ -58,7 +59,7 @@ class ConfigPersistenceService(Service):
|
|||
])
|
||||
|
||||
def resolve_path(self, file_path: str,
|
||||
extra_paths: Optional[List[str]] = None) -> str:
|
||||
extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
"""
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
|
||||
If `_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
|
@ -67,7 +68,7 @@ class ConfigPersistenceService(Service):
|
|||
----------
|
||||
file_path : str
|
||||
Path to the input config file.
|
||||
extra_paths : List[str]
|
||||
extra_paths : Iterable[str]
|
||||
Additional directories to prepend to the list of search paths.
|
||||
|
||||
Returns
|
||||
|
@ -86,7 +87,7 @@ class ConfigPersistenceService(Service):
|
|||
_LOG.debug("Path not resolved: %s", file_path)
|
||||
return file_path
|
||||
|
||||
def load_config(self, json_file_name: str) -> dict:
|
||||
def load_config(self, json_file_name: str) -> Union[dict, List[dict]]:
|
||||
"""
|
||||
Load JSON config file. Search for a file relative to `_config_path`
|
||||
if the input path is not absolute.
|
||||
|
@ -99,18 +100,18 @@ class ConfigPersistenceService(Service):
|
|||
|
||||
Returns
|
||||
-------
|
||||
config : dict
|
||||
config : Union[dict, List[dict]]
|
||||
Free-format dictionary that contains the configuration.
|
||||
"""
|
||||
json_file_name = self.resolve_path(json_file_name)
|
||||
_LOG.info("Load config: %s", json_file_name)
|
||||
with open(json_file_name, mode='r', encoding='utf-8') as fh_json:
|
||||
return json5.load(fh_json)
|
||||
return json5.load(fh_json) # type: ignore[no-any-return]
|
||||
|
||||
def build_environment(self, config: dict,
|
||||
global_config: dict = None,
|
||||
tunables: TunableGroups = None,
|
||||
service: Service = None) -> Environment:
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None) -> Environment:
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
|
@ -146,15 +147,17 @@ class ConfigPersistenceService(Service):
|
|||
tunables = self.load_tunables(env_tunables_path, tunables)
|
||||
|
||||
_LOG.debug("Creating env: %s :: %s", env_name, env_class)
|
||||
env = Environment.new(env_name, env_class, env_config, global_config, tunables, service)
|
||||
env = Environment.new(env_name=env_name, class_name=env_class,
|
||||
config=env_config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
|
||||
_LOG.info("Created env: %s :: %s", env_name, env)
|
||||
return env
|
||||
|
||||
@classmethod
|
||||
def _build_standalone_service(cls, config: dict,
|
||||
global_config: dict = None,
|
||||
parent: Service = None) -> Service:
|
||||
global_config: Optional[dict] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -180,9 +183,9 @@ class ConfigPersistenceService(Service):
|
|||
return service
|
||||
|
||||
@classmethod
|
||||
def _build_composite_service(cls, config_list: List[dict],
|
||||
global_config: dict = None,
|
||||
parent: Service = None) -> Service:
|
||||
def _build_composite_service(cls, config_list: Iterable[dict],
|
||||
global_config: Optional[dict] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -218,8 +221,8 @@ class ConfigPersistenceService(Service):
|
|||
return service
|
||||
|
||||
@classmethod
|
||||
def build_service(cls, config: List[dict], global_config: dict = None,
|
||||
parent: Service = None) -> Service:
|
||||
def build_service(cls, config: Union[dict, List[dict]], global_config: Optional[dict] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -252,7 +255,7 @@ class ConfigPersistenceService(Service):
|
|||
return cls._build_composite_service(config, global_config, parent)
|
||||
|
||||
@staticmethod
|
||||
def build_tunables(config: dict, parent: TunableGroups = None) -> TunableGroups:
|
||||
def build_tunables(config: dict, parent: Optional[TunableGroups] = None) -> TunableGroups:
|
||||
"""
|
||||
Create a new collection of tunable parameters.
|
||||
|
||||
|
@ -280,8 +283,8 @@ class ConfigPersistenceService(Service):
|
|||
groups.update(TunableGroups(config))
|
||||
return groups
|
||||
|
||||
def load_environment(self, json_file_name: str, global_config: dict = None,
|
||||
tunables: TunableGroups = None, service: Service = None) -> Environment:
|
||||
def load_environment(self, json_file_name: str, global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> Environment:
|
||||
"""
|
||||
Load and build new environment from the config file.
|
||||
|
||||
|
@ -302,11 +305,12 @@ class ConfigPersistenceService(Service):
|
|||
A new benchmarking environment.
|
||||
"""
|
||||
config = self.load_config(json_file_name)
|
||||
assert isinstance(config, dict)
|
||||
return self.build_environment(config, global_config, tunables, service)
|
||||
|
||||
def load_environment_list(
|
||||
self, json_file_name: str, global_config: dict = None,
|
||||
tunables: TunableGroups = None, service: Service = None) -> List[Environment]:
|
||||
self, json_file_name: str, global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> List[Environment]:
|
||||
"""
|
||||
Load and build a list of environments from the config file.
|
||||
|
||||
|
@ -335,8 +339,8 @@ class ConfigPersistenceService(Service):
|
|||
for config in config_list
|
||||
]
|
||||
|
||||
def load_services(self, json_file_names: List[str],
|
||||
global_config: dict = None, parent: Service = None) -> Service:
|
||||
def load_services(self, json_file_names: Iterable[str],
|
||||
global_config: Optional[dict] = None, parent: Optional[Service] = None) -> Service:
|
||||
"""
|
||||
Read the configuration files and bundle all service methods
|
||||
from those configs into a single Service object.
|
||||
|
@ -363,8 +367,8 @@ class ConfigPersistenceService(Service):
|
|||
ConfigPersistenceService.build_service(config, global_config, service).export())
|
||||
return service
|
||||
|
||||
def load_tunables(self, json_file_names: List[str],
|
||||
parent: TunableGroups = None) -> TunableGroups:
|
||||
def load_tunables(self, json_file_names: Iterable[str],
|
||||
parent: Optional[TunableGroups] = None) -> TunableGroups:
|
||||
"""
|
||||
Load a collection of tunable parameters from JSON files.
|
||||
|
||||
|
@ -386,5 +390,6 @@ class ConfigPersistenceService(Service):
|
|||
groups.update(parent)
|
||||
for fname in json_file_names:
|
||||
config = self.load_config(fname)
|
||||
assert isinstance(config, dict)
|
||||
groups.update(TunableGroups(config))
|
||||
return groups
|
||||
|
|
|
@ -15,21 +15,25 @@ import shlex
|
|||
import subprocess
|
||||
import logging
|
||||
|
||||
from typing import Optional, Tuple, List, Dict
|
||||
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalExecService(Service):
|
||||
class LocalExecService(Service, SupportsLocalExec):
|
||||
"""
|
||||
Collection of methods to run scripts and commands in an external process
|
||||
on the node acting as the scheduler. Can be useful for data processing
|
||||
due to reduced dependency management complications vs the target environment.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None, parent: Service = None):
|
||||
def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None):
|
||||
"""
|
||||
Create a new instance of a service to run scripts locally.
|
||||
|
||||
|
@ -45,7 +49,7 @@ class LocalExecService(Service):
|
|||
self._temp_dir = self.config.get("temp_dir")
|
||||
self.register([self.temp_dir_context, self.local_exec])
|
||||
|
||||
def temp_dir_context(self, path: Optional[str] = None):
|
||||
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
|
||||
"""
|
||||
Create a temp directory or use the provided path.
|
||||
|
||||
|
@ -63,8 +67,8 @@ class LocalExecService(Service):
|
|||
return tempfile.TemporaryDirectory()
|
||||
return contextlib.nullcontext(path or self._temp_dir)
|
||||
|
||||
def local_exec(self, script_lines: List[str],
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
def local_exec(self, script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, "TunableValue"]] = None,
|
||||
cwd: Optional[str] = None,
|
||||
return_on_error: bool = False) -> Tuple[int, str, str]:
|
||||
"""
|
||||
|
@ -72,10 +76,10 @@ class LocalExecService(Service):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
script_lines : List[str]
|
||||
script_lines : Iterable[str]
|
||||
Lines of the script to run locally.
|
||||
Treat every line as a separate command to run.
|
||||
env : Dict[str, str]
|
||||
env : Mapping[str, Union[int, float, str]]
|
||||
Environment variables (optional).
|
||||
cwd : str
|
||||
Work directory to run the script at.
|
||||
|
@ -110,7 +114,8 @@ class LocalExecService(Service):
|
|||
return (return_code, stdout, stderr)
|
||||
|
||||
def _local_exec_script(self, script_line: str,
|
||||
env: Dict[str, str], cwd: str) -> Tuple[int, str, str]:
|
||||
env_params: Optional[Mapping[str, "TunableValue"]],
|
||||
cwd: str) -> Tuple[int, str, str]:
|
||||
"""
|
||||
Execute the script from `script_path` in a local process.
|
||||
|
||||
|
@ -118,9 +123,9 @@ class LocalExecService(Service):
|
|||
----------
|
||||
script_line : str
|
||||
Line of the script to tun in the local process.
|
||||
args : List[str]
|
||||
args : Iterable[str]
|
||||
Command line arguments for the script.
|
||||
env : Dict[str, str]
|
||||
env_params : Mapping[str, Union[int, float, str]]
|
||||
Environment variables.
|
||||
cwd : str
|
||||
Work directory to run the script at.
|
||||
|
@ -131,7 +136,7 @@ class LocalExecService(Service):
|
|||
A 3-tuple of return code, stdout, and stderr of the script process.
|
||||
"""
|
||||
cmd = shlex.split(script_line)
|
||||
script_path = self._parent.resolve_path(cmd[0])
|
||||
script_path = self._config_loader_service.resolve_path(cmd[0])
|
||||
if os.path.exists(script_path):
|
||||
script_path = os.path.abspath(script_path)
|
||||
|
||||
|
@ -139,8 +144,10 @@ class LocalExecService(Service):
|
|||
if script_path.strip().lower().endswith(".py"):
|
||||
cmd = [sys.executable] + cmd
|
||||
|
||||
if env:
|
||||
env = {key: str(val) for (key, val) in env.items()}
|
||||
env: Dict[str, str] = {}
|
||||
if env_params:
|
||||
env = {key: str(val) for (key, val) in env_params.items()}
|
||||
|
||||
if sys.platform == 'win32':
|
||||
# A hack to run Python on Windows with env variables set:
|
||||
env_copy = os.environ.copy()
|
||||
|
|
|
@ -13,7 +13,8 @@ from typing import Set
|
|||
|
||||
from azure.storage.fileshare import ShareClient
|
||||
|
||||
from mlos_bench.service import Service, FileShareService
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.base_fileshare import FileShareService
|
||||
from mlos_bench.util import check_required_params
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -57,7 +58,7 @@ class AzureFileShareService(FileShareService):
|
|||
credential=config["storageAccountKey"],
|
||||
)
|
||||
|
||||
def download(self, remote_path: str, local_path: str, recursive: bool = True):
|
||||
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
super().download(remote_path, local_path, recursive)
|
||||
dir_client = self._share_client.get_directory_client(remote_path)
|
||||
if dir_client.exists():
|
||||
|
@ -75,13 +76,13 @@ class AzureFileShareService(FileShareService):
|
|||
file_client = self._share_client.get_file_client(remote_path)
|
||||
data = file_client.download_file()
|
||||
with open(local_path, "wb") as output_file:
|
||||
data.readinto(output_file)
|
||||
data.readinto(output_file) # type: ignore[no-untyped-call]
|
||||
|
||||
def upload(self, local_path: str, remote_path: str, recursive: bool = True):
|
||||
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
super().upload(local_path, remote_path, recursive)
|
||||
self._upload(local_path, remote_path, recursive, set())
|
||||
|
||||
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]):
|
||||
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
|
||||
"""
|
||||
Upload contents from a local path to an Azure file share.
|
||||
This method is called from `.upload()` above. We need it to avoid exposing
|
||||
|
@ -124,7 +125,7 @@ class AzureFileShareService(FileShareService):
|
|||
with open(local_path, "rb") as file_data:
|
||||
file_client.upload_file(file_data)
|
||||
|
||||
def _remote_makedirs(self, remote_path: str):
|
||||
def _remote_makedirs(self, remote_path: str) -> None:
|
||||
"""
|
||||
Create remote directories for the entire path.
|
||||
Succeeds even some or all directories along the path already exist.
|
||||
|
|
|
@ -10,18 +10,20 @@ import json
|
|||
import time
|
||||
import logging
|
||||
|
||||
from typing import Any, Tuple, List, Dict, Callable
|
||||
from typing import Callable, Iterable, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from mlos_bench.environment import Status
|
||||
from mlos_bench.service import Service
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
|
||||
from mlos_bench.util import check_required_params
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
||||
class AzureVMService(Service, SupportsVMOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Helper methods to manage VMs on Azure.
|
||||
"""
|
||||
|
@ -134,7 +136,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
self.vm_start,
|
||||
self.vm_stop,
|
||||
self.vm_deprovision,
|
||||
self.vm_reboot,
|
||||
self.vm_restart,
|
||||
self.remote_exec,
|
||||
self.get_remote_exec_results
|
||||
])
|
||||
|
@ -144,7 +146,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
self._poll_timeout = float(config.get("pollTimeout", AzureVMService._POLL_TIMEOUT))
|
||||
self._request_timeout = float(config.get("requestTimeout", AzureVMService._REQUEST_TIMEOUT))
|
||||
|
||||
self._deploy_template = self._parent.load_config(config['deployTemplatePath'])
|
||||
self._deploy_template = self._config_loader_service.load_config(config['deployTemplatePath'])
|
||||
|
||||
self._url_deploy = AzureVMService._URL_DEPLOY.format(
|
||||
subscription=config["subscription"],
|
||||
|
@ -188,7 +190,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_arm_parameters(json_data):
|
||||
def _extract_arm_parameters(json_data: dict) -> dict:
|
||||
"""
|
||||
Extract parameters from the ARM Template REST response JSON.
|
||||
|
||||
|
@ -203,7 +205,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
if val.get("value") is not None
|
||||
}
|
||||
|
||||
def _azure_vm_post_helper(self, url: str) -> Tuple[Status, Dict]:
|
||||
def _azure_vm_post_helper(self, url: str) -> Tuple[Status, dict]:
|
||||
"""
|
||||
General pattern for performing an action on an Azure VM via its REST API.
|
||||
|
||||
|
@ -237,7 +239,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
elif "Location" in response.headers:
|
||||
result["asyncResultsUrl"] = response.headers.get("Location")
|
||||
if "Retry-After" in response.headers:
|
||||
result["pollInterval"] = float(response.headers["Retry-After"])
|
||||
result["pollInterval"] = str(float(response.headers["Retry-After"]))
|
||||
|
||||
return (Status.PENDING, result)
|
||||
else:
|
||||
|
@ -245,7 +247,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
# _LOG.error("Bad Request:\n%s", response.request.body)
|
||||
return (Status.FAILED, {})
|
||||
|
||||
def check_vm_operation_status(self, params: Dict) -> Tuple[Status, Dict]:
|
||||
def check_vm_operation_status(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Checks the status of a pending operation on an Azure VM.
|
||||
|
||||
|
@ -285,7 +287,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.error("Response: %s :: %s", response, response.text)
|
||||
return Status.FAILED, {}
|
||||
|
||||
def wait_vm_deployment(self, is_setup: bool, params: Dict[str, Any]) -> Tuple[Status, Dict]:
|
||||
def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
|
@ -308,7 +310,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
"provision" if is_setup else "deprovision")
|
||||
return self._wait_while(self._check_deployment, Status.PENDING, params)
|
||||
|
||||
def wait_vm_operation(self, params: Dict[str, Any]) -> Tuple[Status, Dict]:
|
||||
def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
|
@ -330,8 +332,8 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.info("Wait for operation on VM %s", self.config["vmName"])
|
||||
return self._wait_while(self.check_vm_operation_status, Status.RUNNING, params)
|
||||
|
||||
def _wait_while(self, func: Callable[[Dict[str, Any]], Tuple[Status, Dict]],
|
||||
loop_status: Status, params: Dict[str, Any]) -> Tuple[Status, Dict]:
|
||||
def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
|
||||
loop_status: Status, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Invoke `func` periodically while the status is equal to `loop_status`.
|
||||
Return TIMED_OUT when timing out.
|
||||
|
@ -377,7 +379,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.warning("Request timed out: %s", params)
|
||||
return (Status.TIMED_OUT, {})
|
||||
|
||||
def _check_deployment(self, _params: Dict) -> Tuple[Status, Dict]:
|
||||
def _check_deployment(self, _params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Check if Azure deployment exists.
|
||||
Return SUCCEEDED if true, PENDING otherwise.
|
||||
|
@ -409,7 +411,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.error("Response: %s :: %s", response, response.text)
|
||||
return (Status.FAILED, {})
|
||||
|
||||
def vm_provision(self, params: Dict) -> Tuple[Status, Dict]:
|
||||
def vm_provision(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Check if Azure VM is ready. Deploy a new VM, if necessary.
|
||||
|
||||
|
@ -464,7 +466,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
# _LOG.error("Bad Request:\n%s", response.request.body)
|
||||
return (Status.FAILED, {})
|
||||
|
||||
def vm_start(self, params: Dict) -> Tuple[Status, Dict]:
|
||||
def vm_start(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Start the VM on Azure.
|
||||
|
||||
|
@ -482,7 +484,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.info("Start VM: %s :: %s", self.config["vmName"], params)
|
||||
return self._azure_vm_post_helper(self._url_start)
|
||||
|
||||
def vm_stop(self) -> Tuple[Status, Dict]:
|
||||
def vm_stop(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Stops the VM on Azure by initiating a graceful shutdown.
|
||||
|
||||
|
@ -495,7 +497,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.info("Stop VM: %s", self.config["vmName"])
|
||||
return self._azure_vm_post_helper(self._url_stop)
|
||||
|
||||
def vm_deprovision(self) -> Tuple[Status, Dict]:
|
||||
def vm_deprovision(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Deallocates the VM on Azure by shutting it down then releasing the compute resources.
|
||||
|
||||
|
@ -508,7 +510,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.info("Deprovision VM: %s", self.config["vmName"])
|
||||
return self._azure_vm_post_helper(self._url_stop)
|
||||
|
||||
def vm_reboot(self) -> Tuple[Status, Dict]:
|
||||
def vm_restart(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Reboot the VM on Azure by initiating a graceful shutdown.
|
||||
|
||||
|
@ -521,13 +523,13 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
_LOG.info("Reboot VM: %s", self.config["vmName"])
|
||||
return self._azure_vm_post_helper(self._url_reboot)
|
||||
|
||||
def remote_exec(self, script: List[str], params: Dict[str, Any]) -> Tuple[Status, Dict]:
|
||||
def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Run a command on Azure VM.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script : List[str]
|
||||
script : Iterable[str]
|
||||
A list of lines to execute as a script on a remote VM.
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of parameters.
|
||||
|
@ -546,7 +548,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
|
||||
json_req = {
|
||||
"commandId": "RunShellScript",
|
||||
"script": script,
|
||||
"script": list(script),
|
||||
"parameters": [{"name": key, "value": val} for (key, val) in params.items()]
|
||||
}
|
||||
|
||||
|
@ -576,7 +578,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
|
|||
# _LOG.error("Bad Request:\n%s", response.request.body)
|
||||
return (Status.FAILED, {})
|
||||
|
||||
def get_remote_exec_results(self, params: Dict) -> Tuple[Status, Dict]:
|
||||
def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Get the results of the asynchronously running command.
|
||||
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
# Service Type Interfaces
|
||||
|
||||
Service loading in `mlos_bench` uses a [mix-in](https://en.wikipedia.org/wiki/Mixin#In_Python) approach to combine the functionality of multiple classes specified at runtime through config files into a single class.
|
||||
|
||||
This can make type checking in the `Environments` that use those `Services` a little tricky, both for developers and checking tools.
|
||||
|
||||
To address this we define `@runtime_checkable` decorated [`Protocols`](https://peps.python.org/pep-0544/) ("interfaces" in other languages) to declare the expected behavior of the `Services` that are loaded at runtime in the `Environments` that use them.
|
|
@ -0,0 +1,22 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Service types for implementing declaring Service behavior for Environments to use in mlos_bench.
|
||||
"""
|
||||
|
||||
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
|
||||
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
|
||||
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
|
||||
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
|
||||
|
||||
|
||||
__all__ = [
|
||||
'SupportsConfigLoading',
|
||||
'SupportsFileShareOps',
|
||||
'SupportsVMOps',
|
||||
'SupportsLocalExec',
|
||||
'SupportsRemoteExec',
|
||||
]
|
|
@ -0,0 +1,111 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for helper functions to lookup and load configs.
|
||||
"""
|
||||
|
||||
from typing import List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
|
||||
|
||||
# Avoid's circular import issues.
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsConfigLoading(Protocol):
|
||||
"""
|
||||
Protocol interface for helper functions to lookup and load configs.
|
||||
"""
|
||||
|
||||
def resolve_path(self, file_path: str,
|
||||
extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
"""
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
|
||||
If `_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : str
|
||||
Path to the input config file.
|
||||
extra_paths : Iterable[str]
|
||||
Additional directories to prepend to the list of search paths.
|
||||
|
||||
Returns
|
||||
-------
|
||||
path : str
|
||||
An actual path to the config or script.
|
||||
"""
|
||||
|
||||
def load_config(self, json_file_name: str) -> Union[dict, List[dict]]:
|
||||
"""
|
||||
Load JSON config file. Search for a file relative to `_config_path`
|
||||
if the input path is not absolute.
|
||||
This method is exported to be used as a service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
json_file_name : str
|
||||
Path to the input config file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
config : Union[dict, List[dict]]
|
||||
Free-format dictionary that contains the configuration.
|
||||
"""
|
||||
|
||||
def build_environment(self, config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional["TunableGroups"] = None,
|
||||
service: Optional["Service"] = None) -> "Environment":
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
A dictionary with three mandatory fields:
|
||||
"name": Human-readable string describing the environment;
|
||||
"class": FQN of a Python class to instantiate;
|
||||
"config": Free-format dictionary to pass to the constructor.
|
||||
global_config : dict
|
||||
Global parameters to add to the environment config.
|
||||
tunables : TunableGroups
|
||||
A collection of groups of tunable parameters for all environments.
|
||||
service: Service
|
||||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
|
||||
Returns
|
||||
-------
|
||||
env : Environment
|
||||
An instance of the `Environment` class initialized with `config`.
|
||||
"""
|
||||
|
||||
def load_environment_list(
|
||||
self, json_file_name: str, global_config: Optional[dict] = None,
|
||||
tunables: Optional["TunableGroups"] = None, service: Optional["Service"] = None) -> List["Environment"]:
|
||||
"""
|
||||
Load and build a list of environments from the config file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
json_file_name : str
|
||||
The environment JSON configuration file.
|
||||
Can contain either one environment or a list of environments.
|
||||
global_config : dict
|
||||
Global parameters to add to the environment config.
|
||||
tunables : TunableGroups
|
||||
An optional collection of tunables to add to the environment.
|
||||
service : Service
|
||||
An optional reference of the parent service to mix in.
|
||||
|
||||
Returns
|
||||
-------
|
||||
env : List[Environment]
|
||||
A list of new benchmarking environments.
|
||||
"""
|
|
@ -0,0 +1,47 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for file share operations.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsFileShareOps(Protocol):
|
||||
"""
|
||||
Protocol interface for file share operations.
|
||||
"""
|
||||
|
||||
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
"""
|
||||
Downloads contents from a remote share path to a local path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
remote_path : str
|
||||
Path to download from the remote file share, a file if recursive=False
|
||||
or a directory if recursive=True.
|
||||
local_path : str
|
||||
Path to store the downloaded content to.
|
||||
recursive : bool
|
||||
If False, ignore the subdirectories;
|
||||
if True (the default), download the entire directory tree.
|
||||
"""
|
||||
|
||||
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
"""
|
||||
Uploads contents from a local path to remote share path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_path : str
|
||||
Path to the local directory to upload contents from.
|
||||
remote_path : str
|
||||
Path in the remote file share to store the uploaded content to.
|
||||
recursive : bool
|
||||
If False, ignore the subdirectories;
|
||||
if True (the default), upload the entire directory tree.
|
||||
"""
|
|
@ -0,0 +1,68 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts and commands locally on the scheduler side.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable
|
||||
|
||||
import tempfile
|
||||
import contextlib
|
||||
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsLocalExec(Protocol):
|
||||
"""
|
||||
Protocol interface for a collection of methods to run scripts and commands
|
||||
in an external process on the node acting as the scheduler. Can be useful
|
||||
for data processing due to reduced dependency management complications vs
|
||||
the target environment.
|
||||
Used in LocalEnv and provided by LocalExecService.
|
||||
"""
|
||||
|
||||
def local_exec(self, script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, TunableValue]] = None,
|
||||
cwd: Optional[str] = None,
|
||||
return_on_error: bool = False) -> Tuple[int, str, str]:
|
||||
"""
|
||||
Execute the script lines from `script_lines` in a local process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_lines : Iterable[str]
|
||||
Lines of the script to run locally.
|
||||
Treat every line as a separate command to run.
|
||||
env : Mapping[str, Union[int, float, str]]
|
||||
Environment variables (optional).
|
||||
cwd : str
|
||||
Work directory to run the script at.
|
||||
If omitted, use `temp_dir` or create a temporary dir.
|
||||
return_on_error : bool
|
||||
If True, stop running script lines on first non-zero return code.
|
||||
The default is False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(return_code, stdout, stderr) : (int, str, str)
|
||||
A 3-tuple of return code, stdout, and stderr of the script process.
|
||||
"""
|
||||
|
||||
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
|
||||
"""
|
||||
Create a temp directory or use the provided path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
A path to the temporary directory. Create a new one if None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
temp_dir_context : TemporaryDirectory
|
||||
Temporary directory context to use in the `with` clause.
|
||||
"""
|
|
@ -0,0 +1,59 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts on a remote host OS.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Tuple, Protocol, runtime_checkable
|
||||
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsRemoteExec(Protocol):
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts on a remote host OS.
|
||||
"""
|
||||
|
||||
def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Run a command on remote host OS.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script : Iterable[str]
|
||||
A list of lines to execute as a script on a remote VM.
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of parameters.
|
||||
They usually come from `const_args` and `tunable_params`
|
||||
properties of the Environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict)
|
||||
A pair of Status and result.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Get the results of the asynchronously running command.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
Must have the "asyncResultsUrl" key to get the results.
|
||||
If the key is not present, return Status.PENDING.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict)
|
||||
A pair of Status and result.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
|
||||
"""
|
|
@ -0,0 +1,125 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for VM provisioning operations.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable
|
||||
|
||||
from mlos_bench.environment.status import Status
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsVMOps(Protocol):
|
||||
"""
|
||||
Protocol interface for VM provisioning operations.
|
||||
"""
|
||||
|
||||
def vm_provision(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Check if VM is ready. Deploy a new VM, if necessary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
VMEnv tunables are variable parameters that, together with the
|
||||
VMEnv configuration, are sufficient to provision a VM.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict={})
|
||||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
is_setup : bool
|
||||
If True, wait for VM being deployed; otherwise, wait for successful deprovisioning.
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict)
|
||||
A pair of Status and result.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
|
||||
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
|
||||
"""
|
||||
|
||||
def vm_start(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Start a VM.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict={})
|
||||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def vm_stop(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Stops the VM by initiating a graceful shutdown.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict={})
|
||||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def vm_restart(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Restarts the VM by initiating a graceful shutdown.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict={})
|
||||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def vm_deprovision(self) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Deallocates the VM by shutting it down then releasing the compute resources.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict={})
|
||||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
|
||||
def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params: dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
Must have the "asyncResultsUrl" key to get the results.
|
||||
If the key is not present, return Status.PENDING.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : (Status, dict)
|
||||
A pair of Status and result.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
|
||||
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
|
||||
"""
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -72,7 +72,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv:
|
|||
Test fixture for MockEnv.
|
||||
"""
|
||||
return MockEnv(
|
||||
"Test Env",
|
||||
name="Test Env",
|
||||
config={
|
||||
"tunable_groups": ["provision", "boot", "kernel"],
|
||||
"seed": 13,
|
||||
|
@ -88,7 +88,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv:
|
|||
Test fixture for MockEnv.
|
||||
"""
|
||||
return MockEnv(
|
||||
"Test Env No Noise",
|
||||
name="Test Env No Noise",
|
||||
config={
|
||||
"tunable_groups": ["provision", "boot", "kernel"],
|
||||
"range": [60, 120]
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.environment.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -8,25 +8,27 @@ Unit tests for mock benchmark environment.
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import MockEnv
|
||||
from mlos_bench.environment.mock_env import MockEnv
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
|
||||
def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups):
|
||||
def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check the default values of the mock environment.
|
||||
"""
|
||||
assert mock_env.setup(tunable_groups)
|
||||
(status, data) = mock_env.run()
|
||||
assert status.is_succeeded
|
||||
assert data is not None
|
||||
assert data["score"] == pytest.approx(78.45, 0.01)
|
||||
# Second time, results should differ because of the noise.
|
||||
(status, data) = mock_env.run()
|
||||
assert status.is_succeeded
|
||||
assert data is not None
|
||||
assert data["score"] == pytest.approx(98.21, 0.01)
|
||||
|
||||
|
||||
def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups):
|
||||
def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check the default values of the mock environment.
|
||||
"""
|
||||
|
@ -35,6 +37,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr
|
|||
# Noise-free results should be the same every time.
|
||||
(status, data) = mock_env_no_noise.run()
|
||||
assert status.is_succeeded
|
||||
assert data is not None
|
||||
assert data["score"] == pytest.approx(80.11, 0.01)
|
||||
|
||||
|
||||
|
@ -51,7 +54,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr
|
|||
}, 79.1),
|
||||
])
|
||||
def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
|
||||
tunable_values: dict, expected_score: float):
|
||||
tunable_values: dict, expected_score: float) -> None:
|
||||
"""
|
||||
Check the benchmark values of the mock environment after the assignment.
|
||||
"""
|
||||
|
@ -59,6 +62,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
|
|||
assert mock_env.setup(tunable_groups)
|
||||
(status, data) = mock_env.run()
|
||||
assert status.is_succeeded
|
||||
assert data is not None
|
||||
assert data["score"] == pytest.approx(expected_score, 0.01)
|
||||
|
||||
|
||||
|
@ -76,7 +80,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
|
|||
])
|
||||
def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv,
|
||||
tunable_groups: TunableGroups,
|
||||
tunable_values: dict, expected_score: float):
|
||||
tunable_values: dict, expected_score: float) -> None:
|
||||
"""
|
||||
Check the benchmark values of the noiseless mock environment after the assignment.
|
||||
"""
|
||||
|
@ -86,4 +90,5 @@ def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv,
|
|||
# Noise-free environment should produce the same results every time.
|
||||
(status, data) = mock_env_no_noise.run()
|
||||
assert status.is_succeeded
|
||||
assert data is not None
|
||||
assert data["score"] == pytest.approx(expected_score, 0.01)
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.optimizer.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -8,9 +8,9 @@ Unit tests for mock mlos_bench optimizer.
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import Status
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizer import MlosCoreOptimizer
|
||||
from mlos_bench.optimizer.mlos_core_optimizer import MlosCoreOptimizer
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
@ -41,7 +41,7 @@ def mock_scores() -> list:
|
|||
return [88.88, 66.66, 99.99]
|
||||
|
||||
|
||||
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list):
|
||||
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None:
|
||||
"""
|
||||
Make sure that llamatune+emukit optimizer initializes and works correctly.
|
||||
"""
|
||||
|
|
|
@ -8,8 +8,8 @@ Unit tests for mock mlos_bench optimizer.
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import Status
|
||||
from mlos_bench.optimizer import MockOptimizer
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.optimizer.mock_optimizer import MockOptimizer
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
@ -40,7 +40,7 @@ def mock_configurations() -> list:
|
|||
|
||||
def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
|
||||
"""
|
||||
Run several iterations of the oiptimizer and return the best score.
|
||||
Run several iterations of the optimizer and return the best score.
|
||||
"""
|
||||
for (tunable_values, score) in mock_configurations:
|
||||
assert mock_opt.not_converged()
|
||||
|
@ -49,10 +49,12 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
|
|||
mock_opt.register(tunables, Status.SUCCEEDED, score)
|
||||
|
||||
(score, _tunables) = mock_opt.get_best_observation()
|
||||
assert score is not None
|
||||
assert isinstance(score, float)
|
||||
return score
|
||||
|
||||
|
||||
def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list):
|
||||
def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> None:
|
||||
"""
|
||||
Make sure that mock optimizer produces consistent suggestions.
|
||||
"""
|
||||
|
@ -60,7 +62,7 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list):
|
|||
assert score == pytest.approx(66.66, 0.01)
|
||||
|
||||
|
||||
def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list):
|
||||
def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None:
|
||||
"""
|
||||
Check the maximization mode of the mock optimizer.
|
||||
"""
|
||||
|
@ -68,7 +70,7 @@ def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: li
|
|||
assert score == pytest.approx(99.99, 0.01)
|
||||
|
||||
|
||||
def test_mock_optimizer_register_fail(mock_opt: MockOptimizer):
|
||||
def test_mock_optimizer_register_fail(mock_opt: MockOptimizer) -> None:
|
||||
"""
|
||||
Check the input acceptance conditions for Optimizer.register().
|
||||
"""
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Optional, List
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import Status
|
||||
from mlos_bench.environment.status import Status
|
||||
from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
@ -46,7 +46,7 @@ def mock_configs() -> List[dict]:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scores() -> List[float]:
|
||||
def mock_scores() -> List[Optional[float]]:
|
||||
"""
|
||||
Mock benchmark results from earlier experiments.
|
||||
"""
|
||||
|
@ -62,13 +62,14 @@ def mock_status() -> List[Status]:
|
|||
|
||||
|
||||
def _test_opt_update_min(opt: Optimizer, configs: List[dict],
|
||||
scores: List[float], status: Optional[List[Status]] = None):
|
||||
scores: List[float], status: Optional[List[Status]] = None) -> None:
|
||||
"""
|
||||
Test the bulk update of the optimizer on the minimization problem.
|
||||
"""
|
||||
opt.bulk_register(configs, scores, status)
|
||||
(score, tunables) = opt.get_best_observation()
|
||||
assert score == pytest.approx(66.66, 0.01)
|
||||
assert tunables is not None
|
||||
assert tunables.get_param_values() == {
|
||||
"vmSize": "Standard_B4ms",
|
||||
"rootfs": "ext4",
|
||||
|
@ -77,13 +78,14 @@ def _test_opt_update_min(opt: Optimizer, configs: List[dict],
|
|||
|
||||
|
||||
def _test_opt_update_max(opt: Optimizer, configs: List[dict],
|
||||
scores: List[float], status: Optional[List[Status]] = None):
|
||||
scores: List[float], status: Optional[List[Status]] = None) -> None:
|
||||
"""
|
||||
Test the bulk update of the optimizer on the maximiation prtoblem.
|
||||
Test the bulk update of the optimizer on the maximization problem.
|
||||
"""
|
||||
opt.bulk_register(configs, scores, status)
|
||||
(score, tunables) = opt.get_best_observation()
|
||||
assert score == pytest.approx(99.99, 0.01)
|
||||
assert tunables is not None
|
||||
assert tunables.get_param_values() == {
|
||||
"vmSize": "Standard_B2s",
|
||||
"rootfs": "xfs",
|
||||
|
@ -92,7 +94,7 @@ def _test_opt_update_max(opt: Optimizer, configs: List[dict],
|
|||
|
||||
|
||||
def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the mock optimizer on the minimization problem.
|
||||
"""
|
||||
|
@ -100,7 +102,7 @@ def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict],
|
|||
|
||||
|
||||
def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the mock optimizer on the maximization problem.
|
||||
"""
|
||||
|
@ -108,7 +110,7 @@ def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict],
|
|||
|
||||
|
||||
def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the EmuKit optimizer.
|
||||
"""
|
||||
|
@ -116,7 +118,7 @@ def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict],
|
|||
|
||||
|
||||
def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the EmuKit optimizer on the maximization problem.
|
||||
"""
|
||||
|
@ -124,7 +126,7 @@ def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List
|
|||
|
||||
|
||||
def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the scikit-optimize GP optimizer.
|
||||
"""
|
||||
|
@ -132,7 +134,7 @@ def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[d
|
|||
|
||||
|
||||
def test_update_scikit_et(scikit_et_opt: MlosCoreOptimizer, mock_configs: List[dict],
|
||||
mock_scores: List[float], mock_status: List[Status]):
|
||||
mock_scores: List[float], mock_status: List[Status]) -> None:
|
||||
"""
|
||||
Test the bulk update of the scikit-optimize ET optimizer.
|
||||
"""
|
||||
|
|
|
@ -10,7 +10,8 @@ from typing import Tuple
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import Environment, MockEnv
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.environment.mock_env import MockEnv
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer
|
||||
|
||||
|
@ -27,17 +28,20 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
|
|||
assert env.setup(tunables)
|
||||
|
||||
(status, output) = env.run()
|
||||
score = output['score']
|
||||
assert status.is_succeeded
|
||||
assert output is not None
|
||||
score = output['score']
|
||||
assert 60 <= score <= 120
|
||||
|
||||
opt.register(tunables, status, score)
|
||||
|
||||
return opt.get_best_observation()
|
||||
(best_score, best_tunables) = opt.get_best_observation()
|
||||
assert isinstance(best_score, float) and isinstance(best_tunables, TunableGroups)
|
||||
return (best_score, best_tunables)
|
||||
|
||||
|
||||
def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
|
||||
mock_opt: MockOptimizer):
|
||||
mock_opt: MockOptimizer) -> None:
|
||||
"""
|
||||
Toy optimization loop with mock environment and optimizer.
|
||||
"""
|
||||
|
@ -51,7 +55,7 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
|
|||
|
||||
|
||||
def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv,
|
||||
scikit_gp_opt: MlosCoreOptimizer):
|
||||
scikit_gp_opt: MlosCoreOptimizer) -> None:
|
||||
"""
|
||||
Toy optimization loop with mock environment and Scikit GP optimizer.
|
||||
"""
|
||||
|
@ -65,7 +69,7 @@ def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv,
|
|||
|
||||
|
||||
def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv,
|
||||
scikit_et_opt: MlosCoreOptimizer):
|
||||
scikit_et_opt: MlosCoreOptimizer) -> None:
|
||||
"""
|
||||
Toy optimization loop with mock environment and Scikit ET optimizer.
|
||||
"""
|
||||
|
@ -79,7 +83,7 @@ def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv,
|
|||
|
||||
|
||||
def test_emukit_optimization_loop(mock_env_no_noise: MockEnv,
|
||||
emukit_opt: MlosCoreOptimizer):
|
||||
emukit_opt: MlosCoreOptimizer) -> None:
|
||||
"""
|
||||
Toy optimization loop with mock environment and EmuKit optimizer.
|
||||
"""
|
||||
|
@ -89,7 +93,7 @@ def test_emukit_optimization_loop(mock_env_no_noise: MockEnv,
|
|||
|
||||
|
||||
def test_emukit_optimization_loop_max(mock_env_no_noise: MockEnv,
|
||||
emukit_opt_max: MlosCoreOptimizer):
|
||||
emukit_opt_max: MlosCoreOptimizer) -> None:
|
||||
"""
|
||||
Toy optimization loop with mock environment and EmuKit optimizer
|
||||
in maximization mode.
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.service.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -9,7 +9,7 @@ Unit tests for configuration persistence service.
|
|||
import os
|
||||
import pytest
|
||||
|
||||
from mlos_bench.service import ConfigPersistenceService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
@ -27,7 +27,7 @@ def config_persistence_service() -> ConfigPersistenceService:
|
|||
})
|
||||
|
||||
|
||||
def test_resolve_path(config_persistence_service: ConfigPersistenceService):
|
||||
def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None:
|
||||
"""
|
||||
Check if we can actually find a file somewhere in `config_path`.
|
||||
"""
|
||||
|
@ -37,7 +37,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService):
|
|||
assert os.path.exists(path)
|
||||
|
||||
|
||||
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService):
|
||||
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None:
|
||||
"""
|
||||
Check if non-existent file resolves without using `config_path`.
|
||||
"""
|
||||
|
@ -47,7 +47,7 @@ def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService)
|
|||
assert path == file_path
|
||||
|
||||
|
||||
def test_load_config(config_persistence_service: ConfigPersistenceService):
|
||||
def test_load_config(config_persistence_service: ConfigPersistenceService) -> None:
|
||||
"""
|
||||
Check if we can successfully load a config file located relative to `config_path`.
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.service.local.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -5,12 +5,17 @@
|
|||
"""
|
||||
Unit tests for LocalExecService to run Python scripts locally.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.service import LocalExecService, ConfigPersistenceService
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.service.local.local_exec import LocalExecService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
@ -25,7 +30,7 @@ def local_exec_service() -> LocalExecService:
|
|||
}))
|
||||
|
||||
|
||||
def test_run_python_script(local_exec_service: LocalExecService):
|
||||
def test_run_python_script(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Run a Python script using a local_exec service.
|
||||
"""
|
||||
|
@ -33,7 +38,7 @@ def test_run_python_script(local_exec_service: LocalExecService):
|
|||
output_file = "./config-kernel.sh"
|
||||
|
||||
# Tunable parameters to save in JSON
|
||||
params = {
|
||||
params: Dict[str, TunableValue] = {
|
||||
"sched_migration_cost_ns": 40000,
|
||||
"sched_granularity_ns": 800000
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@ import sys
|
|||
import pytest
|
||||
import pandas
|
||||
|
||||
from mlos_bench.service import LocalExecService, ConfigPersistenceService
|
||||
from mlos_bench.service.local.local_exec import LocalExecService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
# -- Ignore pylint complaints about pytest references to
|
||||
|
@ -26,7 +27,7 @@ def local_exec_service() -> LocalExecService:
|
|||
return LocalExecService(parent=ConfigPersistenceService())
|
||||
|
||||
|
||||
def test_run_script(local_exec_service: LocalExecService):
|
||||
def test_run_script(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Run a script locally and check the results.
|
||||
"""
|
||||
|
@ -37,7 +38,7 @@ def test_run_script(local_exec_service: LocalExecService):
|
|||
assert stderr.strip() == ""
|
||||
|
||||
|
||||
def test_run_script_multiline(local_exec_service: LocalExecService):
|
||||
def test_run_script_multiline(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Run a multiline script locally and check the results.
|
||||
"""
|
||||
|
@ -51,7 +52,7 @@ def test_run_script_multiline(local_exec_service: LocalExecService):
|
|||
assert stderr.strip() == ""
|
||||
|
||||
|
||||
def test_run_script_multiline_env(local_exec_service: LocalExecService):
|
||||
def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Run a multiline script locally and pass the environment variables to it.
|
||||
"""
|
||||
|
@ -68,7 +69,7 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService):
|
|||
assert stderr.strip() == ""
|
||||
|
||||
|
||||
def test_run_script_read_csv(local_exec_service: LocalExecService):
|
||||
def test_run_script_read_csv(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Run a script locally and read the resulting CSV file.
|
||||
"""
|
||||
|
@ -89,7 +90,7 @@ def test_run_script_read_csv(local_exec_service: LocalExecService):
|
|||
assert all(data.col2 == [222, 444])
|
||||
|
||||
|
||||
def test_run_script_write_read_txt(local_exec_service: LocalExecService):
|
||||
def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Write data a temp location and run a script that updates it there.
|
||||
"""
|
||||
|
@ -112,7 +113,7 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService):
|
|||
assert fh_input.read().split() == ["hello", "world", "test"]
|
||||
|
||||
|
||||
def test_run_script_fail(local_exec_service: LocalExecService):
|
||||
def test_run_script_fail(local_exec_service: LocalExecService) -> None:
|
||||
"""
|
||||
Try to run a non-existent command.
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.service.remote.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.service.remote.azure.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -7,7 +7,9 @@ Tests for mlos_bench.service.remote.azure.azure_fileshare
|
|||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock, patch, call
|
||||
from unittest.mock import MagicMock, Mock, patch, call
|
||||
|
||||
from mlos_bench.service.remote.azure.azure_fileshare import AzureFileShareService
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
# pylint: disable=too-many-arguments
|
||||
|
@ -16,21 +18,21 @@ from unittest.mock import Mock, patch, call
|
|||
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
|
||||
def test_download_file(mock_makedirs, mock_open, azure_fileshare):
|
||||
def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
|
||||
filename = "test.csv"
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
remote_path = f"{remote_folder}/{filename}"
|
||||
local_path = f"{local_folder}/{filename}"
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client.get_directory_client.return_value = Mock(
|
||||
exists=Mock(return_value=False)
|
||||
)
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \
|
||||
patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client:
|
||||
mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))
|
||||
|
||||
azure_fileshare.download(remote_path, local_path)
|
||||
|
||||
mock_share_client.get_file_client.assert_called_with(remote_path)
|
||||
mock_get_file_client.assert_called_with(remote_path)
|
||||
|
||||
mock_makedirs.assert_called_with(
|
||||
local_folder,
|
||||
exist_ok=True,
|
||||
|
@ -40,7 +42,7 @@ def test_download_file(mock_makedirs, mock_open, azure_fileshare):
|
|||
assert open_mode == "wb"
|
||||
|
||||
|
||||
def make_dir_client_returns(remote_folder: str):
|
||||
def make_dir_client_returns(remote_folder: str) -> dict:
|
||||
return {
|
||||
remote_folder: Mock(
|
||||
exists=Mock(return_value=True),
|
||||
|
@ -66,20 +68,24 @@ def make_dir_client_returns(remote_folder: str):
|
|||
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
|
||||
def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare):
|
||||
def test_download_folder_non_recursive(mock_makedirs: MagicMock,
|
||||
mock_open: MagicMock,
|
||||
azure_fileshare: AzureFileShareService) -> None:
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
dir_client_returns = make_dir_client_returns(remote_folder)
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x]
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
|
||||
patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
|
||||
|
||||
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
|
||||
|
||||
azure_fileshare.download(remote_folder, local_folder, recursive=False)
|
||||
|
||||
mock_share_client.get_file_client.assert_called_with(
|
||||
mock_get_file_client.assert_called_with(
|
||||
f"{remote_folder}/a_file_1.csv",
|
||||
)
|
||||
mock_share_client.get_directory_client.assert_has_calls([
|
||||
mock_get_directory_client.assert_has_calls([
|
||||
call(remote_folder),
|
||||
call(f"{remote_folder}/a_file_1.csv"),
|
||||
], any_order=True)
|
||||
|
@ -87,21 +93,22 @@ def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare
|
|||
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
|
||||
def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare):
|
||||
def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
dir_client_returns = make_dir_client_returns(remote_folder)
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x]
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
|
||||
patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
|
||||
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
|
||||
|
||||
azure_fileshare.download(remote_folder, local_folder, recursive=True)
|
||||
|
||||
mock_share_client.get_file_client.assert_has_calls([
|
||||
mock_get_file_client.assert_has_calls([
|
||||
call(f"{remote_folder}/a_file_1.csv"),
|
||||
call(f"{remote_folder}/a_folder/a_file_2.csv"),
|
||||
], any_order=True)
|
||||
mock_share_client.get_directory_client.assert_has_calls([
|
||||
mock_get_directory_client.assert_has_calls([
|
||||
call(remote_folder),
|
||||
call(f"{remote_folder}/a_file_1.csv"),
|
||||
call(f"{remote_folder}/a_folder"),
|
||||
|
@ -111,19 +118,19 @@ def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare):
|
|||
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
|
||||
def test_upload_file(mock_isdir, mock_open, azure_fileshare):
|
||||
def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
|
||||
filename = "test.csv"
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
remote_path = f"{remote_folder}/{filename}"
|
||||
local_path = f"{local_folder}/{filename}"
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
mock_isdir.return_value = False
|
||||
|
||||
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
|
||||
azure_fileshare.upload(local_path, remote_path)
|
||||
|
||||
mock_share_client.get_file_client.assert_called_with(remote_path)
|
||||
mock_get_file_client.assert_called_with(remote_path)
|
||||
open_path, open_mode = mock_open.call_args.args
|
||||
assert os.path.abspath(local_path) == os.path.abspath(open_path)
|
||||
assert open_mode == "rb"
|
||||
|
@ -136,11 +143,11 @@ class MyDirEntry:
|
|||
self.name = name
|
||||
self.is_a_dir = is_a_dir
|
||||
|
||||
def is_dir(self):
|
||||
def is_dir(self) -> bool:
|
||||
return self.is_a_dir
|
||||
|
||||
|
||||
def make_scandir_returns(local_folder: str):
|
||||
def make_scandir_returns(local_folder: str) -> dict:
|
||||
return {
|
||||
local_folder: [
|
||||
MyDirEntry("a_folder", True),
|
||||
|
@ -152,7 +159,7 @@ def make_scandir_returns(local_folder: str):
|
|||
}
|
||||
|
||||
|
||||
def make_isdir_returns(local_folder: str):
|
||||
def make_isdir_returns(local_folder: str) -> dict:
|
||||
return {
|
||||
local_folder: True,
|
||||
f"{local_folder}/a_file_1.csv": False,
|
||||
|
@ -161,7 +168,7 @@ def make_isdir_returns(local_folder: str):
|
|||
}
|
||||
|
||||
|
||||
def process_paths(input_path):
|
||||
def process_paths(input_path: str) -> str:
|
||||
skip_prefix = os.getcwd()
|
||||
# Remove prefix from os.path.abspath if there
|
||||
if input_path == os.path.abspath(input_path):
|
||||
|
@ -175,37 +182,43 @@ def process_paths(input_path):
|
|||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir")
|
||||
def test_upload_directory_non_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare):
|
||||
def test_upload_directory_non_recursive(mock_scandir: MagicMock,
|
||||
mock_isdir: MagicMock,
|
||||
mock_open: MagicMock,
|
||||
azure_fileshare: AzureFileShareService) -> None:
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
scandir_returns = make_scandir_returns(local_folder)
|
||||
isdir_returns = make_isdir_returns(local_folder)
|
||||
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
|
||||
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
|
||||
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
|
||||
azure_fileshare.upload(local_folder, remote_folder, recursive=False)
|
||||
|
||||
mock_share_client.get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
|
||||
mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
|
||||
|
||||
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
|
||||
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir")
|
||||
def test_upload_directory_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare):
|
||||
def test_upload_directory_recursive(mock_scandir: MagicMock,
|
||||
mock_isdir: MagicMock,
|
||||
mock_open: MagicMock,
|
||||
azure_fileshare: AzureFileShareService) -> None:
|
||||
remote_folder = "a/remote/folder"
|
||||
local_folder = "some/local/folder"
|
||||
scandir_returns = make_scandir_returns(local_folder)
|
||||
isdir_returns = make_isdir_returns(local_folder)
|
||||
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
|
||||
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
|
||||
# pylint: disable=protected-access
|
||||
mock_share_client = azure_fileshare._share_client
|
||||
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
|
||||
|
||||
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
|
||||
azure_fileshare.upload(local_folder, remote_folder, recursive=True)
|
||||
|
||||
mock_share_client.get_file_client.assert_has_calls([
|
||||
mock_get_file_client.assert_has_calls([
|
||||
call(f"{remote_folder}/a_file_1.csv"),
|
||||
call(f"{remote_folder}/a_folder/a_file_2.csv"),
|
||||
], any_order=True)
|
||||
|
|
|
@ -10,7 +10,9 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.environment import Status
|
||||
from mlos_bench.environment.status import Status
|
||||
|
||||
from mlos_bench.service.remote.azure.azure_services import AzureVMService
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
# pylint: disable=too-many-arguments
|
||||
|
@ -21,7 +23,7 @@ from mlos_bench.environment import Status
|
|||
("vm_start", True),
|
||||
("vm_stop", False),
|
||||
("vm_deprovision", False),
|
||||
("vm_reboot", False),
|
||||
("vm_restart", False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
("http_status_code", "operation_status"), [
|
||||
|
@ -31,8 +33,8 @@ from mlos_bench.environment import Status
|
|||
(404, Status.FAILED),
|
||||
])
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.requests")
|
||||
def test_vm_operation_status(mock_requests, azure_vm_service, operation_name,
|
||||
accepts_params, http_status_code, operation_status):
|
||||
def test_vm_operation_status(mock_requests: MagicMock, azure_vm_service: AzureVMService, operation_name: str,
|
||||
accepts_params: bool, http_status_code: int, operation_status: Status) -> None:
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = http_status_code
|
||||
|
@ -49,7 +51,7 @@ def test_vm_operation_status(mock_requests, azure_vm_service, operation_name,
|
|||
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.time.sleep")
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.requests")
|
||||
def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service):
|
||||
def test_wait_vm_operation_ready(mock_requests: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService) -> None:
|
||||
|
||||
# Mock response header
|
||||
async_url = "DUMMY_ASYNC_URL"
|
||||
|
@ -73,7 +75,7 @@ def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service):
|
|||
|
||||
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.requests")
|
||||
def test_wait_vm_operation_timeout(mock_requests, azure_vm_service):
|
||||
def test_wait_vm_operation_timeout(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None:
|
||||
|
||||
# Mock response header
|
||||
params = {
|
||||
|
@ -99,8 +101,8 @@ def test_wait_vm_operation_timeout(mock_requests, azure_vm_service):
|
|||
(404, Status.FAILED),
|
||||
])
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.requests")
|
||||
def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, operation_status):
|
||||
|
||||
def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service: AzureVMService,
|
||||
http_status_code: int, operation_status: Status) -> None:
|
||||
script = ["command_1", "command_2"]
|
||||
|
||||
mock_response = MagicMock()
|
||||
|
@ -113,7 +115,7 @@ def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, o
|
|||
|
||||
|
||||
@patch("mlos_bench.service.remote.azure.azure_services.requests")
|
||||
def test_remote_exec_headers_output(mock_requests, azure_vm_service):
|
||||
def test_remote_exec_headers_output(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None:
|
||||
|
||||
async_url_key = "asyncResultsUrl"
|
||||
async_url_value = "DUMMY_ASYNC_URL"
|
||||
|
@ -158,14 +160,15 @@ def test_remote_exec_headers_output(mock_requests, azure_vm_service):
|
|||
(Status.PENDING, {}, {}),
|
||||
(Status.FAILED, {}, {}),
|
||||
])
|
||||
def test_get_remote_exec_results(azure_vm_service, operation_status: Status,
|
||||
wait_output: dict, results_output: dict):
|
||||
def test_get_remote_exec_results(azure_vm_service: AzureVMService, operation_status: Status,
|
||||
wait_output: dict, results_output: dict) -> None:
|
||||
|
||||
params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}
|
||||
|
||||
mock_wait_vm_operation = MagicMock()
|
||||
mock_wait_vm_operation.return_value = (operation_status, wait_output)
|
||||
azure_vm_service.wait_vm_operation = mock_wait_vm_operation
|
||||
# azure_vm_service.wait_vm_operation = mock_wait_vm_operation
|
||||
setattr(azure_vm_service, "wait_vm_operation", mock_wait_vm_operation)
|
||||
|
||||
status, cmd_output = azure_vm_service.get_remote_exec_results(params)
|
||||
|
||||
|
|
|
@ -9,14 +9,14 @@ Configuration test fixtures for azure_services in mlos_bench.
|
|||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from mlos_bench.service import ConfigPersistenceService
|
||||
from mlos_bench.service.config_persistence import ConfigPersistenceService
|
||||
from mlos_bench.service.remote.azure import AzureVMService, AzureFileShareService
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_persistence_service():
|
||||
def config_persistence_service() -> ConfigPersistenceService:
|
||||
"""
|
||||
Test fixture for ConfigPersistenceService.
|
||||
"""
|
||||
|
@ -28,7 +28,7 @@ def config_persistence_service():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_vm_service(config_persistence_service: ConfigPersistenceService):
|
||||
def azure_vm_service(config_persistence_service: ConfigPersistenceService) -> AzureVMService:
|
||||
"""
|
||||
Creates a dummy Azure VM service for tests that require it.
|
||||
"""
|
||||
|
@ -45,7 +45,7 @@ def azure_vm_service(config_persistence_service: ConfigPersistenceService):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_fileshare(config_persistence_service: ConfigPersistenceService):
|
||||
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
|
||||
"""
|
||||
Creates a dummy AzureFileShareService for tests that require it.
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tests for mlos_bench.tunables.
|
||||
Used to make mypy happy about multiple conftest.py modules.
|
||||
"""
|
|
@ -46,7 +46,7 @@ def configuration_space() -> ConfigurationSpace:
|
|||
|
||||
|
||||
def _cmp_tunable_hyperparameter_categorical(
|
||||
tunable: Tunable, cs_param: CategoricalHyperparameter):
|
||||
tunable: Tunable, cs_param: CategoricalHyperparameter) -> None:
|
||||
"""
|
||||
Check if categorical Tunable and ConfigSpace Hyperparameter actually match.
|
||||
"""
|
||||
|
@ -56,7 +56,7 @@ def _cmp_tunable_hyperparameter_categorical(
|
|||
|
||||
|
||||
def _cmp_tunable_hyperparameter_int(
|
||||
tunable: Tunable, cs_param: UniformIntegerHyperparameter):
|
||||
tunable: Tunable, cs_param: UniformIntegerHyperparameter) -> None:
|
||||
"""
|
||||
Check if integer Tunable and ConfigSpace Hyperparameter actually match.
|
||||
"""
|
||||
|
@ -66,7 +66,7 @@ def _cmp_tunable_hyperparameter_int(
|
|||
|
||||
|
||||
def _cmp_tunable_hyperparameter_float(
|
||||
tunable: Tunable, cs_param: UniformFloatHyperparameter):
|
||||
tunable: Tunable, cs_param: UniformFloatHyperparameter) -> None:
|
||||
"""
|
||||
Check if float Tunable and ConfigSpace Hyperparameter actually match.
|
||||
"""
|
||||
|
@ -75,7 +75,7 @@ def _cmp_tunable_hyperparameter_float(
|
|||
assert cs_param.default_value == tunable.value
|
||||
|
||||
|
||||
def test_tunable_to_hyperparameter_categorical(tunable_categorical):
|
||||
def test_tunable_to_hyperparameter_categorical(tunable_categorical: Tunable) -> None:
|
||||
"""
|
||||
Check the conversion of Tunable to CategoricalHyperparameter.
|
||||
"""
|
||||
|
@ -83,7 +83,7 @@ def test_tunable_to_hyperparameter_categorical(tunable_categorical):
|
|||
_cmp_tunable_hyperparameter_categorical(tunable_categorical, cs_param)
|
||||
|
||||
|
||||
def test_tunable_to_hyperparameter_int(tunable_int):
|
||||
def test_tunable_to_hyperparameter_int(tunable_int: Tunable) -> None:
|
||||
"""
|
||||
Check the conversion of Tunable to UniformIntegerHyperparameter.
|
||||
"""
|
||||
|
@ -91,7 +91,7 @@ def test_tunable_to_hyperparameter_int(tunable_int):
|
|||
_cmp_tunable_hyperparameter_int(tunable_int, cs_param)
|
||||
|
||||
|
||||
def test_tunable_to_hyperparameter_float(tunable_float):
|
||||
def test_tunable_to_hyperparameter_float(tunable_float: Tunable) -> None:
|
||||
"""
|
||||
Check the conversion of Tunable to UniformFloatHyperparameter.
|
||||
"""
|
||||
|
@ -106,7 +106,7 @@ _CMP_FUNC = {
|
|||
}
|
||||
|
||||
|
||||
def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups):
|
||||
def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check the conversion of TunableGroups to ConfigurationSpace.
|
||||
Make sure that the corresponding Tunable and Hyperparameter objects match.
|
||||
|
@ -119,7 +119,7 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups):
|
|||
|
||||
|
||||
def test_tunable_groups_to_configspace(
|
||||
tunable_groups: TunableGroups, configuration_space: ConfigurationSpace):
|
||||
tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None:
|
||||
"""
|
||||
Check the conversion of the entire TunableGroups collection
|
||||
to a single ConfigurationSpace object.
|
||||
|
|
|
@ -8,8 +8,11 @@ Unit tests for assigning values to the individual parameters within tunable grou
|
|||
|
||||
import pytest
|
||||
|
||||
from mlos_bench.tunables.tunable import Tunable
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
def test_tunables_assign_unknown_param(tunable_groups):
|
||||
|
||||
def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Make sure that bulk assignment fails for parameters
|
||||
that don't exist in the TunableGroups object.
|
||||
|
@ -23,7 +26,7 @@ def test_tunables_assign_unknown_param(tunable_groups):
|
|||
})
|
||||
|
||||
|
||||
def test_tunables_assign_invalid_categorical(tunable_groups):
|
||||
def test_tunables_assign_invalid_categorical(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check parameter validation for categorical tunables.
|
||||
"""
|
||||
|
@ -31,7 +34,7 @@ def test_tunables_assign_invalid_categorical(tunable_groups):
|
|||
tunable_groups.assign({"vmSize": "InvalidSize"})
|
||||
|
||||
|
||||
def test_tunables_assign_invalid_range(tunable_groups):
|
||||
def test_tunables_assign_invalid_range(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check parameter out-of-range validation for numerical tunables.
|
||||
"""
|
||||
|
@ -39,14 +42,14 @@ def test_tunables_assign_invalid_range(tunable_groups):
|
|||
tunable_groups.assign({"kernel_sched_migration_cost_ns": -2})
|
||||
|
||||
|
||||
def test_tunables_assign_coerce_str(tunable_groups):
|
||||
def test_tunables_assign_coerce_str(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check the conversion from strings when assigning to an integer parameter.
|
||||
"""
|
||||
tunable_groups.assign({"kernel_sched_migration_cost_ns": "10000"})
|
||||
|
||||
|
||||
def test_tunables_assign_coerce_str_range_check(tunable_groups):
|
||||
def test_tunables_assign_coerce_str_range_check(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check the range when assigning to an integer tunable.
|
||||
"""
|
||||
|
@ -54,7 +57,7 @@ def test_tunables_assign_coerce_str_range_check(tunable_groups):
|
|||
tunable_groups.assign({"kernel_sched_migration_cost_ns": "5500000"})
|
||||
|
||||
|
||||
def test_tunables_assign_coerce_str_invalid(tunable_groups):
|
||||
def test_tunables_assign_coerce_str_invalid(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Make sure we fail when assigning an invalid string to an integer tunable.
|
||||
"""
|
||||
|
@ -62,23 +65,23 @@ def test_tunables_assign_coerce_str_invalid(tunable_groups):
|
|||
tunable_groups.assign({"kernel_sched_migration_cost_ns": "1.1"})
|
||||
|
||||
|
||||
def test_tunable_assign_str_to_int(tunable_int):
|
||||
def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None:
|
||||
"""
|
||||
Check str to int coercion.
|
||||
"""
|
||||
tunable_int.value = "10"
|
||||
assert tunable_int.value == 10
|
||||
assert tunable_int.value == 10 # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
def test_tunable_assign_str_to_float(tunable_float):
|
||||
def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None:
|
||||
"""
|
||||
Check str to float coercion.
|
||||
"""
|
||||
tunable_float.value = "0.5"
|
||||
assert tunable_float.value == 0.5
|
||||
assert tunable_float.value == 0.5 # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
def test_tunable_assign_float_to_int(tunable_int):
|
||||
def test_tunable_assign_float_to_int(tunable_int: Tunable) -> None:
|
||||
"""
|
||||
Check float to int coercion.
|
||||
"""
|
||||
|
@ -86,7 +89,7 @@ def test_tunable_assign_float_to_int(tunable_int):
|
|||
assert tunable_int.value == 10
|
||||
|
||||
|
||||
def test_tunable_assign_float_to_int_fail(tunable_int):
|
||||
def test_tunable_assign_float_to_int_fail(tunable_int: Tunable) -> None:
|
||||
"""
|
||||
Check the invalid float to int coercion.
|
||||
"""
|
||||
|
|
|
@ -6,18 +6,21 @@
|
|||
Unit tests for deep copy of tunable objects and groups.
|
||||
"""
|
||||
|
||||
from mlos_bench.tunables.tunable import Tunable
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
def test_copy_tunable_int(tunable_int):
|
||||
|
||||
def test_copy_tunable_int(tunable_int: Tunable) -> None:
|
||||
"""
|
||||
Check if deep copy works for Tunable object.
|
||||
"""
|
||||
tunable_copy = tunable_int.copy()
|
||||
assert tunable_int == tunable_copy
|
||||
tunable_copy.value += 200
|
||||
tunable_copy.value = tunable_copy.numerical_value + 200
|
||||
assert tunable_int != tunable_copy
|
||||
|
||||
|
||||
def test_copy_tunable_groups(tunable_groups):
|
||||
def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None:
|
||||
"""
|
||||
Check if deep copy works for TunableGroups object.
|
||||
"""
|
||||
|
|
|
@ -7,9 +7,9 @@ Tunable parameter definition.
|
|||
"""
|
||||
import copy
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from mlos_bench.tunables.tunable import Tunable
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
|
||||
|
||||
class CovariantTunableGroup:
|
||||
|
@ -32,8 +32,8 @@ class CovariantTunableGroup:
|
|||
"""
|
||||
self._is_updated = True
|
||||
self._name = name
|
||||
self._cost = config.get("cost", 0)
|
||||
self._tunables = {
|
||||
self._cost = int(config.get("cost", 0))
|
||||
self._tunables: Dict[str, Tunable] = {
|
||||
name: Tunable(name, tunable_config)
|
||||
for (name, tunable_config) in config.get("params", {}).items()
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ class CovariantTunableGroup:
|
|||
"""
|
||||
return self._name
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> "CovariantTunableGroup":
|
||||
"""
|
||||
Deep copy of the CovariantTunableGroup object.
|
||||
|
||||
|
@ -62,7 +62,7 @@ class CovariantTunableGroup:
|
|||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""
|
||||
Check if two CovariantTunableGroup objects are equal.
|
||||
|
||||
|
@ -76,12 +76,14 @@ class CovariantTunableGroup:
|
|||
is_equal : bool
|
||||
True if two CovariantTunableGroup objects are equal.
|
||||
"""
|
||||
if not isinstance(other, CovariantTunableGroup):
|
||||
return False
|
||||
return (self._name == other._name and
|
||||
self._cost == other._cost and
|
||||
self._is_updated == other._is_updated and
|
||||
self._tunables == other._tunables)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Clear the update flag. That is, state that running an experiment with the
|
||||
current values of the tunables in this group has no extra cost.
|
||||
|
@ -110,13 +112,13 @@ class CovariantTunableGroup:
|
|||
"""
|
||||
return self._cost if self._is_updated else 0
|
||||
|
||||
def get_names(self) -> List[str]:
|
||||
def get_names(self) -> Iterable[str]:
|
||||
"""
|
||||
Get the names of all tunables in the group.
|
||||
"""
|
||||
return self._tunables.keys()
|
||||
|
||||
def get_values(self) -> Dict[str, Any]:
|
||||
def get_values(self) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Get current values of all tunables in the group.
|
||||
"""
|
||||
|
@ -151,10 +153,10 @@ class CovariantTunableGroup:
|
|||
"""
|
||||
return self._tunables[name]
|
||||
|
||||
def __getitem__(self, name: str):
|
||||
def __getitem__(self, name: str) -> TunableValue:
|
||||
return self.get_tunable(name).value
|
||||
|
||||
def __setitem__(self, name: str, value):
|
||||
def __setitem__(self, name: str, value: TunableValue) -> TunableValue:
|
||||
self._is_updated = True
|
||||
self._tunables[name].value = value
|
||||
return value
|
||||
|
|
|
@ -9,11 +9,32 @@ import copy
|
|||
import collections
|
||||
import logging
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
"""A tunable parameter value type alias."""
|
||||
TunableValue = Union[int, float, str]
|
||||
|
||||
|
||||
class TunableDict(TypedDict, total=False):
|
||||
"""
|
||||
A typed dict for tunable parameters.
|
||||
|
||||
Mostly used for mypy type checking.
|
||||
|
||||
These are the types expected to be received from the json config.
|
||||
"""
|
||||
|
||||
type: str
|
||||
description: Optional[str]
|
||||
default: TunableValue
|
||||
values: Optional[List[str]]
|
||||
range: Optional[Union[Sequence[int], Sequence[float]]]
|
||||
special: Optional[Union[List[int], List[str]]]
|
||||
|
||||
|
||||
class Tunable: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
A tunable parameter definition and its current value.
|
||||
|
@ -25,7 +46,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
"categorical": str,
|
||||
}
|
||||
|
||||
def __init__(self, name: str, config: dict):
|
||||
def __init__(self, name: str, config: TunableDict):
|
||||
"""
|
||||
Create an instance of a new tunable parameter.
|
||||
|
||||
|
@ -39,19 +60,24 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
self._name = name
|
||||
self._type = config["type"] # required
|
||||
self._description = config.get("description")
|
||||
self._default = config.get("default")
|
||||
self._default = config["default"]
|
||||
self._values = config.get("values")
|
||||
self._range = config.get("range")
|
||||
self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
|
||||
config_range = config.get("range")
|
||||
if config_range is not None:
|
||||
assert len(config_range) == 2, f"Invalid range: {config_range}"
|
||||
config_range = (config_range[0], config_range[1])
|
||||
self._range = config_range
|
||||
self._special = config.get("special")
|
||||
self._current_value = self._default
|
||||
if self._type == "categorical":
|
||||
if self.is_categorical:
|
||||
if not (self._values and isinstance(self._values, collections.abc.Iterable)):
|
||||
raise ValueError("Must specify values for the categorical type")
|
||||
if self._range is not None:
|
||||
raise ValueError("Range must be None for the categorical type")
|
||||
if self._special is not None:
|
||||
raise ValueError("Special values must be None for the categorical type")
|
||||
elif self._type in {"int", "float"}:
|
||||
elif self.is_numerical:
|
||||
if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
|
||||
raise ValueError(f"Invalid range: {self._range}")
|
||||
else:
|
||||
|
@ -68,7 +94,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
"""
|
||||
return f"{self._name}={self._current_value}"
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""
|
||||
Check if two Tunable objects are equal.
|
||||
|
||||
|
@ -83,13 +109,15 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
True if the Tunables correspond to the same parameter and have the same value and type.
|
||||
NOTE: ranges and special values are not currently considered in the comparison.
|
||||
"""
|
||||
return (
|
||||
if not isinstance(other, Tunable):
|
||||
return False
|
||||
return bool(
|
||||
self._name == other._name and
|
||||
self._type == other._type and
|
||||
self._current_value == other._current_value
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> "Tunable":
|
||||
"""
|
||||
Deep copy of the Tunable object.
|
||||
|
||||
|
@ -101,14 +129,14 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> TunableValue:
|
||||
"""
|
||||
Get the current value of the tunable.
|
||||
"""
|
||||
return self._current_value
|
||||
|
||||
@value.setter
|
||||
def value(self, value):
|
||||
def value(self, value: TunableValue) -> TunableValue:
|
||||
"""
|
||||
Set the current value of the tunable.
|
||||
"""
|
||||
|
@ -135,13 +163,13 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
self._current_value = coerced_value
|
||||
return self._current_value
|
||||
|
||||
def is_valid(self, value) -> bool:
|
||||
def is_valid(self, value: TunableValue) -> bool:
|
||||
"""
|
||||
Check if the value can be assigned to the tunable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : Any
|
||||
value : Union[int, float, str]
|
||||
Value to validate.
|
||||
|
||||
Returns
|
||||
|
@ -149,13 +177,36 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
is_valid : bool
|
||||
True if the value is valid, False otherwise.
|
||||
"""
|
||||
if self._type == "categorical":
|
||||
if self.is_categorical and self._values:
|
||||
return value in self._values
|
||||
elif self._type in {"int", "float"} and self._range:
|
||||
return (self._range[0] <= value <= self._range[1]) or value == self._default
|
||||
elif self.is_numerical and self._range:
|
||||
assert isinstance(value, (int, float))
|
||||
return bool(self._range[0] <= value <= self._range[1]) or value == self._default
|
||||
else:
|
||||
raise ValueError(f"Invalid parameter type: {self._type}")
|
||||
|
||||
@property
|
||||
def categorical_value(self) -> str:
|
||||
"""
|
||||
Get the current value of the tunable as a number.
|
||||
"""
|
||||
if self.is_categorical:
|
||||
return str(self._current_value)
|
||||
else:
|
||||
raise ValueError("Cannot get categorical values for a numerical tunable.")
|
||||
|
||||
@property
|
||||
def numerical_value(self) -> Union[int, float]:
|
||||
"""
|
||||
Get the current value of the tunable as a number.
|
||||
"""
|
||||
if self._type == "int":
|
||||
return int(self._current_value)
|
||||
elif self._type == "float":
|
||||
return float(self._current_value)
|
||||
else:
|
||||
raise ValueError("Cannot get numerical value for a categorical tunable.")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
|
@ -176,7 +227,31 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
return self._type
|
||||
|
||||
@property
|
||||
def range(self) -> Tuple[Any, Any]:
|
||||
def is_categorical(self) -> bool:
|
||||
"""
|
||||
Check if the tunable is categorical.
|
||||
|
||||
Returns
|
||||
-------
|
||||
is_categorical : bool
|
||||
True if the tunable is categorical, False otherwise.
|
||||
"""
|
||||
return self._type == "categorical"
|
||||
|
||||
@property
|
||||
def is_numerical(self) -> bool:
|
||||
"""
|
||||
Check if the tunable is an integer or float.
|
||||
|
||||
Returns
|
||||
-------
|
||||
is_int : bool
|
||||
True if the tunable is an integer or float, False otherwise.
|
||||
"""
|
||||
return self._type in {"int", "float"}
|
||||
|
||||
@property
|
||||
def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
|
||||
"""
|
||||
Get the range of the tunable if it is numerical, None otherwise.
|
||||
|
||||
|
@ -186,6 +261,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
A 2-tuple of numbers that represents the range of the tunable.
|
||||
Numbers can be int or float, depending on the type of the tunable.
|
||||
"""
|
||||
assert self.is_numerical
|
||||
assert self._range is not None
|
||||
return self._range
|
||||
|
||||
@property
|
||||
|
@ -199,4 +276,6 @@ class Tunable: # pylint: disable=too-many-instance-attributes
|
|||
values : List[str]
|
||||
List of all possible values of a categorical tunable.
|
||||
"""
|
||||
assert self.is_categorical
|
||||
assert self._values is not None
|
||||
return self._values
|
||||
|
|
|
@ -7,9 +7,9 @@ TunableGroups definition.
|
|||
"""
|
||||
import copy
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple
|
||||
|
||||
from mlos_bench.tunables.tunable import Tunable
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ class TunableGroups:
|
|||
A collection of covariant groups of tunable parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
def __init__(self, config: Optional[dict] = None):
|
||||
"""
|
||||
Create a new group of tunable parameters.
|
||||
|
||||
|
@ -27,12 +27,14 @@ class TunableGroups:
|
|||
config : dict
|
||||
Python dict of serialized representation of the covariant tunable groups.
|
||||
"""
|
||||
self._index = {} # Index (Tunable id -> CovariantTunableGroup)
|
||||
self._tunable_groups = {}
|
||||
for (name, group_config) in (config or {}).items():
|
||||
if config is None:
|
||||
config = {}
|
||||
self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup)
|
||||
self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
|
||||
for (name, group_config) in config.items():
|
||||
self._add_group(CovariantTunableGroup(name, group_config))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""
|
||||
Check if two TunableGroups are equal.
|
||||
|
||||
|
@ -46,9 +48,11 @@ class TunableGroups:
|
|||
is_equal : bool
|
||||
True if two TunableGroups are equal.
|
||||
"""
|
||||
return self._tunable_groups == other._tunable_groups
|
||||
if not isinstance(other, TunableGroups):
|
||||
return False
|
||||
return bool(self._tunable_groups == other._tunable_groups)
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> "TunableGroups":
|
||||
"""
|
||||
Deep copy of the TunableGroups object.
|
||||
|
||||
|
@ -60,7 +64,7 @@ class TunableGroups:
|
|||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def _add_group(self, group: CovariantTunableGroup):
|
||||
def _add_group(self, group: CovariantTunableGroup) -> None:
|
||||
"""
|
||||
Add a CovariantTunableGroup to the current collection.
|
||||
|
||||
|
@ -71,7 +75,7 @@ class TunableGroups:
|
|||
self._tunable_groups[group.name] = group
|
||||
self._index.update(dict.fromkeys(group.get_names(), group))
|
||||
|
||||
def update(self, tunables):
|
||||
def update(self, tunables: "TunableGroups") -> "TunableGroups":
|
||||
"""
|
||||
Merge the two collections of covariant tunable groups.
|
||||
|
||||
|
@ -104,20 +108,20 @@ class TunableGroups:
|
|||
for (group_name, group) in self._tunable_groups.items()
|
||||
for tunable in group._tunables.values()) + " }"
|
||||
|
||||
def __getitem__(self, name: str):
|
||||
def __getitem__(self, name: str) -> TunableValue:
|
||||
"""
|
||||
Get the current value of a single tunable parameter.
|
||||
"""
|
||||
return self._index[name][name]
|
||||
|
||||
def __setitem__(self, name: str, value):
|
||||
def __setitem__(self, name: str, value: TunableValue) -> None:
|
||||
"""
|
||||
Update the current value of a single tunable parameter.
|
||||
"""
|
||||
# Use double index to make sure we set the is_updated flag of the group
|
||||
self._index[name][name] = value
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
|
||||
"""
|
||||
An iterator over all tunables in the group.
|
||||
|
||||
|
@ -147,7 +151,7 @@ class TunableGroups:
|
|||
group = self._index[name]
|
||||
return (group.get_tunable(name), group)
|
||||
|
||||
def get_names(self) -> List[str]:
|
||||
def get_names(self) -> Iterable[str]:
|
||||
"""
|
||||
Get the names of all covariance groups in the collection.
|
||||
|
||||
|
@ -158,7 +162,7 @@ class TunableGroups:
|
|||
"""
|
||||
return self._tunable_groups.keys()
|
||||
|
||||
def subgroup(self, group_names: List[str]):
|
||||
def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
|
||||
"""
|
||||
Select the covariance groups from the current set and create a new
|
||||
TunableGroups object that consists of those covariance groups.
|
||||
|
@ -179,8 +183,8 @@ class TunableGroups:
|
|||
tunables._add_group(self._tunable_groups[name])
|
||||
return tunables
|
||||
|
||||
def get_param_values(self, group_names: List[str] = None,
|
||||
into_params: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def get_param_values(self, group_names: Optional[Iterable[str]] = None,
|
||||
into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Get the current values of the tunables that belong to the specified covariance groups.
|
||||
|
||||
|
@ -205,7 +209,7 @@ class TunableGroups:
|
|||
into_params.update(self._tunable_groups[name].get_values())
|
||||
return into_params
|
||||
|
||||
def is_updated(self, group_names: List[str] = None) -> bool:
|
||||
def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
|
||||
"""
|
||||
Check if any of the given covariant tunable groups has been updated.
|
||||
|
||||
|
@ -222,7 +226,7 @@ class TunableGroups:
|
|||
return any(self._tunable_groups[name].is_updated()
|
||||
for name in (group_names or self.get_names()))
|
||||
|
||||
def reset(self, group_names: List[str] = None):
|
||||
def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
|
||||
"""
|
||||
Clear the update flag of given covariant groups.
|
||||
|
||||
|
@ -240,14 +244,14 @@ class TunableGroups:
|
|||
self._tunable_groups[name].reset()
|
||||
return self
|
||||
|
||||
def assign(self, param_values: Dict[str, Any]):
|
||||
def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
|
||||
"""
|
||||
In-place update the values of the tunables from the dictionary
|
||||
of (key, value) pairs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param_values : Dict[str, Any]
|
||||
param_values : Mapping[str, TunableValue]
|
||||
Dictionary mapping Tunable parameter names to new values.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -11,12 +11,13 @@ Various helper functions for mlos_bench.
|
|||
import json
|
||||
import logging
|
||||
import importlib
|
||||
from typing import Any, Tuple, Dict, Iterable
|
||||
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Type, TypeVar, TYPE_CHECKING
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, Dict[str, Any]]:
|
||||
def prepare_class_load(config: dict, global_config: Optional[dict] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Extract the class instantiation parameters from the configuration.
|
||||
|
||||
|
@ -35,7 +36,10 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D
|
|||
class_name = config["class"]
|
||||
class_config = config.setdefault("config", {})
|
||||
|
||||
for key in set(class_config).intersection(global_config or {}):
|
||||
if global_config is None:
|
||||
global_config = {}
|
||||
|
||||
for key in set(class_config).intersection(global_config):
|
||||
class_config[key] = global_config[key]
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
|
@ -45,7 +49,17 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D
|
|||
return (class_name, class_config)
|
||||
|
||||
|
||||
def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environment.base_environment import Environment
|
||||
from mlos_bench.service.base_service import Service
|
||||
from mlos_bench.optimizer.base_optimizer import Optimizer
|
||||
|
||||
# T is a generic with a constraint of the three base classes.
|
||||
T = TypeVar('T', "Environment", "Service", "Optimizer")
|
||||
|
||||
|
||||
# FIXME: Technically this should return a type "class_name" derived from "base_class".
|
||||
def instantiate_from_config(base_class: Type[T], class_name: str, *args: Any, **kwargs: Any) -> T:
|
||||
"""
|
||||
Factory method for a new class instantiated from config.
|
||||
|
||||
|
@ -65,7 +79,7 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
|
|||
|
||||
Returns
|
||||
-------
|
||||
inst : Any
|
||||
inst : Union[Environment, Service, Optimizer]
|
||||
An instance of the `class_name` class.
|
||||
"""
|
||||
# We need to import mlos_bench to make the factory methods work.
|
||||
|
@ -78,10 +92,12 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
|
|||
_LOG.info("Instantiating: %s :: %s", class_name, impl)
|
||||
|
||||
assert issubclass(impl, base_class)
|
||||
return impl(*args, **kwargs)
|
||||
ret: T = impl(*args, **kwargs)
|
||||
assert isinstance(ret, base_class)
|
||||
return ret
|
||||
|
||||
|
||||
def check_required_params(config: Dict[str, Any], required_params: Iterable[str]):
|
||||
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
|
||||
"""
|
||||
Check if all required parameters are present in the configuration.
|
||||
Raise ValueError if any of the parameters are missing.
|
||||
|
|
|
@ -34,13 +34,16 @@ extra_requires = {
|
|||
|
||||
# construct special 'full' extra that adds requirements for all built-in
|
||||
# backend integrations and additional extra features.
|
||||
extra_requires['full'] = list(set(chain(extra_requires.values())))
|
||||
extra_requires['full'] = list(set(chain(extra_requires.values()))) # type: ignore[assignment]
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
setup(
|
||||
name='mlos-bench',
|
||||
version=_VERSION,
|
||||
packages=find_packages(),
|
||||
package_data={
|
||||
'mlos_bench': ['py.typed'],
|
||||
},
|
||||
install_requires=[
|
||||
'mlos-core==' + _VERSION,
|
||||
'requests',
|
||||
|
|
|
@ -7,7 +7,7 @@ Basic initializer module for the mlos_core space adapters.
|
|||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, TypeVar, Union
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
import ConfigSpace
|
||||
|
||||
|
|
|
@ -22,8 +22,6 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
|
|||
aimed at improving the sample-efficiency of the underlying optimizer.
|
||||
"""
|
||||
|
||||
# pylint: disable=consider-alternative-union-syntax,too-many-arguments
|
||||
|
||||
DEFAULT_NUM_LOW_DIMS = 16
|
||||
"""Default number of dimensions in the low-dimensional search space, generated by HeSBO projection"""
|
||||
|
||||
|
@ -40,7 +38,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
|
|||
special_param_values: Optional[dict] = None,
|
||||
max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM,
|
||||
use_approximate_reverse_mapping: bool = False,
|
||||
) -> None:
|
||||
) -> None: # pylint: disable=too-many-arguments
|
||||
"""Create a space adapter that employs LlamaTune's techniques.
|
||||
|
||||
Parameters
|
||||
|
|
10
setup.cfg
10
setup.cfg
|
@ -54,8 +54,12 @@ disallow_untyped_defs = True
|
|||
disallow_incomplete_defs = True
|
||||
strict = True
|
||||
allow_any_generics = True
|
||||
hide_error_codes = False
|
||||
# regex of files to skip type checking
|
||||
exclude = /_pytest/|/build/|doc/|_version.py|setup.py
|
||||
# _version.py and setup.py look like duplicates when run from the root of the repo even though they're part of different packages.
|
||||
# There's not much in them so we just skip them.
|
||||
# We also skip several vendor files that currently throw errors.
|
||||
exclude = mlos_(core|bench)/(_version|setup).py|doc/|/build/|-packages/_pytest/
|
||||
|
||||
# https://github.com/automl/ConfigSpace/issues/293
|
||||
[mypy-ConfigSpace.*]
|
||||
|
@ -65,6 +69,10 @@ ignore_missing_imports = True
|
|||
[mypy-emukit.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
# https://github.com/dpranke/pyjson5/issues/65
|
||||
[mypy-json5]
|
||||
ignore_missing_imports = True
|
||||
|
||||
# https://github.com/matplotlib/matplotlib/issues/25634
|
||||
[mypy-matplotlib.*]
|
||||
ignore_missing_imports = True
|
||||
|
|
Загрузка…
Ссылка в новой задаче