From 9fe180a9a5d5857e178201fa4de5a74e92ac9378 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 19 Apr 2023 13:28:38 -0500 Subject: [PATCH] mypy static type checking for mlos_bench (#306) Adds static type checking via mypy for mlos_bench as well. Builds on #305, #307 - [x] Add `Protocol`s for different `Service` types so `Environment`s can check that the appropriate `Service` mix-ins have been setup. - [x] Config Loader - [x] Local Exec - [x] Remote Exec - [x] Remote Fileshare - [x] VM Ops - [ ] Future PR: Split VM Ops to VM/Host and OS Ops (for local SSH support) --- .vscode/settings.json | 5 +- Makefile | 5 +- doc/Dockerfile | 2 +- doc/source/overview.rst | 6 + .../applications/generate_redis_config.py | 2 +- .../applications/process_redis_results.py | 2 +- .../config/linux-boot/generate_grub_config.py | 2 +- .../generate_kernel_config_script.py | 2 +- .../environment/base_environment.py | 58 +++++--- .../mlos_bench/environment/composite_env.py | 26 ++-- .../mlos_bench/environment/local/local_env.py | 39 +++--- .../environment/local/local_env_fileshare.py | 33 +++-- mlos_bench/mlos_bench/environment/mock_env.py | 19 ++- .../mlos_bench/environment/remote/os_env.py | 55 +++++++- .../environment/remote/remote_env.py | 31 +++-- .../mlos_bench/environment/remote/vm_env.py | 51 ++++++- mlos_bench/mlos_bench/environment/status.py | 14 +- mlos_bench/mlos_bench/launcher.py | 23 ++-- .../mlos_bench/optimizer/base_optimizer.py | 33 ++--- .../optimizer/convert_configspace.py | 4 +- .../optimizer/mlos_core_optimizer.py | 16 +-- .../mlos_bench/optimizer/mock_optimizer.py | 29 ++-- mlos_bench/mlos_bench/py.typed | 0 mlos_bench/mlos_bench/run_bench.py | 2 +- mlos_bench/mlos_bench/run_opt.py | 18 ++- mlos_bench/mlos_bench/service/__init__.py | 5 +- .../mlos_bench/service/base_fileshare.py | 15 ++- mlos_bench/mlos_bench/service/base_service.py | 15 ++- .../mlos_bench/service/config_persistence.py | 61 +++++---- .../mlos_bench/service/local/local_exec.py | 47 ++++--- .../service/remote/azure/azure_fileshare.py | 13 +- .../service/remote/azure/azure_services.py | 50 +++---- mlos_bench/mlos_bench/service/types/README.md | 7 + .../mlos_bench/service/types/__init__.py | 22 +++ .../service/types/config_loader_type.py | 111 ++++++++++++++++ .../service/types/fileshare_type.py | 47 +++++++ .../service/types/local_exec_type.py | 68 ++++++++++ .../service/types/remote_exec_type.py | 59 +++++++++ .../service/types/vm_provisioner_type.py | 125 ++++++++++++++++++ mlos_bench/mlos_bench/tests/__init__.py | 8 ++ mlos_bench/mlos_bench/tests/conftest.py | 4 +- .../mlos_bench/tests/environment/__init__.py | 8 ++ .../tests/environment/mock_env_test.py | 15 ++- .../mlos_bench/tests/optimizer/__init__.py | 8 ++ .../tests/optimizer/llamatune_opt_test.py | 6 +- .../tests/optimizer/mock_opt_test.py | 14 +- .../tests/optimizer/opt_bulk_register_test.py | 24 ++-- .../optimizer/toy_optimization_loop_test.py | 20 +-- .../mlos_bench/tests/service/__init__.py | 8 ++ .../tests/service/config_persistence_test.py | 8 +- .../tests/service/local/__init__.py | 8 ++ .../service/local/local_exec_python_test.py | 11 +- .../tests/service/local/local_exec_test.py | 15 ++- .../tests/service/remote/__init__.py | 8 ++ .../tests/service/remote/azure/__init__.py | 8 ++ .../remote/azure/azure_fileshare_test.py | 113 +++++++++------- .../remote/azure/azure_services_test.py | 27 ++-- .../tests/service/remote/azure/conftest.py | 8 +- .../mlos_bench/tests/tunables/__init__.py | 8 ++ .../tunables/tunable_to_configspace_test.py | 16 +-- 
.../tests/tunables/tunables_assign_test.py | 27 ++-- .../tests/tunables/tunables_copy_test.py | 9 +- .../mlos_bench/tunables/covariant_group.py | 24 ++-- mlos_bench/mlos_bench/tunables/tunable.py | 113 +++++++++++++--- .../mlos_bench/tunables/tunable_groups.py | 48 ++++--- mlos_bench/mlos_bench/util.py | 30 ++++- mlos_bench/setup.py | 5 +- .../mlos_core/spaces/adapters/__init__.py | 2 +- .../mlos_core/spaces/adapters/llamatune.py | 4 +- setup.cfg | 10 +- 70 files changed, 1291 insertions(+), 448 deletions(-) create mode 100644 mlos_bench/mlos_bench/py.typed create mode 100644 mlos_bench/mlos_bench/service/types/README.md create mode 100644 mlos_bench/mlos_bench/service/types/__init__.py create mode 100644 mlos_bench/mlos_bench/service/types/config_loader_type.py create mode 100644 mlos_bench/mlos_bench/service/types/fileshare_type.py create mode 100644 mlos_bench/mlos_bench/service/types/local_exec_type.py create mode 100644 mlos_bench/mlos_bench/service/types/remote_exec_type.py create mode 100644 mlos_bench/mlos_bench/service/types/vm_provisioner_type.py create mode 100644 mlos_bench/mlos_bench/tests/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/environment/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/optimizer/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/service/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/service/local/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/service/remote/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/service/remote/azure/__init__.py create mode 100644 mlos_bench/mlos_bench/tests/tunables/__init__.py diff --git a/.vscode/settings.json b/.vscode/settings.json index c88e5ac0b4..106ebc0235 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,10 @@ // vim: set ft=jsonc: { "makefile.extensionOutputFolder": "./.vscode", - "python.defaultInterpreterPath": "${env:HOME}${env:USERPROFILE}/.conda/envs/mlos_core/bin/python", + // Note: this only works in WSL/Linux currently. + "python.defaultInterpreterPath": "${env:HOME}/.conda/envs/mlos_core/bin/python", + // For Windows it should be this instead: + //"python.defaultInterpreterPath": "${env:USERPROFILE}/.conda/envs/mlos_core/python.exe", "python.testing.pytestEnabled": true, "python.linting.enabled": true, "python.linting.pylintEnabled": true, diff --git a/Makefile b/Makefile index ffa3b991b9..232ea6b268 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ MLOS_BENCH_PYTHON_FILES := $(shell find ./mlos_bench/ -type f -name '*.py' 2>/de DOCKER := $(shell which docker) # Make sure the build directory exists. -MKDIR_BUILD := $(shell mkdir -p build) +MKDIR_BUILD := $(shell test -d build || mkdir build) # Allow overriding the default verbosity of conda for CI jobs. #CONDA_INFO_LEVEL ?= -q @@ -113,7 +113,7 @@ build/pylint.%.${CONDA_ENV_NAME}.build-stamp: build/conda-env.${CONDA_ENV_NAME}. 
touch $@ .PHONY: mypy -mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp # TODO: build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp +mypy: conda-env build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp build/mypy.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES) build/mypy.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES) @@ -339,6 +339,7 @@ build/check-doc.build-stamp: doc/build/html/index.html doc/build/html/htmlcov/in -e 'Problems with "include" directive path:' \ -e 'duplicate object description' \ -e "document isn't included in any toctree" \ + -e "more than one target found for cross-reference" \ -e "toctree contains reference to nonexisting document 'auto_examples/index'" \ -e "failed to import function 'create' from module '(SpaceAdapter|Optimizer)Factory'" \ -e "No module named '(SpaceAdapter|Optimizer)Factory'" \ diff --git a/doc/Dockerfile b/doc/Dockerfile index 4a1f79a89e..2cfb621c11 100644 --- a/doc/Dockerfile +++ b/doc/Dockerfile @@ -1,6 +1,6 @@ FROM nginx:latest RUN apt-get update && apt-get install -y --no-install-recommends linklint curl -# Our ngxinx config overrides the default listening port. +# Our nginx config overrides the default listening port. ARG NGINX_PORT=81 ENV NGINX_PORT=${NGINX_PORT} EXPOSE ${NGINX_PORT} diff --git a/doc/source/overview.rst b/doc/source/overview.rst index 1fcfa741ce..61e7d6297d 100644 --- a/doc/source/overview.rst +++ b/doc/source/overview.rst @@ -156,6 +156,12 @@ Service Mix-ins Service FileShareService + +.. currentmodule:: mlos_bench.service.config_persistence +.. autosummary:: + :toctree: generated/ + :template: class.rst + ConfigPersistenceService Local Services diff --git a/mlos_bench/config/applications/generate_redis_config.py b/mlos_bench/config/applications/generate_redis_config.py index 82aef10772..c1850f5e03 100644 --- a/mlos_bench/config/applications/generate_redis_config.py +++ b/mlos_bench/config/applications/generate_redis_config.py @@ -13,7 +13,7 @@ import json import argparse -def _main(fname_input: str, fname_output: str): +def _main(fname_input: str, fname_output: str) -> None: with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: for (key, val) in json.load(fh_tunables).items(): diff --git a/mlos_bench/config/applications/process_redis_results.py b/mlos_bench/config/applications/process_redis_results.py index 0e9e43d3be..537f8c3501 100644 --- a/mlos_bench/config/applications/process_redis_results.py +++ b/mlos_bench/config/applications/process_redis_results.py @@ -12,7 +12,7 @@ import argparse import pandas as pd -def _main(input_file: str, output_file: str): +def _main(input_file: str, output_file: str) -> None: """ Re-shape Redis benchmark CSV results from wide to long. 
""" diff --git a/mlos_bench/config/linux-boot/generate_grub_config.py b/mlos_bench/config/linux-boot/generate_grub_config.py index 51e5f3acd2..d03e4f5771 100644 --- a/mlos_bench/config/linux-boot/generate_grub_config.py +++ b/mlos_bench/config/linux-boot/generate_grub_config.py @@ -13,7 +13,7 @@ import json import argparse -def _main(fname_input: str, fname_output: str): +def _main(fname_input: str, fname_output: str) -> None: with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: for (key, val) in json.load(fh_tunables).items(): diff --git a/mlos_bench/config/linux-setup/generate_kernel_config_script.py b/mlos_bench/config/linux-setup/generate_kernel_config_script.py index a18423e712..3b90e8ec1a 100644 --- a/mlos_bench/config/linux-setup/generate_kernel_config_script.py +++ b/mlos_bench/config/linux-setup/generate_kernel_config_script.py @@ -13,7 +13,7 @@ import json import argparse -def _main(fname_input: str, fname_output: str): +def _main(fname_input: str, fname_output: str) -> None: with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: for (key, val) in json.load(fh_tunables).items(): diff --git a/mlos_bench/mlos_bench/environment/base_environment.py b/mlos_bench/mlos_bench/environment/base_environment.py index bb506a89ba..593bfa48f2 100644 --- a/mlos_bench/mlos_bench/environment/base_environment.py +++ b/mlos_bench/mlos_bench/environment/base_environment.py @@ -9,9 +9,12 @@ A hierarchy of benchmark environments. import abc import json import logging -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple from mlos_bench.environment.status import Status +from mlos_bench.service.base_service import Service +from mlos_bench.service.types.config_loader_type import SupportsConfigLoading +from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import instantiate_from_config @@ -19,13 +22,21 @@ _LOG = logging.getLogger(__name__) class Environment(metaclass=abc.ABCMeta): + # pylint: disable=too-many-instance-attributes """ An abstract base of all benchmark environments. """ @classmethod - def new(cls, env_name, class_name, config, global_config=None, tunables=None, service=None): - # pylint: disable=too-many-arguments + def new(cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -58,8 +69,13 @@ class Environment(metaclass=abc.ABCMeta): return instantiate_from_config(cls, class_name, env_name, config, global_config, tunables, service) - def __init__(self, name, config, global_config=None, tunables=None, service=None): - # pylint: disable=too-many-arguments + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. 
@@ -84,10 +100,17 @@ class Environment(metaclass=abc.ABCMeta): self.config = config self._service = service self._is_ready = False - self._params = {} + self._params: Dict[str, TunableValue] = {} + + self._config_loader_service: SupportsConfigLoading + if self._service is not None and isinstance(self._service, SupportsConfigLoading): + self._config_loader_service = self._service + + if global_config is None: + global_config = {} self._const_args = config.get("const_args", {}) - for key in set(self._const_args).intersection(global_config or {}): + for key in set(self._const_args).intersection(global_config): self._const_args[key] = global_config[key] for key in config.get("required_args", []): @@ -110,13 +133,13 @@ class Environment(metaclass=abc.ABCMeta): _LOG.debug("Config for: %s\n%s", name, json.dumps(self.config, indent=2)) - def __str__(self): + def __str__(self) -> str: return self.name - def __repr__(self): + def __repr__(self) -> str: return f"Env: {self.__class__} :: '{self.name}'" - def _combine_tunables(self, tunables): + def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ Plug tunable values into the base config. If the tunable group is unknown, ignore it (it might belong to another environment). This method should @@ -130,14 +153,14 @@ class Environment(metaclass=abc.ABCMeta): Returns ------- - params : dict + params : Dict[str, Union[int, float, str]] Free-format dictionary that contains the new environment configuration. """ return tunables.get_param_values( - group_names=self._tunable_params.get_names(), + group_names=list(self._tunable_params.get_names()), into_params=self._const_args.copy()) - def tunable_params(self): + def tunable_params(self) -> TunableGroups: """ Get the configuration space of the given environment. @@ -148,7 +171,7 @@ class Environment(metaclass=abc.ABCMeta): """ return self._tunable_params - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Set up a new benchmark environment, if necessary. This method must be idempotent, i.e., calling it several times in a row should be @@ -170,15 +193,18 @@ class Environment(metaclass=abc.ABCMeta): _LOG.info("Setup %s :: %s", self, tunables) assert isinstance(tunables, TunableGroups) + if global_config is None: + global_config = {} + self._params = self._combine_tunables(tunables) - for key in set(self._params).intersection(global_config or {}): + for key in set(self._params).intersection(global_config): self._params[key] = global_config[key] if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2)) return True - def teardown(self): + def teardown(self) -> None: """ Tear down the benchmark environment. This method must be idempotent, i.e., calling it several times in a row should be equivalent to a diff --git a/mlos_bench/mlos_bench/environment/composite_env.py b/mlos_bench/mlos_bench/environment/composite_env.py index 520be16f46..a2b8115910 100644 --- a/mlos_bench/mlos_bench/environment/composite_env.py +++ b/mlos_bench/mlos_bench/environment/composite_env.py @@ -7,7 +7,7 @@ Composite benchmark environment. """ import logging -from typing import Optional, Tuple +from typing import List, Optional, Tuple from mlos_bench.service.base_service import Service from mlos_bench.environment.status import Status @@ -22,9 +22,13 @@ class CompositeEnv(Environment): Composite benchmark environment. 
""" - def __init__(self, name: str, config: dict, global_config: dict = None, - tunables: TunableGroups = None, service: Service = None): - # pylint: disable=too-many-arguments + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. @@ -44,23 +48,23 @@ class CompositeEnv(Environment): An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name, config, global_config, tunables, service) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - self._children = [] + self._children: List[Environment] = [] for child_config_file in config.get("include_children", []): - for env in self._service.load_environment_list( + for env in self._config_loader_service.load_environment_list( child_config_file, global_config, tunables, self._service): self._add_child(env) for child_config in config.get("children", []): - self._add_child(self._service.build_environment( + self._add_child(self._config_loader_service.build_environment( child_config, global_config, tunables, self._service)) if not self._children: raise ValueError("At least one child environment must be present") - def _add_child(self, env: Environment): + def _add_child(self, env: Environment) -> None: """ Add a new child environment to the composite environment. This method is called from the constructor only. @@ -68,7 +72,7 @@ class CompositeEnv(Environment): self._children.append(env) self._tunable_params.update(env.tunable_params()) - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Set up the children environments. @@ -92,7 +96,7 @@ class CompositeEnv(Environment): ) return self._is_ready - def teardown(self): + def teardown(self) -> None: """ Tear down the children environments. This method is idempotent, i.e., calling it several times is equivalent to a single call. diff --git a/mlos_bench/mlos_bench/environment/local/local_env.py b/mlos_bench/mlos_bench/environment/local/local_env.py index 388bed39b9..f3ee90bbdc 100644 --- a/mlos_bench/mlos_bench/environment/local/local_env.py +++ b/mlos_bench/mlos_bench/environment/local/local_env.py @@ -17,23 +17,25 @@ import pandas from mlos_bench.environment.status import Status from mlos_bench.environment.base_environment import Environment from mlos_bench.service.base_service import Service +from mlos_bench.service.types.local_exec_type import SupportsLocalExec from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) class LocalEnv(Environment): + # pylint: disable=too-many-instance-attributes """ Scheduler-side Environment that runs scripts locally. """ def __init__(self, + *, name: str, config: dict, global_config: Optional[dict] = None, tunables: Optional[TunableGroups] = None, service: Optional[Service] = None): - # pylint: disable=too-many-arguments """ Create a new environment for local execution. @@ -56,15 +58,19 @@ class LocalEnv(Environment): An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name, config, global_config, tunables, service) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" + self._local_exec_service: SupportsLocalExec = self._service self._temp_dir = self.config.get("temp_dir") self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") - self._dump_params_file = self.config.get("dump_params_file") - self._read_results_file = self.config.get("read_results_file") + self._dump_params_file: Optional[str] = self.config.get("dump_params_file") + self._read_results_file: Optional[str] = self.config.get("read_results_file") if self._script_setup is None and \ self._script_run is None and \ @@ -77,7 +83,7 @@ class LocalEnv(Environment): if self._script_run is None and self._read_results_file is not None: raise ValueError("'run' must be present if 'read_results_file' is specified") - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the environment is ready and set up the application and benchmarks, if necessary. @@ -105,7 +111,7 @@ class LocalEnv(Environment): self._is_ready = True return True - with self._service.temp_dir_context(self._temp_dir) as temp_dir: + with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir: _LOG.info("Set up the environment locally: %s at %s", self, temp_dir) @@ -116,7 +122,7 @@ class LocalEnv(Environment): json.dump(tunables.get_param_values(self._tunable_params.get_names()), fh_tunables) - (return_code, _stdout, stderr) = self._service.local_exec( + (return_code, _stdout, stderr) = self._local_exec_service.local_exec( self._script_setup, env=self._params, cwd=temp_dir) if return_code == 0: @@ -125,7 +131,7 @@ class LocalEnv(Environment): _LOG.warning("ERROR: Local setup returns with code %d stderr:\n%s", return_code, stderr) - self._is_ready = return_code == 0 + self._is_ready = bool(return_code == 0) return self._is_ready @@ -145,10 +151,10 @@ class LocalEnv(Environment): if not (status.is_ready and self._script_run): return result - with self._service.temp_dir_context(self._temp_dir) as temp_dir: + with self._local_exec_service.temp_dir_context(self._temp_dir) as temp_dir: _LOG.info("Run script locally on: %s at %s", self, temp_dir) - (return_code, _stdout, stderr) = self._service.local_exec( + (return_code, _stdout, stderr) = self._local_exec_service.local_exec( self._script_run, env=self._params, cwd=temp_dir) if return_code != 0: @@ -156,23 +162,24 @@ class LocalEnv(Environment): return_code, stderr) return (Status.FAILED, None) - data = pandas.read_csv(self._service.resolve_path( + assert self._read_results_file is not None + data = pandas.read_csv(self._config_loader_service.resolve_path( self._read_results_file, extra_paths=[temp_dir])) _LOG.debug("Read data:\n%s", data) if len(data) != 1: _LOG.warning("Local run has %d results - returning the last one", len(data)) - data = data.iloc[-1].to_dict() - _LOG.info("Local run complete: %s ::\n%s", self, data) - return (Status.SUCCEEDED, data) + data_dict = data.iloc[-1].to_dict() + _LOG.info("Local run complete: %s ::\n%s", self, data_dict) + return (Status.SUCCEEDED, data_dict) - def teardown(self): + def teardown(self) -> None: """ 
Clean up the local environment. """ if self._script_teardown: _LOG.info("Local teardown: %s", self) - (status, _) = self._service.local_exec(self._script_teardown, env=self._params) + (status, _, _) = self._local_exec_service.local_exec(self._script_teardown, env=self._params) _LOG.info("Local teardown complete: %s :: %s", self, status) super().teardown() diff --git a/mlos_bench/mlos_bench/environment/local/local_env_fileshare.py b/mlos_bench/mlos_bench/environment/local/local_env_fileshare.py index b5fbe75ae2..da14a5a259 100644 --- a/mlos_bench/mlos_bench/environment/local/local_env_fileshare.py +++ b/mlos_bench/mlos_bench/environment/local/local_env_fileshare.py @@ -10,11 +10,14 @@ and upload/download data to the shared storage. import logging from string import Template -from typing import Optional, Dict, List, Tuple, Any +from typing import List, Generator, Iterable, Mapping, Optional, Tuple from mlos_bench.service.base_service import Service +from mlos_bench.service.types.local_exec_type import SupportsLocalExec +from mlos_bench.service.types.fileshare_type import SupportsFileShareOps from mlos_bench.environment.status import Status from mlos_bench.environment.local.local_env import LocalEnv +from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) @@ -27,12 +30,12 @@ class LocalFileShareEnv(LocalEnv): """ def __init__(self, + *, name: str, config: dict, global_config: Optional[dict] = None, tunables: Optional[TunableGroups] = None, service: Optional[Service] = None): - # pylint: disable=too-many-arguments """ Create a new application environment with a given config. @@ -56,7 +59,16 @@ class LocalFileShareEnv(LocalEnv): An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name, config, global_config, tunables, service) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" + self._local_exec_service: SupportsLocalExec = self._service + + assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ + "LocalEnv requires a service that supports file upload/download operations" + self._file_share_service: SupportsFileShareOps = self._service + self._upload = self._template_from_to("upload") self._download = self._template_from_to("download") @@ -71,7 +83,8 @@ class LocalFileShareEnv(LocalEnv): ] @staticmethod - def _expand(from_to: List[Tuple[Template, Template]], params: Dict[str, Any]): + def _expand(from_to: Iterable[Tuple[Template, Template]], + params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. Return a generator of (str, str) pairs of paths. @@ -81,7 +94,7 @@ class LocalFileShareEnv(LocalEnv): for (path_from, path_to) in from_to ) - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Run setup scripts locally and upload the scripts and data to the shared storage. @@ -102,14 +115,14 @@ class LocalFileShareEnv(LocalEnv): True if operation is successful, false otherwise. 
""" prev_temp_dir = self._temp_dir - with self._service.temp_dir_context(self._temp_dir) as self._temp_dir: + with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir: # Override _temp_dir so that setup and upload both use the same path. self._is_ready = super().setup(tunables, global_config) if self._is_ready: params = self._params.copy() params["PWD"] = self._temp_dir for (path_from, path_to) in self._expand(self._upload, params): - self._service.upload(self._service.resolve_path( + self._file_share_service.upload(self._config_loader_service.resolve_path( path_from, extra_paths=[self._temp_dir]), path_to) self._temp_dir = prev_temp_dir return self._is_ready @@ -128,13 +141,13 @@ class LocalFileShareEnv(LocalEnv): be in the `score` field. """ prev_temp_dir = self._temp_dir - with self._service.temp_dir_context(self._temp_dir) as self._temp_dir: + with self._local_exec_service.temp_dir_context(self._temp_dir) as self._temp_dir: # Override _temp_dir so that download and run both use the same path. params = self._params.copy() params["PWD"] = self._temp_dir for (path_from, path_to) in self._expand(self._download, params): - self._service.download( - path_from, self._service.resolve_path( + self._file_share_service.download( + path_from, self._config_loader_service.resolve_path( path_to, extra_paths=[self._temp_dir])) result = super().run() self._temp_dir = prev_temp_dir diff --git a/mlos_bench/mlos_bench/environment/mock_env.py b/mlos_bench/mlos_bench/environment/mock_env.py index 04822546d1..b9f40fc53e 100644 --- a/mlos_bench/mlos_bench/environment/mock_env.py +++ b/mlos_bench/mlos_bench/environment/mock_env.py @@ -12,9 +12,9 @@ from typing import Optional, Tuple import numpy +from mlos_bench.service.base_service import Service from mlos_bench.environment.status import Status from mlos_bench.environment.base_environment import Environment -from mlos_bench.service.base_service import Service from mlos_bench.tunables import Tunable, TunableGroups _LOG = logging.getLogger(__name__) @@ -28,12 +28,12 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 # Variance of the Gaussian noise added to the benchmark value. def __init__(self, + *, name: str, config: dict, global_config: Optional[dict] = None, tunables: Optional[TunableGroups] = None, service: Optional[Service] = None): - # pylint: disable=too-many-arguments """ Create a new environment that produces mock benchmark data. @@ -52,7 +52,7 @@ class MockEnv(Environment): service: Service An optional service object. Not used by this class. """ - super().__init__(name, config, global_config, tunables, service) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) seed = self.config.get("seed") self._random = random.Random(seed) if seed is not None else None self._range = self.config.get("range") @@ -95,15 +95,14 @@ class MockEnv(Environment): That is, map current value to the [0, 1] range. 
""" val = None - if tunable.type == "categorical": - val = (tunable.categorical_values.index(tunable.value) / + if tunable.is_categorical: + val = (tunable.categorical_values.index(tunable.categorical_value) / float(len(tunable.categorical_values) - 1)) - elif tunable.type in {"int", "float"}: - if not tunable.range: - raise ValueError("Tunable must have a range: " + tunable.name) - val = ((tunable.value - tunable.range[0]) / + elif tunable.is_numerical: + val = ((tunable.numerical_value - tunable.range[0]) / float(tunable.range[1] - tunable.range[0])) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. - return numpy.clip(val, 0, 1) + ret: float = numpy.clip(val, 0, 1) + return ret diff --git a/mlos_bench/mlos_bench/environment/remote/os_env.py b/mlos_bench/mlos_bench/environment/remote/os_env.py index 89921ae3be..dc9a649181 100644 --- a/mlos_bench/mlos_bench/environment/remote/os_env.py +++ b/mlos_bench/mlos_bench/environment/remote/os_env.py @@ -6,9 +6,14 @@ OS-level remote Environment on Azure. """ +from typing import Optional + import logging -from mlos_bench.environment import Environment, Status +from mlos_bench.environment.base_environment import Environment +from mlos_bench.environment.status import Status +from mlos_bench.service.base_service import Service +from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) @@ -19,7 +24,43 @@ class OSEnv(Environment): OS Level Environment for a host. """ - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): + """ + Create a new environment for remote execution. + + Parameters + ---------- + name: str + Human-readable name of the environment. + config : dict + Free-format dictionary that contains the benchmark environment + configuration. Each config must have at least the "tunable_params" + and the "const_args" sections. + `RemoteEnv` must also have at least some of the following parameters: + {setup, run, teardown, wait_boot} + global_config : dict + Free-format dictionary of global parameters (e.g., security credentials) + to be mixed in into the "const_args" section of the local config. + tunables : TunableGroups + A collection of tunable parameters for *all* environments. + service: Service + An optional service object (e.g., providing methods to + deploy or reboot a VM, etc.). + """ + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + # TODO: Refactor this as "host" and "os" operations to accommodate SSH service. + assert self._service is not None and isinstance(self._service, SupportsVMOps), \ + "RemoteEnv requires a service that supports host operations" + self._host_service: SupportsVMOps = self._service + + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the host is up and running; boot it, if necessary. 
@@ -43,21 +84,21 @@ class OSEnv(Environment): if not super().setup(tunables, global_config): return False - (status, params) = self._service.vm_start(self._params) + (status, params) = self._host_service.vm_start(self._params) if status.is_pending: - (status, _) = self._service.wait_vm_operation(params) + (status, _) = self._host_service.wait_vm_operation(params) self._is_ready = status in {Status.SUCCEEDED, Status.READY} return self._is_ready - def teardown(self): + def teardown(self) -> None: """ Clean up and shut down the host without deprovisioning it. """ _LOG.info("OS tear down: %s", self) - (status, params) = self._service.vm_stop() + (status, params) = self._host_service.vm_stop() if status.is_pending: - (status, _) = self._service.wait_vm_operation(params) + (status, _) = self._host_service.wait_vm_operation(params) super().teardown() _LOG.debug("Final status of OS stopping: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environment/remote/remote_env.py b/mlos_bench/mlos_bench/environment/remote/remote_env.py index 404218aba8..abaf65fe31 100644 --- a/mlos_bench/mlos_bench/environment/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environment/remote/remote_env.py @@ -7,11 +7,13 @@ Remotely executed benchmark/script environment. """ import logging -from typing import Optional, Tuple, List +from typing import Iterable, Optional, Tuple from mlos_bench.environment.status import Status from mlos_bench.environment.base_environment import Environment from mlos_bench.service.base_service import Service +from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec +from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) @@ -23,12 +25,12 @@ class RemoteEnv(Environment): """ def __init__(self, + *, name: str, config: dict, global_config: Optional[dict] = None, tunables: Optional[TunableGroups] = None, service: Optional[Service] = None): - # pylint: disable=too-many-arguments """ Create a new environment for remote execution. @@ -51,7 +53,16 @@ class RemoteEnv(Environment): An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name, config, global_config, tunables, service) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ + "RemoteEnv requires a service that supports remote execution operations" + self._remote_exec_service: SupportsRemoteExec = self._service + + # TODO: Refactor this as "host" and "os" operations to accommodate SSH service. + assert self._service is not None and isinstance(self._service, SupportsVMOps), \ + "RemoteEnv requires a service that supports host operations" + self._host_service: SupportsVMOps = self._service self._wait_boot = self.config.get("wait_boot", False) self._script_setup = self.config.get("setup") @@ -65,7 +76,7 @@ class RemoteEnv(Environment): raise ValueError("At least one of {setup, run, teardown}" + " must be present or wait_boot set to True.") - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the environment is ready and set up the application and benchmarks on a remote host. 
@@ -89,9 +100,9 @@ class RemoteEnv(Environment): if self._wait_boot: _LOG.info("Wait for the remote environment to start: %s", self) - (status, params) = self._service.vm_start(self._params) + (status, params) = self._host_service.vm_start(self._params) if status.is_pending: - (status, _) = self._service.wait_vm_operation(params) + (status, _) = self._host_service.wait_vm_operation(params) if not status.is_succeeded: return False @@ -130,7 +141,7 @@ class RemoteEnv(Environment): _LOG.info("Remote run complete: %s :: %s", self, result) return result - def teardown(self): + def teardown(self) -> None: """ Clean up and shut down the remote environment. """ @@ -140,7 +151,7 @@ class RemoteEnv(Environment): _LOG.info("Remote teardown complete: %s :: %s", self, status) super().teardown() - def _remote_exec(self, script: List[str]) -> Tuple[Status, Optional[dict]]: + def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, Optional[dict]]: """ Run a script on the remote host. @@ -156,10 +167,10 @@ class RemoteEnv(Environment): Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} """ _LOG.debug("Submit script: %s", self) - (status, output) = self._service.remote_exec(script, self._params) + (status, output) = self._remote_exec_service.remote_exec(script, self._params) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: - (status, output) = self._service.get_remote_exec_results(output) + (status, output) = self._remote_exec_service.get_remote_exec_results(output) # TODO: extract the results from `output`. _LOG.debug("Status: %s :: %s", status, output) return (status, output) diff --git a/mlos_bench/mlos_bench/environment/remote/vm_env.py b/mlos_bench/mlos_bench/environment/remote/vm_env.py index db69b4ccbe..dd44ca5113 100644 --- a/mlos_bench/mlos_bench/environment/remote/vm_env.py +++ b/mlos_bench/mlos_bench/environment/remote/vm_env.py @@ -6,9 +6,13 @@ "Remote" VM Environment. """ +from typing import Optional + import logging -from mlos_bench.environment import Environment +from mlos_bench.environment.base_environment import Environment +from mlos_bench.service.base_service import Service +from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) @@ -19,7 +23,40 @@ class VMEnv(Environment): "Remote" VM environment. """ - def setup(self, tunables: TunableGroups, global_config: dict = None) -> bool: + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): + """ + Create a new environment for VM operations. + + Parameters + ---------- + name: str + Human-readable name of the environment. + config : dict + Free-format dictionary that contains the benchmark environment + configuration. Each config must have at least the "tunable_params" + and the "const_args" sections. + global_config : dict + Free-format dictionary of global parameters (e.g., security credentials) + to be mixed in into the "const_args" section of the local config. + tunables : TunableGroups + A collection of tunable parameters for *all* environments. + service: Service + An optional service object (e.g., providing methods to + deploy or reboot a VM, etc.). 
+ """ + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsVMOps), \ + "VMEnv requires a service that supports VM operations" + self._vm_service: SupportsVMOps = self._service + + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if VM is ready. (Re)provision and start it, if necessary. @@ -43,21 +80,21 @@ class VMEnv(Environment): if not super().setup(tunables, global_config): return False - (status, params) = self._service.vm_provision(self._params) + (status, params) = self._vm_service.vm_provision(self._params) if status.is_pending: - (status, _) = self._service.wait_vm_deployment(True, params) + (status, _) = self._vm_service.wait_vm_deployment(True, params) self._is_ready = status.is_succeeded return self._is_ready - def teardown(self): + def teardown(self) -> None: """ Shut down the VM and release it. """ _LOG.info("VM tear down: %s", self) - (status, params) = self._service.vm_deprovision() + (status, params) = self._vm_service.vm_deprovision() if status.is_pending: - (status, _) = self._service.wait_vm_deployment(False, params) + (status, _) = self._vm_service.wait_vm_deployment(False, params) super().teardown() _LOG.debug("Final status of VM deprovisioning: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environment/status.py b/mlos_bench/mlos_bench/environment/status.py index 34ea454ae2..f1cce4e890 100644 --- a/mlos_bench/mlos_bench/environment/status.py +++ b/mlos_bench/mlos_bench/environment/status.py @@ -24,7 +24,7 @@ class Status(enum.Enum): TIMED_OUT = 7 @property - def is_good(self): + def is_good(self) -> bool: """ Check if the status of the benchmark/environment is good. """ @@ -36,42 +36,42 @@ class Status(enum.Enum): } @property - def is_pending(self): + def is_pending(self) -> bool: """ Check if the status of the benchmark/environment is PENDING. """ return self == Status.PENDING @property - def is_ready(self): + def is_ready(self) -> bool: """ Check if the status of the benchmark/environment is READY. """ return self == Status.READY @property - def is_succeeded(self): + def is_succeeded(self) -> bool: """ Check if the status of the benchmark/environment is SUCCEEDED. """ return self == Status.SUCCEEDED @property - def is_failed(self): + def is_failed(self) -> bool: """ Check if the status of the benchmark/environment is FAILED. """ return self == Status.FAILED @property - def is_canceled(self): + def is_canceled(self) -> bool: """ Check if the status of the benchmark/environment is CANCELED. """ return self == Status.CANCELED @property - def is_timed_out(self): + def is_timed_out(self) -> bool: """ Check if the status of the benchmark/environment is TIMEDOUT. 
""" diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index 4a691afff5..203dafec3e 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -9,10 +9,11 @@ Helper functions to launch the benchmark and the optimizer from the command line import logging import argparse -from typing import List, Dict, Any +from typing import Any, Dict, Iterable -from mlos_bench.environment import Environment -from mlos_bench.service import LocalExecService, ConfigPersistenceService +from mlos_bench.environment.base_environment import Environment +from mlos_bench.service.local.local_exec import LocalExecService +from mlos_bench.service.config_persistence import ConfigPersistenceService _LOG_LEVEL = logging.INFO _LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' @@ -30,9 +31,9 @@ class Launcher: _LOG.info("Launch: %s", description) - self._config_loader = None - self._env_config_file = None - self._global_config = {} + self._config_loader: ConfigPersistenceService + self._env_config_file: str + self._global_config: Dict[str, Any] = {} self._parser = argparse.ArgumentParser(description=description) self._parser.add_argument( @@ -89,7 +90,9 @@ class Launcher: if args.globals is not None: for config_file in args.globals: - self._global_config.update(self._config_loader.load_config(config_file)) + conf = self._config_loader.load_config(config_file) + assert isinstance(conf, dict) + self._global_config.update(conf) self._global_config.update(Launcher._try_parse_extra_args(args_rest)) if args.config_path: @@ -98,7 +101,7 @@ class Launcher: return args @staticmethod - def _try_parse_extra_args(cmdline: List[str]) -> Dict[str, str]: + def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, str]: """ Helper function to parse global key/value pairs from the command line. """ @@ -132,7 +135,9 @@ class Launcher: Load JSON config file. Use path relative to `config_path` if required. """ assert self._config_loader is not None, "Call after invoking .parse_args()" - return self._config_loader.load_config(json_file_name) + conf = self._config_loader.load_config(json_file_name) + assert isinstance(conf, dict) + return conf def load_env(self) -> Environment: """ diff --git a/mlos_bench/mlos_bench/optimizer/base_optimizer.py b/mlos_bench/mlos_bench/optimizer/base_optimizer.py index 323c80322b..0b6895087b 100644 --- a/mlos_bench/mlos_bench/optimizer/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizer/base_optimizer.py @@ -8,7 +8,7 @@ and mlos_core optimizers. """ import logging -from typing import Optional, Tuple, List, Union +from typing import Dict, Optional, Sequence, Tuple, Union from abc import ABCMeta, abstractmethod from mlos_bench.environment.status import Status @@ -24,7 +24,7 @@ class Optimizer(metaclass=ABCMeta): """ @staticmethod - def load(tunables: TunableGroups, config: dict, global_config: dict = None): + def load(tunables: TunableGroups, config: dict, global_config: Optional[dict] = None) -> "Optimizer": """ Instantiate the Optimizer shim from the configuration. @@ -48,7 +48,7 @@ class Optimizer(metaclass=ABCMeta): return opt @classmethod - def new(cls, class_name: str, tunables: TunableGroups, config: dict): + def new(cls, class_name: str, tunables: TunableGroups, config: dict) -> "Optimizer": """ Factory method for a new optimizer with a given config. 
@@ -89,11 +89,13 @@ class Optimizer(metaclass=ABCMeta): self._tunables = tunables self._iter = 1 self._max_iter = int(self._config.pop('max_iterations', 10)) - self._opt_target = self._config.pop('maximize', None) - if self._opt_target is None: + self._opt_target: str + _opt_target = self._config.pop('maximize', None) + if _opt_target is None: self._opt_target = self._config.pop('minimize', 'score') self._opt_sign = 1 else: + self._opt_target = _opt_target if 'minimize' in self._config: raise ValueError("Cannot specify both 'maximize' and 'minimize'.") self._opt_sign = -1 @@ -110,18 +112,18 @@ class Optimizer(metaclass=ABCMeta): return self._opt_target @abstractmethod - def bulk_register(self, configs: List[dict], scores: List[float], - status: Optional[List[Status]] = None) -> bool: + def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]], + status: Optional[Sequence[Status]] = None) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. Parameters ---------- - configs : List[dict] + configs : Sequence[dict] Records of tunable values from other experiments. - scores : List[float] + scores : Sequence[float] Benchmark results from experiments that correspond to `configs`. - status : Optional[List[float]] + status : Optional[Sequence[float]] Status of the experiments that correspond to `configs`. Returns @@ -152,7 +154,7 @@ class Optimizer(metaclass=ABCMeta): @abstractmethod def register(self, tunables: TunableGroups, status: Status, - score: Union[float, dict] = None) -> float: + score: Optional[Union[float, Dict[str, float]]] = None) -> Optional[float]: """ Register the observation for the given configuration. @@ -163,7 +165,7 @@ class Optimizer(metaclass=ABCMeta): Usually it's the same config that the `.suggest()` method returned. status : Status Final status of the experiment (e.g., SUCCEEDED or FAILED). - score : Union[float, dict] + score : Union[float, Dict[str, float]] A scalar or a dict with the final benchmark results. None if the experiment was not successful. @@ -178,7 +180,7 @@ class Optimizer(metaclass=ABCMeta): raise ValueError("Status and score must be consistent.") return self._get_score(status, score) - def _get_score(self, status: Status, score: Union[float, dict]) -> float: + def _get_score(self, status: Status, score: Optional[Union[float, Dict[str, float]]]) -> Optional[float]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. @@ -187,7 +189,7 @@ class Optimizer(metaclass=ABCMeta): ---------- status : Status Final status of the experiment (e.g., SUCCEEDED or FAILED). - score : Union[float, dict] + score : Union[float, Dict[str, float]] A scalar or a dict with the final benchmark results. None if the experiment was not successful. @@ -198,6 +200,7 @@ class Optimizer(metaclass=ABCMeta): """ if not status.is_succeeded: return None + assert score is not None if isinstance(score, dict): score = score[self._opt_target] return float(score) * self._opt_sign @@ -210,7 +213,7 @@ class Optimizer(metaclass=ABCMeta): return self._iter <= self._max_iter @abstractmethod - def get_best_observation(self) -> Tuple[float, TunableGroups]: + def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]: """ Get the best observation so far. 
diff --git a/mlos_bench/mlos_bench/optimizer/convert_configspace.py b/mlos_bench/mlos_bench/optimizer/convert_configspace.py index 0b921df7fb..c9801e99e3 100644 --- a/mlos_bench/mlos_bench/optimizer/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizer/convert_configspace.py @@ -8,6 +8,8 @@ Functions to convert TunableGroups to ConfigSpace for use with the mlos_core opt import logging +from typing import Optional + from ConfigSpace.hyperparameters import Hyperparameter from ConfigSpace import UniformIntegerHyperparameter from ConfigSpace import UniformFloatHyperparameter @@ -21,7 +23,7 @@ _LOG = logging.getLogger(__name__) def _tunable_to_hyperparameter( - tunable: Tunable, group_name: str = None, cost: int = 0) -> Hyperparameter: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> Hyperparameter: """ Convert a single Tunable to an equivalent ConfigSpace Hyperparameter object. diff --git a/mlos_bench/mlos_bench/optimizer/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizer/mlos_core_optimizer.py index d7997ff8f6..c44f01ad2a 100644 --- a/mlos_bench/mlos_bench/optimizer/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizer/mlos_core_optimizer.py @@ -7,7 +7,7 @@ A wrapper for mlos_core optimizers for mlos_bench. """ import logging -from typing import Optional, Tuple, List, Union +from typing import Optional, Sequence, Tuple, Union import pandas as pd @@ -47,17 +47,17 @@ class MlosCoreOptimizer(Optimizer): space_adapter_type=space_adapter_type, space_adapter_kwargs=space_adapter_config) - def bulk_register(self, configs: List[dict], scores: List[float], - status: Optional[List[Status]] = None) -> bool: + def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False tunables_names = list(self._tunables.get_param_values().keys()) df_configs = pd.DataFrame(configs)[tunables_names] df_scores = pd.Series(scores, dtype=float) * self._opt_sign if status is not None: - df_status_succ = pd.Series(status) == Status.SUCCEEDED - df_configs = df_configs[df_status_succ] - df_scores = df_scores[df_status_succ] + df_status_ok = pd.Series(status) == Status.SUCCEEDED + df_configs = df_configs[df_status_ok] + df_scores = df_scores[df_status_ok] self._opt.register(df_configs, df_scores) if _LOG.isEnabledFor(logging.DEBUG): (score, _) = self.get_best_observation() @@ -70,7 +70,7 @@ class MlosCoreOptimizer(Optimizer): return self._tunables.copy().assign(df_config.loc[0].to_dict()) def register(self, tunables: TunableGroups, status: Status, - score: Union[float, dict] = None) -> float: + score: Optional[Union[float, dict]] = None) -> Optional[float]: score = super().register(tunables, status, score) # TODO: mlos_core currently does not support registration of failed trials: if status.is_succeeded: @@ -81,7 +81,7 @@ class MlosCoreOptimizer(Optimizer): self._iter += 1 return score - def get_best_observation(self) -> Tuple[float, TunableGroups]: + def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]: df_config = self._opt.get_best_observation() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizer/mock_optimizer.py b/mlos_bench/mlos_bench/optimizer/mock_optimizer.py index f979141544..2d605c80fd 100644 --- a/mlos_bench/mlos_bench/optimizer/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizer/mock_optimizer.py @@ -8,9 +8,11 @@ Mock optimizer for mlos_bench. 
import random import logging -from typing import Optional, Tuple, List, Union + +from typing import Callable, Dict, Optional, Sequence, Tuple, Union from mlos_bench.environment.status import Status +from mlos_bench.tunables.tunable import Tunable, TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizer.base_optimizer import Optimizer @@ -26,16 +28,16 @@ class MockOptimizer(Optimizer): def __init__(self, tunables: TunableGroups, config: dict): super().__init__(tunables, config) rnd = random.Random(config.get("seed", 42)) - self._random = { + self._random: Dict[str, Callable[[Tunable], TunableValue]] = { "categorical": lambda tunable: rnd.choice(tunable.categorical_values), "float": lambda tunable: rnd.uniform(*tunable.range), "int": lambda tunable: rnd.randint(*tunable.range), } - self._best_config = None - self._best_score = None + self._best_config: Optional[TunableGroups] = None + self._best_score: Optional[float] = None - def bulk_register(self, configs: List[dict], scores: List[float], - status: Optional[List[Status]] = None) -> bool: + def bulk_register(self, configs: Sequence[dict], scores: Sequence[Optional[float]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: @@ -60,15 +62,18 @@ class MockOptimizer(Optimizer): return tunables def register(self, tunables: TunableGroups, status: Status, - score: Union[float, dict] = None) -> float: - score = super().register(tunables, status, score) - if status.is_succeeded and (self._best_score is None or score < self._best_score): - self._best_score = score + score: Optional[Union[float, dict]] = None) -> Optional[float]: + registered_score = super().register(tunables, status, score) + if status.is_succeeded and ( + self._best_score is None or (registered_score is not None and registered_score < self._best_score) + ): + self._best_score = registered_score self._best_config = tunables.copy() self._iter += 1 - return score + return registered_score - def get_best_observation(self) -> Tuple[float, TunableGroups]: + def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) + assert self._best_config is not None return (self._best_score * self._opt_sign, self._best_config) diff --git a/mlos_bench/mlos_bench/py.typed b/mlos_bench/mlos_bench/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mlos_bench/mlos_bench/run_bench.py b/mlos_bench/mlos_bench/run_bench.py index 48abe1b6d3..5be039f452 100755 --- a/mlos_bench/mlos_bench/run_bench.py +++ b/mlos_bench/mlos_bench/run_bench.py @@ -17,7 +17,7 @@ from mlos_bench.launcher import Launcher _LOG = logging.getLogger(__name__) -def _main(): +def _main() -> None: launcher = Launcher("mlos_bench run_bench") diff --git a/mlos_bench/mlos_bench/run_opt.py b/mlos_bench/mlos_bench/run_opt.py index 3ad7c5097f..70f877bc12 100755 --- a/mlos_bench/mlos_bench/run_opt.py +++ b/mlos_bench/mlos_bench/run_opt.py @@ -9,16 +9,20 @@ OS Autotune main optimization loop. See `--help` output for details. 
""" +from typing import Tuple, Union + import logging from mlos_bench.launcher import Launcher -from mlos_bench.optimizer import Optimizer -from mlos_bench.environment import Status, Environment +from mlos_bench.optimizer.base_optimizer import Optimizer +from mlos_bench.environment.base_environment import Environment +from mlos_bench.environment.status import Status +from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) -def _main(): +def _main() -> None: launcher = Launcher("mlos_bench run_opt") @@ -40,7 +44,7 @@ def _main(): _LOG.info("Final result: %s", result) -def _optimize(env: Environment, opt: Optimizer, no_teardown: bool): +def _optimize(env: Environment, opt: Optimizer, no_teardown: bool) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]: """ Main optimization loop. """ @@ -57,7 +61,11 @@ def _optimize(env: Environment, opt: Optimizer, no_teardown: bool): continue (status, output) = env.run() # Block and wait for the final result - value = output['score'] if status.is_succeeded else None + if status.is_succeeded: + assert output is not None + value = output['score'] + else: + value = None _LOG.info("Result: %s = %s :: %s", tunables, status, value) opt.register(tunables, status, value) diff --git a/mlos_bench/mlos_bench/service/__init__.py b/mlos_bench/mlos_bench/service/__init__.py index 367e8e56bc..14e687b4d1 100644 --- a/mlos_bench/mlos_bench/service/__init__.py +++ b/mlos_bench/mlos_bench/service/__init__.py @@ -8,14 +8,11 @@ Services for implementing Environments for mlos_bench. from mlos_bench.service.base_service import Service from mlos_bench.service.base_fileshare import FileShareService - from mlos_bench.service.local.local_exec import LocalExecService -from mlos_bench.service.config_persistence import ConfigPersistenceService __all__ = [ 'Service', - 'LocalExecService', 'FileShareService', - 'ConfigPersistenceService', + 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/service/base_fileshare.py b/mlos_bench/mlos_bench/service/base_fileshare.py index 2964c34db8..86f02c171d 100644 --- a/mlos_bench/mlos_bench/service/base_fileshare.py +++ b/mlos_bench/mlos_bench/service/base_fileshare.py @@ -11,11 +11,12 @@ import logging from abc import ABCMeta, abstractmethod from mlos_bench.service.base_service import Service +from mlos_bench.service.types.fileshare_type import SupportsFileShareOps _LOG = logging.getLogger(__name__) -class FileShareService(Service, metaclass=ABCMeta): +class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): """ An abstract base of all file shares. """ @@ -41,7 +42,7 @@ class FileShareService(Service, metaclass=ABCMeta): ]) @abstractmethod - def download(self, remote_path: str, local_path: str, recursive: bool = True): + def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None: """ Downloads contents from a remote share path to a local path. @@ -56,11 +57,11 @@ class FileShareService(Service, metaclass=ABCMeta): If False, ignore the subdirectories; if True (the default), download the entire directory tree. 
""" - _LOG.info("Download from File Share %srecursively: %s -> %s", - "" if recursive else "non-", remote_path, local_path) + _LOG.info("Download from File Share %s recursively: %s -> %s", + "" if recursive else "non", remote_path, local_path) @abstractmethod - def upload(self, local_path: str, remote_path: str, recursive: bool = True): + def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None: """ Uploads contents from a local path to remote share path. @@ -74,5 +75,5 @@ class FileShareService(Service, metaclass=ABCMeta): If False, ignore the subdirectories; if True (the default), upload the entire directory tree. """ - _LOG.info("Upload to File Share %srecursively: %s -> %s", - "" if recursive else "non-", local_path, remote_path) + _LOG.info("Upload to File Share %s recursively: %s -> %s", + "" if recursive else "non", local_path, remote_path) diff --git a/mlos_bench/mlos_bench/service/base_service.py b/mlos_bench/mlos_bench/service/base_service.py index 87580fe36f..e8e795f417 100644 --- a/mlos_bench/mlos_bench/service/base_service.py +++ b/mlos_bench/mlos_bench/service/base_service.py @@ -9,8 +9,9 @@ Base class for the service mix-ins. import json import logging -from typing import Callable, Dict +from typing import Callable, Dict, List, Optional, Union +from mlos_bench.service.types.config_loader_type import SupportsConfigLoading from mlos_bench.util import instantiate_from_config _LOG = logging.getLogger(__name__) @@ -22,7 +23,7 @@ class Service: """ @classmethod - def new(cls, class_name: str, config: dict, parent): + def new(cls, class_name: str, config: dict, parent: Optional["Service"]) -> "Service": """ Factory method for a new service with a given config. @@ -46,7 +47,7 @@ class Service: """ return instantiate_from_config(cls, class_name, config, parent) - def __init__(self, config: dict = None, parent=None): + def __init__(self, config: Optional[dict] = None, parent: Optional["Service"] = None): """ Create a new service with a given config. @@ -61,11 +62,15 @@ class Service: """ self.config = config or {} self._parent = parent - self._services = {} + self._services: Dict[str, Callable] = {} if parent: self.register(parent.export()) + self._config_loader_service: SupportsConfigLoading + if parent and isinstance(parent, SupportsConfigLoading): + self._config_loader_service = parent + if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Service: %s Config:\n%s", self.__class__.__name__, json.dumps(self.config, indent=2)) @@ -73,7 +78,7 @@ class Service: self.__class__.__name__, [] if parent is None else list(parent._services.keys())) - def register(self, services): + def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None: """ Register new mix-in services. 
diff --git a/mlos_bench/mlos_bench/service/config_persistence.py b/mlos_bench/mlos_bench/service/config_persistence.py index b3e8613c59..6616db08d4 100644 --- a/mlos_bench/mlos_bench/service/config_persistence.py +++ b/mlos_bench/mlos_bench/service/config_persistence.py @@ -12,24 +12,25 @@ import os import json # For logging only import logging -from typing import Optional, List +from typing import List, Iterable, Optional, Union import json5 # To read configs with comments and other JSON5 syntax features from mlos_bench.util import prepare_class_load from mlos_bench.environment.base_environment import Environment from mlos_bench.service.base_service import Service +from mlos_bench.service.types.config_loader_type import SupportsConfigLoading from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) -class ConfigPersistenceService(Service): +class ConfigPersistenceService(Service, SupportsConfigLoading): """ Collection of methods to deserialize the Environment, Service, and TunableGroups objects. """ - def __init__(self, config: dict = None, parent: Service = None): + def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None): """ Create a new instance of config persistence service. @@ -58,7 +59,7 @@ class ConfigPersistenceService(Service): ]) def resolve_path(self, file_path: str, - extra_paths: Optional[List[str]] = None) -> str: + extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -67,7 +68,7 @@ class ConfigPersistenceService(Service): ---------- file_path : str Path to the input config file. - extra_paths : List[str] + extra_paths : Iterable[str] Additional directories to prepend to the list of search paths. Returns @@ -86,7 +87,7 @@ class ConfigPersistenceService(Service): _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config(self, json_file_name: str) -> dict: + def load_config(self, json_file_name: str) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. @@ -99,18 +100,18 @@ class ConfigPersistenceService(Service): Returns ------- - config : dict + config : Union[dict, List[dict]] Free-format dictionary that contains the configuration. """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) with open(json_file_name, mode='r', encoding='utf-8') as fh_json: - return json5.load(fh_json) + return json5.load(fh_json) # type: ignore[no-any-return] def build_environment(self, config: dict, - global_config: dict = None, - tunables: TunableGroups = None, - service: Service = None) -> Environment: + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None) -> Environment: """ Factory method for a new environment with a given config. 
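`resolve_path()` in the hunks above documents a search-path lookup: prepend configured directories when the input path is not absolute, otherwise return it unchanged. A rough, self-contained sketch of that behavior, where `config_paths` stands in for the service's configured search list:

```python
import os
from typing import Iterable, Optional


def resolve_path(file_path: str, config_paths: Iterable[str],
                 extra_paths: Optional[Iterable[str]] = None) -> str:
    """Return the first existing match under the search dirs, else the path as-is."""
    if os.path.isabs(file_path):
        return file_path
    for prefix in list(extra_paths or []) + list(config_paths):
        candidate = os.path.join(prefix, file_path)
        if os.path.exists(candidate):
            return candidate
    return file_path
```

Widening the parameter to `Iterable[str]`, as the signature change above does, lets callers pass lists, tuples, or generators without upsetting the type checker.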
@@ -146,15 +147,17 @@ class ConfigPersistenceService(Service): tunables = self.load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new(env_name, env_class, env_config, global_config, tunables, service) + env = Environment.new(env_name=env_name, class_name=env_class, + config=env_config, global_config=global_config, + tunables=tunables, service=service) _LOG.info("Created env: %s :: %s", env_name, env) return env @classmethod def _build_standalone_service(cls, config: dict, - global_config: dict = None, - parent: Service = None) -> Service: + global_config: Optional[dict] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -180,9 +183,9 @@ class ConfigPersistenceService(Service): return service @classmethod - def _build_composite_service(cls, config_list: List[dict], - global_config: dict = None, - parent: Service = None) -> Service: + def _build_composite_service(cls, config_list: Iterable[dict], + global_config: Optional[dict] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -218,8 +221,8 @@ class ConfigPersistenceService(Service): return service @classmethod - def build_service(cls, config: List[dict], global_config: dict = None, - parent: Service = None) -> Service: + def build_service(cls, config: Union[dict, List[dict]], global_config: Optional[dict] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -252,7 +255,7 @@ class ConfigPersistenceService(Service): return cls._build_composite_service(config, global_config, parent) @staticmethod - def build_tunables(config: dict, parent: TunableGroups = None) -> TunableGroups: + def build_tunables(config: dict, parent: Optional[TunableGroups] = None) -> TunableGroups: """ Create a new collection of tunable parameters. @@ -280,8 +283,8 @@ class ConfigPersistenceService(Service): groups.update(TunableGroups(config)) return groups - def load_environment(self, json_file_name: str, global_config: dict = None, - tunables: TunableGroups = None, service: Service = None) -> Environment: + def load_environment(self, json_file_name: str, global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> Environment: """ Load and build new environment from the config file. @@ -302,11 +305,12 @@ class ConfigPersistenceService(Service): A new benchmarking environment. """ config = self.load_config(json_file_name) + assert isinstance(config, dict) return self.build_environment(config, global_config, tunables, service) def load_environment_list( - self, json_file_name: str, global_config: dict = None, - tunables: TunableGroups = None, service: Service = None) -> List[Environment]: + self, json_file_name: str, global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, service: Optional[Service] = None) -> List[Environment]: """ Load and build a list of environments from the config file. 
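The `assert isinstance(config, dict)` added in `load_environment()` (and in `load_tunables()` below) is the usual way to narrow a `Union` return type for mypy when a call site only handles one branch. A small sketch of the pattern, using hypothetical helpers rather than the real mlos_bench API:

```python
import json
from typing import List, Union


def load_json(path: str) -> Union[dict, List[dict]]:
    """Hypothetical loader that may return one config or a list of configs."""
    with open(path, mode='r', encoding='utf-8') as fh:
        data = json.load(fh)
    assert isinstance(data, (dict, list))
    return data


def load_single(path: str) -> dict:
    config = load_json(path)
    # Narrows Union[dict, List[dict]] to dict for mypy; fails fast otherwise.
    assert isinstance(config, dict)
    return config
```

An explicit exception with a message would be friendlier in production paths, but the assert keeps the diff minimal and is all the type checker needs.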
@@ -335,8 +339,8 @@ class ConfigPersistenceService(Service): for config in config_list ] - def load_services(self, json_file_names: List[str], - global_config: dict = None, parent: Service = None) -> Service: + def load_services(self, json_file_names: Iterable[str], + global_config: Optional[dict] = None, parent: Optional[Service] = None) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -363,8 +367,8 @@ class ConfigPersistenceService(Service): ConfigPersistenceService.build_service(config, global_config, service).export()) return service - def load_tunables(self, json_file_names: List[str], - parent: TunableGroups = None) -> TunableGroups: + def load_tunables(self, json_file_names: Iterable[str], + parent: Optional[TunableGroups] = None) -> TunableGroups: """ Load a collection of tunable parameters from JSON files. @@ -386,5 +390,6 @@ class ConfigPersistenceService(Service): groups.update(parent) for fname in json_file_names: config = self.load_config(fname) + assert isinstance(config, dict) groups.update(TunableGroups(config)) return groups diff --git a/mlos_bench/mlos_bench/service/local/local_exec.py b/mlos_bench/mlos_bench/service/local/local_exec.py index e2772bcdb4..eb58b89f94 100644 --- a/mlos_bench/mlos_bench/service/local/local_exec.py +++ b/mlos_bench/mlos_bench/service/local/local_exec.py @@ -15,21 +15,25 @@ import shlex import subprocess import logging -from typing import Optional, Tuple, List, Dict +from typing import Dict, Iterable, Mapping, Optional, Tuple, Union, TYPE_CHECKING from mlos_bench.service.base_service import Service +from mlos_bench.service.types.local_exec_type import SupportsLocalExec + +if TYPE_CHECKING: + from mlos_bench.tunables.tunable import TunableValue _LOG = logging.getLogger(__name__) -class LocalExecService(Service): +class LocalExecService(Service, SupportsLocalExec): """ Collection of methods to run scripts and commands in an external process on the node acting as the scheduler. Can be useful for data processing due to reduced dependency management complications vs the target environment. """ - def __init__(self, config: dict = None, parent: Service = None): + def __init__(self, config: Optional[dict] = None, parent: Optional[Service] = None): """ Create a new instance of a service to run scripts locally. @@ -45,7 +49,7 @@ class LocalExecService(Service): self._temp_dir = self.config.get("temp_dir") self.register([self.temp_dir_context, self.local_exec]) - def temp_dir_context(self, path: Optional[str] = None): + def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. @@ -63,8 +67,8 @@ class LocalExecService(Service): return tempfile.TemporaryDirectory() return contextlib.nullcontext(path or self._temp_dir) - def local_exec(self, script_lines: List[str], - env: Optional[Dict[str, str]] = None, + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, cwd: Optional[str] = None, return_on_error: bool = False) -> Tuple[int, str, str]: """ @@ -72,10 +76,10 @@ class LocalExecService(Service): Parameters ---------- - script_lines : List[str] + script_lines : Iterable[str] Lines of the script to run locally. Treat every line as a separate command to run. - env : Dict[str, str] + env : Mapping[str, Union[int, float, str]] Environment variables (optional). cwd : str Work directory to run the script at. 
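The `local_exec.py` hunks below rework how environment variables reach the local process: tunable values (which may be ints or floats) are stringified, and on Windows the parent environment is merged back in with `PYTHONPATH` cleared. A simplified sketch of that behavior; `run_line` and its parameters are illustrative, not the actual service method:

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping, Union

TunableValue = Union[int, float, str]


def run_line(cmd: List[str], env_params: Mapping[str, TunableValue]) -> int:
    # subprocess requires string values, so stringify ints/floats up front.
    env: Dict[str, str] = {key: str(val) for (key, val) in env_params.items()}
    if sys.platform == 'win32':
        # Merge with the parent environment so the interpreter can still start;
        # clear PYTHONPATH to avoid leaking the scheduler's module paths.
        env = {**os.environ, "PYTHONPATH": "", **env}
    return subprocess.run(cmd, env=env, check=False).returncode
```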
@@ -110,7 +114,8 @@ class LocalExecService(Service): return (return_code, stdout, stderr) def _local_exec_script(self, script_line: str, - env: Dict[str, str], cwd: str) -> Tuple[int, str, str]: + env_params: Optional[Mapping[str, "TunableValue"]], + cwd: str) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -118,9 +123,9 @@ class LocalExecService(Service): ---------- script_line : str Line of the script to tun in the local process. - args : List[str] + args : Iterable[str] Command line arguments for the script. - env : Dict[str, str] + env_params : Mapping[str, Union[int, float, str]] Environment variables. cwd : str Work directory to run the script at. @@ -131,7 +136,7 @@ class LocalExecService(Service): A 3-tuple of return code, stdout, and stderr of the script process. """ cmd = shlex.split(script_line) - script_path = self._parent.resolve_path(cmd[0]) + script_path = self._config_loader_service.resolve_path(cmd[0]) if os.path.exists(script_path): script_path = os.path.abspath(script_path) @@ -139,14 +144,16 @@ class LocalExecService(Service): if script_path.strip().lower().endswith(".py"): cmd = [sys.executable] + cmd - if env: - env = {key: str(val) for (key, val) in env.items()} - if sys.platform == 'win32': - # A hack to run Python on Windows with env variables set: - env_copy = os.environ.copy() - env_copy["PYTHONPATH"] = "" - env_copy.update(env) - env = env_copy + env: Dict[str, str] = {} + if env_params: + env = {key: str(val) for (key, val) in env_params.items()} + + if sys.platform == 'win32': + # A hack to run Python on Windows with env variables set: + env_copy = os.environ.copy() + env_copy["PYTHONPATH"] = "" + env_copy.update(env) + env = env_copy _LOG.info("Run: %s", cmd) diff --git a/mlos_bench/mlos_bench/service/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/service/remote/azure/azure_fileshare.py index f48e4d1b32..2012723519 100644 --- a/mlos_bench/mlos_bench/service/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/service/remote/azure/azure_fileshare.py @@ -13,7 +13,8 @@ from typing import Set from azure.storage.fileshare import ShareClient -from mlos_bench.service import Service, FileShareService +from mlos_bench.service.base_service import Service +from mlos_bench.service.base_fileshare import FileShareService from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) @@ -57,7 +58,7 @@ class AzureFileShareService(FileShareService): credential=config["storageAccountKey"], ) - def download(self, remote_path: str, local_path: str, recursive: bool = True): + def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None: super().download(remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -75,13 +76,13 @@ class AzureFileShareService(FileShareService): file_client = self._share_client.get_file_client(remote_path) data = file_client.download_file() with open(local_path, "wb") as output_file: - data.readinto(output_file) + data.readinto(output_file) # type: ignore[no-untyped-call] - def upload(self, local_path: str, remote_path: str, recursive: bool = True): + def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None: super().upload(local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) - def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]): + def _upload(self, local_path: str, remote_path: 
str, recursive: bool, seen: Set[str]) -> None: """ Upload contents from a local path to an Azure file share. This method is called from `.upload()` above. We need it to avoid exposing @@ -124,7 +125,7 @@ class AzureFileShareService(FileShareService): with open(local_path, "rb") as file_data: file_client.upload_file(file_data) - def _remote_makedirs(self, remote_path: str): + def _remote_makedirs(self, remote_path: str) -> None: """ Create remote directories for the entire path. Succeeds even some or all directories along the path already exist. diff --git a/mlos_bench/mlos_bench/service/remote/azure/azure_services.py b/mlos_bench/mlos_bench/service/remote/azure/azure_services.py index ee594a7a75..758d238238 100644 --- a/mlos_bench/mlos_bench/service/remote/azure/azure_services.py +++ b/mlos_bench/mlos_bench/service/remote/azure/azure_services.py @@ -10,18 +10,20 @@ import json import time import logging -from typing import Any, Tuple, List, Dict, Callable +from typing import Callable, Iterable, Tuple import requests -from mlos_bench.environment import Status -from mlos_bench.service import Service +from mlos_bench.environment.status import Status +from mlos_bench.service.base_service import Service +from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec +from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) -class AzureVMService(Service): # pylint: disable=too-many-instance-attributes +class AzureVMService(Service, SupportsVMOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes """ Helper methods to manage VMs on Azure. """ @@ -134,7 +136,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes self.vm_start, self.vm_stop, self.vm_deprovision, - self.vm_reboot, + self.vm_restart, self.remote_exec, self.get_remote_exec_results ]) @@ -144,7 +146,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes self._poll_timeout = float(config.get("pollTimeout", AzureVMService._POLL_TIMEOUT)) self._request_timeout = float(config.get("requestTimeout", AzureVMService._REQUEST_TIMEOUT)) - self._deploy_template = self._parent.load_config(config['deployTemplatePath']) + self._deploy_template = self._config_loader_service.load_config(config['deployTemplatePath']) self._url_deploy = AzureVMService._URL_DEPLOY.format( subscription=config["subscription"], @@ -188,7 +190,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes } @staticmethod - def _extract_arm_parameters(json_data): + def _extract_arm_parameters(json_data: dict) -> dict: """ Extract parameters from the ARM Template REST response JSON. @@ -203,7 +205,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes if val.get("value") is not None } - def _azure_vm_post_helper(self, url: str) -> Tuple[Status, Dict]: + def _azure_vm_post_helper(self, url: str) -> Tuple[Status, dict]: """ General pattern for performing an action on an Azure VM via its REST API. 
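`_azure_vm_post_helper()` above captures a common Azure REST pattern: a 202 response means the operation is still pending, and the response headers tell the caller where and how often to poll. A rough sketch of that header handling; the function name is made up, the `Azure-AsyncOperation` branch and the 200 branch are assumptions based on the standard Azure async-operation convention (only the `Location` fallback and failure path are visible in the diff):

```python
from typing import Dict, Tuple

import requests


def classify_post_response(response: requests.Response) -> Tuple[str, dict]:
    if response.status_code == 200:
        return ("SUCCEEDED", {})
    if response.status_code == 202:
        result: Dict[str, str] = {}
        if "Azure-AsyncOperation" in response.headers:
            result["asyncResultsUrl"] = response.headers["Azure-AsyncOperation"]
        elif "Location" in response.headers:
            result["asyncResultsUrl"] = response.headers["Location"]
        if "Retry-After" in response.headers:
            # Kept as a string so the result stays Dict[str, str], matching the
            # str(float(...)) conversion in the next hunk.
            result["pollInterval"] = str(float(response.headers["Retry-After"]))
        return ("PENDING", result)
    return ("FAILED", {})
```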
@@ -237,7 +239,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes elif "Location" in response.headers: result["asyncResultsUrl"] = response.headers.get("Location") if "Retry-After" in response.headers: - result["pollInterval"] = float(response.headers["Retry-After"]) + result["pollInterval"] = str(float(response.headers["Retry-After"])) return (Status.PENDING, result) else: @@ -245,7 +247,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes # _LOG.error("Bad Request:\n%s", response.request.body) return (Status.FAILED, {}) - def check_vm_operation_status(self, params: Dict) -> Tuple[Status, Dict]: + def check_vm_operation_status(self, params: dict) -> Tuple[Status, dict]: """ Checks the status of a pending operation on an Azure VM. @@ -285,7 +287,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.error("Response: %s :: %s", response, response.text) return Status.FAILED, {} - def wait_vm_deployment(self, is_setup: bool, params: Dict[str, Any]) -> Tuple[Status, Dict]: + def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -308,7 +310,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes "provision" if is_setup else "deprovision") return self._wait_while(self._check_deployment, Status.PENDING, params) - def wait_vm_operation(self, params: Dict[str, Any]) -> Tuple[Status, Dict]: + def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -330,8 +332,8 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.info("Wait for operation on VM %s", self.config["vmName"]) return self._wait_while(self.check_vm_operation_status, Status.RUNNING, params) - def _wait_while(self, func: Callable[[Dict[str, Any]], Tuple[Status, Dict]], - loop_status: Status, params: Dict[str, Any]) -> Tuple[Status, Dict]: + def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], + loop_status: Status, params: dict) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. @@ -377,7 +379,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment(self, _params: Dict) -> Tuple[Status, Dict]: + def _check_deployment(self, _params: dict) -> Tuple[Status, dict]: """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. @@ -409,7 +411,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.error("Response: %s :: %s", response, response.text) return (Status.FAILED, {}) - def vm_provision(self, params: Dict) -> Tuple[Status, Dict]: + def vm_provision(self, params: dict) -> Tuple[Status, dict]: """ Check if Azure VM is ready. Deploy a new VM, if necessary. @@ -464,7 +466,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes # _LOG.error("Bad Request:\n%s", response.request.body) return (Status.FAILED, {}) - def vm_start(self, params: Dict) -> Tuple[Status, Dict]: + def vm_start(self, params: dict) -> Tuple[Status, dict]: """ Start the VM on Azure. 
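`_wait_while()` in the hunks above is a generic poll-until-the-status-changes loop with a timeout. A standalone sketch of the same control flow, with a simplified `Status` enum and made-up interval/timeout defaults:

```python
import time
from enum import Enum
from typing import Callable, Tuple


class Status(Enum):
    PENDING = 1
    RUNNING = 2
    SUCCEEDED = 3
    FAILED = 4
    TIMED_OUT = 5


def wait_while(func: Callable[[dict], Tuple[Status, dict]], loop_status: Status,
               params: dict, poll_interval: float = 1.0,
               poll_timeout: float = 30.0) -> Tuple[Status, dict]:
    """Call `func` until it stops returning `loop_status` or the timeout hits."""
    deadline = time.time() + poll_timeout
    while time.time() < deadline:
        (status, result) = func(params)
        if status != loop_status:
            return (status, result)
        time.sleep(poll_interval)
    return (Status.TIMED_OUT, {})
```

Typing `func` as `Callable[[dict], Tuple[Status, dict]]`, as the signature change above does, lets mypy check both `self._check_deployment` and `self.check_vm_operation_status` against the same contract.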
@@ -482,7 +484,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.info("Start VM: %s :: %s", self.config["vmName"], params) return self._azure_vm_post_helper(self._url_start) - def vm_stop(self) -> Tuple[Status, Dict]: + def vm_stop(self) -> Tuple[Status, dict]: """ Stops the VM on Azure by initiating a graceful shutdown. @@ -495,7 +497,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.info("Stop VM: %s", self.config["vmName"]) return self._azure_vm_post_helper(self._url_stop) - def vm_deprovision(self) -> Tuple[Status, Dict]: + def vm_deprovision(self) -> Tuple[Status, dict]: """ Deallocates the VM on Azure by shutting it down then releasing the compute resources. @@ -508,7 +510,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.info("Deprovision VM: %s", self.config["vmName"]) return self._azure_vm_post_helper(self._url_stop) - def vm_reboot(self) -> Tuple[Status, Dict]: + def vm_restart(self) -> Tuple[Status, dict]: """ Reboot the VM on Azure by initiating a graceful shutdown. @@ -521,13 +523,13 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes _LOG.info("Reboot VM: %s", self.config["vmName"]) return self._azure_vm_post_helper(self._url_reboot) - def remote_exec(self, script: List[str], params: Dict[str, Any]) -> Tuple[Status, Dict]: + def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]: """ Run a command on Azure VM. Parameters ---------- - script : List[str] + script : Iterable[str] A list of lines to execute as a script on a remote VM. params : dict Flat dictionary of (key, value) pairs of parameters. @@ -546,7 +548,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes json_req = { "commandId": "RunShellScript", - "script": script, + "script": list(script), "parameters": [{"name": key, "value": val} for (key, val) in params.items()] } @@ -576,7 +578,7 @@ class AzureVMService(Service): # pylint: disable=too-many-instance-attributes # _LOG.error("Bad Request:\n%s", response.request.body) return (Status.FAILED, {}) - def get_remote_exec_results(self, params: Dict) -> Tuple[Status, Dict]: + def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]: """ Get the results of the asynchronously running command. diff --git a/mlos_bench/mlos_bench/service/types/README.md b/mlos_bench/mlos_bench/service/types/README.md new file mode 100644 index 0000000000..d5e992ef77 --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/README.md @@ -0,0 +1,7 @@ +# Service Type Interfaces + +Service loading in `mlos_bench` uses a [mix-in](https://en.wikipedia.org/wiki/Mixin#In_Python) approach to combine the functionality of multiple classes specified at runtime through config files into a single class. + +This can make type checking in the `Environments` that use those `Services` a little tricky, both for developers and checking tools. + +To address this we define `@runtime_checkable` decorated [`Protocols`](https://peps.python.org/pep-0544/) ("interfaces" in other languages) to declare the expected behavior of the `Services` that are loaded at runtime in the `Environments` that use them. diff --git a/mlos_bench/mlos_bench/service/types/__init__.py b/mlos_bench/mlos_bench/service/types/__init__.py new file mode 100644 index 0000000000..0006a7d2d4 --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/__init__.py @@ -0,0 +1,22 @@ +# +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. +# +""" +Service types for implementing declaring Service behavior for Environments to use in mlos_bench. +""" + +from mlos_bench.service.types.config_loader_type import SupportsConfigLoading +from mlos_bench.service.types.fileshare_type import SupportsFileShareOps +from mlos_bench.service.types.vm_provisioner_type import SupportsVMOps +from mlos_bench.service.types.local_exec_type import SupportsLocalExec +from mlos_bench.service.types.remote_exec_type import SupportsRemoteExec + + +__all__ = [ + 'SupportsConfigLoading', + 'SupportsFileShareOps', + 'SupportsVMOps', + 'SupportsLocalExec', + 'SupportsRemoteExec', +] diff --git a/mlos_bench/mlos_bench/service/types/config_loader_type.py b/mlos_bench/mlos_bench/service/types/config_loader_type.py new file mode 100644 index 0000000000..8fa4d81c58 --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/config_loader_type.py @@ -0,0 +1,111 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Protocol interface for helper functions to lookup and load configs. +""" + +from typing import List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING + + +# Avoid's circular import issues. +if TYPE_CHECKING: + from mlos_bench.tunables.tunable_groups import TunableGroups + from mlos_bench.service.base_service import Service + from mlos_bench.environment.base_environment import Environment + + +@runtime_checkable +class SupportsConfigLoading(Protocol): + """ + Protocol interface for helper functions to lookup and load configs. + """ + + def resolve_path(self, file_path: str, + extra_paths: Optional[Iterable[str]] = None) -> str: + """ + Prepend the suitable `_config_path` to `path` if the latter is not absolute. + If `_config_path` is `None` or `path` is absolute, return `path` as is. + + Parameters + ---------- + file_path : str + Path to the input config file. + extra_paths : Iterable[str] + Additional directories to prepend to the list of search paths. + + Returns + ------- + path : str + An actual path to the config or script. + """ + + def load_config(self, json_file_name: str) -> Union[dict, List[dict]]: + """ + Load JSON config file. Search for a file relative to `_config_path` + if the input path is not absolute. + This method is exported to be used as a service. + + Parameters + ---------- + json_file_name : str + Path to the input config file. + + Returns + ------- + config : Union[dict, List[dict]] + Free-format dictionary that contains the configuration. + """ + + def build_environment(self, config: dict, + global_config: Optional[dict] = None, + tunables: Optional["TunableGroups"] = None, + service: Optional["Service"] = None) -> "Environment": + """ + Factory method for a new environment with a given config. + + Parameters + ---------- + config : dict + A dictionary with three mandatory fields: + "name": Human-readable string describing the environment; + "class": FQN of a Python class to instantiate; + "config": Free-format dictionary to pass to the constructor. + global_config : dict + Global parameters to add to the environment config. + tunables : TunableGroups + A collection of groups of tunable parameters for all environments. + service: Service + An optional service object (e.g., providing methods to + deploy or reboot a VM, etc.). + + Returns + ------- + env : Environment + An instance of the `Environment` class initialized with `config`. 
+ """ + + def load_environment_list( + self, json_file_name: str, global_config: Optional[dict] = None, + tunables: Optional["TunableGroups"] = None, service: Optional["Service"] = None) -> List["Environment"]: + """ + Load and build a list of environments from the config file. + + Parameters + ---------- + json_file_name : str + The environment JSON configuration file. + Can contain either one environment or a list of environments. + global_config : dict + Global parameters to add to the environment config. + tunables : TunableGroups + An optional collection of tunables to add to the environment. + service : Service + An optional reference of the parent service to mix in. + + Returns + ------- + env : List[Environment] + A list of new benchmarking environments. + """ diff --git a/mlos_bench/mlos_bench/service/types/fileshare_type.py b/mlos_bench/mlos_bench/service/types/fileshare_type.py new file mode 100644 index 0000000000..1a5aa9c00f --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/fileshare_type.py @@ -0,0 +1,47 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Protocol interface for file share operations. +""" + +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class SupportsFileShareOps(Protocol): + """ + Protocol interface for file share operations. + """ + + def download(self, remote_path: str, local_path: str, recursive: bool = True) -> None: + """ + Downloads contents from a remote share path to a local path. + + Parameters + ---------- + remote_path : str + Path to download from the remote file share, a file if recursive=False + or a directory if recursive=True. + local_path : str + Path to store the downloaded content to. + recursive : bool + If False, ignore the subdirectories; + if True (the default), download the entire directory tree. + """ + + def upload(self, local_path: str, remote_path: str, recursive: bool = True) -> None: + """ + Uploads contents from a local path to remote share path. + + Parameters + ---------- + local_path : str + Path to the local directory to upload contents from. + remote_path : str + Path in the remote file share to store the uploaded content to. + recursive : bool + If False, ignore the subdirectories; + if True (the default), upload the entire directory tree. + """ diff --git a/mlos_bench/mlos_bench/service/types/local_exec_type.py b/mlos_bench/mlos_bench/service/types/local_exec_type.py new file mode 100644 index 0000000000..5dbd47d52d --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/local_exec_type.py @@ -0,0 +1,68 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Protocol interface for Service types that provide helper functions to run +scripts and commands locally on the scheduler side. +""" + +from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable + +import tempfile +import contextlib + +from mlos_bench.tunables.tunable import TunableValue + + +@runtime_checkable +class SupportsLocalExec(Protocol): + """ + Protocol interface for a collection of methods to run scripts and commands + in an external process on the node acting as the scheduler. Can be useful + for data processing due to reduced dependency management complications vs + the target environment. + Used in LocalEnv and provided by LocalExecService. 
+ """ + + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None, + return_on_error: bool = False) -> Tuple[int, str, str]: + """ + Execute the script lines from `script_lines` in a local process. + + Parameters + ---------- + script_lines : Iterable[str] + Lines of the script to run locally. + Treat every line as a separate command to run. + env : Mapping[str, Union[int, float, str]] + Environment variables (optional). + cwd : str + Work directory to run the script at. + If omitted, use `temp_dir` or create a temporary dir. + return_on_error : bool + If True, stop running script lines on first non-zero return code. + The default is False. + + Returns + ------- + (return_code, stdout, stderr) : (int, str, str) + A 3-tuple of return code, stdout, and stderr of the script process. + """ + + def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + """ + Create a temp directory or use the provided path. + + Parameters + ---------- + path : str + A path to the temporary directory. Create a new one if None. + + Returns + ------- + temp_dir_context : TemporaryDirectory + Temporary directory context to use in the `with` clause. + """ diff --git a/mlos_bench/mlos_bench/service/types/remote_exec_type.py b/mlos_bench/mlos_bench/service/types/remote_exec_type.py new file mode 100644 index 0000000000..279f2e6d91 --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/remote_exec_type.py @@ -0,0 +1,59 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Protocol interface for Service types that provide helper functions to run +scripts on a remote host OS. +""" + +from typing import Iterable, Tuple, Protocol, runtime_checkable + + +from mlos_bench.environment.status import Status + + +@runtime_checkable +class SupportsRemoteExec(Protocol): + """ + Protocol interface for Service types that provide helper functions to run + scripts on a remote host OS. + """ + + def remote_exec(self, script: Iterable[str], params: dict) -> Tuple[Status, dict]: + """ + Run a command on remote host OS. + + Parameters + ---------- + script : Iterable[str] + A list of lines to execute as a script on a remote VM. + params : dict + Flat dictionary of (key, value) pairs of parameters. + They usually come from `const_args` and `tunable_params` + properties of the Environment. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def get_remote_exec_results(self, params: dict) -> Tuple[Status, dict]: + """ + Get the results of the asynchronously running command. + + Parameters + ---------- + params : dict + Flat dictionary of (key, value) pairs of tunable parameters. + Must have the "asyncResultsUrl" key to get the results. + If the key is not present, return Status.PENDING. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} + """ diff --git a/mlos_bench/mlos_bench/service/types/vm_provisioner_type.py b/mlos_bench/mlos_bench/service/types/vm_provisioner_type.py new file mode 100644 index 0000000000..20958f1340 --- /dev/null +++ b/mlos_bench/mlos_bench/service/types/vm_provisioner_type.py @@ -0,0 +1,125 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Protocol interface for VM provisioning operations. 
+""" + +from typing import Tuple, Protocol, runtime_checkable + +from mlos_bench.environment.status import Status + + +@runtime_checkable +class SupportsVMOps(Protocol): + """ + Protocol interface for VM provisioning operations. + """ + + def vm_provision(self, params: dict) -> Tuple[Status, dict]: + """ + Check if VM is ready. Deploy a new VM, if necessary. + + Parameters + ---------- + params : dict + Flat dictionary of (key, value) pairs of tunable parameters. + VMEnv tunables are variable parameters that, together with the + VMEnv configuration, are sufficient to provision a VM. + + Returns + ------- + result : (Status, dict={}) + A pair of Status and result. The result is always {}. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def wait_vm_deployment(self, is_setup: bool, params: dict) -> Tuple[Status, dict]: + """ + Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. + Return TIMED_OUT when timing out. + + Parameters + ---------- + is_setup : bool + If True, wait for VM being deployed; otherwise, wait for successful deprovisioning. + params : dict + Flat dictionary of (key, value) pairs of tunable parameters. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} + Result is info on the operation runtime if SUCCEEDED, otherwise {}. + """ + + def vm_start(self, params: dict) -> Tuple[Status, dict]: + """ + Start a VM. + + Parameters + ---------- + params : dict + Flat dictionary of (key, value) pairs of tunable parameters. + + Returns + ------- + result : (Status, dict={}) + A pair of Status and result. The result is always {}. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def vm_stop(self) -> Tuple[Status, dict]: + """ + Stops the VM by initiating a graceful shutdown. + + Returns + ------- + result : (Status, dict={}) + A pair of Status and result. The result is always {}. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def vm_restart(self) -> Tuple[Status, dict]: + """ + Restarts the VM by initiating a graceful shutdown. + + Returns + ------- + result : (Status, dict={}) + A pair of Status and result. The result is always {}. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def vm_deprovision(self) -> Tuple[Status, dict]: + """ + Deallocates the VM by shutting it down then releasing the compute resources. + + Returns + ------- + result : (Status, dict={}) + A pair of Status and result. The result is always {}. + Status is one of {PENDING, SUCCEEDED, FAILED} + """ + + def wait_vm_operation(self, params: dict) -> Tuple[Status, dict]: + """ + Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. + Return TIMED_OUT when timing out. + + Parameters + ---------- + params: dict + Flat dictionary of (key, value) pairs of tunable parameters. + Must have the "asyncResultsUrl" key to get the results. + If the key is not present, return Status.PENDING. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} + Result is info on the operation runtime if SUCCEEDED, otherwise {}. + """ diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py new file mode 100644 index 0000000000..dc483dc4d6 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench. 
+Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index ad7b9a83dc..f99a56648a 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -72,7 +72,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: Test fixture for MockEnv. """ return MockEnv( - "Test Env", + name="Test Env", config={ "tunable_groups": ["provision", "boot", "kernel"], "seed": 13, @@ -88,7 +88,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: Test fixture for MockEnv. """ return MockEnv( - "Test Env No Noise", + name="Test Env No Noise", config={ "tunable_groups": ["provision", "boot", "kernel"], "range": [60, 120] diff --git a/mlos_bench/mlos_bench/tests/environment/__init__.py b/mlos_bench/mlos_bench/tests/environment/__init__.py new file mode 100644 index 0000000000..5460a816cb --- /dev/null +++ b/mlos_bench/mlos_bench/tests/environment/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.environment. +Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/environment/mock_env_test.py b/mlos_bench/mlos_bench/tests/environment/mock_env_test.py index 97b98a138d..4cf26068ee 100644 --- a/mlos_bench/mlos_bench/tests/environment/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environment/mock_env_test.py @@ -8,25 +8,27 @@ Unit tests for mock benchmark environment. import pytest -from mlos_bench.environment import MockEnv +from mlos_bench.environment.mock_env import MockEnv from mlos_bench.tunables.tunable_groups import TunableGroups -def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups): +def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> None: """ Check the default values of the mock environment. """ assert mock_env.setup(tunable_groups) (status, data) = mock_env.run() assert status.is_succeeded + assert data is not None assert data["score"] == pytest.approx(78.45, 0.01) # Second time, results should differ because of the noise. (status, data) = mock_env.run() assert status.is_succeeded + assert data is not None assert data["score"] == pytest.approx(98.21, 0.01) -def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups): +def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None: """ Check the default values of the mock environment. """ @@ -35,6 +37,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr # Noise-free results should be the same every time. (status, data) = mock_env_no_noise.run() assert status.is_succeeded + assert data is not None assert data["score"] == pytest.approx(80.11, 0.01) @@ -51,7 +54,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr }, 79.1), ]) def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float): + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the mock environment after the assignment. 
""" @@ -59,6 +62,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, assert mock_env.setup(tunable_groups) (status, data) = mock_env.run() assert status.is_succeeded + assert data is not None assert data["score"] == pytest.approx(expected_score, 0.01) @@ -76,7 +80,7 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, ]) def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float): + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the noiseless mock environment after the assignment. """ @@ -86,4 +90,5 @@ def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, # Noise-free environment should produce the same results every time. (status, data) = mock_env_no_noise.run() assert status.is_succeeded + assert data is not None assert data["score"] == pytest.approx(expected_score, 0.01) diff --git a/mlos_bench/mlos_bench/tests/optimizer/__init__.py b/mlos_bench/mlos_bench/tests/optimizer/__init__.py new file mode 100644 index 0000000000..b36043747b --- /dev/null +++ b/mlos_bench/mlos_bench/tests/optimizer/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.optimizer. +Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/optimizer/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizer/llamatune_opt_test.py index 46ec7b7391..57236b2170 100644 --- a/mlos_bench/mlos_bench/tests/optimizer/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizer/llamatune_opt_test.py @@ -8,9 +8,9 @@ Unit tests for mock mlos_bench optimizer. import pytest -from mlos_bench.environment import Status +from mlos_bench.environment.status import Status from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.optimizer import MlosCoreOptimizer +from mlos_bench.optimizer.mlos_core_optimizer import MlosCoreOptimizer # pylint: disable=redefined-outer-name @@ -41,7 +41,7 @@ def mock_scores() -> list: return [88.88, 66.66, 99.99] -def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list): +def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None: """ Make sure that llamatune+emukit optimizer initializes and works correctly. """ diff --git a/mlos_bench/mlos_bench/tests/optimizer/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizer/mock_opt_test.py index fe2c5f5d8f..c3168395a7 100644 --- a/mlos_bench/mlos_bench/tests/optimizer/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizer/mock_opt_test.py @@ -8,8 +8,8 @@ Unit tests for mock mlos_bench optimizer. import pytest -from mlos_bench.environment import Status -from mlos_bench.optimizer import MockOptimizer +from mlos_bench.environment.status import Status +from mlos_bench.optimizer.mock_optimizer import MockOptimizer # pylint: disable=redefined-outer-name @@ -40,7 +40,7 @@ def mock_configurations() -> list: def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """ - Run several iterations of the oiptimizer and return the best score. + Run several iterations of the optimizer and return the best score. 
""" for (tunable_values, score) in mock_configurations: assert mock_opt.not_converged() @@ -49,10 +49,12 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: mock_opt.register(tunables, Status.SUCCEEDED, score) (score, _tunables) = mock_opt.get_best_observation() + assert score is not None + assert isinstance(score, float) return score -def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list): +def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> None: """ Make sure that mock optimizer produces consistent suggestions. """ @@ -60,7 +62,7 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list): assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list): +def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None: """ Check the maximization mode of the mock optimizer. """ @@ -68,7 +70,7 @@ def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: li assert score == pytest.approx(99.99, 0.01) -def test_mock_optimizer_register_fail(mock_opt: MockOptimizer): +def test_mock_optimizer_register_fail(mock_opt: MockOptimizer) -> None: """ Check the input acceptance conditions for Optimizer.register(). """ diff --git a/mlos_bench/mlos_bench/tests/optimizer/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizer/opt_bulk_register_test.py index 61990bc9c2..d9b4f6d203 100644 --- a/mlos_bench/mlos_bench/tests/optimizer/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizer/opt_bulk_register_test.py @@ -10,7 +10,7 @@ from typing import Optional, List import pytest -from mlos_bench.environment import Status +from mlos_bench.environment.status import Status from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer # pylint: disable=redefined-outer-name @@ -46,7 +46,7 @@ def mock_configs() -> List[dict]: @pytest.fixture -def mock_scores() -> List[float]: +def mock_scores() -> List[Optional[float]]: """ Mock benchmark results from earlier experiments. """ @@ -62,13 +62,14 @@ def mock_status() -> List[Status]: def _test_opt_update_min(opt: Optimizer, configs: List[dict], - scores: List[float], status: Optional[List[Status]] = None): + scores: List[float], status: Optional[List[Status]] = None) -> None: """ Test the bulk update of the optimizer on the minimization problem. """ opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() assert score == pytest.approx(66.66, 0.01) + assert tunables is not None assert tunables.get_param_values() == { "vmSize": "Standard_B4ms", "rootfs": "ext4", @@ -77,13 +78,14 @@ def _test_opt_update_min(opt: Optimizer, configs: List[dict], def _test_opt_update_max(opt: Optimizer, configs: List[dict], - scores: List[float], status: Optional[List[Status]] = None): + scores: List[float], status: Optional[List[Status]] = None) -> None: """ - Test the bulk update of the optimizer on the maximiation prtoblem. + Test the bulk update of the optimizer on the maximization problem. 
""" opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() assert score == pytest.approx(99.99, 0.01) + assert tunables is not None assert tunables.get_param_values() == { "vmSize": "Standard_B2s", "rootfs": "xfs", @@ -92,7 +94,7 @@ def _test_opt_update_max(opt: Optimizer, configs: List[dict], def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the minimization problem. """ @@ -100,7 +102,7 @@ def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict], def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the maximization problem. """ @@ -108,7 +110,7 @@ def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict], def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the EmuKit optimizer. """ @@ -116,7 +118,7 @@ def test_update_emukit(emukit_opt: MlosCoreOptimizer, mock_configs: List[dict], def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the EmuKit optimizer on the maximization problem. """ @@ -124,7 +126,7 @@ def test_update_emukit_max(emukit_opt_max: MlosCoreOptimizer, mock_configs: List def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the scikit-optimize GP optimizer. """ @@ -132,7 +134,7 @@ def test_update_scikit_gp(scikit_gp_opt: MlosCoreOptimizer, mock_configs: List[d def test_update_scikit_et(scikit_et_opt: MlosCoreOptimizer, mock_configs: List[dict], - mock_scores: List[float], mock_status: List[Status]): + mock_scores: List[float], mock_status: List[Status]) -> None: """ Test the bulk update of the scikit-optimize ET optimizer. 
""" diff --git a/mlos_bench/mlos_bench/tests/optimizer/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizer/toy_optimization_loop_test.py index 3386c34c62..a5e4d455f9 100644 --- a/mlos_bench/mlos_bench/tests/optimizer/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizer/toy_optimization_loop_test.py @@ -10,7 +10,8 @@ from typing import Tuple import pytest -from mlos_bench.environment import Environment, MockEnv +from mlos_bench.environment.base_environment import Environment +from mlos_bench.environment.mock_env import MockEnv from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizer import Optimizer, MockOptimizer, MlosCoreOptimizer @@ -27,17 +28,20 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: assert env.setup(tunables) (status, output) = env.run() - score = output['score'] assert status.is_succeeded + assert output is not None + score = output['score'] assert 60 <= score <= 120 opt.register(tunables, status, score) - return opt.get_best_observation() + (best_score, best_tunables) = opt.get_best_observation() + assert isinstance(best_score, float) and isinstance(best_tunables, TunableGroups) + return (best_score, best_tunables) def test_mock_optimization_loop(mock_env_no_noise: MockEnv, - mock_opt: MockOptimizer): + mock_opt: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -51,7 +55,7 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv, - scikit_gp_opt: MlosCoreOptimizer): + scikit_gp_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and Scikit GP optimizer. """ @@ -65,7 +69,7 @@ def test_scikit_gp_optimization_loop(mock_env_no_noise: MockEnv, def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv, - scikit_et_opt: MlosCoreOptimizer): + scikit_et_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and Scikit ET optimizer. """ @@ -79,7 +83,7 @@ def test_scikit_et_optimization_loop(mock_env_no_noise: MockEnv, def test_emukit_optimization_loop(mock_env_no_noise: MockEnv, - emukit_opt: MlosCoreOptimizer): + emukit_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and EmuKit optimizer. """ @@ -89,7 +93,7 @@ def test_emukit_optimization_loop(mock_env_no_noise: MockEnv, def test_emukit_optimization_loop_max(mock_env_no_noise: MockEnv, - emukit_opt_max: MlosCoreOptimizer): + emukit_opt_max: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and EmuKit optimizer in maximization mode. diff --git a/mlos_bench/mlos_bench/tests/service/__init__.py b/mlos_bench/mlos_bench/tests/service/__init__.py new file mode 100644 index 0000000000..e5a7a5bd7c --- /dev/null +++ b/mlos_bench/mlos_bench/tests/service/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.service. +Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/service/config_persistence_test.py b/mlos_bench/mlos_bench/tests/service/config_persistence_test.py index 833b4fd3d3..c63fe9a47c 100644 --- a/mlos_bench/mlos_bench/tests/service/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/service/config_persistence_test.py @@ -9,7 +9,7 @@ Unit tests for configuration persistence service. 
import os import pytest -from mlos_bench.service import ConfigPersistenceService +from mlos_bench.service.config_persistence import ConfigPersistenceService # pylint: disable=redefined-outer-name @@ -27,7 +27,7 @@ def config_persistence_service() -> ConfigPersistenceService: }) -def test_resolve_path(config_persistence_service: ConfigPersistenceService): +def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None: """ Check if we can actually find a file somewhere in `config_path`. """ @@ -37,7 +37,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService): assert os.path.exists(path) -def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService): +def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None: """ Check if non-existent file resolves without using `config_path`. """ @@ -47,7 +47,7 @@ def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) assert path == file_path -def test_load_config(config_persistence_service: ConfigPersistenceService): +def test_load_config(config_persistence_service: ConfigPersistenceService) -> None: """ Check if we can successfully load a config file located relative to `config_path`. """ diff --git a/mlos_bench/mlos_bench/tests/service/local/__init__.py b/mlos_bench/mlos_bench/tests/service/local/__init__.py new file mode 100644 index 0000000000..7e4517055e --- /dev/null +++ b/mlos_bench/mlos_bench/tests/service/local/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.service.local. +Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/service/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/service/local/local_exec_python_test.py index 6eedb98643..0d1a663dae 100644 --- a/mlos_bench/mlos_bench/tests/service/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/service/local/local_exec_python_test.py @@ -5,12 +5,17 @@ """ Unit tests for LocalExecService to run Python scripts locally. """ + +from typing import Dict + import os import json import pytest -from mlos_bench.service import LocalExecService, ConfigPersistenceService +from mlos_bench.tunables.tunable import TunableValue +from mlos_bench.service.local.local_exec import LocalExecService +from mlos_bench.service.config_persistence import ConfigPersistenceService # pylint: disable=redefined-outer-name @@ -25,7 +30,7 @@ def local_exec_service() -> LocalExecService: })) -def test_run_python_script(local_exec_service: LocalExecService): +def test_run_python_script(local_exec_service: LocalExecService) -> None: """ Run a Python script using a local_exec service. 
""" @@ -33,7 +38,7 @@ def test_run_python_script(local_exec_service: LocalExecService): output_file = "./config-kernel.sh" # Tunable parameters to save in JSON - params = { + params: Dict[str, TunableValue] = { "sched_migration_cost_ns": 40000, "sched_granularity_ns": 800000 } diff --git a/mlos_bench/mlos_bench/tests/service/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/service/local/local_exec_test.py index f3e8362b3b..1fd790c3db 100644 --- a/mlos_bench/mlos_bench/tests/service/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/service/local/local_exec_test.py @@ -11,7 +11,8 @@ import sys import pytest import pandas -from mlos_bench.service import LocalExecService, ConfigPersistenceService +from mlos_bench.service.local.local_exec import LocalExecService +from mlos_bench.service.config_persistence import ConfigPersistenceService # pylint: disable=redefined-outer-name # -- Ignore pylint complaints about pytest references to @@ -26,7 +27,7 @@ def local_exec_service() -> LocalExecService: return LocalExecService(parent=ConfigPersistenceService()) -def test_run_script(local_exec_service: LocalExecService): +def test_run_script(local_exec_service: LocalExecService) -> None: """ Run a script locally and check the results. """ @@ -37,7 +38,7 @@ def test_run_script(local_exec_service: LocalExecService): assert stderr.strip() == "" -def test_run_script_multiline(local_exec_service: LocalExecService): +def test_run_script_multiline(local_exec_service: LocalExecService) -> None: """ Run a multiline script locally and check the results. """ @@ -51,7 +52,7 @@ def test_run_script_multiline(local_exec_service: LocalExecService): assert stderr.strip() == "" -def test_run_script_multiline_env(local_exec_service: LocalExecService): +def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: """ Run a multiline script locally and pass the environment variables to it. """ @@ -68,7 +69,7 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService): assert stderr.strip() == "" -def test_run_script_read_csv(local_exec_service: LocalExecService): +def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """ Run a script locally and read the resulting CSV file. """ @@ -89,7 +90,7 @@ def test_run_script_read_csv(local_exec_service: LocalExecService): assert all(data.col2 == [222, 444]) -def test_run_script_write_read_txt(local_exec_service: LocalExecService): +def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None: """ Write data a temp location and run a script that updates it there. """ @@ -112,7 +113,7 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService): assert fh_input.read().split() == ["hello", "world", "test"] -def test_run_script_fail(local_exec_service: LocalExecService): +def test_run_script_fail(local_exec_service: LocalExecService) -> None: """ Try to run a non-existent command. """ diff --git a/mlos_bench/mlos_bench/tests/service/remote/__init__.py b/mlos_bench/mlos_bench/tests/service/remote/__init__.py new file mode 100644 index 0000000000..3dc9a35cbf --- /dev/null +++ b/mlos_bench/mlos_bench/tests/service/remote/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.service.remote. +Used to make mypy happy about multiple conftest.py modules. 
+""" diff --git a/mlos_bench/mlos_bench/tests/service/remote/azure/__init__.py b/mlos_bench/mlos_bench/tests/service/remote/azure/__init__.py new file mode 100644 index 0000000000..0a615d8e31 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/service/remote/azure/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.service.remote.azure. +Used to make mypy happy about multiple conftest.py modules. +""" diff --git a/mlos_bench/mlos_bench/tests/service/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/service/remote/azure/azure_fileshare_test.py index 28de1b0a56..0575ddfe27 100644 --- a/mlos_bench/mlos_bench/tests/service/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/service/remote/azure/azure_fileshare_test.py @@ -7,7 +7,9 @@ Tests for mlos_bench.service.remote.azure.azure_fileshare """ import os -from unittest.mock import Mock, patch, call +from unittest.mock import MagicMock, Mock, patch, call + +from mlos_bench.service.remote.azure.azure_fileshare import AzureFileShareService # pylint: disable=missing-function-docstring # pylint: disable=too-many-arguments @@ -16,31 +18,31 @@ from unittest.mock import Mock, patch, call @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs") -def test_download_file(mock_makedirs, mock_open, azure_fileshare): +def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client - mock_share_client.get_directory_client.return_value = Mock( - exists=Mock(return_value=False) - ) + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ + patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: + mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) - azure_fileshare.download(remote_path, local_path) + azure_fileshare.download(remote_path, local_path) - mock_share_client.get_file_client.assert_called_with(remote_path) - mock_makedirs.assert_called_with( - local_folder, - exist_ok=True, - ) - open_path, open_mode = mock_open.call_args.args - assert os.path.abspath(local_path) == os.path.abspath(open_path) - assert open_mode == "wb" + mock_get_file_client.assert_called_with(remote_path) + + mock_makedirs.assert_called_with( + local_folder, + exist_ok=True, + ) + open_path, open_mode = mock_open.call_args.args + assert os.path.abspath(local_path) == os.path.abspath(open_path) + assert open_mode == "wb" -def make_dir_client_returns(remote_folder: str): +def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), @@ -66,20 +68,24 @@ def make_dir_client_returns(remote_folder: str): @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare): +def test_download_folder_non_recursive(mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: 
remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client - mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x] + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: - azure_fileshare.download(remote_folder, local_folder, recursive=False) + mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] - mock_share_client.get_file_client.assert_called_with( + azure_fileshare.download(remote_folder, local_folder, recursive=False) + + mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_share_client.get_directory_client.assert_has_calls([ + mock_get_directory_client.assert_has_calls([ call(remote_folder), call(f"{remote_folder}/a_file_1.csv"), ], any_order=True) @@ -87,21 +93,22 @@ def test_download_folder_non_recursive(mock_makedirs, mock_open, azure_fileshare @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare): +def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client - mock_share_client.get_directory_client.side_effect = lambda x: dir_client_returns[x] + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] - azure_fileshare.download(remote_folder, local_folder, recursive=True) + azure_fileshare.download(remote_folder, local_folder, recursive=True) - mock_share_client.get_file_client.assert_has_calls([ + mock_get_file_client.assert_has_calls([ call(f"{remote_folder}/a_file_1.csv"), call(f"{remote_folder}/a_folder/a_file_2.csv"), ], any_order=True) - mock_share_client.get_directory_client.assert_has_calls([ + mock_get_directory_client.assert_has_calls([ call(remote_folder), call(f"{remote_folder}/a_file_1.csv"), call(f"{remote_folder}/a_folder"), @@ -111,19 +118,19 @@ def test_download_folder_recursive(mock_makedirs, mock_open, azure_fileshare): @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file(mock_isdir, mock_open, azure_fileshare): +def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False - 
azure_fileshare.upload(local_path, remote_path) + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + azure_fileshare.upload(local_path, remote_path) - mock_share_client.get_file_client.assert_called_with(remote_path) + mock_get_file_client.assert_called_with(remote_path) open_path, open_mode = mock_open.call_args.args assert os.path.abspath(local_path) == os.path.abspath(open_path) assert open_mode == "rb" @@ -136,11 +143,11 @@ class MyDirEntry: self.name = name self.is_a_dir = is_a_dir - def is_dir(self): + def is_dir(self) -> bool: return self.is_a_dir -def make_scandir_returns(local_folder: str): +def make_scandir_returns(local_folder: str) -> dict: return { local_folder: [ MyDirEntry("a_folder", True), @@ -152,7 +159,7 @@ def make_scandir_returns(local_folder: str): } -def make_isdir_returns(local_folder: str): +def make_isdir_returns(local_folder: str) -> dict: return { local_folder: True, f"{local_folder}/a_file_1.csv": False, @@ -161,7 +168,7 @@ def make_isdir_returns(local_folder: str): } -def process_paths(input_path): +def process_paths(input_path: str) -> str: skip_prefix = os.getcwd() # Remove prefix from os.path.abspath if there if input_path == os.path.abspath(input_path): @@ -175,37 +182,43 @@ def process_paths(input_path): @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare): +def test_upload_directory_non_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access - azure_fileshare.upload(local_folder, remote_folder, recursive=False) + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + azure_fileshare.upload(local_folder, remote_folder, recursive=False) - mock_share_client.get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv") + mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv") @patch("mlos_bench.service.remote.azure.azure_fileshare.open") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.service.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive(mock_scandir, mock_isdir, mock_open, azure_fileshare): +def test_upload_directory_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - # pylint: disable=protected-access - mock_share_client = azure_fileshare._share_client + mock_share_client = azure_fileshare._share_client # pylint: 
disable=protected-access - azure_fileshare.upload(local_folder, remote_folder, recursive=True) + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + azure_fileshare.upload(local_folder, remote_folder, recursive=True) - mock_share_client.get_file_client.assert_has_calls([ + mock_get_file_client.assert_has_calls([ call(f"{remote_folder}/a_file_1.csv"), call(f"{remote_folder}/a_folder/a_file_2.csv"), ], any_order=True) diff --git a/mlos_bench/mlos_bench/tests/service/remote/azure/azure_services_test.py b/mlos_bench/mlos_bench/tests/service/remote/azure/azure_services_test.py index 3882a196af..e81c868477 100644 --- a/mlos_bench/mlos_bench/tests/service/remote/azure/azure_services_test.py +++ b/mlos_bench/mlos_bench/tests/service/remote/azure/azure_services_test.py @@ -10,7 +10,9 @@ from unittest.mock import MagicMock, patch import pytest -from mlos_bench.environment import Status +from mlos_bench.environment.status import Status + +from mlos_bench.service.remote.azure.azure_services import AzureVMService # pylint: disable=missing-function-docstring # pylint: disable=too-many-arguments @@ -21,7 +23,7 @@ from mlos_bench.environment import Status ("vm_start", True), ("vm_stop", False), ("vm_deprovision", False), - ("vm_reboot", False), + ("vm_restart", False), ]) @pytest.mark.parametrize( ("http_status_code", "operation_status"), [ @@ -31,8 +33,8 @@ from mlos_bench.environment import Status (404, Status.FAILED), ]) @patch("mlos_bench.service.remote.azure.azure_services.requests") -def test_vm_operation_status(mock_requests, azure_vm_service, operation_name, - accepts_params, http_status_code, operation_status): +def test_vm_operation_status(mock_requests: MagicMock, azure_vm_service: AzureVMService, operation_name: str, + accepts_params: bool, http_status_code: int, operation_status: Status) -> None: mock_response = MagicMock() mock_response.status_code = http_status_code @@ -49,7 +51,7 @@ def test_vm_operation_status(mock_requests, azure_vm_service, operation_name, @patch("mlos_bench.service.remote.azure.azure_services.time.sleep") @patch("mlos_bench.service.remote.azure.azure_services.requests") -def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service): +def test_wait_vm_operation_ready(mock_requests: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService) -> None: # Mock response header async_url = "DUMMY_ASYNC_URL" @@ -73,7 +75,7 @@ def test_wait_vm_operation_ready(mock_requests, mock_sleep, azure_vm_service): @patch("mlos_bench.service.remote.azure.azure_services.requests") -def test_wait_vm_operation_timeout(mock_requests, azure_vm_service): +def test_wait_vm_operation_timeout(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None: # Mock response header params = { @@ -99,8 +101,8 @@ def test_wait_vm_operation_timeout(mock_requests, azure_vm_service): (404, Status.FAILED), ]) @patch("mlos_bench.service.remote.azure.azure_services.requests") -def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, operation_status): - +def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service: AzureVMService, + http_status_code: int, operation_status: Status) -> None: script = ["command_1", "command_2"] mock_response = MagicMock() @@ -113,7 +115,7 @@ def test_remote_exec_status(mock_requests, azure_vm_service, http_status_code, o @patch("mlos_bench.service.remote.azure.azure_services.requests") -def test_remote_exec_headers_output(mock_requests, azure_vm_service): +def 
test_remote_exec_headers_output(mock_requests: MagicMock, azure_vm_service: AzureVMService) -> None: async_url_key = "asyncResultsUrl" async_url_value = "DUMMY_ASYNC_URL" @@ -158,14 +160,15 @@ def test_remote_exec_headers_output(mock_requests, azure_vm_service): (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), ]) -def test_get_remote_exec_results(azure_vm_service, operation_status: Status, - wait_output: dict, results_output: dict): +def test_get_remote_exec_results(azure_vm_service: AzureVMService, operation_status: Status, + wait_output: dict, results_output: dict) -> None: params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"} mock_wait_vm_operation = MagicMock() mock_wait_vm_operation.return_value = (operation_status, wait_output) - azure_vm_service.wait_vm_operation = mock_wait_vm_operation + # azure_vm_service.wait_vm_operation = mock_wait_vm_operation + setattr(azure_vm_service, "wait_vm_operation", mock_wait_vm_operation) status, cmd_output = azure_vm_service.get_remote_exec_results(params) diff --git a/mlos_bench/mlos_bench/tests/service/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/service/remote/azure/conftest.py index 39970fe517..3dfb333cb6 100644 --- a/mlos_bench/mlos_bench/tests/service/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/service/remote/azure/conftest.py @@ -9,14 +9,14 @@ Configuration test fixtures for azure_services in mlos_bench. from unittest.mock import patch import pytest -from mlos_bench.service import ConfigPersistenceService +from mlos_bench.service.config_persistence import ConfigPersistenceService from mlos_bench.service.remote.azure import AzureVMService, AzureFileShareService # pylint: disable=redefined-outer-name @pytest.fixture -def config_persistence_service(): +def config_persistence_service() -> ConfigPersistenceService: """ Test fixture for ConfigPersistenceService. """ @@ -28,7 +28,7 @@ def config_persistence_service(): @pytest.fixture -def azure_vm_service(config_persistence_service: ConfigPersistenceService): +def azure_vm_service(config_persistence_service: ConfigPersistenceService) -> AzureVMService: """ Creates a dummy Azure VM service for tests that require it. """ @@ -45,7 +45,7 @@ def azure_vm_service(config_persistence_service: ConfigPersistenceService): @pytest.fixture -def azure_fileshare(config_persistence_service: ConfigPersistenceService): +def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: """ Creates a dummy AzureFileShareService for tests that require it. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/__init__.py b/mlos_bench/mlos_bench/tests/tunables/__init__.py new file mode 100644 index 0000000000..83c046e575 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/tunables/__init__.py @@ -0,0 +1,8 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Tests for mlos_bench.tunables. +Used to make mypy happy about multiple conftest.py modules. 
+""" diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 5544b9a341..7cd8d650cf 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -46,7 +46,7 @@ def configuration_space() -> ConfigurationSpace: def _cmp_tunable_hyperparameter_categorical( - tunable: Tunable, cs_param: CategoricalHyperparameter): + tunable: Tunable, cs_param: CategoricalHyperparameter) -> None: """ Check if categorical Tunable and ConfigSpace Hyperparameter actually match. """ @@ -56,7 +56,7 @@ def _cmp_tunable_hyperparameter_categorical( def _cmp_tunable_hyperparameter_int( - tunable: Tunable, cs_param: UniformIntegerHyperparameter): + tunable: Tunable, cs_param: UniformIntegerHyperparameter) -> None: """ Check if integer Tunable and ConfigSpace Hyperparameter actually match. """ @@ -66,7 +66,7 @@ def _cmp_tunable_hyperparameter_int( def _cmp_tunable_hyperparameter_float( - tunable: Tunable, cs_param: UniformFloatHyperparameter): + tunable: Tunable, cs_param: UniformFloatHyperparameter) -> None: """ Check if float Tunable and ConfigSpace Hyperparameter actually match. """ @@ -75,7 +75,7 @@ def _cmp_tunable_hyperparameter_float( assert cs_param.default_value == tunable.value -def test_tunable_to_hyperparameter_categorical(tunable_categorical): +def test_tunable_to_hyperparameter_categorical(tunable_categorical: Tunable) -> None: """ Check the conversion of Tunable to CategoricalHyperparameter. """ @@ -83,7 +83,7 @@ def test_tunable_to_hyperparameter_categorical(tunable_categorical): _cmp_tunable_hyperparameter_categorical(tunable_categorical, cs_param) -def test_tunable_to_hyperparameter_int(tunable_int): +def test_tunable_to_hyperparameter_int(tunable_int: Tunable) -> None: """ Check the conversion of Tunable to UniformIntegerHyperparameter. """ @@ -91,7 +91,7 @@ def test_tunable_to_hyperparameter_int(tunable_int): _cmp_tunable_hyperparameter_int(tunable_int, cs_param) -def test_tunable_to_hyperparameter_float(tunable_float): +def test_tunable_to_hyperparameter_float(tunable_float: Tunable) -> None: """ Check the conversion of Tunable to UniformFloatHyperparameter. """ @@ -106,7 +106,7 @@ _CMP_FUNC = { } -def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups): +def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> None: """ Check the conversion of TunableGroups to ConfigurationSpace. Make sure that the corresponding Tunable and Hyperparameter objects match. @@ -119,7 +119,7 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups): def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace): + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: """ Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index 3713588d4d..9177b8d4cb 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -8,8 +8,11 @@ Unit tests for assigning values to the individual parameters within tunable grou import pytest +from mlos_bench.tunables.tunable import Tunable +from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunables_assign_unknown_param(tunable_groups): + +def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: """ Make sure that bulk assignment fails for parameters that don't exist in the TunableGroups object. @@ -23,7 +26,7 @@ def test_tunables_assign_unknown_param(tunable_groups): }) -def test_tunables_assign_invalid_categorical(tunable_groups): +def test_tunables_assign_invalid_categorical(tunable_groups: TunableGroups) -> None: """ Check parameter validation for categorical tunables. """ @@ -31,7 +34,7 @@ def test_tunables_assign_invalid_categorical(tunable_groups): tunable_groups.assign({"vmSize": "InvalidSize"}) -def test_tunables_assign_invalid_range(tunable_groups): +def test_tunables_assign_invalid_range(tunable_groups: TunableGroups) -> None: """ Check parameter out-of-range validation for numerical tunables. """ @@ -39,14 +42,14 @@ def test_tunables_assign_invalid_range(tunable_groups): tunable_groups.assign({"kernel_sched_migration_cost_ns": -2}) -def test_tunables_assign_coerce_str(tunable_groups): +def test_tunables_assign_coerce_str(tunable_groups: TunableGroups) -> None: """ Check the conversion from strings when assigning to an integer parameter. """ tunable_groups.assign({"kernel_sched_migration_cost_ns": "10000"}) -def test_tunables_assign_coerce_str_range_check(tunable_groups): +def test_tunables_assign_coerce_str_range_check(tunable_groups: TunableGroups) -> None: """ Check the range when assigning to an integer tunable. """ @@ -54,7 +57,7 @@ def test_tunables_assign_coerce_str_range_check(tunable_groups): tunable_groups.assign({"kernel_sched_migration_cost_ns": "5500000"}) -def test_tunables_assign_coerce_str_invalid(tunable_groups): +def test_tunables_assign_coerce_str_invalid(tunable_groups: TunableGroups) -> None: """ Make sure we fail when assigning an invalid string to an integer tunable. """ @@ -62,23 +65,23 @@ def test_tunables_assign_coerce_str_invalid(tunable_groups): tunable_groups.assign({"kernel_sched_migration_cost_ns": "1.1"}) -def test_tunable_assign_str_to_int(tunable_int): +def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: """ Check str to int coercion. """ tunable_int.value = "10" - assert tunable_int.value == 10 + assert tunable_int.value == 10 # type: ignore[comparison-overlap] -def test_tunable_assign_str_to_float(tunable_float): +def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: """ Check str to float coercion. """ tunable_float.value = "0.5" - assert tunable_float.value == 0.5 + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] -def test_tunable_assign_float_to_int(tunable_int): +def test_tunable_assign_float_to_int(tunable_int: Tunable) -> None: """ Check float to int coercion. 
""" @@ -86,7 +89,7 @@ def test_tunable_assign_float_to_int(tunable_int): assert tunable_int.value == 10 -def test_tunable_assign_float_to_int_fail(tunable_int): +def test_tunable_assign_float_to_int_fail(tunable_int: Tunable) -> None: """ Check the invalid float to int coercion. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py index 3d1ee5b5df..440060ebe1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py @@ -6,18 +6,21 @@ Unit tests for deep copy of tunable objects and groups. """ +from mlos_bench.tunables.tunable import Tunable +from mlos_bench.tunables.tunable_groups import TunableGroups -def test_copy_tunable_int(tunable_int): + +def test_copy_tunable_int(tunable_int: Tunable) -> None: """ Check if deep copy works for Tunable object. """ tunable_copy = tunable_int.copy() assert tunable_int == tunable_copy - tunable_copy.value += 200 + tunable_copy.value = tunable_copy.numerical_value + 200 assert tunable_int != tunable_copy -def test_copy_tunable_groups(tunable_groups): +def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None: """ Check if deep copy works for TunableGroups object. """ diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index b8349662ab..b01adf0e25 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -7,9 +7,9 @@ Tunable parameter definition. """ import copy -from typing import Any, Dict, List +from typing import Dict, Iterable -from mlos_bench.tunables.tunable import Tunable +from mlos_bench.tunables.tunable import Tunable, TunableValue class CovariantTunableGroup: @@ -32,8 +32,8 @@ class CovariantTunableGroup: """ self._is_updated = True self._name = name - self._cost = config.get("cost", 0) - self._tunables = { + self._cost = int(config.get("cost", 0)) + self._tunables: Dict[str, Tunable] = { name: Tunable(name, tunable_config) for (name, tunable_config) in config.get("params", {}).items() } @@ -50,7 +50,7 @@ class CovariantTunableGroup: """ return self._name - def copy(self): + def copy(self) -> "CovariantTunableGroup": """ Deep copy of the CovariantTunableGroup object. @@ -62,7 +62,7 @@ class CovariantTunableGroup: """ return copy.deepcopy(self) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """ Check if two CovariantTunableGroup objects are equal. @@ -76,12 +76,14 @@ class CovariantTunableGroup: is_equal : bool True if two CovariantTunableGroup objects are equal. """ + if not isinstance(other, CovariantTunableGroup): + return False return (self._name == other._name and self._cost == other._cost and self._is_updated == other._is_updated and self._tunables == other._tunables) - def reset(self): + def reset(self) -> None: """ Clear the update flag. That is, state that running an experiment with the current values of the tunables in this group has no extra cost. @@ -110,13 +112,13 @@ class CovariantTunableGroup: """ return self._cost if self._is_updated else 0 - def get_names(self) -> List[str]: + def get_names(self) -> Iterable[str]: """ Get the names of all tunables in the group. """ return self._tunables.keys() - def get_values(self) -> Dict[str, Any]: + def get_values(self) -> Dict[str, TunableValue]: """ Get current values of all tunables in the group. 
""" @@ -151,10 +153,10 @@ class CovariantTunableGroup: """ return self._tunables[name] - def __getitem__(self, name: str): + def __getitem__(self, name: str) -> TunableValue: return self.get_tunable(name).value - def __setitem__(self, name: str, value): + def __setitem__(self, name: str, value: TunableValue) -> TunableValue: self._is_updated = True self._tunables[name].value = value return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 3adf0b84a3..227278bc33 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -9,11 +9,32 @@ import copy import collections import logging -from typing import Any, List, Tuple +from typing import List, Optional, Sequence, Tuple, TypedDict, Union _LOG = logging.getLogger(__name__) +"""A tunable parameter value type alias.""" +TunableValue = Union[int, float, str] + + +class TunableDict(TypedDict, total=False): + """ + A typed dict for tunable parameters. + + Mostly used for mypy type checking. + + These are the types expected to be received from the json config. + """ + + type: str + description: Optional[str] + default: TunableValue + values: Optional[List[str]] + range: Optional[Union[Sequence[int], Sequence[float]]] + special: Optional[Union[List[int], List[str]]] + + class Tunable: # pylint: disable=too-many-instance-attributes """ A tunable parameter definition and its current value. @@ -25,7 +46,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes "categorical": str, } - def __init__(self, name: str, config: dict): + def __init__(self, name: str, config: TunableDict): """ Create an instance of a new tunable parameter. @@ -39,19 +60,24 @@ class Tunable: # pylint: disable=too-many-instance-attributes self._name = name self._type = config["type"] # required self._description = config.get("description") - self._default = config.get("default") + self._default = config["default"] self._values = config.get("values") - self._range = config.get("range") + self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None + config_range = config.get("range") + if config_range is not None: + assert len(config_range) == 2, f"Invalid range: {config_range}" + config_range = (config_range[0], config_range[1]) + self._range = config_range self._special = config.get("special") self._current_value = self._default - if self._type == "categorical": + if self.is_categorical: if not (self._values and isinstance(self._values, collections.abc.Iterable)): raise ValueError("Must specify values for the categorical type") if self._range is not None: raise ValueError("Range must be None for the categorical type") if self._special is not None: raise ValueError("Special values must be None for the categorical type") - elif self._type in {"int", "float"}: + elif self.is_numerical: if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]: raise ValueError(f"Invalid range: {self._range}") else: @@ -68,7 +94,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes """ return f"{self._name}={self._current_value}" - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """ Check if two Tunable objects are equal. @@ -83,13 +109,15 @@ class Tunable: # pylint: disable=too-many-instance-attributes True if the Tunables correspond to the same parameter and have the same value and type. NOTE: ranges and special values are not currently considered in the comparison. 
""" - return ( + if not isinstance(other, Tunable): + return False + return bool( self._name == other._name and self._type == other._type and self._current_value == other._current_value ) - def copy(self): + def copy(self) -> "Tunable": """ Deep copy of the Tunable object. @@ -101,14 +129,14 @@ class Tunable: # pylint: disable=too-many-instance-attributes return copy.deepcopy(self) @property - def value(self): + def value(self) -> TunableValue: """ Get the current value of the tunable. """ return self._current_value @value.setter - def value(self, value): + def value(self, value: TunableValue) -> TunableValue: """ Set the current value of the tunable. """ @@ -135,13 +163,13 @@ class Tunable: # pylint: disable=too-many-instance-attributes self._current_value = coerced_value return self._current_value - def is_valid(self, value) -> bool: + def is_valid(self, value: TunableValue) -> bool: """ Check if the value can be assigned to the tunable. Parameters ---------- - value : Any + value : Union[int, float, str] Value to validate. Returns @@ -149,13 +177,36 @@ class Tunable: # pylint: disable=too-many-instance-attributes is_valid : bool True if the value is valid, False otherwise. """ - if self._type == "categorical": + if self.is_categorical and self._values: return value in self._values - elif self._type in {"int", "float"} and self._range: - return (self._range[0] <= value <= self._range[1]) or value == self._default + elif self.is_numerical and self._range: + assert isinstance(value, (int, float)) + return bool(self._range[0] <= value <= self._range[1]) or value == self._default else: raise ValueError(f"Invalid parameter type: {self._type}") + @property + def categorical_value(self) -> str: + """ + Get the current value of the tunable as a number. + """ + if self.is_categorical: + return str(self._current_value) + else: + raise ValueError("Cannot get categorical values for a numerical tunable.") + + @property + def numerical_value(self) -> Union[int, float]: + """ + Get the current value of the tunable as a number. + """ + if self._type == "int": + return int(self._current_value) + elif self._type == "float": + return float(self._current_value) + else: + raise ValueError("Cannot get numerical value for a categorical tunable.") + @property def name(self) -> str: """ @@ -176,7 +227,31 @@ class Tunable: # pylint: disable=too-many-instance-attributes return self._type @property - def range(self) -> Tuple[Any, Any]: + def is_categorical(self) -> bool: + """ + Check if the tunable is categorical. + + Returns + ------- + is_categorical : bool + True if the tunable is categorical, False otherwise. + """ + return self._type == "categorical" + + @property + def is_numerical(self) -> bool: + """ + Check if the tunable is an integer or float. + + Returns + ------- + is_int : bool + True if the tunable is an integer or float, False otherwise. + """ + return self._type in {"int", "float"} + + @property + def range(self) -> Union[Tuple[int, int], Tuple[float, float]]: """ Get the range of the tunable if it is numerical, None otherwise. @@ -186,6 +261,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes A 2-tuple of numbers that represents the range of the tunable. Numbers can be int or float, depending on the type of the tunable. """ + assert self.is_numerical + assert self._range is not None return self._range @property @@ -199,4 +276,6 @@ class Tunable: # pylint: disable=too-many-instance-attributes values : List[str] List of all possible values of a categorical tunable. 
""" + assert self.is_categorical + assert self._values is not None return self._values diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 1e7cd6175d..1993f23a16 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -7,9 +7,9 @@ TunableGroups definition. """ import copy -from typing import Any, Dict, List, Tuple +from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple -from mlos_bench.tunables.tunable import Tunable +from mlos_bench.tunables.tunable import Tunable, TunableValue from mlos_bench.tunables.covariant_group import CovariantTunableGroup @@ -18,7 +18,7 @@ class TunableGroups: A collection of covariant groups of tunable parameters. """ - def __init__(self, config: dict = None): + def __init__(self, config: Optional[dict] = None): """ Create a new group of tunable parameters. @@ -27,12 +27,14 @@ class TunableGroups: config : dict Python dict of serialized representation of the covariant tunable groups. """ - self._index = {} # Index (Tunable id -> CovariantTunableGroup) - self._tunable_groups = {} - for (name, group_config) in (config or {}).items(): + if config is None: + config = {} + self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) + self._tunable_groups: Dict[str, CovariantTunableGroup] = {} + for (name, group_config) in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """ Check if two TunableGroups are equal. @@ -46,9 +48,11 @@ class TunableGroups: is_equal : bool True if two TunableGroups are equal. """ - return self._tunable_groups == other._tunable_groups + if not isinstance(other, TunableGroups): + return False + return bool(self._tunable_groups == other._tunable_groups) - def copy(self): + def copy(self) -> "TunableGroups": """ Deep copy of the TunableGroups object. @@ -60,7 +64,7 @@ class TunableGroups: """ return copy.deepcopy(self) - def _add_group(self, group: CovariantTunableGroup): + def _add_group(self, group: CovariantTunableGroup) -> None: """ Add a CovariantTunableGroup to the current collection. @@ -71,7 +75,7 @@ class TunableGroups: self._tunable_groups[group.name] = group self._index.update(dict.fromkeys(group.get_names(), group)) - def update(self, tunables): + def update(self, tunables: "TunableGroups") -> "TunableGroups": """ Merge the two collections of covariant tunable groups. @@ -104,20 +108,20 @@ class TunableGroups: for (group_name, group) in self._tunable_groups.items() for tunable in group._tunables.values()) + " }" - def __getitem__(self, name: str): + def __getitem__(self, name: str) -> TunableValue: """ Get the current value of a single tunable parameter. """ return self._index[name][name] - def __setitem__(self, name: str, value): + def __setitem__(self, name: str, value: TunableValue) -> None: """ Update the current value of a single tunable parameter. """ # Use double index to make sure we set the is_updated flag of the group self._index[name][name] = value - def __iter__(self): + def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]: """ An iterator over all tunables in the group. 
@@ -147,7 +151,7 @@ class TunableGroups: group = self._index[name] return (group.get_tunable(name), group) - def get_names(self) -> List[str]: + def get_names(self) -> Iterable[str]: """ Get the names of all covariance groups in the collection. @@ -158,7 +162,7 @@ class TunableGroups: """ return self._tunable_groups.keys() - def subgroup(self, group_names: List[str]): + def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": """ Select the covariance groups from the current set and create a new TunableGroups object that consists of those covariance groups. @@ -179,8 +183,8 @@ class TunableGroups: tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values(self, group_names: List[str] = None, - into_params: Dict[str, Any] = None) -> Dict[str, Any]: + def get_param_values(self, group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -205,7 +209,7 @@ class TunableGroups: into_params.update(self._tunable_groups[name].get_values()) return into_params - def is_updated(self, group_names: List[str] = None) -> bool: + def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: """ Check if any of the given covariant tunable groups has been updated. @@ -222,7 +226,7 @@ class TunableGroups: return any(self._tunable_groups[name].is_updated() for name in (group_names or self.get_names())) - def reset(self, group_names: List[str] = None): + def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": """ Clear the update flag of given covariant groups. @@ -240,14 +244,14 @@ class TunableGroups: self._tunable_groups[name].reset() return self - def assign(self, param_values: Dict[str, Any]): + def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups": """ In-place update the values of the tunables from the dictionary of (key, value) pairs. Parameters ---------- - param_values : Dict[str, Any] + param_values : Mapping[str, TunableValue] Dictionary mapping Tunable parameter names to new values. Returns diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index d425899475..58705bf46e 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -11,12 +11,13 @@ Various helper functions for mlos_bench. import json import logging import importlib -from typing import Any, Tuple, Dict, Iterable + +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Type, TypeVar, TYPE_CHECKING _LOG = logging.getLogger(__name__) -def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load(config: dict, global_config: Optional[dict] = None) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. 
@@ -35,7 +36,10 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D class_name = config["class"] class_config = config.setdefault("config", {}) - for key in set(class_config).intersection(global_config or {}): + if global_config is None: + global_config = {} + + for key in set(class_config).intersection(global_config): class_config[key] = global_config[key] if _LOG.isEnabledFor(logging.DEBUG): @@ -45,7 +49,17 @@ def prepare_class_load(config: dict, global_config: dict = None) -> Tuple[str, D return (class_name, class_config) -def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs): +if TYPE_CHECKING: + from mlos_bench.environment.base_environment import Environment + from mlos_bench.service.base_service import Service + from mlos_bench.optimizer.base_optimizer import Optimizer + +# T is a generic with a constraint of the three base classes. +T = TypeVar('T', "Environment", "Service", "Optimizer") + + +# FIXME: Technically this should return a type "class_name" derived from "base_class". +def instantiate_from_config(base_class: Type[T], class_name: str, *args: Any, **kwargs: Any) -> T: """ Factory method for a new class instantiated from config. @@ -65,7 +79,7 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs): Returns ------- - inst : Any + inst : Union[Environment, Service, Optimizer] An instance of the `class_name` class. """ # We need to import mlos_bench to make the factory methods work. @@ -78,10 +92,12 @@ def instantiate_from_config(base_class: type, class_name: str, *args, **kwargs): _LOG.info("Instantiating: %s :: %s", class_name, impl) assert issubclass(impl, base_class) - return impl(*args, **kwargs) + ret: T = impl(*args, **kwargs) + assert isinstance(ret, base_class) + return ret -def check_required_params(config: Dict[str, Any], required_params: Iterable[str]): +def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None: """ Check if all required parameters are present in the configuration. Raise ValueError if any of the parameters are missing. diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index f620c24626..85ead3969a 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -34,13 +34,16 @@ extra_requires = { # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(extra_requires.values()))) +extra_requires['full'] = list(set(chain(extra_requires.values()))) # type: ignore[assignment] # pylint: disable=duplicate-code setup( name='mlos-bench', version=_VERSION, packages=find_packages(), + package_data={ + 'mlos_bench': ['py.typed'], + }, install_requires=[ 'mlos-core==' + _VERSION, 'requests', diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 89032d7c64..b0b8e31b47 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -7,7 +7,7 @@ Basic initializer module for the mlos_core space adapters. 
""" from enum import Enum -from typing import Optional, TypeVar, Union +from typing import Optional, TypeVar import ConfigSpace diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index a7c60a9b11..a4e313f0a3 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -22,8 +22,6 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- aimed at improving the sample-efficiency of the underlying optimizer. """ - # pylint: disable=consider-alternative-union-syntax,too-many-arguments - DEFAULT_NUM_LOW_DIMS = 16 """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" @@ -40,7 +38,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- special_param_values: Optional[dict] = None, max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, use_approximate_reverse_mapping: bool = False, - ) -> None: + ) -> None: # pylint: disable=too-many-arguments """Create a space adapter that employs LlamaTune's techniques. Parameters diff --git a/setup.cfg b/setup.cfg index c2ef3d3d5e..39d22126a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,8 +54,12 @@ disallow_untyped_defs = True disallow_incomplete_defs = True strict = True allow_any_generics = True +hide_error_codes = False # regex of files to skip type checking -exclude = /_pytest/|/build/|doc/|_version.py|setup.py +# _version.py and setup.py look like duplicates when run from the root of the repo even though they're part of different packages. +# There's not much in them so we just skip them. +# We also skip several vendor files that currently throw errors. +exclude = mlos_(core|bench)/(_version|setup).py|doc/|/build/|-packages/_pytest/ # https://github.com/automl/ConfigSpace/issues/293 [mypy-ConfigSpace.*] @@ -65,6 +69,10 @@ ignore_missing_imports = True [mypy-emukit.*] ignore_missing_imports = True +# https://github.com/dpranke/pyjson5/issues/65 +[mypy-json5] +ignore_missing_imports = True + # https://github.com/matplotlib/matplotlib/issues/25634 [mypy-matplotlib.*] ignore_missing_imports = True