mypy static type checking for mlos_bench (#306)

Adds static type checking via mypy for mlos_bench as well.

Builds on #305 and #307.

- [x] Add `Protocol`s for different `Service` types so `Environment`s can check that the appropriate `Service` mix-ins have been set up (see the sketch after this list).
  - [x] Config Loader
  - [x] Local Exec
  - [x] Remote Exec
  - [x] Remote Fileshare
  - [x] VM Ops
    - [ ] Future PR: Split VM Ops into VM/Host and OS Ops (for local SSH support)
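
A minimal, self-contained sketch of the pattern this adds, assuming the service protocols are declared `@runtime_checkable` so `isinstance()` checks work at runtime. `LocalEnvSketch`, `EchoService`, and the local `TunableValue` alias below are illustrative stand-ins, not the actual mlos_bench classes:

```python
from typing import Iterable, Mapping, Optional, Protocol, Tuple, Union, runtime_checkable

TunableValue = Union[int, float, str]  # local stand-in for mlos_bench's TunableValue alias


@runtime_checkable
class SupportsLocalExec(Protocol):
    """Protocol for Service mix-ins that can run scripts/commands locally."""

    def local_exec(self, script_lines: Iterable[str],
                   env: Optional[Mapping[str, TunableValue]] = None,
                   cwd: Optional[str] = None) -> Tuple[int, str, str]:
        ...


class LocalEnvSketch:
    """Hypothetical Environment-like consumer that requires local execution support."""

    def __init__(self, service: object):
        # isinstance() against a @runtime_checkable Protocol verifies the mix-in methods
        # are present, and lets mypy narrow the attribute to the protocol type.
        assert isinstance(service, SupportsLocalExec), \
            "LocalEnvSketch requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = service

    def setup(self, script: Iterable[str]) -> bool:
        (return_code, _stdout, _stderr) = self._local_exec_service.local_exec(list(script))
        return return_code == 0


class EchoService:
    """Toy service that satisfies SupportsLocalExec structurally (no inheritance needed)."""

    def local_exec(self, script_lines, env=None, cwd=None):
        return (0, "\n".join(script_lines), "")


print(LocalEnvSketch(EchoService()).setup(["echo hello"]))   # True
```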
Brian Kroth 2023-04-19 13:28:38 -05:00 committed by GitHub
Parent 10e4fa5b91
Commit 9fe180a9a5
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
70 changed files with 1291 additions and 448 deletions

.vscode/settings.json (vendored)
View File

@ -1,7 +1,10 @@
// vim: set ft=jsonc:
{
"makefile.extensionOutputFolder": "./.vscode",
"python.defaultInterpreterPath": "${env:HOME}${env:USERPROFILE}/.conda/envs/mlos_core/bin/python",
// Note: this only works in WSL/Linux currently.
"python.defaultInterpreterPath": "${env:HOME}/.conda/envs/mlos_core/bin/python",
// For Windows it should be this instead:
//"python.defaultInterpreterPath": "${env:USERPROFILE}/.conda/envs/mlos_core/python.exe",
"python.testing.pytestEnabled": true,
"python.linting.enabled": true,
"python.linting.pylintEnabled": true,

View File

@ -12,7 +12,7 @@ MLOS_BENCH_PYTHON_FILES := $(shell find ./mlos_bench/ -type f -name '*.py' 2>/de
DOCKER := $(shell which docker)
# Make sure the build directory exists.
MKDIR_BUILD := $(shell mkdir -p build)
MKDIR_BUILD := $(shell test -d build || mkdir build)
# Allow overriding the default verbosity of conda for CI jobs.
#CONDA_INFO_LEVEL ?= -q
@ -113,7 +113,7 @@ build/pylint.%.${CONDA_ENV_NAME}.build-stamp: build/conda-env.${CONDA_ENV_NAME}.
touch $@
.PHONY: mypy
mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp # TODO: build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp
mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp
build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
@ -339,6 +339,7 @@ build/check-doc.build-stamp: doc/build/html/index.html doc/build/html/htmlcov/in
-e 'Problems with "include" directive path:' \
-e 'duplicate object description' \
-e "document isn't included in any toctree" \
-e "more than one target found for cross-reference" \
-e "toctree contains reference to nonexisting document 'auto_examples/index'" \
-e "failed to import function 'create' from module '(SpaceAdapter|Optimizer)Factory'" \
-e "No module named '(SpaceAdapter|Optimizer)Factory'" \

View File

@ -1,6 +1,6 @@
FROM nginx:latest
RUN apt-get update && apt-get install -y --no-install-recommends linklint curl
# Our ngxinx config overrides the default listening port.
# Our nginx config overrides the default listening port.
ARG NGINX_PORT=81
ENV NGINX_PORT=${NGINX_PORT}
EXPOSE ${NGINX_PORT}

View File

@ -156,6 +156,12 @@ Service Mix-ins
Service
FileShareService
.. currentmodule:: mlos_bench.service.config_persistence
.. autosummary::
:toctree: generated/
:template: class.rst
ConfigPersistenceService
Local Services

View File

@ -13,7 +13,7 @@ import json
import argparse
def _main(fname_input: str, fname_output: str):
def _main(fname_input: str, fname_output: str) -> None:
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in json.load(fh_tunables).items():

View File

@ -12,7 +12,7 @@ import argparse
import pandas as pd
def _main(input_file: str, output_file: str):
def _main(input_file: str, output_file: str) -> None:
"""
Re-shape Redis benchmark CSV results from wide to long.
"""

View File

@ -13,7 +13,7 @@ import json
import argparse
def _main(fname_input: str, fname_output: str):
def _main(fname_input: str, fname_output: str) -> None:
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in json.load(fh_tunables).items():

View File

@ -13,7 +13,7 @@ import json
import argparse
def _main(fname_input: str, fname_output: str):
def _main(fname_input: str, fname_output: str) -> None:
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in json.load(fh_tunables).items():

View File

@ -9,9 +9,12 @@ A hierarchy of benchmark environments.
import abc
import json
import logging
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
from mlos_bench.environment.status import Status
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import instantiate_from_config
@ -19,13 +22,21 @@ _LOG = logging.getLogger(__name__)
class Environment(metaclass=abc.ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
An abstract base of all benchmark environments.
"""
@classmethod
def new(cls, env_name, class_name, config, global_config=None, tunables=None, service=None):
# pylint: disable=too-many-arguments
def new(cls,
*,
env_name: str,
class_name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
) -> "Environment":
"""
Factory method for a new environment with a given config.
@ -58,8 +69,13 @@ class Environment(metaclass=abc.ABCMeta):
return instantiate_from_config(cls, class_name, env_name, config,
global_config, tunables, service)
def __init__(self, name, config, global_config=None, tunables=None, service=None):
# pylint: disable=too-many-arguments
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
"""
Create a new environment with a given config.
@ -84,10 +100,17 @@ class Environment(metaclass=abc.ABCMeta):
self.config = config
self._service = service
self._is_ready = False
self._params = {}
self._params: Dict[str, TunableValue] = {}
self._config_loader_service: SupportsConfigLoading
if self._service is not None and isinstance(self._service, SupportsConfigLoading):
self._config_loader_service = self._service
if global_config is None:
global_config = {}
self._const_args = config.get("const_args", {})
for key in set(self._const_args).intersection(global_config or {}):
for key in set(self._const_args).intersection(global_config):
self._const_args[key] = global_config[key]
for key in config.get("required_args", []):
@ -110,13 +133,13 @@ class Environment(metaclass=abc.ABCMeta):
_LOG.debug("Config for: %s\n%s",
name, json.dumps(self.config, indent=2))
def __str__(self):
def __str__(self) -> str:
return self.name
def __repr__(self):
def __repr__(self) -> str:
return f"Env: {self.__class__} :: '{self.name}'"
def _combine_tunables(self, tunables):
def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
"""
Plug tunable values into the base config. If the tunable group is unknown,
ignore it (it might belong to another environment). This method should
@ -130,14 +153,14 @@ class Environment(metaclass=abc.ABCMeta):
Returns
-------
params : dict
params : Dict[str, Union[int, float, str]]
Free-format dictionary that contains the new environment configuration.
"""
return tunables.get_param_values(
group_names=self._tunable_params.get_names(),
group_names=list(self._tunable_params.get_names()),
into_params=self._const_args.copy())
def tunable_params(self):
def tunable_params(self) -> TunableGroups:
"""
Get the configuration space of the given environment.
@ -148,7 +171,7 @@ class Environment(metaclass=abc.ABCMeta):
"""
return self._tunable_params
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Set up a new benchmark environment, if necessary. This method must be
idempotent, i.e., calling it several times in a row should be
@ -170,15 +193,18 @@ class Environment(metaclass=abc.ABCMeta):
_LOG.info("Setup %s :: %s", self, tunables)
assert isinstance(tunables, TunableGroups)
if global_config is None:
global_config = {}
self._params = self._combine_tunables(tunables)
for key in set(self._params).intersection(global_config or {}):
for key in set(self._params).intersection(global_config):
self._params[key] = global_config[key]
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))
return True
def teardown(self):
def teardown(self) -> None:
"""
Tear down the benchmark environment. This method must be idempotent,
i.e., calling it several times in a row should be equivalent to a

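Below is a small, self-contained illustration of the keyword-only constructor and `Optional[...] = None` defaults that the `Environment` diff above switches to. `EnvSketch` is a hypothetical toy class, not the real `Environment`:

```python
from typing import Optional


class EnvSketch:
    """Toy stand-in for the Environment constructor signature shown above."""

    def __init__(self, *,                       # `*` makes all parameters keyword-only
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None):
        # `Optional[dict] = None` (instead of `dict = None`) satisfies mypy's strict
        # optional checking and forces an explicit None-handling step:
        self.name = name
        self.config = config
        self.global_config = global_config if global_config is not None else {}


env = EnvSketch(name="mock", config={"const_args": {}})   # OK: keyword arguments
# EnvSketch("mock", {"const_args": {}})                   # TypeError: no positional args
```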
View File

@ -7,7 +7,7 @@ Composite benchmark environment.
"""
import logging
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from mlos_bench.service.base_service import Service
from mlos_bench.environment.status import Status
@ -22,9 +22,13 @@ class CompositeEnv(Environment):
Composite benchmark environment.
"""
def __init__(self, name: str, config: dict, global_config: dict = None,
tunables: TunableGroups = None, service: Service = None):
# pylint: disable=too-many-arguments
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
"""
Create a new environment with a given config.
@ -44,23 +48,23 @@ class CompositeEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name, config, global_config, tunables, service)
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
self._children = []
self._children: List[Environment] = []
for child_config_file in config.get("include_children", []):
for env in self._service.load_environment_list(
for env in self._config_loader_service.load_environment_list(
child_config_file, global_config, tunables, self._service):
self._add_child(env)
for child_config in config.get("children", []):
self._add_child(self._service.build_environment(
self._add_child(self._config_loader_service.build_environment(
child_config, global_config, tunables, self._service))
if not self._children:
raise ValueError("At least one child environment must be present")
def _add_child(self, env: Environment):
def _add_child(self, env: Environment) -> None:
"""
Add a new child environment to the composite environment.
This method is called from the constructor only.
@ -68,7 +72,7 @@ class CompositeEnv(Environment):
self._children.append(env)
self._tunable_params.update(env.tunable_params())
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Set up the children environments.
@ -92,7 +96,7 @@ class CompositeEnv(Environment):
)
return self._is_ready
def teardown(self):
def teardown(self) -> None:
"""
Tear down the children environments. This method is idempotent,
i.e., calling it several times is equivalent to a single call.

View File

@ -17,23 +17,25 @@ import pandas
from mlos_bench.environment.status import Status
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class LocalEnv(Environment):
# pylint: disable=too-many-instance-attributes
"""
Scheduler-side Environment that runs scripts locally.
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
# pylint: disable=too-many-arguments
"""
Create a new environment for local execution.
@ -56,15 +58,19 @@ class LocalEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name, config, global_config, tunables, service)
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
"LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
self._temp_dir = self.config.get("temp_dir")
self._script_setup = self.config.get("setup")
self._script_run = self.config.get("run")
self._script_teardown = self.config.get("teardown")
self._dump_params_file = self.config.get("dump_params_file")
self._read_results_file = self.config.get("read_results_file")
self._dump_params_file: Optional[str] = self.config.get("dump_params_file")
self._read_results_file: Optional[str] = self.config.get("read_results_file")
if self._script_setup is None and \
self._script_run is None and \
@ -77,7 +83,7 @@ class LocalEnv(Environment):
if self._script_run is None and self._read_results_file is not None:
raise ValueError("'run' must be present if 'read_results_file' is specified")
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if the environment is ready and set up the application
and benchmarks, if necessary.
@ -105,7 +111,7 @@ class LocalEnv(Environment):
self._is_ready = True
return True
with self._service.temp_dir_context(self._temp_dir) as temp_dir:
with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir:
_LOG.info("Set up the environment locally: %s at %s", self, temp_dir)
@ -116,7 +122,7 @@ class LocalEnv(Environment):
json.dump(tunables.get_param_values(self._tunable_params.get_names()),
fh_tunables)
(return_code, _stdout, stderr) = self._service.local_exec(
(return_code, _stdout, stderr) = self._local_exec_service.local_exec(
self._script_setup, env=self._params, cwd=temp_dir)
if return_code == 0:
@ -125,7 +131,7 @@ class LocalEnv(Environment):
_LOG.warning("ERROR: Local setup returns with code %d stderr:\n%s",
return_code, stderr)
self._is_ready = return_code == 0
self._is_ready = bool(return_code == 0)
return self._is_ready
@ -145,10 +151,10 @@ class LocalEnv(Environment):
if not (status.is_ready and self._script_run):
return result
with self._service.temp_dir_context(self._temp_dir) as temp_dir:
with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir:
_LOG.info("Run script locally on: %s at %s", self, temp_dir)
(return_code, _stdout, stderr) = self._service.local_exec(
(return_code, _stdout, stderr) = self._local_exec_service.local_exec(
self._script_run, env=self._params, cwd=temp_dir)
if return_code != 0:
@ -156,23 +162,24 @@ class LocalEnv(Environment):
return_code, stderr)
return (Status.FAILED, None)
data = pandas.read_csv(self._service.resolve_path(
assert self._read_results_file is not None
data = pandas.read_csv(self._config_loader_service.resolve_path(
self._read_results_file, extra_paths=[temp_dir]))
_LOG.debug("Read data:\n%s", data)
if len(data) != 1:
_LOG.warning("Local run has %d results - returning the last one", len(data))
data = data.iloc[-1].to_dict()
_LOG.info("Local run complete: %s ::\n%s", self, data)
return (Status.SUCCEEDED, data)
data_dict = data.iloc[-1].to_dict()
_LOG.info("Local run complete: %s ::\n%s", self, data_dict)
return (Status.SUCCEEDED, data_dict)
def teardown(self):
def teardown(self) -> None:
"""
Clean up the local environment.
"""
if self._script_teardown:
_LOG.info("Local teardown: %s", self)
(status, _) = self._service.local_exec(self._script_teardown, env=self._params)
(status, _, _) = self._local_exec_service.local_exec(self._script_teardown, env=self._params)
_LOG.info("Local teardown complete: %s :: %s", self, status)
super().teardown()

View File

@ -10,11 +10,14 @@ and upload/download data to the shared storage.
import logging
from string import Template
from typing import Optional, Dict, List, Tuple, Any
from typing import List, Generator, Iterable, Mapping, Optional, Tuple
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
from mlos_bench.environment.status import Status
from mlos_bench.environment.local.local_env import LocalEnv
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@ -27,12 +30,12 @@ class LocalFileShareEnv(LocalEnv):
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
# pylint: disable=too-many-arguments
"""
Create a new application environment with a given config.
@ -56,7 +59,16 @@ class LocalFileShareEnv(LocalEnv):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name, config, global_config, tunables, service)
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
"LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
"LocalEnv requires a service that supports file upload/download operations"
self._file_share_service: SupportsFileShareOps = self._service
self._upload = self._template_from_to("upload")
self._download = self._template_from_to("download")
@ -71,7 +83,8 @@ class LocalFileShareEnv(LocalEnv):
]
@staticmethod
def _expand(from_to: List[Tuple[Template, Template]], params: Dict[str, Any]):
def _expand(from_to: Iterable[Tuple[Template, Template]],
params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
"""
Substitute $var parameters in from/to path templates.
Return a generator of (str, str) pairs of paths.
@ -81,7 +94,7 @@ class LocalFileShareEnv(LocalEnv):
for (path_from, path_to) in from_to
)
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Run setup scripts locally and upload the scripts and data to the shared storage.
@ -102,14 +115,14 @@ class LocalFileShareEnv(LocalEnv):
True if operation is successful, false otherwise.
"""
prev_temp_dir = self._temp_dir
with self._service.temp_dir_context(self._temp_dir) as self._temp_dir:
with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir:
# Override _temp_dir so that setup and upload both use the same path.
self._is_ready = super().setup(tunables, global_config)
if self._is_ready:
params = self._params.copy()
params["PWD"] = self._temp_dir
for (path_from, path_to) in self._expand(self._upload, params):
self._service.upload(self._service.resolve_path(
self._file_share_service.upload(self._config_loader_service.resolve_path(
path_from, extra_paths=[self._temp_dir]), path_to)
self._temp_dir = prev_temp_dir
return self._is_ready
@ -128,13 +141,13 @@ class LocalFileShareEnv(LocalEnv):
be in the `score` field.
"""
prev_temp_dir = self._temp_dir
with self._service.temp_dir_context(self._temp_dir) as self._temp_dir:
with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir:
# Override _temp_dir so that download and run both use the same path.
params = self._params.copy()
params["PWD"] = self._temp_dir
for (path_from, path_to) in self._expand(self._download, params):
self._service.download(
path_from, self._service.resolve_path(
self._file_share_service.download(
path_from, self._config_loader_service.resolve_path(
path_to, extra_paths=[self._temp_dir]))
result = super().run()
self._temp_dir = prev_temp_dir

View File

@ -12,9 +12,9 @@ from typing import Optional, Tuple
import numpy
from mlos_bench.service.base_service import Service
from mlos_bench.environment.status import Status
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.tunables import Tunable, TunableGroups
_LOG = logging.getLogger(__name__)
@ -28,12 +28,12 @@ class MockEnv(Environment):
_NOISE_VAR = 0.2 # Variance of the Gaussian noise added to the benchmark value.
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
# pylint: disable=too-many-arguments
"""
Create a new environment that produces mock benchmark data.
@ -52,7 +52,7 @@ class MockEnv(Environment):
service: Service
An optional service object. Not used by this class.
"""
super().__init__(name, config, global_config, tunables, service)
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
seed = self.config.get("seed")
self._random = random.Random(seed) if seed is not None else None
self._range = self.config.get("range")
@ -95,15 +95,14 @@ class MockEnv(Environment):
That is, map current value to the [0, 1] range.
"""
val = None
if tunable.type == "categorical":
val = (tunable.categorical_values.index(tunable.value) /
if tunable.is_categorical:
val = (tunable.categorical_values.index(tunable.categorical_value) /
float(len(tunable.categorical_values) - 1))
elif tunable.type in {"int", "float"}:
if not tunable.range:
raise ValueError("Tunable must have a range: " + tunable.name)
val = ((tunable.value - tunable.range[0]) /
elif tunable.is_numerical:
val = ((tunable.numerical_value - tunable.range[0]) /
float(tunable.range[1] - tunable.range[0]))
else:
raise ValueError("Invalid parameter type: " + tunable.type)
# Explicitly clip the value in case of numerical errors.
return numpy.clip(val, 0, 1)
ret: float = numpy.clip(val, 0, 1)
return ret

View File

@ -6,9 +6,14 @@
OS-level remote Environment on Azure.
"""
from typing import Optional
import logging
from mlos_bench.environment import Environment, Status
from mlos_bench.environment.base_environment import Environment
from mlos_bench.environment.status import Status
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@ -19,7 +24,43 @@ class OSEnv(Environment):
OS Level Environment for a host.
"""
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
"""
Create a new environment for remote execution.
Parameters
----------
name: str
Human-readable name of the environment.
config : dict
Free-format dictionary that contains the benchmark environment
configuration. Each config must have at least the "tunable_params"
and the "const_args" sections.
`RemoteEnv` must also have at least some of the following parameters:
{setup, run, teardown, wait_boot}
global_config : dict
Free-format dictionary of global parameters (e.g., security credentials)
to be mixed in into the "const_args" section of the local config.
tunables : TunableGroups
A collection of tunable parameters for *all* environments.
service: Service
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
# TODO: Refactor this as "host" and "os" operations to accommodate SSH service.
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
"RemoteEnv requires a service that supports host operations"
self._host_service: SupportsVMOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if the host is up and running; boot it, if necessary.
@ -43,21 +84,21 @@ class OSEnv(Environment):
if not super().setup(tunables, global_config):
return False
(status, params) = self._service.vm_start(self._params)
(status, params) = self._host_service.vm_start(self._params)
if status.is_pending:
(status, _) = self._service.wait_vm_operation(params)
(status, _) = self._host_service.wait_vm_operation(params)
self._is_ready = status in {Status.SUCCEEDED, Status.READY}
return self._is_ready
def teardown(self):
def teardown(self) -> None:
"""
Clean up and shut down the host without deprovisioning it.
"""
_LOG.info("OS tear down: %s", self)
(status, params) = self._service.vm_stop()
(status, params) = self._host_service.vm_stop()
if status.is_pending:
(status, _) = self._service.wait_vm_operation(params)
(status, _) = self._host_service.wait_vm_operation(params)
super().teardown()
_LOG.debug("Final status of OS stopping: %s :: %s", self, status)

View File

@ -7,11 +7,13 @@ Remotely executed benchmark/script environment.
"""
import logging
from typing import Optional, Tuple, List
from typing import Iterable, Optional, Tuple
from mlos_bench.environment.status import Status
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@ -23,12 +25,12 @@ class RemoteEnv(Environment):
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
# pylint: disable=too-many-arguments
"""
Create a new environment for remote execution.
@ -51,7 +53,16 @@ class RemoteEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name, config, global_config, tunables, service)
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \
"RemoteEnv requires a service that supports remote execution operations"
self._remote_exec_service: SupportsRemoteExec = self._service
# TODO: Refactor this as "host" and "os" operations to accommodate SSH service.
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
"RemoteEnv requires a service that supports host operations"
self._host_service: SupportsVMOps = self._service
self._wait_boot = self.config.get("wait_boot", False)
self._script_setup = self.config.get("setup")
@ -65,7 +76,7 @@ class RemoteEnv(Environment):
raise ValueError("At least one of {setup, run, teardown}" +
" must be present or wait_boot set to True.")
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if the environment is ready and set up the application
and benchmarks on a remote host.
@ -89,9 +100,9 @@ class RemoteEnv(Environment):
if self._wait_boot:
_LOG.info("Wait for the remote environment to start: %s", self)
(status, params) = self._service.vm_start(self._params)
(status, params) = self._host_service.vm_start(self._params)
if status.is_pending:
(status, _) = self._service.wait_vm_operation(params)
(status, _) = self._host_service.wait_vm_operation(params)
if not status.is_succeeded:
return False
@ -130,7 +141,7 @@ class RemoteEnv(Environment):
_LOG.info("Remote run complete: %s :: %s", self, result)
return result
def teardown(self):
def teardown(self) -> None:
"""
Clean up and shut down the remote environment.
"""
@ -140,7 +151,7 @@ class RemoteEnv(Environment):
_LOG.info("Remote teardown complete: %s :: %s", self, status)
super().teardown()
def _remote_exec(self, script: List[str]) -> Tuple[Status, Optional[dict]]:
def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, Optional[dict]]:
"""
Run a script on the remote host.
@ -156,10 +167,10 @@ class RemoteEnv(Environment):
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
"""
_LOG.debug("Submit script: %s", self)
(status, output) = self._service.remote_exec(script, self._params)
(status, output) = self._remote_exec_service.remote_exec(script, self._params)
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)
if status in {Status.PENDING, Status.SUCCEEDED}:
(status, output) = self._service.get_remote_exec_results(output)
(status, output) = self._remote_exec_service.get_remote_exec_results(output)
# TODO: extract the results from `output`.
_LOG.debug("Status: %s :: %s", status, output)
return (status, output)

View File

@ -6,9 +6,13 @@
"Remote" VM Environment.
"""
from typing import Optional
import logging
from mlos_bench.environment import Environment
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@ -19,7 +23,40 @@ class VMEnv(Environment):
"Remote" VM environment.
"""
def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool:
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
"""
Create a new environment for VM operations.
Parameters
----------
name: str
Human-readable name of the environment.
config : dict
Free-format dictionary that contains the benchmark environment
configuration. Each config must have at least the "tunable_params"
and the "const_args" sections.
global_config : dict
Free-format dictionary of global parameters (e.g., security credentials)
to be mixed in into the "const_args" section of the local config.
tunables : TunableGroups
A collection of tunable parameters for *all* environments.
service: Service
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
assert self._service is not None and isinstance(self._service, SupportsVMOps), \
"VMEnv requires a service that supports VM operations"
self._vm_service: SupportsVMOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if VM is ready. (Re)provision and start it, if necessary.
@ -43,21 +80,21 @@ class VMEnv(Environment):
if not super().setup(tunables, global_config):
return False
(status, params) = self._service.vm_provision(self._params)
(status, params) = self._vm_service.vm_provision(self._params)
if status.is_pending:
(status, _) = self._service.wait_vm_deployment(True, params)
(status, _) = self._vm_service.wait_vm_deployment(True, params)
self._is_ready = status.is_succeeded
return self._is_ready
def teardown(self):
def teardown(self) -> None:
"""
Shut down the VM and release it.
"""
_LOG.info("VM tear down: %s", self)
(status, params) = self._service.vm_deprovision()
(status, params) = self._vm_service.vm_deprovision()
if status.is_pending:
(status, _) = self._service.wait_vm_deployment(False, params)
(status, _) = self._vm_service.wait_vm_deployment(False, params)
super().teardown()
_LOG.debug("Final status of VM deprovisioning: %s :: %s", self, status)

View File

@ -24,7 +24,7 @@ class Status(enum.Enum):
TIMED_OUT = 7
@property
def is_good(self):
def is_good(self) -> bool:
"""
Check if the status of the benchmark/environment is good.
"""
@ -36,42 +36,42 @@ class Status(enum.Enum):
}
@property
def is_pending(self):
def is_pending(self) -> bool:
"""
Check if the status of the benchmark/environment is PENDING.
"""
return self == Status.PENDING
@property
def is_ready(self):
def is_ready(self) -> bool:
"""
Check if the status of the benchmark/environment is READY.
"""
return self == Status.READY
@property
def is_succeeded(self):
def is_succeeded(self) -> bool:
"""
Check if the status of the benchmark/environment is SUCCEEDED.
"""
return self == Status.SUCCEEDED
@property
def is_failed(self):
def is_failed(self) -> bool:
"""
Check if the status of the benchmark/environment is FAILED.
"""
return self == Status.FAILED
@property
def is_canceled(self):
def is_canceled(self) -> bool:
"""
Check if the status of the benchmark/environment is CANCELED.
"""
return self == Status.CANCELED
@property
def is_timed_out(self):
def is_timed_out(self) -> bool:
"""
Check if the status of the benchmark/environment is TIMEDOUT.
"""

View File

@ -9,10 +9,11 @@ Helper functions to launch the benchmark and the optimizer from the command line
import logging
import argparse
from typing import List, Dict, Any
from typing import Any, Dict, Iterable
from mlos_bench.environment import Environment
from mlos_bench.service import LocalExecService, ConfigPersistenceService
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.local.local_exec import LocalExecService
from mlos_bench.service.config_persistence import ConfigPersistenceService
_LOG_LEVEL = logging.INFO
_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
@ -30,9 +31,9 @@ class Launcher:
_LOG.info("Launch: %s", description)
self._config_loader = None
self._env_config_file = None
self._global_config = {}
self._config_loader: ConfigPersistenceService
self._env_config_file: str
self._global_config: Dict[str, Any] = {}
self._parser = argparse.ArgumentParser(description=description)
self._parser.add_argument(
@ -89,7 +90,9 @@ class Launcher:
if args.globals is not None:
for config_file in args.globals:
self._global_config.update(self._config_loader.load_config(config_file))
conf = self._config_loader.load_config(config_file)
assert isinstance(conf, dict)
self._global_config.update(conf)
self._global_config.update(Launcher._try_parse_extra_args(args_rest))
if args.config_path:
@ -98,7 +101,7 @@ class Launcher:
return args
@staticmethod
def _try_parse_extra_args(cmdline: List[str]) -> Dict[str, str]:
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, str]:
"""
Helper function to parse global key/value pairs from the command line.
"""
@ -132,7 +135,9 @@ class Launcher:
Load JSON config file. Use path relative to `config_path` if required.
"""
assert self._config_loader is not None, "Call after invoking .parse_args()"
return self._config_loader.load_config(json_file_name)
conf = self._config_loader.load_config(json_file_name)
assert isinstance(conf, dict)
return conf
def load_env(self) -> Environment:
"""

View File

@ -8,7 +8,7 @@ and mlos_core optimizers.
"""
import logging
from typing import Optional, Tuple, List, Union
from typing import Dict, Optional, Sequence, Tuple, Union
from abc import ABCMeta, abstractmethod
from mlos_bench.environment.status import Status
@ -24,7 +24,7 @@ class Optimizer(metaclass=ABCMeta):
"""
@staticmethod
def load(tunables: TunableGroups, config: dict, global_config: dict = None):
def load(tunables: TunableGroups, config: dict, global_config: Optional[dict] = None) -> "Optimizer":
"""
Instantiate the Optimizer shim from the configuration.
@ -48,7 +48,7 @@ class Optimizer(metaclass=ABCMeta):
return opt
@classmethod
def new(cls, class_name: str, tunables: TunableGroups, config: dict):
def new(cls, class_name: str, tunables: TunableGroups, config: dict) -> "Optimizer":
"""
Factory method for a new optimizer with a given config.
@ -89,11 +89,13 @@ class Optimizer(metaclass=ABCMeta):
self._tunables = tunables
self._iter = 1
self._max_iter = int(self._config.pop('max_iterations', 10))
self._opt_target = self._config.pop('maximize', None)
if self._opt_target is None:
self._opt_target: str
_opt_target = self._config.pop('maximize', None)
if _opt_target is None:
self._opt_target = self._config.pop('minimize', 'score')
self._opt_sign = 1
else:
self._opt_target = _opt_target
if 'minimize' in self._config:
raise ValueError("Cannot specify both 'maximize' and 'minimize'.")
self._opt_sign = -1
@ -110,18 +112,18 @@ class Optimizer(metaclass=ABCMeta):
return self._opt_target
@abstractmethod
def bulk_register(self, configs: List[dict], scores: List[float],
status: Optional[List[Status]] = None) -> bool:
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
status: Optional[Sequence[Status]] = None) -> bool:
"""
Pre-load the optimizer with the bulk data from previous experiments.
Parameters
----------
configs : List[dict]
configs : Sequence[dict]
Records of tunable values from other experiments.
scores : List[float]
scores : Sequence[float]
Benchmark results from experiments that correspond to `configs`.
status : Optional[List[float]]
status : Optional[Sequence[float]]
Status of the experiments that correspond to `configs`.
Returns
@ -152,7 +154,7 @@ class Optimizer(metaclass=ABCMeta):
@abstractmethod
def register(self, tunables: TunableGroups, status: Status,
score: Union[float, dict] = None) -> float:
score: Optional[Union[float, Dict[str, float]]] = None) -> Optional[float]:
"""
Register the observation for the given configuration.
@ -163,7 +165,7 @@ class Optimizer(metaclass=ABCMeta):
Usually it's the same config that the `.suggest()` method returned.
status : Status
Final status of the experiment (e.g., SUCCEEDED or FAILED).
score : Union[float, dict]
score : Union[float, Dict[str, float]]
A scalar or a dict with the final benchmark results.
None if the experiment was not successful.
@ -178,7 +180,7 @@ class Optimizer(metaclass=ABCMeta):
raise ValueError("Status and score must be consistent.")
return self._get_score(status, score)
def _get_score(self, status: Status, score: Union[float, dict]) -> float:
def _get_score(self, status: Status, score: Optional[Union[float, Dict[str, float]]]) -> Optional[float]:
"""
Extract a scalar benchmark score from the dataframe.
Change the sign if we are maximizing.
@ -187,7 +189,7 @@ class Optimizer(metaclass=ABCMeta):
----------
status : Status
Final status of the experiment (e.g., SUCCEEDED or FAILED).
score : Union[float, dict]
score : Union[float, Dict[str, float]]
A scalar or a dict with the final benchmark results.
None if the experiment was not successful.
@ -198,6 +200,7 @@ class Optimizer(metaclass=ABCMeta):
"""
if not status.is_succeeded:
return None
assert score is not None
if isinstance(score, dict):
score = score[self._opt_target]
return float(score) * self._opt_sign
@ -210,7 +213,7 @@ class Optimizer(metaclass=ABCMeta):
return self._iter <= self._max_iter
@abstractmethod
def get_best_observation(self) -> Tuple[float, TunableGroups]:
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
"""
Get the best observation so far.

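A condensed, stand-alone sketch of the maximize/minimize handling refactored in the `Optimizer` diff above. The function names `resolve_opt_target` and `get_score` are illustrative; the real logic lives in the constructor and `_get_score`:

```python
from typing import Dict, Optional, Tuple, Union


def resolve_opt_target(config: Dict[str, str]) -> Tuple[str, int]:
    """Return (target_metric, sign); sign == -1 turns maximization into minimization."""
    maximize = config.pop('maximize', None)
    if maximize is None:
        return (config.pop('minimize', 'score'), 1)
    if 'minimize' in config:
        raise ValueError("Cannot specify both 'maximize' and 'minimize'.")
    return (maximize, -1)


def get_score(succeeded: bool, score: Optional[Union[float, Dict[str, float]]],
              target: str, sign: int) -> Optional[float]:
    """Extract a scalar score; failed runs yield None, hence the Optional return type."""
    if not succeeded:
        return None
    assert score is not None
    if isinstance(score, dict):
        score = score[target]
    return float(score) * sign


(target, sign) = resolve_opt_target({"maximize": "throughput"})
print(get_score(True, {"throughput": 120.0, "latency": 3.5}, target, sign))   # -120.0
```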
View File

@ -8,6 +8,8 @@ Functions to convert TunableGroups to ConfigSpace for use with the mlos_core opt
import logging
from typing import Optional
from ConfigSpace.hyperparameters import Hyperparameter
from ConfigSpace import UniformIntegerHyperparameter
from ConfigSpace import UniformFloatHyperparameter
@ -21,7 +23,7 @@ _LOG = logging.getLogger(__name__)
def _tunable_to_hyperparameter(
tunable: Tunable, group_name: str = None, cost: int = 0) -> Hyperparameter:
tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> Hyperparameter:
"""
Convert a single Tunable to an equivalent ConfigSpace Hyperparameter object.

View File

@ -7,7 +7,7 @@ A wrapper for mlos_core optimizers for mlos_bench.
"""
import logging
from typing import Optional, Tuple, List, Union
from typing import Optional, Sequence, Tuple, Union
import pandas as pd
@ -47,17 +47,17 @@ class MlosCoreOptimizer(Optimizer):
space_adapter_type=space_adapter_type,
space_adapter_kwargs=space_adapter_config)
def bulk_register(self, configs: List[dict], scores: List[float],
status: Optional[List[Status]] = None) -> bool:
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
status: Optional[Sequence[Status]] = None) -> bool:
if not super().bulk_register(configs, scores, status):
return False
tunables_names = list(self._tunables.get_param_values().keys())
df_configs = pd.DataFrame(configs)[tunables_names]
df_scores = pd.Series(scores, dtype=float) * self._opt_sign
if status is not None:
df_status_succ = pd.Series(status) == Status.SUCCEEDED
df_configs = df_configs[df_status_succ]
df_scores = df_scores[df_status_succ]
df_status_ok = pd.Series(status) == Status.SUCCEEDED
df_configs = df_configs[df_status_ok]
df_scores = df_scores[df_status_ok]
self._opt.register(df_configs, df_scores)
if _LOG.isEnabledFor(logging.DEBUG):
(score, _) = self.get_best_observation()
@ -70,7 +70,7 @@ class MlosCoreOptimizer(Optimizer):
return self._tunables.copy().assign(df_config.loc[0].to_dict())
def register(self, tunables: TunableGroups, status: Status,
score: Union[float, dict] = None) -> float:
score: Optional[Union[float, dict]] = None) -> Optional[float]:
score = super().register(tunables, status, score)
# TODO: mlos_core currently does not support registration of failed trials:
if status.is_succeeded:
@ -81,7 +81,7 @@ class MlosCoreOptimizer(Optimizer):
self._iter += 1
return score
def get_best_observation(self) -> Tuple[float, TunableGroups]:
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
df_config = self._opt.get_best_observation()
if len(df_config) == 0:
return (None, None)

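A small pandas illustration, using made-up trial data, of the status filtering in `bulk_register` above: only SUCCEEDED trials get registered with the underlying mlos_core optimizer.

```python
import pandas as pd

# Made-up prior-trial data; in the real code the mask comes from
# `pd.Series(status) == Status.SUCCEEDED`.
configs = [{"x": 1}, {"x": 2}, {"x": 3}]
scores = [10.0, None, 7.5]                  # a failed trial has no score
status_ok = [True, False, True]

df_configs = pd.DataFrame(configs)[["x"]]   # keep only the tunables' columns
df_scores = pd.Series(scores, dtype=float)  # None becomes NaN
mask = pd.Series(status_ok)

df_configs = df_configs[mask]               # drop failed trials before registering
df_scores = df_scores[mask]
print(df_configs.to_dict("records"))        # [{'x': 1}, {'x': 3}]
print(df_scores.tolist())                   # [10.0, 7.5]
```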
View File

@ -8,9 +8,11 @@ Mock optimizer for mlos_bench.
import random
import logging
from typing import Optional, Tuple, List, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
from mlos_bench.environment.status import Status
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizer.base_optimizer import Optimizer
@ -26,16 +28,16 @@ class MockOptimizer(Optimizer):
def __init__(self, tunables: TunableGroups, config: dict):
super().__init__(tunables, config)
rnd = random.Random(config.get("seed", 42))
self._random = {
self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
"categorical": lambda tunable: rnd.choice(tunable.categorical_values),
"float": lambda tunable: rnd.uniform(*tunable.range),
"int": lambda tunable: rnd.randint(*tunable.range),
}
self._best_config = None
self._best_score = None
self._best_config: Optional[TunableGroups] = None
self._best_score: Optional[float] = None
def bulk_register(self, configs: List[dict], scores: List[float],
status: Optional[List[Status]] = None) -> bool:
def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]],
status: Optional[Sequence[Status]] = None) -> bool:
if not super().bulk_register(configs, scores, status):
return False
if status is None:
@ -60,15 +62,18 @@ class MockOptimizer(Optimizer):
return tunables
def register(self, tunables: TunableGroups, status: Status,
score: Union[float, dict] = None) -> float:
score = super().register(tunables, status, score)
if status.is_succeeded and (self._best_score is None or score < self._best_score):
self._best_score = score
score: Optional[Union[float, dict]] = None) -> Optional[float]:
registered_score = super().register(tunables, status, score)
if status.is_succeeded and (
self._best_score is None or (registered_score is not None and registered_score < self._best_score)
):
self._best_score = registered_score
self._best_config = tunables.copy()
self._iter += 1
return score
return registered_score
def get_best_observation(self) -> Tuple[float, TunableGroups]:
def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
if self._best_score is None:
return (None, None)
assert self._best_config is not None
return (self._best_score * self._opt_sign, self._best_config)

View File

View File

@ -17,7 +17,7 @@ from mlos_bench.launcher import Launcher
_LOG = logging.getLogger(__name__)
def _main():
def _main() -> None:
launcher = Launcher("mlos_bench run_bench")

View File

@ -9,16 +9,20 @@ OS Autotune main optimization loop.
See `--help` output for details.
"""
from typing import Tuple, Union
import logging
from mlos_bench.launcher import Launcher
from mlos_bench.optimizer import Optimizer
from mlos_bench.environment import Status, Environment
from mlos_bench.optimizer.base_optimizer import Optimizer
from mlos_bench.environment.base_environment import Environment
from mlos_bench.environment.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
def _main():
def _main() -> None:
launcher = Launcher("mlos_bench run_opt")
@ -40,7 +44,7 @@ def _main():
_LOG.info("Final result: %s", result)
def _optimize(env: Environment, opt: Optimizer, no_teardown: bool):
def _optimize(env: Environment, opt: Optimizer, no_teardown: bool) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
"""
Main optimization loop.
"""
@ -57,7 +61,11 @@ def _optimize(env: Environment, opt: Optimizer, no_teardown: bool):
continue
(status, output) = env.run() # Block and wait for the final result
value = output['score'] if status.is_succeeded else None
if status.is_succeeded:
assert output is not None
value = output['score']
else:
value = None
_LOG.info("Result: %s = %s :: %s", tunables, status, value)
opt.register(tunables, status, value)

View File

@ -8,14 +8,11 @@ Services for implementing Environments for mlos_bench.
from mlos_bench.service.base_service import Service
from mlos_bench.service.base_fileshare import FileShareService
from mlos_bench.service.local.local_exec import LocalExecService
from mlos_bench.service.config_persistence import ConfigPersistenceService
__all__ = [
'Service',
'LocalExecService',
'FileShareService',
'ConfigPersistenceService',
'LocalExecService',
]

View File

@ -11,11 +11,12 @@ import logging
from abc import ABCMeta, abstractmethod
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
_LOG = logging.getLogger(__name__)
class FileShareService(Service, metaclass=ABCMeta):
class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
"""
An abstract base of all file shares.
"""
@ -41,7 +42,7 @@ class FileShareService(Service, metaclass=ABCMeta):
])
@abstractmethod
def download(self, remote_path: str, local_path: str, recursive: bool = True):
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
"""
Downloads contents from a remote share path to a local path.
@ -57,10 +58,10 @@ class FileShareService(Service, metaclass=ABCMeta):
if True (the default), download the entire directory tree.
"""
_LOG.info("Download from File Share %s recursively: %s -> %s",
"" if recursive else "non-", remote_path, local_path)
"" if recursive else "non", remote_path, local_path)
@abstractmethod
def upload(self, local_path: str, remote_path: str, recursive: bool = True):
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
"""
Uploads contents from a local path to remote share path.
@ -75,4 +76,4 @@ class FileShareService(Service, metaclass=ABCMeta):
if True (the default), upload the entire directory tree.
"""
_LOG.info("Upload to File Share %s recursively: %s -> %s",
"" if recursive else "non-", local_path, remote_path)
"" if recursive else "non", local_path, remote_path)

View File

@ -9,8 +9,9 @@ Base class for the service mix-ins.
import json
import logging
from typing import Callable, Dict
from typing import Callable, Dict, List, Optional, Union
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
from mlos_bench.util import instantiate_from_config
_LOG = logging.getLogger(__name__)
@ -22,7 +23,7 @@ class Service:
"""
@classmethod
def new(cls, class_name: str, config: dict, parent):
def new(cls, class_name: str, config: dict, parent: Optional["Service"]) -> "Service":
"""
Factory method for a new service with a given config.
@ -46,7 +47,7 @@ class Service:
"""
return instantiate_from_config(cls, class_name, config, parent)
def __init__(self, config: dict = None, parent=None):
def __init__(self, config: Optional[dict] = None, parent: Optional["Service"] = None):
"""
Create a new service with a given config.
@ -61,11 +62,15 @@ class Service:
"""
self.config = config or {}
self._parent = parent
self._services = {}
self._services: Dict[str, Callable] = {}
if parent:
self.register(parent.export())
self._config_loader_service: SupportsConfigLoading
if parent and isinstance(parent, SupportsConfigLoading):
self._config_loader_service = parent
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Service: %s Config:\n%s",
self.__class__.__name__, json.dumps(self.config, indent=2))
@ -73,7 +78,7 @@ class Service:
self.__class__.__name__,
[] if parent is None else list(parent._services.keys()))
def register(self, services):
def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
"""
Register new mix-in services.

View File

@ -12,24 +12,25 @@ import os
import json # For logging only
import logging
from typing import Optional, List
from typing import List, Iterable, Optional, Union
import json5 # To read configs with comments and other JSON5 syntax features
from mlos_bench.util import prepare_class_load
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class ConfigPersistenceService(Service):
class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
Collection of methods to deserialize the Environment, Service, and TunableGroups objects.
"""
def __init__(self, config: dict = None, parent: Service = None):
def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None):
"""
Create a new instance of config persistence service.
@ -58,7 +59,7 @@ class ConfigPersistenceService(Service):
])
def resolve_path(self, file_path: str,
extra_paths: Optional[List[str]] = None) -> str:
extra_paths: Optional[Iterable[str]] = None) -> str:
"""
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
If `_config_path` is `None` or `path` is absolute, return `path` as is.
@ -67,7 +68,7 @@ class ConfigPersistenceService(Service):
----------
file_path : str
Path to the input config file.
extra_paths : List[str]
extra_paths : Iterable[str]
Additional directories to prepend to the list of search paths.
Returns
@ -86,7 +87,7 @@ class ConfigPersistenceService(Service):
_LOG.debug("Path not resolved: %s", file_path)
return file_path
def load_config(self, json_file_name: str) -> dict:
def load_config(self, json_file_name: str) -> Union[dict, List[dict]]:
"""
Load JSON config file. Search for a file relative to `_config_path`
if the input path is not absolute.
@ -99,18 +100,18 @@ class ConfigPersistenceService(Service):
Returns
-------
config : dict
config : Union[dict, List[dict]]
Free-format dictionary that contains the configuration.
"""
json_file_name = self.resolve_path(json_file_name)
_LOG.info("Load config: %s", json_file_name)
with open(json_file_name, mode='r', encoding='utf-8') as fh_json:
return json5.load(fh_json)
return json5.load(fh_json) # type: ignore[no-any-return]
def build_environment(self, config: dict,
global_config: dict = None,
tunables: TunableGroups = None,
service: Service = None) -> Environment:
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None) -> Environment:
"""
Factory method for a new environment with a given config.
@ -146,15 +147,17 @@ class ConfigPersistenceService(Service):
tunables = self.load_tunables(env_tunables_path, tunables)
_LOG.debug("Creating env: %s :: %s", env_name, env_class)
env = Environment.new(env_name, env_class, env_config, global_config, tunables, service)
env = Environment.new(env_name=env_name, class_name=env_class,
config=env_config, global_config=global_config,
tunables=tunables, service=service)
_LOG.info("Created env: %s :: %s", env_name, env)
return env
@classmethod
def _build_standalone_service(cls, config: dict,
global_config: dict = None,
parent: Service = None) -> Service:
global_config: Optional[dict] = None,
parent: Optional[Service] = None) -> Service:
"""
Factory method for a new service with a given config.
@ -180,9 +183,9 @@ class ConfigPersistenceService(Service):
return service
@classmethod
def _build_composite_service(cls, config_list: List[dict],
global_config: dict = None,
parent: Service = None) -> Service:
def _build_composite_service(cls, config_list: Iterable[dict],
global_config: Optional[dict] = None,
parent: Optional[Service] = None) -> Service:
"""
Factory method for a new service with a given config.
@ -218,8 +221,8 @@ class ConfigPersistenceService(Service):
return service
@classmethod
def build_service(cls, config: List[dict], global_config: dict = None,
parent: Service = None) -> Service:
def build_service(cls, config: Union[dict, List[dict]], global_config: Optional[dict] = None,
parent: Optional[Service] = None) -> Service:
"""
Factory method for a new service with a given config.
@ -252,7 +255,7 @@ class ConfigPersistenceService(Service):
return cls._build_composite_service(config, global_config, parent)
@staticmethod
def build_tunables(config: dict, parent: TunableGroups = None) -> TunableGroups:
def build_tunables(config: dict, parent: Optional[TunableGroups] = None) -> TunableGroups:
"""
Create a new collection of tunable parameters.
@ -280,8 +283,8 @@ class ConfigPersistenceService(Service):
groups.update(TunableGroups(config))
return groups
def load_environment(self, json_file_name: str, global_config: dict = None,
tunables: TunableGroups = None, service: Service = None) -> Environment:
def load_environment(self, json_file_name: str, global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> Environment:
"""
Load and build new environment from the config file.
@ -302,11 +305,12 @@ class ConfigPersistenceService(Service):
A new benchmarking environment.
"""
config = self.load_config(json_file_name)
assert isinstance(config, dict)
return self.build_environment(config, global_config, tunables, service)
def load_environment_list(
self, json_file_name: str, global_config: dict = None,
tunables: TunableGroups = None, service: Service = None) -> List[Environment]:
self, json_file_name: str, global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> List[Environment]:
"""
Load and build a list of environments from the config file.
@ -335,8 +339,8 @@ class ConfigPersistenceService(Service):
for config in config_list
]
def load_services(self, json_file_names: List[str],
global_config: dict = None, parent: Service = None) -> Service:
def load_services(self, json_file_names: Iterable[str],
global_config: Optional[dict] = None, parent: Optional[Service] = None) -> Service:
"""
Read the configuration files and bundle all service methods
from those configs into a single Service object.
@ -363,8 +367,8 @@ class ConfigPersistenceService(Service):
ConfigPersistenceService.build_service(config, global_config, service).export())
return service
def load_tunables(self, json_file_names: List[str],
parent: TunableGroups = None) -> TunableGroups:
def load_tunables(self, json_file_names: Iterable[str],
parent: Optional[TunableGroups] = None) -> TunableGroups:
"""
Load a collection of tunable parameters from JSON files.
@ -386,5 +390,6 @@ class ConfigPersistenceService(Service):
groups.update(parent)
for fname in json_file_names:
config = self.load_config(fname)
assert isinstance(config, dict)
groups.update(TunableGroups(config))
return groups

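A minimal, dependency-free sketch of the load-and-narrow pattern used in `ConfigPersistenceService` above: the loader is declared to return `Union[dict, List[dict]]` (JSON configs can be either), and callers `assert isinstance(...)` to narrow the type for mypy. Plain `json` stands in for `json5` here, and path resolution is omitted:

```python
import io
import json
from typing import List, Union


def load_config(fh: io.TextIOBase) -> Union[dict, List[dict]]:
    # The real service resolves the path first and uses json5 (comments, trailing
    # commas, etc.); plain json keeps this sketch dependency-free.
    return json.load(fh)


conf = load_config(io.StringIO('{"experimentId": "demo"}'))
assert isinstance(conf, dict)    # narrows Union[dict, List[dict]] -> dict for mypy
print(conf["experimentId"])      # demo
```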
View File

@ -15,21 +15,25 @@ import shlex
import subprocess
import logging
from typing import Optional, Tuple, List, Dict
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union, TYPE_CHECKING
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
if TYPE_CHECKING:
from mlos_bench.tunables.tunable import TunableValue
_LOG = logging.getLogger(__name__)
class LocalExecService(Service):
class LocalExecService(Service, SupportsLocalExec):
"""
Collection of methods to run scripts and commands in an external process
on the node acting as the scheduler. Can be useful for data processing
due to reduced dependency management complications vs the target environment.
"""
def __init__(self, config: dict = None, parent: Service = None):
def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None):
"""
Create a new instance of a service to run scripts locally.
@ -45,7 +49,7 @@ class LocalExecService(Service):
self._temp_dir = self.config.get("temp_dir")
self.register([self.temp_dir_context, self.local_exec])
def temp_dir_context(self, path: Optional[str] = None):
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
"""
Create a temp directory or use the provided path.
@ -63,8 +67,8 @@ class LocalExecService(Service):
return tempfile.TemporaryDirectory()
return contextlib.nullcontext(path or self._temp_dir)
def local_exec(self, script_lines: List[str],
env: Optional[Dict[str, str]] = None,
def local_exec(self, script_lines: Iterable[str],
env: Optional[Mapping[str, "TunableValue"]] = None,
cwd: Optional[str] = None,
return_on_error: bool = False) -> Tuple[int, str, str]:
"""
@ -72,10 +76,10 @@ class LocalExecService(Service):
Parameters
----------
script_lines : List[str]
script_lines : Iterable[str]
Lines of the script to run locally.
Treat every line as a separate command to run.
env : Dict[str, str]
env : Mapping[str, Union[int, float, str]]
Environment variables (optional).
cwd : str
Work directory to run the script at.
@ -110,7 +114,8 @@ class LocalExecService(Service):
return (return_code, stdout, stderr)
def _local_exec_script(self, script_line: str,
env: Dict[str, str], cwd: str) -> Tuple[int, str, str]:
env_params: Optional[Mapping[str, "TunableValue"]],
cwd: str) -> Tuple[int, str, str]:
"""
Execute the script from `script_path` in a local process.
@ -118,9 +123,9 @@ class LocalExecService(Service):
----------
script_line : str
Line of the script to run in the local process.
args : List[str]
args : Iterable[str]
Command line arguments for the script.
env : Dict[str, str]
env_params : Mapping[str, Union[int, float, str]]
Environment variables.
cwd : str
Work directory to run the script at.
@ -131,7 +136,7 @@ class LocalExecService(Service):
A 3-tuple of return code, stdout, and stderr of the script process.
"""
cmd = shlex.split(script_line)
script_path = self._parent.resolve_path(cmd[0])
script_path = self._config_loader_service.resolve_path(cmd[0])
if os.path.exists(script_path):
script_path = os.path.abspath(script_path)
@ -139,8 +144,10 @@ class LocalExecService(Service):
if script_path.strip().lower().endswith(".py"):
cmd = [sys.executable] + cmd
if env:
env = {key: str(val) for (key, val) in env.items()}
env: Dict[str, str] = {}
if env_params:
env = {key: str(val) for (key, val) in env_params.items()}
if sys.platform == 'win32':
# A hack to run Python on Windows with env variables set:
env_copy = os.environ.copy()


@ -13,7 +13,8 @@ from typing import Set
from azure.storage.fileshare import ShareClient
from mlos_bench.service import Service, FileShareService
from mlos_bench.service.base_service import Service
from mlos_bench.service.base_fileshare import FileShareService
from mlos_bench.util import check_required_params
_LOG = logging.getLogger(__name__)
@ -57,7 +58,7 @@ class AzureFileShareService(FileShareService):
credential=config["storageAccountKey"],
)
def download(self, remote_path: str, local_path: str, recursive: bool = True):
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
super().download(remote_path, local_path, recursive)
dir_client = self._share_client.get_directory_client(remote_path)
if dir_client.exists():
@ -75,13 +76,13 @@ class AzureFileShareService(FileShareService):
file_client = self._share_client.get_file_client(remote_path)
data = file_client.download_file()
with open(local_path, "wb") as output_file:
data.readinto(output_file)
data.readinto(output_file) # type: ignore[no-untyped-call]
def upload(self, local_path: str, remote_path: str, recursive: bool = True):
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
super().upload(local_path, remote_path, recursive)
self._upload(local_path, remote_path, recursive, set())
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]):
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
"""
Upload contents from a local path to an Azure file share.
This method is called from `.upload()` above. We need it to avoid exposing
@ -124,7 +125,7 @@ class AzureFileShareService(FileShareService):
with open(local_path, "rb") as file_data:
file_client.upload_file(file_data)
def _remote_makedirs(self, remote_path: str):
def _remote_makedirs(self, remote_path: str) -> None:
"""
Create remote directories for the entire path.
Succeeds even if some or all directories along the path already exist.


@ -10,18 +10,20 @@ import json
import time
import logging
from typing import Any, Tuple, List, Dict, Callable
from typing import Callable, Iterable, Tuple
import requests
from mlos_bench.environment import Status
from mlos_bench.service import Service
from mlos_bench.environment.status import Status
from mlos_bench.service.base_service import Service
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
from mlos_bench.util import check_required_params
_LOG = logging.getLogger(__name__)
class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
class AzureVMService(Service, SupportsVMOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes
"""
Helper methods to manage VMs on Azure.
"""
@ -134,7 +136,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
self.vm_start,
self.vm_stop,
self.vm_deprovision,
self.vm_reboot,
self.vm_restart,
self.remote_exec,
self.get_remote_exec_results
])
@ -144,7 +146,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
self._poll_timeout = float(config.get("pollTimeout", AzureVMService._POLL_TIMEOUT))
self._request_timeout = float(config.get("requestTimeout", AzureVMService._REQUEST_TIMEOUT))
self._deploy_template = self._parent.load_config(config['deployTemplatePath'])
self._deploy_template = self._config_loader_service.load_config(config['deployTemplatePath'])
self._url_deploy = AzureVMService._URL_DEPLOY.format(
subscription=config["subscription"],
@ -188,7 +190,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
}
@staticmethod
def _extract_arm_parameters(json_data):
def _extract_arm_parameters(json_data: dict) -> dict:
"""
Extract parameters from the ARM Template REST response JSON.
@ -203,7 +205,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
if val.get("value") is not None
}
def _azure_vm_post_helper(self, url: str) -> Tuple[Status, Dict]:
def _azure_vm_post_helper(self, url: str) -> Tuple[Status, dict]:
"""
General pattern for performing an action on an Azure VM via its REST API.
@ -237,7 +239,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
elif "Location" in response.headers:
result["asyncResultsUrl"] = response.headers.get("Location")
if "Retry-After" in response.headers:
result["pollInterval"] = float(response.headers["Retry-After"])
result["pollInterval"] = str(float(response.headers["Retry-After"]))
return (Status.PENDING, result)
else:
@ -245,7 +247,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
# _LOG.error("Bad Request:\n%s", response.request.body)
return (Status.FAILED, {})
def check_vm_operation_status(self, params: Dict) -> Tuple[Status, Dict]:
def check_vm_operation_status(self, params: dict) -> Tuple[Status, dict]:
"""
Checks the status of a pending operation on an Azure VM.
@ -285,7 +287,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.error("Response: %s :: %s", response, response.text)
return Status.FAILED, {}
def wait_vm_deployment(self, is_setup: bool, params: Dict[str, Any]) -> Tuple[Status, Dict]:
def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]:
"""
Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
@ -308,7 +310,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
"provision" if is_setup else "deprovision")
return self._wait_while(self._check_deployment, Status.PENDING, params)
def wait_vm_operation(self, params: Dict[str, Any]) -> Tuple[Status, Dict]:
def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]:
"""
Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
@ -330,8 +332,8 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.info("Wait for operation on VM %s", self.config["vmName"])
return self._wait_while(self.check_vm_operation_status, Status.RUNNING, params)
def _wait_while(self, func: Callable[[Dict[str, Any]], Tuple[Status, Dict]],
loop_status: Status, params: Dict[str, Any]) -> Tuple[Status, Dict]:
def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
loop_status: Status, params: dict) -> Tuple[Status, dict]:
"""
Invoke `func` periodically while the status is equal to `loop_status`.
Return TIMED_OUT when timing out.
@ -377,7 +379,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.warning("Request timed out: %s", params)
return (Status.TIMED_OUT, {})
def _check_deployment(self, _params: Dict) -> Tuple[Status, Dict]:
def _check_deployment(self, _params: dict) -> Tuple[Status, dict]:
"""
Check if Azure deployment exists.
Return SUCCEEDED if true, PENDING otherwise.
@ -409,7 +411,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.error("Response: %s :: %s", response, response.text)
return (Status.FAILED, {})
def vm_provision(self, params: Dict) -> Tuple[Status, Dict]:
def vm_provision(self, params: dict) -> Tuple[Status, dict]:
"""
Check if Azure VM is ready. Deploy a new VM, if necessary.
@ -464,7 +466,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
# _LOG.error("Bad Request:\n%s", response.request.body)
return (Status.FAILED, {})
def vm_start(self, params: Dict) -> Tuple[Status, Dict]:
def vm_start(self, params: dict) -> Tuple[Status, dict]:
"""
Start the VM on Azure.
@ -482,7 +484,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.info("Start VM: %s :: %s", self.config["vmName"], params)
return self._azure_vm_post_helper(self._url_start)
def vm_stop(self) -> Tuple[Status, Dict]:
def vm_stop(self) -> Tuple[Status, dict]:
"""
Stops the VM on Azure by initiating a graceful shutdown.
@ -495,7 +497,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.info("Stop VM: %s", self.config["vmName"])
return self._azure_vm_post_helper(self._url_stop)
def vm_deprovision(self) -> Tuple[Status, Dict]:
def vm_deprovision(self) -> Tuple[Status, dict]:
"""
Deallocates the VM on Azure by shutting it down then releasing the compute resources.
@ -508,7 +510,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.info("Deprovision VM: %s", self.config["vmName"])
return self._azure_vm_post_helper(self._url_stop)
def vm_reboot(self) -> Tuple[Status, Dict]:
def vm_restart(self) -> Tuple[Status, dict]:
"""
Reboot the VM on Azure by initiating a graceful shutdown.
@ -521,13 +523,13 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
_LOG.info("Reboot VM: %s", self.config["vmName"])
return self._azure_vm_post_helper(self._url_reboot)
def remote_exec(self, script: List[str], params: Dict[str, Any]) -> Tuple[Status, Dict]:
def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]:
"""
Run a command on Azure VM.
Parameters
----------
script : List[str]
script : Iterable[str]
A list of lines to execute as a script on a remote VM.
params : dict
Flat dictionary of (key, value) pairs of parameters.
@ -546,7 +548,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
json_req = {
"commandId": "RunShellScript",
"script": script,
"script": list(script),
"parameters": [{"name": key, "value": val} for (key, val) in params.items()]
}
@ -576,7 +578,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes
# _LOG.error("Bad Request:\n%s", response.request.body)
return (Status.FAILED, {})
def get_remote_exec_results(self, params: Dict) -> Tuple[Status, Dict]:
def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]:
"""
Get the results of the asynchronously running command.


@ -0,0 +1,7 @@
# Service Type Interfaces
Service loading in `mlos_bench` uses a [mix-in](https://en.wikipedia.org/wiki/Mixin#In_Python) approach to combine the functionality of multiple classes specified at runtime through config files into a single class.
This can make type checking in the `Environments` that use those `Services` a little tricky, both for developers and checking tools.
To address this, we define `@runtime_checkable`-decorated [`Protocols`](https://peps.python.org/pep-0544/) ("interfaces" in other languages) to declare the expected behavior of the `Services` that are loaded at runtime in the `Environments` that use them.
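
For illustration, a minimal sketch of how an `Environment` might verify at runtime that its mixed-in `Service` actually provides the expected behavior (the `SupportsLocalExec` protocol comes from this change; the helper function below is hypothetical):

from mlos_bench.service.types.local_exec_type import SupportsLocalExec

def check_local_exec_service(service: object) -> SupportsLocalExec:
    # Because the protocol is @runtime_checkable, isinstance() confirms that the
    # mixed-in service actually exposes local_exec() and temp_dir_context().
    assert isinstance(service, SupportsLocalExec), \
        "Environment requires a Service that supports local execution"
    return service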


@ -0,0 +1,22 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Service types for declaring Service behavior for Environments to use in mlos_bench.
"""
from mlos_bench.service.types.config_loader_type import SupportsConfigLoading
from mlos_bench.service.types.fileshare_type import SupportsFileShareOps
from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps
from mlos_bench.service.types.local_exec_type import SupportsLocalExec
from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec
__all__ = [
'SupportsConfigLoading',
'SupportsFileShareOps',
'SupportsVMOps',
'SupportsLocalExec',
'SupportsRemoteExec',
]


@ -0,0 +1,111 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for helper functions to lookup and load configs.
"""
from typing import List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING
# Avoids circular import issues.
if TYPE_CHECKING:
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.service.base_service import Service
from mlos_bench.environment.base_environment import Environment
@runtime_checkable
class SupportsConfigLoading(Protocol):
"""
Protocol interface for helper functions to lookup and load configs.
"""
def resolve_path(self, file_path: str,
extra_paths: Optional[Iterable[str]] = None) -> str:
"""
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
If `_config_path` is `None` or `path` is absolute, return `path` as is.
Parameters
----------
file_path : str
Path to the input config file.
extra_paths : Iterable[str]
Additional directories to prepend to the list of search paths.
Returns
-------
path : str
An actual path to the config or script.
"""
def load_config(self, json_file_name: str) -> Union[dict, List[dict]]:
"""
Load JSON config file. Search for a file relative to `_config_path`
if the input path is not absolute.
This method is exported to be used as a service.
Parameters
----------
json_file_name : str
Path to the input config file.
Returns
-------
config : Union[dict, List[dict]]
Free-format dictionary that contains the configuration.
"""
def build_environment(self, config: dict,
global_config: Optional[dict] = None,
tunables: Optional["TunableGroups"] = None,
service: Optional["Service"] = None) -> "Environment":
"""
Factory method for a new environment with a given config.
Parameters
----------
config : dict
A dictionary with three mandatory fields:
"name": Human-readable string describing the environment;
"class": FQN of a Python class to instantiate;
"config": Free-format dictionary to pass to the constructor.
global_config : dict
Global parameters to add to the environment config.
tunables : TunableGroups
A collection of groups of tunable parameters for all environments.
service: Service
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
Returns
-------
env : Environment
An instance of the `Environment` class initialized with `config`.
"""
def load_environment_list(
self, json_file_name: str, global_config: Optional[dict] = None,
tunables: Optional["TunableGroups"] = None, service: Optional["Service"] = None) -> List["Environment"]:
"""
Load and build a list of environments from the config file.
Parameters
----------
json_file_name : str
The environment JSON configuration file.
Can contain either one environment or a list of environments.
global_config : dict
Global parameters to add to the environment config.
tunables : TunableGroups
An optional collection of tunables to add to the environment.
service : Service
An optional reference of the parent service to mix in.
Returns
-------
env : List[Environment]
A list of new benchmarking environments.
"""


@ -0,0 +1,47 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for file share operations.
"""
from typing import Protocol, runtime_checkable
@runtime_checkable
class SupportsFileShareOps(Protocol):
"""
Protocol interface for file share operations.
"""
def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None:
"""
Downloads contents from a remote share path to a local path.
Parameters
----------
remote_path : str
Path to download from the remote file share, a file if recursive=False
or a directory if recursive=True.
local_path : str
Path to store the downloaded content to.
recursive : bool
If False, ignore the subdirectories;
if True (the default), download the entire directory tree.
"""
def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None:
"""
Uploads contents from a local path to a remote share path.
Parameters
----------
local_path : str
Path to the local directory to upload contents from.
remote_path : str
Path in the remote file share to store the uploaded content to.
recursive : bool
If False, ignore the subdirectories;
if True (the default), upload the entire directory tree.
"""


@ -0,0 +1,68 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Service types that provide helper functions to run
scripts and commands locally on the scheduler side.
"""
from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable
import tempfile
import contextlib
from mlos_bench.tunables.tunable import TunableValue
@runtime_checkable
class SupportsLocalExec(Protocol):
"""
Protocol interface for a collection of methods to run scripts and commands
in an external process on the node acting as the scheduler. Can be useful
for data processing due to reduced dependency management complications vs
the target environment.
Used in LocalEnv and provided by LocalExecService.
"""
def local_exec(self, script_lines: Iterable[str],
env: Optional[Mapping[str, TunableValue]] = None,
cwd: Optional[str] = None,
return_on_error: bool = False) -> Tuple[int, str, str]:
"""
Execute the script lines from `script_lines` in a local process.
Parameters
----------
script_lines : Iterable[str]
Lines of the script to run locally.
Treat every line as a separate command to run.
env : Mapping[str, Union[int, float, str]]
Environment variables (optional).
cwd : str
Work directory to run the script at.
If omitted, use `temp_dir` or create a temporary dir.
return_on_error : bool
If True, stop running script lines on first non-zero return code.
The default is False.
Returns
-------
(return_code, stdout, stderr) : (int, str, str)
A 3-tuple of return code, stdout, and stderr of the script process.
"""
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
"""
Create a temp directory or use the provided path.
Parameters
----------
path : str
A path to the temporary directory. Create a new one if None.
Returns
-------
temp_dir_context : TemporaryDirectory
Temporary directory context to use in the `with` clause.
"""


@ -0,0 +1,59 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Service types that provide helper functions to run
scripts on a remote host OS.
"""
from typing import Iterable, Tuple, Protocol, runtime_checkable
from mlos_bench.environment.status import Status
@runtime_checkable
class SupportsRemoteExec(Protocol):
"""
Protocol interface for Service types that provide helper functions to run
scripts on a remote host OS.
"""
def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]:
"""
Run a command on remote host OS.
Parameters
----------
script : Iterable[str]
A list of lines to execute as a script on a remote VM.
params : dict
Flat dictionary of (key, value) pairs of parameters.
They usually come from `const_args` and `tunable_params`
properties of the Environment.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]:
"""
Get the results of the asynchronously running command.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
Must have the "asyncResultsUrl" key to get the results.
If the key is not present, return Status.PENDING.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
"""


@ -0,0 +1,125 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for VM provisioning operations.
"""
from typing import Tuple, Protocol, runtime_checkable
from mlos_bench.environment.status import Status
@runtime_checkable
class SupportsVMOps(Protocol):
"""
Protocol interface for VM provisioning operations.
"""
def vm_provision(self, params: dict) -> Tuple[Status, dict]:
"""
Check if VM is ready. Deploy a new VM, if necessary.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
VMEnv tunables are variable parameters that, together with the
VMEnv configuration, are sufficient to provision a VM.
Returns
-------
result : (Status, dict={})
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]:
"""
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Parameters
----------
is_setup : bool
If True, wait for VM being deployed; otherwise, wait for successful deprovisioning.
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
def vm_start(self, params: dict) -> Tuple[Status, dict]:
"""
Start a VM.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
Returns
-------
result : (Status, dict={})
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def vm_stop(self) -> Tuple[Status, dict]:
"""
Stops the VM by initiating a graceful shutdown.
Returns
-------
result : (Status, dict={})
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def vm_restart(self) -> Tuple[Status, dict]:
"""
Restarts the VM by initiating a graceful shutdown.
Returns
-------
result : (Status, dict={})
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def vm_deprovision(self) -> Tuple[Status, dict]:
"""
Deallocates the VM by shutting it down then releasing the compute resources.
Returns
-------
result : (Status, dict={})
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]:
"""
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Parameters
----------
params: dict
Flat dictionary of (key, value) pairs of tunable parameters.
Must have the "asyncResultsUrl" key to get the results.
If the key is not present, return Status.PENDING.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -72,7 +72,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv:
Test fixture for MockEnv.
"""
return MockEnv(
"Test Env",
name="Test Env",
config={
"tunable_groups": ["provision", "boot", "kernel"],
"seed": 13,
@ -88,7 +88,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv:
Test fixture for MockEnv.
"""
return MockEnv(
"Test Env No Noise",
name="Test Env No Noise",
config={
"tunable_groups": ["provision", "boot", "kernel"],
"range": [60, 120]


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.environment.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -8,25 +8,27 @@ Unit tests for mock benchmark environment.
import pytest
from mlos_bench.environment import MockEnv
from mlos_bench.environment.mock_env import MockEnv
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups):
def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> None:
"""
Check the default values of the mock environment.
"""
assert mock_env.setup(tunable_groups)
(status, data) = mock_env.run()
assert status.is_succeeded
assert data is not None
assert data["score"] == pytest.approx(78.45, 0.01)
# Second time, results should differ because of the noise.
(status, data) = mock_env.run()
assert status.is_succeeded
assert data is not None
assert data["score"] == pytest.approx(98.21, 0.01)
def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups):
def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None:
"""
Check the default values of the mock environment.
"""
@ -35,6 +37,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr
# Noise-free results should be the same every time.
(status, data) = mock_env_no_noise.run()
assert status.is_succeeded
assert data is not None
assert data["score"] == pytest.approx(80.11, 0.01)
@ -51,7 +54,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr
}, 79.1),
])
def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
tunable_values: dict, expected_score: float):
tunable_values: dict, expected_score: float) -> None:
"""
Check the benchmark values of the mock environment after the assignment.
"""
@ -59,6 +62,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
assert mock_env.setup(tunable_groups)
(status, data) = mock_env.run()
assert status.is_succeeded
assert data is not None
assert data["score"] == pytest.approx(expected_score, 0.01)
@ -76,7 +80,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
])
def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv,
tunable_groups: TunableGroups,
tunable_values: dict, expected_score: float):
tunable_values: dict, expected_score: float) -> None:
"""
Check the benchmark values of the noiseless mock environment after the assignment.
"""
@ -86,4 +90,5 @@ def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv,
# Noise-free environment should produce the same results every time.
(status, data) = mock_env_no_noise.run()
assert status.is_succeeded
assert data is not None
assert data["score"] == pytest.approx(expected_score, 0.01)


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.optimizer.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -8,9 +8,9 @@ Unit tests for mock mlos_bench optimizer.
import pytest
from mlos_bench.environment import Status
from mlos_bench.environment.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizer import MlosCoreOptimizer
from mlos_bench.optimizer.mlos_core_optimizer import MlosCoreOptimizer
# pylint: disable=redefined-outer-name
@ -41,7 +41,7 @@ def mock_scores() -> list:
return [88.88, 66.66, 99.99]
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list):
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None:
"""
Make sure that llamatune+emukit optimizer initializes and works correctly.
"""


@ -8,8 +8,8 @@ Unit tests for mock mlos_bench optimizer.
import pytest
from mlos_bench.environment import Status
from mlos_bench.optimizer import MockOptimizer
from mlos_bench.environment.status import Status
from mlos_bench.optimizer.mock_optimizer import MockOptimizer
# pylint: disable=redefined-outer-name
@ -40,7 +40,7 @@ def mock_configurations() -> list:
def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
"""
Run several iterations of the oiptimizer and return the best score.
Run several iterations of the optimizer and return the best score.
"""
for (tunable_values, score) in mock_configurations:
assert mock_opt.not_converged()
@ -49,10 +49,12 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
mock_opt.register(tunables, Status.SUCCEEDED, score)
(score, _tunables) = mock_opt.get_best_observation()
assert score is not None
assert isinstance(score, float)
return score
def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list):
def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> None:
"""
Make sure that mock optimizer produces consistent suggestions.
"""
@ -60,7 +62,7 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list):
assert score == pytest.approx(66.66, 0.01)
def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list):
def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None:
"""
Check the maximization mode of the mock optimizer.
"""
@ -68,7 +70,7 @@ def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: li
assert score == pytest.approx(99.99, 0.01)
def test_mock_optimizer_register_fail(mock_opt: MockOptimizer):
def test_mock_optimizer_register_fail(mock_opt: MockOptimizer) -> None:
"""
Check the input acceptance conditions for Optimizer.register().
"""


@ -10,7 +10,7 @@ from typing import Optional, List
import pytest
from mlos_bench.environment import Status
from mlos_bench.environment.status import Status
from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer
# pylint: disable=redefined-outer-name
@ -46,7 +46,7 @@ def mock_configs() -> List[dict]:
@pytest.fixture
def mock_scores() -> List[float]:
def mock_scores() -> List[Optional[float]]:
"""
Mock benchmark results from earlier experiments.
"""
@ -62,13 +62,14 @@ def mock_status() -> List[Status]:
def _test_opt_update_min(opt: Optimizer, configs: List[dict],
scores: List[float], status: Optional[List[Status]] = None):
scores: List[float], status: Optional[List[Status]] = None) -> None:
"""
Test the bulk update of the optimizer on the minimization problem.
"""
opt.bulk_register(configs, scores, status)
(score, tunables) = opt.get_best_observation()
assert score == pytest.approx(66.66, 0.01)
assert tunables is not None
assert tunables.get_param_values() == {
"vmSize": "Standard_B4ms",
"rootfs": "ext4",
@ -77,13 +78,14 @@ def _test_opt_update_min(opt: Optimizer, configs: List[dict],
def _test_opt_update_max(opt: Optimizer, configs: List[dict],
scores: List[float], status: Optional[List[Status]] = None):
scores: List[float], status: Optional[List[Status]] = None) -> None:
"""
Test the bulk update of the optimizer on the maximiation prtoblem.
Test the bulk update of the optimizer on the maximization problem.
"""
opt.bulk_register(configs, scores, status)
(score, tunables) = opt.get_best_observation()
assert score == pytest.approx(99.99, 0.01)
assert tunables is not None
assert tunables.get_param_values() == {
"vmSize": "Standard_B2s",
"rootfs": "xfs",
@ -92,7 +94,7 @@ def _test_opt_update_max(opt: Optimizer, configs: List[dict],
def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the mock optimizer on the minimization problem.
"""
@ -100,7 +102,7 @@ def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict],
def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the mock optimizer on the maximization problem.
"""
@ -108,7 +110,7 @@ def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict],
def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the EmuKit optimizer.
"""
@ -116,7 +118,7 @@ def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict],
def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the EmuKit optimizer on the maximization problem.
"""
@ -124,7 +126,7 @@ def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List
def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the scikit-optimize GP optimizer.
"""
@ -132,7 +134,7 @@ def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[d
def test_update_scikit_et(scikit_et_opt: MlosCoreOptimizer, mock_configs: List[dict],
mock_scores: List[float], mock_status: List[Status]):
mock_scores: List[float], mock_status: List[Status]) -> None:
"""
Test the bulk update of the scikit-optimize ET optimizer.
"""


@ -10,7 +10,8 @@ from typing import Tuple
import pytest
from mlos_bench.environment import Environment, MockEnv
from mlos_bench.environment.base_environment import Environment
from mlos_bench.environment.mock_env import MockEnv
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer
@ -27,17 +28,20 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
assert env.setup(tunables)
(status, output) = env.run()
score = output['score']
assert status.is_succeeded
assert output is not None
score = output['score']
assert 60 <= score <= 120
opt.register(tunables, status, score)
return opt.get_best_observation()
(best_score, best_tunables) = opt.get_best_observation()
assert isinstance(best_score, float) and isinstance(best_tunables, TunableGroups)
return (best_score, best_tunables)
def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
mock_opt: MockOptimizer):
mock_opt: MockOptimizer) -> None:
"""
Toy optimization loop with mock environment and optimizer.
"""
@ -51,7 +55,7 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv,
scikit_gp_opt: MlosCoreOptimizer):
scikit_gp_opt: MlosCoreOptimizer) -> None:
"""
Toy optimization loop with mock environment and Scikit GP optimizer.
"""
@ -65,7 +69,7 @@ def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv,
def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv,
scikit_et_opt: MlosCoreOptimizer):
scikit_et_opt: MlosCoreOptimizer) -> None:
"""
Toy optimization loop with mock environment and Scikit ET optimizer.
"""
@ -79,7 +83,7 @@ def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv,
def test_emukit_optimization_loop(mock_env_no_noise: MockEnv,
emukit_opt: MlosCoreOptimizer):
emukit_opt: MlosCoreOptimizer) -> None:
"""
Toy optimization loop with mock environment and EmuKit optimizer.
"""
@ -89,7 +93,7 @@ def test_emukit_optimization_loop(mock_env_no_noise: MockEnv,
def test_emukit_optimization_loop_max(mock_env_no_noise: MockEnv,
emukit_opt_max: MlosCoreOptimizer):
emukit_opt_max: MlosCoreOptimizer) -> None:
"""
Toy optimization loop with mock environment and EmuKit optimizer
in maximization mode.


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.service.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -9,7 +9,7 @@ Unit tests for configuration persistence service.
import os
import pytest
from mlos_bench.service import ConfigPersistenceService
from mlos_bench.service.config_persistence import ConfigPersistenceService
# pylint: disable=redefined-outer-name
@ -27,7 +27,7 @@ def config_persistence_service() -> ConfigPersistenceService:
})
def test_resolve_path(config_persistence_service: ConfigPersistenceService):
def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None:
"""
Check if we can actually find a file somewhere in `config_path`.
"""
@ -37,7 +37,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService):
assert os.path.exists(path)
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService):
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None:
"""
Check if a non-existent file resolves without using `config_path`.
"""
@ -47,7 +47,7 @@ def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService)
assert path == file_path
def test_load_config(config_persistence_service: ConfigPersistenceService):
def test_load_config(config_persistence_service: ConfigPersistenceService) -> None:
"""
Check if we can successfully load a config file located relative to `config_path`.
"""


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.service.local.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -5,12 +5,17 @@
"""
Unit tests for LocalExecService to run Python scripts locally.
"""
from typing import Dict
import os
import json
import pytest
from mlos_bench.service import LocalExecService, ConfigPersistenceService
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.service.local.local_exec import LocalExecService
from mlos_bench.service.config_persistence import ConfigPersistenceService
# pylint: disable=redefined-outer-name
@ -25,7 +30,7 @@ def local_exec_service() -> LocalExecService:
}))
def test_run_python_script(local_exec_service: LocalExecService):
def test_run_python_script(local_exec_service: LocalExecService) -> None:
"""
Run a Python script using a local_exec service.
"""
@ -33,7 +38,7 @@ def test_run_python_script(local_exec_service: LocalExecService):
output_file = "./config-kernel.sh"
# Tunable parameters to save in JSON
params = {
params: Dict[str, TunableValue] = {
"sched_migration_cost_ns": 40000,
"sched_granularity_ns": 800000
}


@ -11,7 +11,8 @@ import sys
import pytest
import pandas
from mlos_bench.service import LocalExecService, ConfigPersistenceService
from mlos_bench.service.local.local_exec import LocalExecService
from mlos_bench.service.config_persistence import ConfigPersistenceService
# pylint: disable=redefined-outer-name
# -- Ignore pylint complaints about pytest references to
@ -26,7 +27,7 @@ def local_exec_service() -> LocalExecService:
return LocalExecService(parent=ConfigPersistenceService())
def test_run_script(local_exec_service: LocalExecService):
def test_run_script(local_exec_service: LocalExecService) -> None:
"""
Run a script locally and check the results.
"""
@ -37,7 +38,7 @@ def test_run_script(local_exec_service: LocalExecService):
assert stderr.strip() == ""
def test_run_script_multiline(local_exec_service: LocalExecService):
def test_run_script_multiline(local_exec_service: LocalExecService) -> None:
"""
Run a multiline script locally and check the results.
"""
@ -51,7 +52,7 @@ def test_run_script_multiline(local_exec_service: LocalExecService):
assert stderr.strip() == ""
def test_run_script_multiline_env(local_exec_service: LocalExecService):
def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None:
"""
Run a multiline script locally and pass the environment variables to it.
"""
@ -68,7 +69,7 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService):
assert stderr.strip() == ""
def test_run_script_read_csv(local_exec_service: LocalExecService):
def test_run_script_read_csv(local_exec_service: LocalExecService) -> None:
"""
Run a script locally and read the resulting CSV file.
"""
@ -89,7 +90,7 @@ def test_run_script_read_csv(local_exec_service: LocalExecService):
assert all(data.col2 == [222, 444])
def test_run_script_write_read_txt(local_exec_service: LocalExecService):
def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None:
"""
Write data a temp location and run a script that updates it there.
"""
@ -112,7 +113,7 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService):
assert fh_input.read().split() == ["hello", "world", "test"]
def test_run_script_fail(local_exec_service: LocalExecService):
def test_run_script_fail(local_exec_service: LocalExecService) -> None:
"""
Try to run a non-existent command.
"""


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.service.remote.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.service.remote.azure.
Used to make mypy happy about multiple conftest.py modules.
"""


@ -7,7 +7,9 @@ Tests for mlos_bench.service.remote.azure.azure_fileshare
"""
import os
from unittest.mock import Mock, patch, call
from unittest.mock import MagicMock, Mock, patch, call
from mlos_bench.service.remote.azure.azure_fileshare import AzureFileShareService
# pylint: disable=missing-function-docstring
# pylint: disable=too-many-arguments
@ -16,21 +18,21 @@ from unittest.mock import Mock, patch, call
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
def test_download_file(mock_makedirs, mock_open, azure_fileshare):
def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
filename = "test.csv"
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
remote_path = f"{remote_folder}/{filename}"
local_path = f"{local_folder}/{filename}"
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client.get_directory_client.return_value = Mock(
exists=Mock(return_value=False)
)
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \
patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client:
mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))
azure_fileshare.download(remote_path, local_path)
mock_share_client.get_file_client.assert_called_with(remote_path)
mock_get_file_client.assert_called_with(remote_path)
mock_makedirs.assert_called_with(
local_folder,
exist_ok=True,
@ -40,7 +42,7 @@ def test_download_file(mock_makedirs, mock_open, azure_fileshare):
assert open_mode == "wb"
def make_dir_client_returns(remote_folder: str):
def make_dir_client_returns(remote_folder: str) -> dict:
return {
remote_folder: Mock(
exists=Mock(return_value=True),
@ -66,20 +68,24 @@ def make_dir_client_returns(remote_folder: str):
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare):
def test_download_folder_non_recursive(mock_makedirs: MagicMock,
mock_open: MagicMock,
azure_fileshare: AzureFileShareService) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
azure_fileshare.download(remote_folder, local_folder, recursive=False)
mock_share_client.get_file_client.assert_called_with(
mock_get_file_client.assert_called_with(
f"{remote_folder}/a_file_1.csv",
)
mock_share_client.get_directory_client.assert_has_calls([
mock_get_directory_client.assert_has_calls([
call(remote_folder),
call(f"{remote_folder}/a_file_1.csv"),
], any_order=True)
@ -87,21 +93,22 @@ def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs")
def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare):
def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
azure_fileshare.download(remote_folder, local_folder, recursive=True)
mock_share_client.get_file_client.assert_has_calls([
mock_get_file_client.assert_has_calls([
call(f"{remote_folder}/a_file_1.csv"),
call(f"{remote_folder}/a_folder/a_file_2.csv"),
], any_order=True)
mock_share_client.get_directory_client.assert_has_calls([
mock_get_directory_client.assert_has_calls([
call(remote_folder),
call(f"{remote_folder}/a_file_1.csv"),
call(f"{remote_folder}/a_folder"),
@ -111,19 +118,19 @@ def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare):
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
def test_upload_file(mock_isdir, mock_open, azure_fileshare):
def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
filename = "test.csv"
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
remote_path = f"{remote_folder}/{filename}"
local_path = f"{local_folder}/{filename}"
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
mock_isdir.return_value = False
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
azure_fileshare.upload(local_path, remote_path)
mock_share_client.get_file_client.assert_called_with(remote_path)
mock_get_file_client.assert_called_with(remote_path)
open_path, open_mode = mock_open.call_args.args
assert os.path.abspath(local_path) == os.path.abspath(open_path)
assert open_mode == "rb"
@ -136,11 +143,11 @@ class MyDirEntry:
self.name = name
self.is_a_dir = is_a_dir
def is_dir(self):
def is_dir(self) -> bool:
return self.is_a_dir
def make_scandir_returns(local_folder: str):
def make_scandir_returns(local_folder: str) -> dict:
return {
local_folder: [
MyDirEntry("a_folder", True),
@ -152,7 +159,7 @@ def make_scandir_returns(local_folder: str):
}
def make_isdir_returns(local_folder: str):
def make_isdir_returns(local_folder: str) -> dict:
return {
local_folder: True,
f"{local_folder}/a_file_1.csv": False,
@ -161,7 +168,7 @@ def make_isdir_returns(local_folder: str):
}
def process_paths(input_path):
def process_paths(input_path: str) -> str:
skip_prefix = os.getcwd()
# Remove prefix from os.path.abspath if there
if input_path == os.path.abspath(input_path):
@ -175,37 +182,43 @@ def process_paths(input_path):
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir")
def test_upload_directory_non_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare):
def test_upload_directory_non_recursive(mock_scandir: MagicMock,
mock_isdir: MagicMock,
mock_open: MagicMock,
azure_fileshare: AzureFileShareService) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
scandir_returns = make_scandir_returns(local_folder)
isdir_returns = make_isdir_returns(local_folder)
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
azure_fileshare.upload(local_folder, remote_folder, recursive=False)
mock_share_client.get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
@patch("mlos_bench.service.remote.azure.azure_fileshare.open")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir")
def test_upload_directory_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare):
def test_upload_directory_recursive(mock_scandir: MagicMock,
mock_isdir: MagicMock,
mock_open: MagicMock,
azure_fileshare: AzureFileShareService) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
scandir_returns = make_scandir_returns(local_folder)
isdir_returns = make_isdir_returns(local_folder)
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
# pylint: disable=protected-access
mock_share_client = azure_fileshare._share_client
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
azure_fileshare.upload(local_folder, remote_folder, recursive=True)
mock_share_client.get_file_client.assert_has_calls([
mock_get_file_client.assert_has_calls([
call(f"{remote_folder}/a_file_1.csv"),
call(f"{remote_folder}/a_folder/a_file_2.csv"),
], any_order=True)


@ -10,7 +10,9 @@ from unittest.mock import MagicMock, patch
import pytest
from mlos_bench.environment import Status
from mlos_bench.environment.status import Status
from mlos_bench.service.remote.azure.azure_services import AzureVMService
# pylint: disable=missing-function-docstring
# pylint: disable=too-many-arguments
@ -21,7 +23,7 @@ from mlos_bench.environment import Status
("vm_start", True),
("vm_stop", False),
("vm_deprovision", False),
("vm_reboot", False),
("vm_restart", False),
])
@pytest.mark.parametrize(
("http_status_code", "operation_status"), [
@ -31,8 +33,8 @@ from mlos_bench.environment import Status
(404, Status.FAILED),
])
@patch("mlos_bench.service.remote.azure.azure_services.requests")
def test_vm_operation_status(mock_requests, azure_vm_service, operation_name,
accepts_params, http_status_code, operation_status):
def test_vm_operation_status(mock_requests: MagicMock, azure_vm_service: AzureVMService, operation_name: str,
accepts_params: bool, http_status_code: int, operation_status: Status) -> None:
mock_response = MagicMock()
mock_response.status_code = http_status_code
@ -49,7 +51,7 @@ def test_vm_operation_status(mock_requests, azure_vm_service, operation_name,
@patch("mlos_bench.service.remote.azure.azure_services.time.sleep")
@patch("mlos_bench.service.remote.azure.azure_services.requests")
def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service):
def test_wait_vm_operation_ready(mock_requests: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService) -> None:
# Mock response header
async_url = "DUMMY_ASYNC_URL"
@ -73,7 +75,7 @@ def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service):
@patch("mlos_bench.service.remote.azure.azure_services.requests")
def test_wait_vm_operation_timeout(mock_requests, azure_vm_service):
def test_wait_vm_operation_timeout(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None:
# Mock response header
params = {
@ -99,8 +101,8 @@ def test_wait_vm_operation_timeout(mock_requests, azure_vm_service):
(404, Status.FAILED),
])
@patch("mlos_bench.service.remote.azure.azure_services.requests")
def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, operation_status):
def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service: AzureVMService,
http_status_code: int, operation_status: Status) -> None:
script = ["command_1", "command_2"]
mock_response = MagicMock()
@ -113,7 +115,7 @@ def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, o
@patch("mlos_bench.service.remote.azure.azure_services.requests")
def test_remote_exec_headers_output(mock_requests, azure_vm_service):
def test_remote_exec_headers_output(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None:
async_url_key = "asyncResultsUrl"
async_url_value = "DUMMY_ASYNC_URL"
@ -158,14 +160,15 @@ def test_remote_exec_headers_output(mock_requests, azure_vm_service):
(Status.PENDING, {}, {}),
(Status.FAILED, {}, {}),
])
def test_get_remote_exec_results(azure_vm_service, operation_status: Status,
wait_output: dict, results_output: dict):
def test_get_remote_exec_results(azure_vm_service: AzureVMService, operation_status: Status,
wait_output: dict, results_output: dict) -> None:
params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}
mock_wait_vm_operation = MagicMock()
mock_wait_vm_operation.return_value = (operation_status, wait_output)
azure_vm_service.wait_vm_operation = mock_wait_vm_operation
# azure_vm_service.wait_vm_operation = mock_wait_vm_operation
setattr(azure_vm_service, "wait_vm_operation", mock_wait_vm_operation)
status, cmd_output = azure_vm_service.get_remote_exec_results(params)


@ -9,14 +9,14 @@ Configuration test fixtures for azure_services in mlos_bench.
from unittest.mock import patch
import pytest
from mlos_bench.service import ConfigPersistenceService
from mlos_bench.service.config_persistence import ConfigPersistenceService
from mlos_bench.service.remote.azure import AzureVMService, AzureFileShareService
# pylint: disable=redefined-outer-name
@pytest.fixture
def config_persistence_service():
def config_persistence_service() -> ConfigPersistenceService:
"""
Test fixture for ConfigPersistenceService.
"""
@ -28,7 +28,7 @@ def config_persistence_service():
@pytest.fixture
def azure_vm_service(config_persistence_service: ConfigPersistenceService):
def azure_vm_service(config_persistence_service: ConfigPersistenceService) -> AzureVMService:
"""
Creates a dummy Azure VM service for tests that require it.
"""
@ -45,7 +45,7 @@ def azure_vm_service(config_persistence_service: ConfigPersistenceService):
@pytest.fixture
def azure_fileshare(config_persistence_service: ConfigPersistenceService):
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
"""
Creates a dummy AzureFileShareService for tests that require it.
"""

View file

@ -0,0 +1,8 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tests for mlos_bench.tunables.
Used to make mypy happy about multiple conftest.py modules.
"""

View file

@ -46,7 +46,7 @@ def configuration_space() -> ConfigurationSpace:
def _cmp_tunable_hyperparameter_categorical(
tunable: Tunable, cs_param: CategoricalHyperparameter):
tunable: Tunable, cs_param: CategoricalHyperparameter) -> None:
"""
Check if categorical Tunable and ConfigSpace Hyperparameter actually match.
"""
@ -56,7 +56,7 @@ def _cmp_tunable_hyperparameter_categorical(
def _cmp_tunable_hyperparameter_int(
tunable: Tunable, cs_param: UniformIntegerHyperparameter):
tunable: Tunable, cs_param: UniformIntegerHyperparameter) -> None:
"""
Check if integer Tunable and ConfigSpace Hyperparameter actually match.
"""
@ -66,7 +66,7 @@ def _cmp_tunable_hyperparameter_int(
def _cmp_tunable_hyperparameter_float(
tunable: Tunable, cs_param: UniformFloatHyperparameter):
tunable: Tunable, cs_param: UniformFloatHyperparameter) -> None:
"""
Check if float Tunable and ConfigSpace Hyperparameter actually match.
"""
@ -75,7 +75,7 @@ def _cmp_tunable_hyperparameter_float(
assert cs_param.default_value == tunable.value
def test_tunable_to_hyperparameter_categorical(tunable_categorical):
def test_tunable_to_hyperparameter_categorical(tunable_categorical: Tunable) -> None:
"""
Check the conversion of Tunable to CategoricalHyperparameter.
"""
@ -83,7 +83,7 @@ def test_tunable_to_hyperparameter_categorical(tunable_categorical):
_cmp_tunable_hyperparameter_categorical(tunable_categorical, cs_param)
def test_tunable_to_hyperparameter_int(tunable_int):
def test_tunable_to_hyperparameter_int(tunable_int: Tunable) -> None:
"""
Check the conversion of Tunable to UniformIntegerHyperparameter.
"""
@ -91,7 +91,7 @@ def test_tunable_to_hyperparameter_int(tunable_int):
_cmp_tunable_hyperparameter_int(tunable_int, cs_param)
def test_tunable_to_hyperparameter_float(tunable_float):
def test_tunable_to_hyperparameter_float(tunable_float: Tunable) -> None:
"""
Check the conversion of Tunable to UniformFloatHyperparameter.
"""
@ -106,7 +106,7 @@ _CMP_FUNC = {
}
def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups):
def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> None:
"""
Check the conversion of TunableGroups to ConfigurationSpace.
Make sure that the corresponding Tunable and Hyperparameter objects match.
@ -119,7 +119,7 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups):
def test_tunable_groups_to_configspace(
tunable_groups: TunableGroups, configuration_space: ConfigurationSpace):
tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None:
"""
Check the conversion of the entire TunableGroups collection
to a single ConfigurationSpace object.

View file

@ -8,8 +8,11 @@ Unit tests for assigning values to the individual parameters within tunable grou
import pytest
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunables_assign_unknown_param(tunable_groups):
def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None:
"""
Make sure that bulk assignment fails for parameters
that don't exist in the TunableGroups object.
@ -23,7 +26,7 @@ def test_tunables_assign_unknown_param(tunable_groups):
})
def test_tunables_assign_invalid_categorical(tunable_groups):
def test_tunables_assign_invalid_categorical(tunable_groups: TunableGroups) -> None:
"""
Check parameter validation for categorical tunables.
"""
@ -31,7 +34,7 @@ def test_tunables_assign_invalid_categorical(tunable_groups):
tunable_groups.assign({"vmSize": "InvalidSize"})
def test_tunables_assign_invalid_range(tunable_groups):
def test_tunables_assign_invalid_range(tunable_groups: TunableGroups) -> None:
"""
Check parameter out-of-range validation for numerical tunables.
"""
@ -39,14 +42,14 @@ def test_tunables_assign_invalid_range(tunable_groups):
tunable_groups.assign({"kernel_sched_migration_cost_ns": -2})
def test_tunables_assign_coerce_str(tunable_groups):
def test_tunables_assign_coerce_str(tunable_groups: TunableGroups) -> None:
"""
Check the conversion from strings when assigning to an integer parameter.
"""
tunable_groups.assign({"kernel_sched_migration_cost_ns": "10000"})
def test_tunables_assign_coerce_str_range_check(tunable_groups):
def test_tunables_assign_coerce_str_range_check(tunable_groups: TunableGroups) -> None:
"""
Check the range when assigning to an integer tunable.
"""
@ -54,7 +57,7 @@ def test_tunables_assign_coerce_str_range_check(tunable_groups):
tunable_groups.assign({"kernel_sched_migration_cost_ns": "5500000"})
def test_tunables_assign_coerce_str_invalid(tunable_groups):
def test_tunables_assign_coerce_str_invalid(tunable_groups: TunableGroups) -> None:
"""
Make sure we fail when assigning an invalid string to an integer tunable.
"""
@ -62,23 +65,23 @@ def test_tunables_assign_coerce_str_invalid(tunable_groups):
tunable_groups.assign({"kernel_sched_migration_cost_ns": "1.1"})
def test_tunable_assign_str_to_int(tunable_int):
def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None:
"""
Check str to int coercion.
"""
tunable_int.value = "10"
assert tunable_int.value == 10
assert tunable_int.value == 10 # type: ignore[comparison-overlap]
def test_tunable_assign_str_to_float(tunable_float):
def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None:
"""
Check str to float coercion.
"""
tunable_float.value = "0.5"
assert tunable_float.value == 0.5
assert tunable_float.value == 0.5 # type: ignore[comparison-overlap]
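The two `type: ignore[comparison-overlap]` comments are a consequence of strict equality checking (`strict = True` implies `--strict-equality`): after the string assignment, mypy evidently treats the value as `str`, so comparing it with a number looks like a comparison between non-overlapping types. A minimal standalone sketch of the same effect, using a plain variable instead of the Tunable property:

    from typing import Union

    x: Union[int, float, str] = "10"   # mypy narrows x to str here
    assert x == 10                     # error: non-overlapping equality check ("str" vs "int")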
def test_tunable_assign_float_to_int(tunable_int):
def test_tunable_assign_float_to_int(tunable_int: Tunable) -> None:
"""
Check float to int coercion.
"""
@ -86,7 +89,7 @@ def test_tunable_assign_float_to_int(tunable_int):
assert tunable_int.value == 10
def test_tunable_assign_float_to_int_fail(tunable_int):
def test_tunable_assign_float_to_int_fail(tunable_int: Tunable) -> None:
"""
Check the invalid float to int coercion.
"""

View file

@ -6,18 +6,21 @@
Unit tests for deep copy of tunable objects and groups.
"""
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_copy_tunable_int(tunable_int):
def test_copy_tunable_int(tunable_int: Tunable) -> None:
"""
Check if deep copy works for Tunable object.
"""
tunable_copy = tunable_int.copy()
assert tunable_int == tunable_copy
tunable_copy.value += 200
tunable_copy.value = tunable_copy.numerical_value + 200
assert tunable_int != tunable_copy
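The switch from `tunable_copy.value += 200` to going through `numerical_value` follows from `value` now being typed as TunableValue, a union that includes `str`, so arithmetic on it no longer type-checks. Roughly:

    from typing import Union

    def get_value() -> Union[int, float, str]:
        return 100

    total = get_value() + 200  # mypy: unsupported operand types for + ("str" and "int")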
def test_copy_tunable_groups(tunable_groups):
def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None:
"""
Check if deep copy works for TunableGroups object.
"""

View file

@ -7,9 +7,9 @@ Tunable parameter definition.
"""
import copy
from typing import Any, Dict, List
from typing import Dict, Iterable
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValue
class CovariantTunableGroup:
@ -32,8 +32,8 @@ class CovariantTunableGroup:
"""
self._is_updated = True
self._name = name
self._cost = config.get("cost", 0)
self._tunables = {
self._cost = int(config.get("cost", 0))
self._tunables: Dict[str, Tunable] = {
name: Tunable(name, tunable_config)
for (name, tunable_config) in config.get("params", {}).items()
}
@ -50,7 +50,7 @@ class CovariantTunableGroup:
"""
return self._name
def copy(self):
def copy(self) -> "CovariantTunableGroup":
"""
Deep copy of the CovariantTunableGroup object.
@ -62,7 +62,7 @@ class CovariantTunableGroup:
"""
return copy.deepcopy(self)
def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""
Check if two CovariantTunableGroup objects are equal.
@ -76,12 +76,14 @@ class CovariantTunableGroup:
is_equal : bool
True if two CovariantTunableGroup objects are equal.
"""
if not isinstance(other, CovariantTunableGroup):
return False
return (self._name == other._name and
self._cost == other._cost and
self._is_updated == other._is_updated and
self._tunables == other._tunables)
def reset(self):
def reset(self) -> None:
"""
Clear the update flag. That is, state that running an experiment with the
current values of the tunables in this group has no extra cost.
@ -110,13 +112,13 @@ class CovariantTunableGroup:
"""
return self._cost if self._is_updated else 0
def get_names(self) -> List[str]:
def get_names(self) -> Iterable[str]:
"""
Get the names of all tunables in the group.
"""
return self._tunables.keys()
def get_values(self) -> Dict[str, Any]:
def get_values(self) -> Dict[str, TunableValue]:
"""
Get current values of all tunables in the group.
"""
@ -151,10 +153,10 @@ class CovariantTunableGroup:
"""
return self._tunables[name]
def __getitem__(self, name: str):
def __getitem__(self, name: str) -> TunableValue:
return self.get_tunable(name).value
def __setitem__(self, name: str, value):
def __setitem__(self, name: str, value: TunableValue) -> TunableValue:
self._is_updated = True
self._tunables[name].value = value
return value

View file

@ -9,11 +9,32 @@ import copy
import collections
import logging
from typing import Any, List, Tuple
from typing import List, Optional, Sequence, Tuple, TypedDict, Union
_LOG = logging.getLogger(__name__)
"""A tunable parameter value type alias."""
TunableValue = Union[int, float, str]
class TunableDict(TypedDict, total=False):
"""
A typed dict for tunable parameters.
Mostly used for mypy type checking.
These are the types expected to be received from the json config.
"""
type: str
description: Optional[str]
default: TunableValue
values: Optional[List[str]]
range: Optional[Union[Sequence[int], Sequence[float]]]
special: Optional[Union[List[int], List[str]]]
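For orientation, a hypothetical config fragment that satisfies this TypedDict; the parameter name and the numbers are made up for illustration and do not come from the shipped configs:

    example_config: TunableDict = {
        "type": "int",
        "description": "An example integer tunable",
        "default": 2000,
        "range": [0, 50000],
        "special": [-1],
    }
    example_tunable = Tunable("example_param", example_config)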
class Tunable: # pylint: disable=too-many-instance-attributes
"""
A tunable parameter definition and its current value.
@ -25,7 +46,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes
"categorical": str,
}
def __init__(self, name: str, config: dict):
def __init__(self, name: str, config: TunableDict):
"""
Create an instance of a new tunable parameter.
@ -39,19 +60,24 @@ class Tunable: # pylint: disable=too-many-instance-attributes
self._name = name
self._type = config["type"] # required
self._description = config.get("description")
self._default = config.get("default")
self._default = config["default"]
self._values = config.get("values")
self._range = config.get("range")
self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
config_range = config.get("range")
if config_range is not None:
assert len(config_range) == 2, f"Invalid range: {config_range}"
config_range = (config_range[0], config_range[1])
self._range = config_range
self._special = config.get("special")
self._current_value = self._default
if self._type == "categorical":
if self.is_categorical:
if not (self._values and isinstance(self._values, collections.abc.Iterable)):
raise ValueError("Must specify values for the categorical type")
if self._range is not None:
raise ValueError("Range must be None for the categorical type")
if self._special is not None:
raise ValueError("Special values must be None for the categorical type")
elif self._type in {"int", "float"}:
elif self.is_numerical:
if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
raise ValueError(f"Invalid range: {self._range}")
else:
@ -68,7 +94,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes
"""
return f"{self._name}={self._current_value}"
def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""
Check if two Tunable objects are equal.
@ -83,13 +109,15 @@ class Tunable: # pylint: disable=too-many-instance-attributes
True if the Tunables correspond to the same parameter and have the same value and type.
NOTE: ranges and special values are not currently considered in the comparison.
"""
return (
if not isinstance(other, Tunable):
return False
return bool(
self._name == other._name and
self._type == other._type and
self._current_value == other._current_value
)
def copy(self):
def copy(self) -> "Tunable":
"""
Deep copy of the Tunable object.
@ -101,14 +129,14 @@ class Tunable: # pylint: disable=too-many-instance-attributes
return copy.deepcopy(self)
@property
def value(self):
def value(self) -> TunableValue:
"""
Get the current value of the tunable.
"""
return self._current_value
@value.setter
def value(self, value):
def value(self, value: TunableValue) -> TunableValue:
"""
Set the current value of the tunable.
"""
@ -135,13 +163,13 @@ class Tunable: # pylint: disable=too-many-instance-attributes
self._current_value = coerced_value
return self._current_value
def is_valid(self, value) -> bool:
def is_valid(self, value: TunableValue) -> bool:
"""
Check if the value can be assigned to the tunable.
Parameters
----------
value : Any
value : Union[int, float, str]
Value to validate.
Returns
@ -149,13 +177,36 @@ class Tunable: # pylint: disable=too-many-instance-attributes
is_valid : bool
True if the value is valid, False otherwise.
"""
if self._type == "categorical":
if self.is_categorical and self._values:
return value in self._values
elif self._type in {"int", "float"} and self._range:
return (self._range[0] <= value <= self._range[1]) or value == self._default
elif self.is_numerical and self._range:
assert isinstance(value, (int, float))
return bool(self._range[0] <= value <= self._range[1]) or value == self._default
else:
raise ValueError(f"Invalid parameter type: {self._type}")
@property
def categorical_value(self) -> str:
"""
Get the current value of the tunable as a string.
"""
if self.is_categorical:
return str(self._current_value)
else:
raise ValueError("Cannot get categorical values for a numerical tunable.")
@property
def numerical_value(self) -> Union[int, float]:
"""
Get the current value of the tunable as a number.
"""
if self._type == "int":
return int(self._current_value)
elif self._type == "float":
return float(self._current_value)
else:
raise ValueError("Cannot get numerical value for a categorical tunable.")
@property
def name(self) -> str:
"""
@ -176,7 +227,31 @@ class Tunable: # pylint: disable=too-many-instance-attributes
return self._type
@property
def range(self) -> Tuple[Any, Any]:
def is_categorical(self) -> bool:
"""
Check if the tunable is categorical.
Returns
-------
is_categorical : bool
True if the tunable is categorical, False otherwise.
"""
return self._type == "categorical"
@property
def is_numerical(self) -> bool:
"""
Check if the tunable is an integer or float.
Returns
-------
is_numerical : bool
True if the tunable is an integer or float, False otherwise.
"""
return self._type in {"int", "float"}
@property
def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
"""
Get the range of the tunable if it is numerical, None otherwise.
@ -186,6 +261,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes
A 2-tuple of numbers that represents the range of the tunable.
Numbers can be int or float, depending on the type of the tunable.
"""
assert self.is_numerical
assert self._range is not None
return self._range
@property
@ -199,4 +276,6 @@ class Tunable: # pylint: disable=too-many-instance-attributes
values : List[str]
List of all possible values of a categorical tunable.
"""
assert self.is_categorical
assert self._values is not None
return self._values

View file

@ -7,9 +7,9 @@ TunableGroups definition.
"""
import copy
from typing import Any, Dict, List, Tuple
from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
@ -18,7 +18,7 @@ class TunableGroups:
A collection of covariant groups of tunable parameters.
"""
def __init__(self, config: dict = None):
def __init__(self, config: Optional[dict] = None):
"""
Create a new group of tunable parameters.
@ -27,12 +27,14 @@ class TunableGroups:
config : dict
Python dict of serialized representation of the covariant tunable groups.
"""
self._index = {} # Index (Tunable id -> CovariantTunableGroup)
self._tunable_groups = {}
for (name, group_config) in (config or {}).items():
if config is None:
config = {}
self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup)
self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
for (name, group_config) in config.items():
self._add_group(CovariantTunableGroup(name, group_config))
def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""
Check if two TunableGroups are equal.
@ -46,9 +48,11 @@ class TunableGroups:
is_equal : bool
True if two TunableGroups are equal.
"""
return self._tunable_groups == other._tunable_groups
if not isinstance(other, TunableGroups):
return False
return bool(self._tunable_groups == other._tunable_groups)
def copy(self):
def copy(self) -> "TunableGroups":
"""
Deep copy of the TunableGroups object.
@ -60,7 +64,7 @@ class TunableGroups:
"""
return copy.deepcopy(self)
def _add_group(self, group: CovariantTunableGroup):
def _add_group(self, group: CovariantTunableGroup) -> None:
"""
Add a CovariantTunableGroup to the current collection.
@ -71,7 +75,7 @@ class TunableGroups:
self._tunable_groups[group.name] = group
self._index.update(dict.fromkeys(group.get_names(), group))
def update(self, tunables):
def update(self, tunables: "TunableGroups") -> "TunableGroups":
"""
Merge the two collections of covariant tunable groups.
@ -104,20 +108,20 @@ class TunableGroups:
for (group_name, group) in self._tunable_groups.items()
for tunable in group._tunables.values()) + " }"
def __getitem__(self, name: str):
def __getitem__(self, name: str) -> TunableValue:
"""
Get the current value of a single tunable parameter.
"""
return self._index[name][name]
def __setitem__(self, name: str, value):
def __setitem__(self, name: str, value: TunableValue) -> None:
"""
Update the current value of a single tunable parameter.
"""
# Use double index to make sure we set the is_updated flag of the group
self._index[name][name] = value
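The double-index routing matters because it is CovariantTunableGroup.__setitem__ that flips the group's is_updated flag. An illustrative sequence, assuming `groups` is an already-populated TunableGroups and `some_param` (a made-up name) accepts the value:

    groups.reset()              # clear the update flags on all groups
    groups["some_param"] = 10   # routed through CovariantTunableGroup.__setitem__
    assert groups.is_updated()  # the owning group is now marked as updated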
def __iter__(self):
def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
"""
An iterator over all tunables in the group.
@ -147,7 +151,7 @@ class TunableGroups:
group = self._index[name]
return (group.get_tunable(name), group)
def get_names(self) -> List[str]:
def get_names(self) -> Iterable[str]:
"""
Get the names of all covariance groups in the collection.
@ -158,7 +162,7 @@ class TunableGroups:
"""
return self._tunable_groups.keys()
def subgroup(self, group_names: List[str]):
def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
"""
Select the covariance groups from the current set and create a new
TunableGroups object that consists of those covariance groups.
@ -179,8 +183,8 @@ class TunableGroups:
tunables._add_group(self._tunable_groups[name])
return tunables
def get_param_values(self, group_names: List[str] = None,
into_params: Dict[str, Any] = None) -> Dict[str, Any]:
def get_param_values(self, group_names: Optional[Iterable[str]] = None,
into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]:
"""
Get the current values of the tunables that belong to the specified covariance groups.
@ -205,7 +209,7 @@ class TunableGroups:
into_params.update(self._tunable_groups[name].get_values())
return into_params
def is_updated(self, group_names: List[str] = None) -> bool:
def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
"""
Check if any of the given covariant tunable groups has been updated.
@ -222,7 +226,7 @@ class TunableGroups:
return any(self._tunable_groups[name].is_updated()
for name in (group_names or self.get_names()))
def reset(self, group_names: List[str] = None):
def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
"""
Clear the update flag of given covariant groups.
@ -240,14 +244,14 @@ class TunableGroups:
self._tunable_groups[name].reset()
return self
def assign(self, param_values: Dict[str, Any]):
def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
"""
In-place update the values of the tunables from the dictionary
of (key, value) pairs.
Parameters
----------
param_values : Dict[str, Any]
param_values : Mapping[str, TunableValue]
Dictionary mapping Tunable parameter names to new values.
Returns

View file

@ -11,12 +11,13 @@ Various helper functions for mlos_bench.
import json
import logging
import importlib
from typing import Any, Tuple, Dict, Iterable
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Type, TypeVar, TYPE_CHECKING
_LOG = logging.getLogger(__name__)
def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, Dict[str, Any]]:
def prepare_class_load(config: dict, global_config: Optional[dict] = None) -> Tuple[str, Dict[str, Any]]:
"""
Extract the class instantiation parameters from the configuration.
@ -35,7 +36,10 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D
class_name = config["class"]
class_config = config.setdefault("config", {})
for key in set(class_config).intersection(global_config or {}):
if global_config is None:
global_config = {}
for key in set(class_config).intersection(global_config):
class_config[key] = global_config[key]
if _LOG.isEnabledFor(logging.DEBUG):
@ -45,7 +49,17 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D
return (class_name, class_config)
def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
if TYPE_CHECKING:
from mlos_bench.environment.base_environment import Environment
from mlos_bench.service.base_service import Service
from mlos_bench.optimizer.base_optimizer import Optimizer
# T is a generic with a constraint of the three base classes.
T = TypeVar('T', "Environment", "Service", "Optimizer")
# FIXME: Technically this should return a type "class_name" derived from "base_class".
def instantiate_from_config(base_class: Type[T], class_name: str, *args: Any, **kwargs: Any) -> T:
"""
Factory method for a new class instantiated from config.
@ -65,7 +79,7 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
Returns
-------
inst : Any
inst : Union[Environment, Service, Optimizer]
An instance of the `class_name` class.
"""
# We need to import mlos_bench to make the factory methods work.
@ -78,10 +92,12 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs):
_LOG.info("Instantiating: %s :: %s", class_name, impl)
assert issubclass(impl, base_class)
return impl(*args, **kwargs)
ret: T = impl(*args, **kwargs)
assert isinstance(ret, base_class)
return ret
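For reference, a hypothetical sketch of how the two helpers compose; the class path, the timeout parameter, and passing the merged config as a `config` keyword are made up for illustration:

    from mlos_bench.service.base_service import Service
    from mlos_bench.util import prepare_class_load, instantiate_from_config

    config = {"class": "my_pkg.services.MyService",   # hypothetical Service subclass
              "config": {"timeout": 30}}
    global_config = {"timeout": 60}                    # keys present in both dicts win

    (class_name, class_config) = prepare_class_load(config, global_config)
    # class_name == "my_pkg.services.MyService"; class_config == {"timeout": 60}
    svc: Service = instantiate_from_config(Service, class_name, config=class_config)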
def check_required_params(config: Dict[str, Any], required_params: Iterable[str]):
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
"""
Check if all required parameters are present in the configuration.
Raise ValueError if any of the parameters are missing.

View file

@ -34,13 +34,16 @@ extra_requires = {
# construct special 'full' extra that adds requirements for all built-in
# backend integrations and additional extra features.
extra_requires['full'] = list(set(chain(extra_requires.values())))
extra_requires['full'] = list(set(chain(extra_requires.values()))) # type: ignore[assignment]
# pylint: disable=duplicate-code
setup(
name='mlos-bench',
version=_VERSION,
packages=find_packages(),
package_data={
'mlos_bench': ['py.typed'],
},
install_requires=[
'mlos-core==' + _VERSION,
'requests',

View file

@ -7,7 +7,7 @@ Basic initializer module for the mlos_core space adapters.
"""
from enum import Enum
from typing import Optional, TypeVar, Union
from typing import Optional, TypeVar
import ConfigSpace

View file

@ -22,8 +22,6 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
aimed at improving the sample-efficiency of the underlying optimizer.
"""
# pylint: disable=consider-alternative-union-syntax,too-many-arguments
DEFAULT_NUM_LOW_DIMS = 16
"""Default number of dimensions in the low-dimensional search space, generated by HeSBO projection"""
@ -40,7 +38,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
special_param_values: Optional[dict] = None,
max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM,
use_approximate_reverse_mapping: bool = False,
) -> None:
) -> None: # pylint: disable=too-many-arguments
"""Create a space adapter that employs LlamaTune's techniques.
Parameters

View file

@ -54,8 +54,12 @@ disallow_untyped_defs = True
disallow_incomplete_defs = True
strict = True
allow_any_generics = True
hide_error_codes = False
# regex of files to skip type checking
exclude = /_pytest/|/build/|doc/|_version.py|setup.py
# _version.py and setup.py look like duplicates when run from the root of the repo even though they're part of different packages.
# There's not much in them so we just skip them.
# We also skip several vendor files that currently throw errors.
exclude = mlos_(core|bench)/(_version|setup).py|doc/|/build/|-packages/_pytest/
# https://github.com/automl/ConfigSpace/issues/293
[mypy-ConfigSpace.*]
@ -65,6 +69,10 @@ ignore_missing_imports = True
[mypy-emukit.*]
ignore_missing_imports = True
# https://github.com/dpranke/pyjson5/issues/65
[mypy-json5]
ignore_missing_imports = True
# https://github.com/matplotlib/matplotlib/issues/25634
[mypy-matplotlib.*]
ignore_missing_imports = True