diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 76a699a4d3..1f56be9d1a 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -45,7 +45,6 @@
// Adjust the python interpreter path to point to the conda environment
"python.defaultInterpreterPath": "/opt/conda/envs/mlos/bin/python",
"python.testing.pytestPath": "/opt/conda/envs/mlos/bin/pytest",
- "python.formatting.autopep8Path": "/opt/conda/envs/mlos/bin/autopep8",
"python.linting.pylintPath": "/opt/conda/envs/mlos/bin/pylint",
"pylint.path": [
"/opt/conda/envs/mlos/bin/pylint"
@@ -71,9 +70,8 @@
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
- // TODO: Enable additional formatter extensions:
- //"ms-python.black-formatter",
- //"ms-python.isort",
+ "ms-python.black-formatter",
+ "ms-python.isort",
"ms-python.pylint",
"ms-python.python",
"ms-python.vscode-pylance",
diff --git a/.editorconfig b/.editorconfig
index e984d47595..7e753174de 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -12,6 +12,9 @@ charset = utf-8
# Note: this is not currently supported by all editors or their editorconfig plugins.
max_line_length = 132
+[*.py]
+max_line_length = 99
+
# Makefiles need tab indentation
[{Makefile,*.mk}]
indent_style = tab
diff --git a/.gitignore b/.gitignore
index 157dba7a4d..471d653344 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# Ignore the .git directory (e.g., so that ripgrep and similar tools skip it)
+.git/
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/.pylintrc b/.pylintrc
index 6b308d1966..c6c512ecb7 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -35,7 +35,7 @@ load-plugins=
[FORMAT]
# Maximum number of characters on a single line.
-max-line-length=132
+max-line-length=99
[MESSAGE CONTROL]
disable=
@@ -48,5 +48,5 @@ disable=
missing-raises-doc
[STRING]
-#check-quote-consistency=yes
+check-quote-consistency=yes
check-str-concat-over-line-jumps=yes
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
index 327bf5c51c..76dce33d5a 100644
--- a/.vscode/extensions.json
+++ b/.vscode/extensions.json
@@ -14,9 +14,8 @@
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
- // TODO: Enable additional formatter extensions:
- //"ms-python.black-formatter",
- //"ms-python.isort",
+ "ms-python.black-formatter",
+ "ms-python.isort",
"ms-python.pylint",
"ms-python.python",
"ms-python.vscode-pylance",
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 1e8eb58adb..6b9729290f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -125,14 +125,10 @@
],
"esbonio.sphinx.confDir": "${workspaceFolder}/doc/source",
"esbonio.sphinx.buildDir": "${workspaceFolder}/doc/build/",
- "autopep8.args": [
- "--experimental"
- ],
"[python]": {
- // TODO: Enable black formatter
- //"editor.defaultFormatter": "ms-python.black-formatter",
- //"editor.formatOnSave": true,
- //"editor.formatOnSaveMode": "modifications"
+ "editor.defaultFormatter": "ms-python.black-formatter",
+ "editor.formatOnSave": true,
+ "editor.formatOnSaveMode": "modifications"
},
// See Also .vscode/launch.json for environment variable args to pytest during debug sessions.
// For the rest, see setup.cfg
diff --git a/Makefile b/Makefile
index 128b3dc849..62ea9a3359 100644
--- a/Makefile
+++ b/Makefile
@@ -34,7 +34,7 @@ conda-env: build/conda-env.${CONDA_ENV_NAME}.build-stamp
MLOS_CORE_CONF_FILES := mlos_core/pyproject.toml mlos_core/setup.py mlos_core/MANIFEST.in
MLOS_BENCH_CONF_FILES := mlos_bench/pyproject.toml mlos_bench/setup.py mlos_bench/MANIFEST.in
MLOS_VIZ_CONF_FILES := mlos_viz/pyproject.toml mlos_viz/setup.py mlos_viz/MANIFEST.in
-MLOS_GLOBAL_CONF_FILES := setup.cfg # pyproject.toml
+MLOS_GLOBAL_CONF_FILES := setup.cfg pyproject.toml
MLOS_PKGS := mlos_core mlos_bench mlos_viz
MLOS_PKG_CONF_FILES := $(MLOS_CORE_CONF_FILES) $(MLOS_BENCH_CONF_FILES) $(MLOS_VIZ_CONF_FILES) $(MLOS_GLOBAL_CONF_FILES)
@@ -69,9 +69,9 @@ ifneq (,$(filter format,$(MAKECMDGOALS)))
endif
build/format.${CONDA_ENV_NAME}.build-stamp: build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
-# TODO: enable isort and black formatters
-#build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
-#build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
+build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
+build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
+build/format.${CONDA_ENV_NAME}.build-stamp: build/docformatter.${CONDA_ENV_NAME}.build-stamp
build/format.${CONDA_ENV_NAME}.build-stamp:
touch $@
@@ -111,8 +111,8 @@ build/isort.${CONDA_ENV_NAME}.build-stamp:
# NOTE: when using pattern rules (involving %) we can only add one line of
# prerequisities, so we use this pattern to compose the list as variables.
-# Both isort and licenseheaders alter files, so only run one at a time, by
-# making licenseheaders an order-only prerequisite.
+# black, licenseheaders, isort, and docformatter all alter files, so run only one at
+# a time by adding ordering prerequisites between them, but only when those goals are requested.
ISORT_COMMON_PREREQS :=
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
ISORT_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
@@ -126,7 +126,7 @@ build/isort.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
build/isort.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_COMMON_PREREQS)
# Reformat python file imports with isort.
- conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$?)
+ conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$+)
touch $@
.PHONY: black
@@ -142,8 +142,8 @@ build/black.${CONDA_ENV_NAME}.build-stamp: build/black.mlos_viz.${CONDA_ENV_NAME
build/black.${CONDA_ENV_NAME}.build-stamp:
touch $@
-# Both black, licenseheaders, and isort all alter files, so only run one at a time, by
-# making licenseheaders and isort an order-only prerequisite.
+# black, licenseheaders, isort, and docformatter all alter files, so run only one at
+# a time by adding ordering prerequisites between them, but only when those goals are requested.
BLACK_COMMON_PREREQS :=
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
BLACK_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
@@ -160,13 +160,52 @@ build/black.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS)
# Reformat python files with black.
- conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$?)
+ conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$+)
touch $@
+.PHONY: docformatter
+docformatter: build/docformatter.${CONDA_ENV_NAME}.build-stamp
+
+ifneq (,$(filter docformatter,$(MAKECMDGOALS)))
+ FORMAT_PREREQS += build/docformatter.${CONDA_ENV_NAME}.build-stamp
+endif
+
+build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp
+build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp
+build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp
+build/docformatter.${CONDA_ENV_NAME}.build-stamp:
+ touch $@
+
+# black, licenseheaders, isort, and docformatter all alter files, so run only one at
+# a time by adding ordering prerequisites between them, but only when those goals are requested.
+DOCFORMATTER_COMMON_PREREQS :=
+ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
+DOCFORMATTER_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
+endif
+ifneq (,$(filter format isort,$(MAKECMDGOALS)))
+DOCFORMATTER_COMMON_PREREQS += build/isort.${CONDA_ENV_NAME}.build-stamp
+endif
+ifneq (,$(filter format black,$(MAKECMDGOALS)))
+DOCFORMATTER_COMMON_PREREQS += build/black.${CONDA_ENV_NAME}.build-stamp
+endif
+DOCFORMATTER_COMMON_PREREQS += build/conda-env.${CONDA_ENV_NAME}.build-stamp
+DOCFORMATTER_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
+
+build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
+build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
+build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
+
+# docformatter returns non-zero when it changes anything, so ignore its return code
+# and immediately re-run it in check mode to verify the results.
+build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS)
+ # Reformat python file docstrings with docformatter.
+ conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$+) || true
+ conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
+ touch $@
+
+
.PHONY: check
-check: pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
-# TODO: Enable isort and black checks
-#check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
+check: isort-check black-check docformatter-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
.PHONY: black-check
black-check: build/black-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
@@ -185,7 +224,27 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/black-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS)
# Check for import sort order.
# Note: if this fails use "make format" or "make black" to fix it.
- conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$?)
+ conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$+)
+ touch $@
+
+.PHONY: docformatter-check
+docformatter-check: build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
+docformatter-check: build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp
+docformatter-check: build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp
+
+# Make sure docformatter format rules run before docformatter-check rules.
+build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
+build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
+build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
+
+DOCFORMATTER_CHECK_COMMON_PREREQS := build/conda-env.${CONDA_ENV_NAME}.build-stamp
+DOCFORMATTER_CHECK_COMMON_PREREQS += $(FORMAT_PREREQS)
+DOCFORMATTER_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
+
+build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_CHECK_COMMON_PREREQS)
+	# Check for docstring formatting with docformatter.
+ # Note: if this fails use "make format" or "make docformatter" to fix it.
+ conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
touch $@
.PHONY: isort-check
@@ -204,7 +263,7 @@ ISORT_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/isort-check.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_CHECK_COMMON_PREREQS)
# Note: if this fails use "make format" or "make isort" to fix it.
- conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$?)
+ conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$+)
touch $@
.PHONY: pycodestyle
@@ -723,7 +782,12 @@ clean-doc:
.PHONY: clean-format
clean-format:
- # TODO: add black and isort rules
+ rm -f build/black.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/black.mlos_*.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/docformatter.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/docformatter.mlos_*.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/isort.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/isort.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
rm -f build/licenseheaders-prereqs.${CONDA_ENV_NAME}.build-stamp
@@ -733,6 +797,13 @@ clean-check:
rm -f build/pylint.${CONDA_ENV_NAME}.build-stamp
rm -f build/pylint.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/mypy.mlos_*.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/black-check.build-stamp
+ rm -f build/black-check.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/black-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/docformatter-check.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/docformatter-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/isort-check.${CONDA_ENV_NAME}.build-stamp
+ rm -f build/isort-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/pycodestyle.build-stamp
rm -f build/pycodestyle.${CONDA_ENV_NAME}.build-stamp
rm -f build/pycodestyle.mlos_*.${CONDA_ENV_NAME}.build-stamp
diff --git a/conda-envs/mlos-3.10.yml b/conda-envs/mlos-3.10.yml
index 64dca68dda..b76d48e5b5 100644
--- a/conda-envs/mlos-3.10.yml
+++ b/conda-envs/mlos-3.10.yml
@@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conda-envs/mlos-3.11.yml b/conda-envs/mlos-3.11.yml
index f6add2a586..64ee6fd58a 100644
--- a/conda-envs/mlos-3.11.yml
+++ b/conda-envs/mlos-3.11.yml
@@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conda-envs/mlos-3.8.yml b/conda-envs/mlos-3.8.yml
index 2eb0b25cef..b1e14c7402 100644
--- a/conda-envs/mlos-3.8.yml
+++ b/conda-envs/mlos-3.8.yml
@@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conda-envs/mlos-3.9.yml b/conda-envs/mlos-3.9.yml
index f35aadb5e3..edccdab405 100644
--- a/conda-envs/mlos-3.9.yml
+++ b/conda-envs/mlos-3.9.yml
@@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conda-envs/mlos-windows.yml b/conda-envs/mlos-windows.yml
index b2c5467458..107b6fb2cf 100644
--- a/conda-envs/mlos-windows.yml
+++ b/conda-envs/mlos-windows.yml
@@ -28,10 +28,10 @@ dependencies:
# This also requires a more recent vs2015_runtime from conda-forge.
- pyrfr>=0.9.0
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml
index a257197761..5cd35fdbba 100644
--- a/conda-envs/mlos.yml
+++ b/conda-envs/mlos.yml
@@ -24,10 +24,10 @@ dependencies:
# FIXME: https://github.com/microsoft/MLOS/issues/727
- python<3.12
- pip:
- - autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
+ - docformatter
- licenseheaders
- mypy
- pandas-stubs
diff --git a/conftest.py b/conftest.py
index e22395f82f..7985ef8239 100644
--- a/conftest.py
+++ b/conftest.py
@@ -32,14 +32,23 @@ def pytest_configure(config: pytest.Config) -> None:
Add some additional (global) configuration steps for pytest.
"""
# Workaround some issues loading emukit in certain environments.
- if os.environ.get('DISPLAY', None):
+ if os.environ.get("DISPLAY", None):
try:
- import matplotlib # pylint: disable=import-outside-toplevel
- matplotlib.rcParams['backend'] = 'agg'
- if is_master(config) or dict(getattr(config, 'workerinput', {}))['workerid'] == 'gw0':
+ import matplotlib # pylint: disable=import-outside-toplevel
+
+ matplotlib.rcParams["backend"] = "agg"
+ if is_master(config) or dict(getattr(config, "workerinput", {}))["workerid"] == "gw0":
# Only warn once.
- warn(UserWarning('DISPLAY environment variable is set, which can cause problems in some setups (e.g. WSL). '
- + f'Adjusting matplotlib backend to "{matplotlib.rcParams["backend"]}" to compensate.'))
+ warn(
+ UserWarning(
+ (
+ "DISPLAY environment variable is set, "
+ "which can cause problems in some setups (e.g. WSL). "
+ f"Adjusting matplotlib backend to '{matplotlib.rcParams['backend']}' "
+ "to compensate."
+ )
+ )
+ )
except ImportError:
pass
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 3e25d9b082..a06436ba1f 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -87,10 +87,11 @@ autosummary_generate = True
numpydoc_class_members_toctree = False
autodoc_default_options = {
- 'members': True,
- 'undoc-members': True,
- # Don't generate documentation for some (non-private) functions that are more for internal implementation use.
- 'exclude-members': 'mlos_bench.util.check_required_params'
+ "members": True,
+ "undoc-members": True,
+ # Don't generate documentation for some (non-private) functions that are more
+ # for internal implementation use.
+ "exclude-members": "mlos_bench.util.check_required_params",
}
# Generate the plots for the gallery
diff --git a/mlos_bench/mlos_bench/__init__.py b/mlos_bench/mlos_bench/__init__.py
index 1fed310b78..db8c235041 100644
--- a/mlos_bench/mlos_bench/__init__.py
+++ b/mlos_bench/mlos_bench/__init__.py
@@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-mlos_bench is a framework to help automate benchmarking and and
-OS/application parameter autotuning.
+"""mlos_bench is a framework to help automate benchmarking and and OS/application
+parameter autotuning.
"""
diff --git a/mlos_bench/mlos_bench/config/__init__.py b/mlos_bench/mlos_bench/config/__init__.py
index 590e3d50d0..b78386118c 100644
--- a/mlos_bench/mlos_bench/config/__init__.py
+++ b/mlos_bench/mlos_bench/config/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-mlos_bench.config
-"""
+"""mlos_bench.config."""
diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py
index 679d0d4ceb..43baeb1cf8 100644
--- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py
+++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py
@@ -3,41 +3,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Script for post-processing FIO results for mlos_bench.
-"""
+"""Script for post-processing FIO results for mlos_bench."""
import argparse
import itertools
import json
-
from typing import Any, Iterator, Tuple
import pandas
def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]:
- """
- Flatten every dict in the hierarchy and rename the keys with the dict path.
- """
+ """Flatten every dict in the hierarchy and rename the keys with the dict path."""
if isinstance(data, dict):
- for (key, val) in data.items():
+ for key, val in data.items():
yield from _flat_dict(val, f"{path}.{key}")
else:
yield (path, data)
def _main(input_file: str, output_file: str, prefix: str) -> None:
- """
- Convert FIO read data from JSON to tall CSV.
- """
- with open(input_file, mode='r', encoding='utf-8') as fh_input:
+ """Convert FIO read data from JSON to tall CSV."""
+ with open(input_file, mode="r", encoding="utf-8") as fh_input:
json_data = json.load(fh_input)
- data = list(itertools.chain(
- _flat_dict(json_data["jobs"][0], prefix),
- _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util")
- ))
+ data = list(
+ itertools.chain(
+ _flat_dict(json_data["jobs"][0], prefix),
+ _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"),
+ )
+ )
tall_df = pandas.DataFrame(data, columns=["metric", "value"])
tall_df.to_csv(output_file, index=False)
@@ -50,12 +45,18 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.")
parser.add_argument(
- "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).")
+ "input",
+ help="FIO benchmark results in JSON format (downloaded from a remote VM).",
+ )
parser.add_argument(
- "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).")
+ "output",
+ help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).",
+ )
parser.add_argument(
- "--prefix", default="fio",
- help="Prefix of the metric IDs (default 'fio')")
+ "--prefix",
+ default="fio",
+ help="Prefix of the metric IDs (default 'fio')",
+ )
args = parser.parse_args()
_main(args.input, args.output, args.prefix)
diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py
index c1850f5e03..d41f20d2a9 100644
--- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py
+++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py
@@ -9,22 +9,24 @@ Helper script to generate Redis config from tunable parameters JSON.
Run: `./generate_redis_config.py ./input-params.json ./output-redis.cfg`
"""
-import json
import argparse
+import json
def _main(fname_input: str, fname_output: str) -> None:
- with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
- open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
- for (key, val) in json.load(fh_tunables).items():
- line = f'{key} {val}'
+ with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
+ fname_output, "wt", encoding="utf-8", newline=""
+ ) as fh_config:
+ for key, val in json.load(fh_tunables).items():
+ line = f"{key} {val}"
fh_config.write(line + "\n")
print(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="generate Redis config from tunable parameters JSON.")
+ description="generate Redis config from tunable parameters JSON."
+ )
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("output", help="Output Redis config file.")
args = parser.parse_args()
diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py
index e33c717953..8b979e5014 100644
--- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py
+++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py
@@ -3,9 +3,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Script for post-processing redis-benchmark results.
-"""
+"""Script for post-processing redis-benchmark results."""
import argparse
@@ -13,26 +11,25 @@ import pandas as pd
def _main(input_file: str, output_file: str) -> None:
- """
- Re-shape Redis benchmark CSV results from wide to long.
- """
+ """Re-shape Redis benchmark CSV results from wide to long."""
df_wide = pd.read_csv(input_file)
# Format the results from wide to long
# The target is columns of metric and value to act as key-value pairs.
df_long = (
- df_wide
- .melt(id_vars=["test"])
+ df_wide.melt(id_vars=["test"])
.assign(metric=lambda df: df["test"] + "_" + df["variable"])
.drop(columns=["test", "variable"])
.loc[:, ["metric", "value"]]
)
# Add a default `score` metric to the end of the dataframe.
- df_long = pd.concat([
- df_long,
- pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]})
- ])
+ df_long = pd.concat(
+ [
+ df_long,
+ pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}),
+ ]
+ )
df_long.to_csv(output_file, index=False)
print(f"Converted: {input_file} -> {output_file}")
@@ -41,8 +38,13 @@ def _main(input_file: str, output_file: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.")
- parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).")
- parser.add_argument("output", help="Converted Redis benchmark data" +
- " (to be consumed by OS Autotune framework).")
+ parser.add_argument(
+ "input",
+ help="Redis benchmark results (downloaded from a remote VM).",
+ )
+ parser.add_argument(
+ "output",
+ help="Converted Redis benchmark data (to be consumed by OS Autotune framework).",
+ )
args = parser.parse_args()
_main(args.input, args.output)
diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py
index 41bd162459..9b75f04008 100755
--- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py
+++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py
@@ -6,16 +6,18 @@
"""
Python script to parse through JSON and create new config file.
-This script will be run in the SCHEDULER.
-NEW_CFG will need to be copied over to the VM (/etc/default/grub.d).
+This script will be run in the SCHEDULER. NEW_CFG will need to be copied over to the VM
+(/etc/default/grub.d).
"""
import json
JSON_CONFIG_FILE = "config-boot-time.json"
NEW_CFG = "zz-mlos-boot-params.cfg"
-with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \
- open(NEW_CFG, 'w', encoding='UTF-8') as fh_config:
+with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open(
+ NEW_CFG, "w", encoding="UTF-8"
+) as fh_config:
for key, val in json.load(fh_json).items():
- fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$'
- f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n')
+ fh_config.write(
+ 'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n'
+ )
diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py
index d03e4f5771..9f130e5c0e 100755
--- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py
+++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py
@@ -9,14 +9,15 @@ Helper script to generate GRUB config from tunable parameters JSON.
Run: `./generate_grub_config.py ./input-boot-params.json ./output-grub.cfg`
"""
-import json
import argparse
+import json
def _main(fname_input: str, fname_output: str) -> None:
- with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
- open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
- for (key, val) in json.load(fh_tunables).items():
+ with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
+ fname_output, "wt", encoding="utf-8", newline=""
+ ) as fh_config:
+ for key, val in json.load(fh_tunables).items():
line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"'
fh_config.write(line + "\n")
print(line)
@@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Generate GRUB config from tunable parameters JSON.")
+ description="Generate GRUB config from tunable parameters JSON."
+ )
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("output", help="Output shell script to configure GRUB.")
args = parser.parse_args()
diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py
index e6d8039729..a4e5e5ccb6 100755
--- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py
+++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py
@@ -6,11 +6,14 @@
"""
Helper script to generate a script to update kernel parameters from tunables JSON.
-Run: `./generate_kernel_config_script.py ./kernel-params.json ./kernel-params-meta.json ./config-kernel.sh`
+Run:
+ ./generate_kernel_config_script.py \
+ ./kernel-params.json ./kernel-params-meta.json \
+ ./config-kernel.sh
"""
-import json
import argparse
+import json
def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
@@ -22,7 +25,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
tunables_meta = json.load(fh_meta)
with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
- for (key, val) in tunables_data.items():
+ for key, val in tunables_data.items():
meta = tunables_meta.get(key, {})
name_prefix = meta.get("name_prefix", "")
line = f'echo "{val}" > {name_prefix}{key}'
@@ -33,7 +36,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="generate a script to update kernel parameters from tunables JSON.")
+ description="generate a script to update kernel parameters from tunables JSON."
+ )
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("meta", help="JSON file with tunable parameters metadata.")
diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py
index 73daf81c3b..d4987add63 100644
--- a/mlos_bench/mlos_bench/config/schemas/__init__.py
+++ b/mlos_bench/mlos_bench/config/schemas/__init__.py
@@ -2,14 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A module for managing config schemas and their validation.
-"""
-
-from mlos_bench.config.schemas.config_schemas import ConfigSchema, CONFIG_SCHEMA_DIR
+"""A module for managing config schemas and their validation."""
+from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema
__all__ = [
- 'ConfigSchema',
- 'CONFIG_SCHEMA_DIR',
+ "ConfigSchema",
+ "CONFIG_SCHEMA_DIR",
]
diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py
index 9c4a066be5..b7ce402b5d 100644
--- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py
+++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A simple class for describing where to find different config schemas and validating configs against them.
+"""A simple class for describing where to find different config schemas and validating
+configs against them.
"""
+import json # schema files are pure json - no comments
import logging
from enum import Enum
-from os import path, walk, environ
+from os import environ, path, walk
from typing import Dict, Iterator, Mapping
-import json # schema files are pure json - no comments
import jsonschema
-
from referencing import Registry, Resource
from referencing.jsonschema import DRAFT202012
@@ -28,16 +27,21 @@ CONFIG_SCHEMA_DIR = path_join(path.dirname(__file__), abs_path=True)
# It is used in `ConfigSchema.validate()` method below.
# NOTE: this may cause pytest to fail if it's expecting exceptions
# to be raised for invalid configs.
-_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION'
-_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower()
- in {'true', 'y', 'yes', 'on', '1'})
+_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION"
+_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in {
+ "true",
+ "y",
+ "yes",
+ "on",
+ "1",
+}
# Note: we separate out the SchemaStore from a class method on ConfigSchema
# because of issues with mypy/pylint and non-Enum-member class members.
class SchemaStore(Mapping):
- """
- A simple class for storing schemas and subschemas for the validator to reference.
+ """A simple class for storing schemas and subschemas for the validator to
+ reference.
"""
# A class member mapping of schema id to schema object.
@@ -58,7 +62,9 @@ class SchemaStore(Mapping):
@classmethod
def _load_schemas(cls) -> None:
- """Loads all schemas and subschemas into the schema store for the validator to reference."""
+ """Loads all schemas and subschemas into the schema store for the validator to
+ reference.
+ """
if cls._SCHEMA_STORE:
return
for root, _, files in walk(CONFIG_SCHEMA_DIR):
@@ -78,13 +84,17 @@ class SchemaStore(Mapping):
@classmethod
def _load_registry(cls) -> None:
- """Also store them in a Registry object for referencing by recent versions of jsonschema."""
+ """Also store them in a Registry object for referencing by recent versions of
+ jsonschema.
+ """
if not cls._SCHEMA_STORE:
cls._load_schemas()
- cls._REGISTRY = Registry().with_resources([
- (url, Resource.from_contents(schema, default_specification=DRAFT202012))
- for url, schema in cls._SCHEMA_STORE.items()
- ])
+ cls._REGISTRY = Registry().with_resources(
+ [
+ (url, Resource.from_contents(schema, default_specification=DRAFT202012))
+ for url, schema in cls._SCHEMA_STORE.items()
+ ]
+ )
@property
def registry(self) -> Registry:
@@ -98,9 +108,7 @@ SCHEMA_STORE = SchemaStore()
class ConfigSchema(Enum):
- """
- An enum to help describe schema types and help validate configs against them.
- """
+ """An enum to help describe schema types and help validate configs against them."""
CLI = path_join(CONFIG_SCHEMA_DIR, "cli/cli-schema.json")
GLOBALS = path_join(CONFIG_SCHEMA_DIR, "cli/globals-schema.json")
diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py
index 4ccef7817b..e209f12bed 100644
--- a/mlos_bench/mlos_bench/dict_templater.py
+++ b/mlos_bench/mlos_bench/dict_templater.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Simple class to help with nested dictionary $var templating.
-"""
+"""Simple class to help with nested dictionary $var templating."""
from copy import deepcopy
from string import Template
@@ -13,10 +11,8 @@ from typing import Any, Dict, Optional
from mlos_bench.os_environ import environ
-class DictTemplater: # pylint: disable=too-few-public-methods
- """
- Simple class to help with nested dictionary $var templating.
- """
+class DictTemplater: # pylint: disable=too-few-public-methods
+ """Simple class to help with nested dictionary $var templating."""
def __init__(self, source_dict: Dict[str, Any]):
"""
@@ -32,9 +28,12 @@ class DictTemplater: # pylint: disable=too-few-public-methods
# The source/target dictionary to expand.
self._dict: Dict[str, Any] = {}
- def expand_vars(self, *,
- extra_source_dict: Optional[Dict[str, Any]] = None,
- use_os_env: bool = False) -> Dict[str, Any]:
+ def expand_vars(
+ self,
+ *,
+ extra_source_dict: Optional[Dict[str, Any]] = None,
+ use_os_env: bool = False,
+ ) -> Dict[str, Any]:
"""
Expand the template variables in the destination dictionary.
@@ -55,10 +54,13 @@ class DictTemplater: # pylint: disable=too-few-public-methods
assert isinstance(self._dict, dict)
return self._dict
- def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any:
- """
- Recursively expand $var strings in the currently operating dictionary.
- """
+ def _expand_vars(
+ self,
+ value: Any,
+ extra_source_dict: Optional[Dict[str, Any]],
+ use_os_env: bool,
+ ) -> Any:
+ """Recursively expand $var strings in the currently operating dictionary."""
if isinstance(value, str):
# First try to expand all $vars internally.
value = Template(value).safe_substitute(self._dict)
@@ -71,7 +73,7 @@ class DictTemplater: # pylint: disable=too-few-public-methods
elif isinstance(value, dict):
# Note: we use a loop instead of dict comprehension in order to
# allow secondary expansion of subsequent values immediately.
- for (key, val) in value.items():
+ for key, val in value.items():
value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
elif isinstance(value, list):
value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]
diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py
index 9ed5480908..ff649af50e 100644
--- a/mlos_bench/mlos_bench/environments/__init__.py
+++ b/mlos_bench/mlos_bench/environments/__init__.py
@@ -2,26 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tunable Environments for mlos_bench.
-"""
+"""Tunable Environments for mlos_bench."""
-from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
-
-from mlos_bench.environments.mock_env import MockEnv
-from mlos_bench.environments.remote.remote_env import RemoteEnv
+from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
-from mlos_bench.environments.composite_env import CompositeEnv
+from mlos_bench.environments.mock_env import MockEnv
+from mlos_bench.environments.remote.remote_env import RemoteEnv
+from mlos_bench.environments.status import Status
__all__ = [
- 'Status',
-
- 'Environment',
- 'MockEnv',
- 'RemoteEnv',
- 'LocalEnv',
- 'LocalFileShareEnv',
- 'CompositeEnv',
+ "Status",
+ "Environment",
+ "MockEnv",
+ "RemoteEnv",
+ "LocalEnv",
+ "LocalFileShareEnv",
+ "CompositeEnv",
]
diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py
index 508d78589b..2c51a5ee8d 100644
--- a/mlos_bench/mlos_bench/environments/base_environment.py
+++ b/mlos_bench/mlos_bench/environments/base_environment.py
@@ -2,19 +2,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A hierarchy of benchmark environments.
-"""
+"""A hierarchy of benchmark environments."""
import abc
import json
import logging
from datetime import datetime
from types import TracebackType
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
-from typing_extensions import Literal
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
from pytz import UTC
+from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
@@ -32,20 +41,19 @@ _LOG = logging.getLogger(__name__)
class Environment(metaclass=abc.ABCMeta):
# pylint: disable=too-many-instance-attributes
- """
- An abstract base of all benchmark environments.
- """
+ """An abstract base of all benchmark environments."""
@classmethod
- def new(cls,
- *,
- env_name: str,
- class_name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None,
- ) -> "Environment":
+ def new(
+ cls,
+ *,
+ env_name: str,
+ class_name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ) -> "Environment":
"""
Factory method for a new environment with a given config.
@@ -83,16 +91,18 @@ class Environment(metaclass=abc.ABCMeta):
config=config,
global_config=global_config,
tunables=tunables,
- service=service
+ service=service,
)
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment with a given config.
@@ -123,24 +133,33 @@ class Environment(metaclass=abc.ABCMeta):
self._const_args: Dict[str, TunableValue] = config.get("const_args", {})
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Environment: '%s' Service: %s", name,
- self._service.pprint() if self._service else None)
+ _LOG.debug(
+ "Environment: '%s' Service: %s",
+ name,
+ self._service.pprint() if self._service else None,
+ )
if tunables is None:
- _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name)
+ _LOG.warning(
+ (
+ "No tunables provided for %s. "
+ "Tunable inheritance across composite environments may be broken."
+ ),
+ name,
+ )
tunables = TunableGroups()
groups = self._expand_groups(
config.get("tunable_params", []),
- (global_config or {}).get("tunable_params_map", {}))
+ (global_config or {}).get("tunable_params_map", {}),
+ )
_LOG.debug("Tunable groups for: '%s' :: %s", name, groups)
self._tunable_params = tunables.subgroup(groups)
# If a parameter comes from the tunables, do not require it in the const_args or globals
- req_args = (
- set(config.get("required_args", [])) -
- set(self._tunable_params.get_param_values().keys())
+ req_args = set(config.get("required_args", [])) - set(
+ self._tunable_params.get_param_values().keys()
)
merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
self._const_args = self._expand_vars(self._const_args, global_config or {})
@@ -149,14 +168,12 @@ class Environment(metaclass=abc.ABCMeta):
_LOG.debug("Parameters for '%s' :: %s", name, self._params)
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Config for: '%s'\n%s",
- name, json.dumps(self.config, indent=2))
+ _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))
def _validate_json_config(self, config: dict, name: str) -> None:
- """
- Reconstructs a basic json config that this class might have been
- instantiated from in order to validate configs provided outside the
- file loading mechanism.
+ """Reconstructs a basic json config that this class might have been instantiated
+ from in order to validate configs provided outside the file loading
+ mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@@ -168,8 +185,10 @@ class Environment(metaclass=abc.ABCMeta):
ConfigSchema.ENVIRONMENT.validate(json_config)
@staticmethod
- def _expand_groups(groups: Iterable[str],
- groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]:
+ def _expand_groups(
+ groups: Iterable[str],
+ groups_exp: Dict[str, Union[str, Sequence[str]]],
+ ) -> List[str]:
"""
Expand `$tunable_group` into actual names of the tunable groups.
@@ -191,7 +210,12 @@ class Environment(metaclass=abc.ABCMeta):
if grp[:1] == "$":
tunable_group_name = grp[1:]
if tunable_group_name not in groups_exp:
- raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}")
+ raise KeyError(
+ (
+ f"Expected tunable group name ${tunable_group_name} "
+                        f"undefined in {groups_exp}"
+ )
+ )
add_groups = groups_exp[tunable_group_name]
res += [add_groups] if isinstance(add_groups, str) else add_groups
else:
@@ -199,10 +223,11 @@ class Environment(metaclass=abc.ABCMeta):
return res
@staticmethod
- def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict:
- """
- Expand `$var` into actual values of the variables.
- """
+ def _expand_vars(
+ params: Dict[str, TunableValue],
+ global_config: Dict[str, TunableValue],
+ ) -> dict:
+ """Expand `$var` into actual values of the variables."""
return DictTemplater(params).expand_vars(extra_source_dict=global_config)
@property
@@ -210,10 +235,8 @@ class Environment(metaclass=abc.ABCMeta):
assert self._service is not None
return self._service.config_loader_service
- def __enter__(self) -> 'Environment':
- """
- Enter the environment's benchmarking context.
- """
+ def __enter__(self) -> "Environment":
+ """Enter the environment's benchmarking context."""
_LOG.debug("Environment START :: %s", self)
assert not self._in_context
if self._service:
@@ -221,12 +244,13 @@ class Environment(metaclass=abc.ABCMeta):
self._in_context = True
return self
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
- """
- Exit the context of the benchmarking environment.
- """
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
+ """Exit the context of the benchmarking environment."""
ex_throw = None
if ex_val is None:
_LOG.debug("Environment END :: %s", self)
@@ -256,8 +280,8 @@ class Environment(metaclass=abc.ABCMeta):
def pprint(self, indent: int = 4, level: int = 0) -> str:
"""
- Pretty-print the environment configuration.
- For composite environments, print all children environments as well.
+ Pretty-print the environment configuration. For composite environments, print
+ all children environments as well.
Parameters
----------
@@ -277,8 +301,8 @@ class Environment(metaclass=abc.ABCMeta):
def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
"""
Plug tunable values into the base config. If the tunable group is unknown,
- ignore it (it might belong to another environment). This method should
- never mutate the original config or the tunables.
+ ignore it (it might belong to another environment). This method should never
+ mutate the original config or the tunables.
Parameters
----------
@@ -293,7 +317,8 @@ class Environment(metaclass=abc.ABCMeta):
"""
return tunables.get_param_values(
group_names=list(self._tunable_params.get_covariant_group_names()),
- into_params=self._const_args.copy())
+ into_params=self._const_args.copy(),
+ )
@property
def tunable_params(self) -> TunableGroups:
@@ -310,21 +335,23 @@ class Environment(metaclass=abc.ABCMeta):
@property
def parameters(self) -> Dict[str, TunableValue]:
"""
- Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
- Note that before `.setup()` is called, all tunables will be set to None.
+ Key/value pairs of all environment parameters (i.e., `const_args` and
+ `tunable_params`). Note that before `.setup()` is called, all tunables will be
+ set to None.
Returns
-------
parameters : Dict[str, TunableValue]
- Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
+ Key/value pairs of all environment parameters
+ (i.e., `const_args` and `tunable_params`).
"""
return self._params
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Set up a new benchmark environment, if necessary. This method must be
- idempotent, i.e., calling it several times in a row should be
- equivalent to a single call.
+ idempotent, i.e., calling it several times in a row should be equivalent to a
+ single call.
Parameters
----------
@@ -353,10 +380,15 @@ class Environment(metaclass=abc.ABCMeta):
# (Derived classes still have to check `self._tunable_params.is_updated()`).
is_updated = self._tunable_params.is_updated()
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, {
- name: self._tunable_params.is_updated([name])
- for name in self._tunable_params.get_covariant_group_names()
- })
+ _LOG.debug(
+ "Env '%s': Tunable groups reset = %s :: %s",
+ self,
+ is_updated,
+ {
+ name: self._tunable_params.is_updated([name])
+ for name in self._tunable_params.get_covariant_group_names()
+ },
+ )
else:
_LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)
@@ -371,9 +403,10 @@ class Environment(metaclass=abc.ABCMeta):
def teardown(self) -> None:
"""
- Tear down the benchmark environment. This method must be idempotent,
- i.e., calling it several times in a row should be equivalent to a
- single call.
+ Tear down the benchmark environment.
+
+ This method must be idempotent, i.e., calling it several times in a row should
+ be equivalent to a single call.
"""
_LOG.info("Teardown %s", self)
# Make sure we create a context before invoking setup/run/status/teardown
diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py
index 06b4f431be..e37d5273eb 100644
--- a/mlos_bench/mlos_bench/environments/composite_env.py
+++ b/mlos_bench/mlos_bench/environments/composite_env.py
@@ -2,20 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Composite benchmark environment.
-"""
+"""Composite benchmark environment."""
import logging
from datetime import datetime
-
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type
+
from typing_extensions import Literal
-from mlos_bench.services.base_service import Service
-from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
+from mlos_bench.environments.status import Status
+from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@@ -23,17 +21,17 @@ _LOG = logging.getLogger(__name__)
class CompositeEnv(Environment):
- """
- Composite benchmark environment.
- """
+ """Composite benchmark environment."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment with a given config.
@@ -53,8 +51,13 @@ class CompositeEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
# By default, the Environment includes only the tunables explicitly specified
# in the "tunable_params" section of the config. `CompositeEnv`, however, must
@@ -70,17 +73,27 @@ class CompositeEnv(Environment):
# each CompositeEnv gets a copy of the original global config and adjusts it with
# the `const_args` specific to it.
global_config = (global_config or {}).copy()
- for (key, val) in self._const_args.items():
+ for key, val in self._const_args.items():
global_config.setdefault(key, val)
for child_config_file in config.get("include_children", []):
for env in self._config_loader_service.load_environment_list(
- child_config_file, tunables, global_config, self._const_args, self._service):
+ child_config_file,
+ tunables,
+ global_config,
+ self._const_args,
+ self._service,
+ ):
self._add_child(env, tunables)
for child_config in config.get("children", []):
env = self._config_loader_service.build_environment(
- child_config, tunables, global_config, self._const_args, self._service)
+ child_config,
+ tunables,
+ global_config,
+ self._const_args,
+ self._service,
+ )
self._add_child(env, tunables)
_LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)
@@ -92,9 +105,12 @@ class CompositeEnv(Environment):
self._child_contexts = [env.__enter__() for env in self._children]
return super().__enter__()
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
ex_throw = None
for env in reversed(self._children):
try:
@@ -111,9 +127,7 @@ class CompositeEnv(Environment):
@property
def children(self) -> List[Environment]:
- """
- Return the list of child environments.
- """
+ """Return the list of child environments."""
return self._children
def pprint(self, indent: int = 4, level: int = 0) -> str:
@@ -132,12 +146,16 @@ class CompositeEnv(Environment):
pretty : str
Pretty-printed environment configuration.
"""
- return super().pprint(indent, level) + '\n' + '\n'.join(
- child.pprint(indent, level + 1) for child in self._children)
+ return (
+ super().pprint(indent, level)
+ + "\n"
+ + "\n".join(child.pprint(indent, level + 1) for child in self._children)
+ )
def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
"""
Add a new child environment to the composite environment.
+
This method is called from the constructor only.
"""
_LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
@@ -165,14 +183,16 @@ class CompositeEnv(Environment):
"""
assert self._in_context
self._is_ready = super().setup(tunables, global_config) and all(
- env_context.setup(tunables, global_config) for env_context in self._child_contexts)
+ env_context.setup(tunables, global_config) for env_context in self._child_contexts
+ )
return self._is_ready
def teardown(self) -> None:
"""
- Tear down the children environments. This method is idempotent,
- i.e., calling it several times is equivalent to a single call.
- The environments are being torn down in the reverse order.
+ Tear down the children environments.
+
+ This method is idempotent, i.e., calling it several times is equivalent to a
+ single call. The environments are being torn down in the reverse order.
"""
assert self._in_context
for env_context in reversed(self._child_contexts):
@@ -181,9 +201,9 @@ class CompositeEnv(Environment):
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
"""
- Submit a new experiment to the environment.
- Return the result of the *last* child environment if successful,
- or the status of the last failed environment otherwise.
+ Submit a new experiment to the environment. Return the result of the *last*
+ child environment if successful, or the status of the last failed environment
+ otherwise.
Returns
-------
@@ -238,5 +258,6 @@ class CompositeEnv(Environment):
final_status = final_status or status
_LOG.info("Final status: %s :: %s", self, final_status)
- # Return the status and the timestamp of the last child environment or the first failed child environment.
+ # Return the status and the timestamp of the last child environment or the
+ # first failed child environment.
return (final_status, timestamp, joint_telemetry)
diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py
index 0cdd8349b4..7de10647f1 100644
--- a/mlos_bench/mlos_bench/environments/local/__init__.py
+++ b/mlos_bench/mlos_bench/environments/local/__init__.py
@@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Local Environments for mlos_bench.
-"""
+"""Local Environments for mlos_bench."""
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
__all__ = [
- 'LocalEnv',
- 'LocalFileShareEnv',
+ "LocalEnv",
+ "LocalFileShareEnv",
]
diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py
index 01f0337c1f..c3e81bb94d 100644
--- a/mlos_bench/mlos_bench/environments/local/local_env.py
+++ b/mlos_bench/mlos_bench/environments/local/local_env.py
@@ -2,27 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Scheduler-side benchmark environment to run scripts locally.
-"""
+"""Scheduler-side benchmark environment to run scripts locally."""
import json
import logging
import sys
-
+from contextlib import nullcontext
from datetime import datetime
from tempfile import TemporaryDirectory
-from contextlib import nullcontext
-
from types import TracebackType
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union
-from typing_extensions import Literal
import pandas
+from typing_extensions import Literal
-from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.script_env import ScriptEnv
+from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable import TunableValue
@@ -34,17 +30,17 @@ _LOG = logging.getLogger(__name__)
class LocalEnv(ScriptEnv):
# pylint: disable=too-many-instance-attributes
- """
- Scheduler-side Environment that runs scripts locally.
- """
+ """Scheduler-side Environment that runs scripts locally."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for local execution.
@@ -67,11 +63,17 @@ class LocalEnv(ScriptEnv):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
- assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
- "LocalEnv requires a service that supports local execution"
+ assert self._service is not None and isinstance(
+ self._service, SupportsLocalExec
+ ), "LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
self._temp_dir: Optional[str] = None
@@ -85,16 +87,19 @@ class LocalEnv(ScriptEnv):
def __enter__(self) -> Environment:
assert self._temp_dir is None and self._temp_dir_context is None
- self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir"))
+ self._temp_dir_context = self._local_exec_service.temp_dir_context(
+ self.config.get("temp_dir"),
+ )
self._temp_dir = self._temp_dir_context.__enter__()
return super().__enter__()
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
- """
- Exit the context of the benchmarking environment.
- """
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
+ """Exit the context of the benchmarking environment."""
assert not (self._temp_dir is None or self._temp_dir_context is None)
self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
self._temp_dir = None
@@ -103,8 +108,8 @@ class LocalEnv(ScriptEnv):
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
- Check if the environment is ready and set up the application
- and benchmarks, if necessary.
+ Check if the environment is ready and set up the application and benchmarks, if
+ necessary.
Parameters
----------
@@ -139,10 +144,14 @@ class LocalEnv(ScriptEnv):
fname = path_join(self._temp_dir, self._dump_meta_file)
_LOG.debug("Dump tunables metadata to file: %s", fname)
with open(fname, "wt", encoding="utf-8") as fh_meta:
- json.dump({
- tunable.name: tunable.meta
- for (tunable, _group) in self._tunable_params if tunable.meta
- }, fh_meta)
+ json.dump(
+ {
+ tunable.name: tunable.meta
+ for (tunable, _group) in self._tunable_params
+ if tunable.meta
+ },
+ fh_meta,
+ )
if self._script_setup:
(return_code, _output) = self._local_exec(self._script_setup, self._temp_dir)
@@ -182,18 +191,28 @@ class LocalEnv(ScriptEnv):
_LOG.debug("Not reading the data at: %s", self)
return (Status.SUCCEEDED, timestamp, stdout_data)
- data = self._normalize_columns(pandas.read_csv(
- self._config_loader_service.resolve_path(
- self._read_results_file, extra_paths=[self._temp_dir]),
- index_col=False,
- ))
+ data = self._normalize_columns(
+ pandas.read_csv(
+ self._config_loader_service.resolve_path(
+ self._read_results_file,
+ extra_paths=[self._temp_dir],
+ ),
+ index_col=False,
+ )
+ )
_LOG.debug("Read data:\n%s", data)
if list(data.columns) == ["metric", "value"]:
- _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data))
+ _LOG.info(
+ "Local results have (metric,value) header and %d rows: assume long format",
+ len(data),
+ )
data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list())
# Try to convert string metrics to numbers.
- data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive)
+ data = data.apply( # type: ignore[assignment] # (false positive)
+ pandas.to_numeric,
+ errors="coerce",
+ ).fillna(data)
elif len(data) == 1:
_LOG.info("Local results have 1 row: assume wide format")
else:
@@ -205,14 +224,12 @@ class LocalEnv(ScriptEnv):
@staticmethod
def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
- """
- Strip trailing spaces from column names (Windows only).
- """
+ """Strip trailing spaces from column names (Windows only)."""
# Windows cmd interpretation of > redirect symbols can leave trailing spaces in
# the final column, which leads to misnamed columns.
# For now, we simply strip trailing spaces from column names to account for that.
- if sys.platform == 'win32':
- data.rename(str.rstrip, axis='columns', inplace=True)
+ if sys.platform == "win32":
+ data.rename(str.rstrip, axis="columns", inplace=True)
return data
def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
@@ -224,24 +241,24 @@ class LocalEnv(ScriptEnv):
assert self._temp_dir is not None
try:
fname = self._config_loader_service.resolve_path(
- self._read_telemetry_file, extra_paths=[self._temp_dir])
+ self._read_telemetry_file,
+ extra_paths=[self._temp_dir],
+ )
# TODO: Use the timestamp of the CSV file as our status timestamp?
# FIXME: We should not be assuming that the only output file type is a CSV.
- data = self._normalize_columns(
- pandas.read_csv(fname, index_col=False))
+ data = self._normalize_columns(pandas.read_csv(fname, index_col=False))
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
expected_col_names = ["timestamp", "metric", "value"]
if len(data.columns) != len(expected_col_names):
- raise ValueError(f'Telemetry data must have columns {expected_col_names}')
+ raise ValueError(f"Telemetry data must have columns {expected_col_names}")
if list(data.columns) != expected_col_names:
# Assume no header - this is ok for telemetry data.
- data = pandas.read_csv(
- fname, index_col=False, names=expected_col_names)
+ data = pandas.read_csv(fname, index_col=False, names=expected_col_names)
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
except FileNotFoundError as ex:
@@ -250,15 +267,17 @@ class LocalEnv(ScriptEnv):
_LOG.debug("Read telemetry data:\n%s", data)
col_dtypes: Mapping[int, Type] = {0: datetime}
- return (status, timestamp, [
- (pandas.Timestamp(ts).to_pydatetime(), metric, value)
- for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
- ])
+ return (
+ status,
+ timestamp,
+ [
+ (pandas.Timestamp(ts).to_pydatetime(), metric, value)
+ for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
+ ],
+ )
def teardown(self) -> None:
- """
- Clean up the local environment.
- """
+ """Clean up the local environment."""
if self._script_teardown:
_LOG.info("Local teardown: %s", self)
(return_code, _output) = self._local_exec(self._script_teardown)
@@ -285,7 +304,10 @@ class LocalEnv(ScriptEnv):
env_params = self._get_env_params()
_LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params)
(return_code, stdout, stderr) = self._local_exec_service.local_exec(
- script, env=env_params, cwd=cwd)
+ script,
+ env=env_params,
+ cwd=cwd,
+ )
if return_code != 0:
_LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr)
return (return_code, {"stdout": stdout, "stderr": stderr})
diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py
index 6aea7acfc4..14ba59f3f6 100644
--- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py
+++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py
@@ -2,22 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Scheduler-side Environment to run scripts locally
-and upload/download data to the shared storage.
+"""Scheduler-side Environment to run scripts locally and upload/download data to the
+shared storage.
"""
import logging
-
from datetime import datetime
from string import Template
-from typing import Any, Dict, List, Generator, Iterable, Mapping, Optional, Tuple
+from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple
-from mlos_bench.services.base_service import Service
-from mlos_bench.services.types.local_exec_type import SupportsLocalExec
-from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
-from mlos_bench.environments.status import Status
from mlos_bench.environments.local.local_env import LocalEnv
+from mlos_bench.environments.status import Status
+from mlos_bench.services.base_service import Service
+from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
+from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@@ -25,18 +23,19 @@ _LOG = logging.getLogger(__name__)
class LocalFileShareEnv(LocalEnv):
- """
- Scheduler-side Environment that runs scripts locally
- and uploads/downloads data to the shared file storage.
+ """Scheduler-side Environment that runs scripts locally and uploads/downloads data
+ to the shared file storage.
"""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new application environment with a given config.
@@ -60,34 +59,41 @@ class LocalFileShareEnv(LocalEnv):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
- assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
- "LocalEnv requires a service that supports local execution"
+ assert self._service is not None and isinstance(
+ self._service, SupportsLocalExec
+ ), "LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
- assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
- "LocalEnv requires a service that supports file upload/download operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsFileShareOps
+ ), "LocalEnv requires a service that supports file upload/download operations"
self._file_share_service: SupportsFileShareOps = self._service
self._upload = self._template_from_to("upload")
self._download = self._template_from_to("download")
def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]:
+ """Convert a list of {"from": "...", "to": "..."} to a list of pairs of
+ string.Template objects so that we can plug self._params into them later.
"""
- Convert a list of {"from": "...", "to": "..."} to a list of pairs
- of string.Template objects so that we can plug in self._params into it later.
- """
- return [
- (Template(d['from']), Template(d['to']))
- for d in self.config.get(config_key, [])
- ]
+ return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]
@staticmethod
- def _expand(from_to: Iterable[Tuple[Template, Template]],
- params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
+ def _expand(
+ from_to: Iterable[Tuple[Template, Template]],
+ params: Mapping[str, TunableValue],
+ ) -> Generator[Tuple[str, str], None, None]:
"""
Substitute $var parameters in from/to path templates.
+
Return a generator of (str, str) pairs of paths.
"""
return (
@@ -120,9 +126,15 @@ class LocalFileShareEnv(LocalEnv):
assert self._temp_dir is not None
params = self._get_env_params(restrict=False)
params["PWD"] = self._temp_dir
- for (path_from, path_to) in self._expand(self._upload, params):
- self._file_share_service.upload(self._params, self._config_loader_service.resolve_path(
- path_from, extra_paths=[self._temp_dir]), path_to)
+ for path_from, path_to in self._expand(self._upload, params):
+ self._file_share_service.upload(
+ self._params,
+ self._config_loader_service.resolve_path(
+ path_from,
+ extra_paths=[self._temp_dir],
+ ),
+ path_to,
+ )
return self._is_ready
def _download_files(self, ignore_missing: bool = False) -> None:
@@ -138,11 +150,16 @@ class LocalFileShareEnv(LocalEnv):
assert self._temp_dir is not None
params = self._get_env_params(restrict=False)
params["PWD"] = self._temp_dir
- for (path_from, path_to) in self._expand(self._download, params):
+ for path_from, path_to in self._expand(self._download, params):
try:
- self._file_share_service.download(self._params,
- path_from, self._config_loader_service.resolve_path(
- path_to, extra_paths=[self._temp_dir]))
+ self._file_share_service.download(
+ self._params,
+ path_from,
+ self._config_loader_service.resolve_path(
+ path_to,
+ extra_paths=[self._temp_dir],
+ ),
+ )
except FileNotFoundError as ex:
_LOG.warning("Cannot download: %s", path_from)
if not ignore_missing:
@@ -153,8 +170,8 @@ class LocalFileShareEnv(LocalEnv):
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
"""
- Download benchmark results from the shared storage
- and run post-processing scripts locally.
+ Download benchmark results from the shared storage and run post-processing
+ scripts locally.
Returns
-------
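For context on the _template_from_to() / _expand() changes above: the upload/download config entries are {"from": ..., "to": ...} pairs whose $var placeholders get filled in from the environment parameters. A rough standalone sketch with made-up paths and parameters follows (safe_substitute is used here purely for illustration).

from string import Template

upload_config = [{"from": "$PWD/results.csv", "to": "exp/$experiment_id/results.csv"}]
params = {"PWD": "/tmp/mlos", "experiment_id": "my-experiment"}

# Same shape as _template_from_to(): a list of (from, to) Template pairs.
from_to = [(Template(d["from"]), Template(d["to"])) for d in upload_config]

# Same idea as _expand(): substitute $var parameters in both path templates.
expanded = [
    (path_from.safe_substitute(params), path_to.safe_substitute(params))
    for (path_from, path_to) in from_to
]
print(expanded)  # [('/tmp/mlos/results.csv', 'exp/my-experiment/results.csv')]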
diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py
index d8ffe3e47d..a8888c5e28 100644
--- a/mlos_bench/mlos_bench/environments/mock_env.py
+++ b/mlos_bench/mlos_bench/environments/mock_env.py
@@ -2,40 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Scheduler-side environment to mock the benchmark results.
-"""
+"""Scheduler-side environment to mock the benchmark results."""
-import random
import logging
+import random
from datetime import datetime
from typing import Dict, Optional, Tuple
import numpy
-from mlos_bench.services.base_service import Service
-from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
+from mlos_bench.environments.status import Status
+from mlos_bench.services.base_service import Service
from mlos_bench.tunables import Tunable, TunableGroups, TunableValue
_LOG = logging.getLogger(__name__)
class MockEnv(Environment):
- """
- Scheduler-side environment to mock the benchmark results.
- """
+ """Scheduler-side environment to mock the benchmark results."""
_NOISE_VAR = 0.2
"""Variance of the Gaussian noise added to the benchmark value."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment that produces mock benchmark data.
@@ -55,8 +53,13 @@ class MockEnv(Environment):
service: Service
An optional service object. Not used by this class.
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
seed = int(self.config.get("mock_env_seed", -1))
self._random = random.Random(seed or None) if seed >= 0 else None
self._range = self.config.get("mock_env_range")
@@ -81,9 +84,9 @@ class MockEnv(Environment):
return result
# Simple convex function of all tunable parameters.
- score = numpy.mean(numpy.square([
- self._normalized(tunable) for (tunable, _group) in self._tunable_params
- ]))
+ score = numpy.mean(
+ numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
+ )
# Add noise and shift the benchmark value from [0, 1] to a given range.
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0
@@ -97,15 +100,16 @@ class MockEnv(Environment):
def _normalized(tunable: Tunable) -> float:
"""
Get the NORMALIZED value of a tunable.
+
That is, map current value to the [0, 1] range.
"""
val = None
if tunable.is_categorical:
- val = (tunable.categories.index(tunable.category) /
- float(len(tunable.categories) - 1))
+ val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1)
elif tunable.is_numerical:
- val = ((tunable.numerical_value - tunable.range[0]) /
- float(tunable.range[1] - tunable.range[0]))
+ val = (tunable.numerical_value - tunable.range[0]) / float(
+ tunable.range[1] - tunable.range[0]
+ )
else:
raise ValueError("Invalid parameter type: " + tunable.type)
# Explicitly clip the value in case of numerical errors.
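A worked example of the MockEnv scoring shown above, with two hypothetical tunables (not part of the patch): each tunable is normalized to [0, 1] and the benchmark value is the mean of the squares, before the optional noise and range scaling are applied.

import numpy

# Hypothetical tunables: a numerical one in [0, 100] currently set to 25,
# and a categorical one ["a", "b", "c"] currently set to "c".
normalized = [
    (25 - 0) / float(100 - 0),         # numerical: (value - low) / (high - low) -> 0.25
    ["a", "b", "c"].index("c") / 2.0,  # categorical: index / (n_categories - 1) -> 1.0
]

score = numpy.mean(numpy.square(normalized))
print(score)  # (0.25**2 + 1.0**2) / 2 = 0.53125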
diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py
index f07575ac86..10608a6980 100644
--- a/mlos_bench/mlos_bench/environments/remote/__init__.py
+++ b/mlos_bench/mlos_bench/environments/remote/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Remote Tunable Environments for mlos_bench.
-"""
+"""Remote Tunable Environments for mlos_bench."""
from mlos_bench.environments.remote.host_env import HostEnv
from mlos_bench.environments.remote.network_env import NetworkEnv
@@ -14,10 +12,10 @@ from mlos_bench.environments.remote.saas_env import SaaSEnv
from mlos_bench.environments.remote.vm_env import VMEnv
__all__ = [
- 'HostEnv',
- 'NetworkEnv',
- 'OSEnv',
- 'RemoteEnv',
- 'SaaSEnv',
- 'VMEnv',
+ "HostEnv",
+ "NetworkEnv",
+ "OSEnv",
+ "RemoteEnv",
+ "SaaSEnv",
+ "VMEnv",
]
diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py
index 4b63e47278..c6d1c0145e 100644
--- a/mlos_bench/mlos_bench/environments/remote/host_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/host_env.py
@@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Remote host Environment.
-"""
-
-from typing import Optional
+"""Remote host Environment."""
import logging
+from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
@@ -19,17 +16,17 @@ _LOG = logging.getLogger(__name__)
class HostEnv(Environment):
- """
- Remote host environment.
- """
+ """Remote host environment."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for host operations.
@@ -50,10 +47,17 @@ class HostEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM/host, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
- assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \
- "HostEnv requires a service that supports host provisioning operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsHostProvisioning
+ ), "HostEnv requires a service that supports host provisioning operations"
self._host_service: SupportsHostProvisioning = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@@ -88,9 +92,7 @@ class HostEnv(Environment):
return self._is_ready
def teardown(self) -> None:
- """
- Shut down the Host and release it.
- """
+ """Shut down the Host and release it."""
_LOG.info("Host tear down: %s", self)
(status, params) = self._host_service.deprovision_host(self._params)
if status.is_pending():
diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py
index d049beacfd..c87ddd0899 100644
--- a/mlos_bench/mlos_bench/environments/remote/network_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/network_env.py
@@ -2,17 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Network Environment.
-"""
-
-from typing import Optional
+"""Network Environment."""
import logging
+from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
-from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
+from mlos_bench.services.types.network_provisioner_type import (
+ SupportsNetworkProvisioning,
+)
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@@ -26,13 +25,15 @@ class NetworkEnv(Environment):
but no real tuning is expected for it ... yet.
"""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for network operations.
@@ -53,14 +54,21 @@ class NetworkEnv(Environment):
An optional service object (e.g., providing methods to
deploy a network, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
# Virtual networks can be used for more than one experiment, so by default
# we don't attempt to deprovision them.
self._deprovision_on_teardown = config.get("deprovision_on_teardown", False)
- assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \
- "NetworkEnv requires a service that supports network provisioning"
+ assert self._service is not None and isinstance(
+ self._service, SupportsNetworkProvisioning
+ ), "NetworkEnv requires a service that supports network provisioning"
self._network_service: SupportsNetworkProvisioning = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@@ -96,15 +104,16 @@ class NetworkEnv(Environment):
return self._is_ready
def teardown(self) -> None:
- """
- Shut down the Network and releases it.
- """
+ """Shut down the Network and releases it."""
if not self._deprovision_on_teardown:
_LOG.info("Skipping Network deprovision: %s", self)
return
# Else
_LOG.info("Network tear down: %s", self)
- (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True)
+ (status, params) = self._network_service.deprovision_network(
+ self._params,
+ ignore_errors=True,
+ )
if status.is_pending():
(status, _) = self._network_service.wait_network_deployment(params, is_setup=False)
diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py
index bb5a0238a4..4328b8f694 100644
--- a/mlos_bench/mlos_bench/environments/remote/os_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/os_env.py
@@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-OS-level remote Environment on Azure.
-"""
-
-from typing import Optional
+"""OS-level remote Environment on Azure."""
import logging
+from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
@@ -21,17 +18,17 @@ _LOG = logging.getLogger(__name__)
class OSEnv(Environment):
- """
- OS Level Environment for a host.
- """
+ """OS Level Environment for a host."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for remote execution.
@@ -54,14 +51,22 @@ class OSEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
- assert self._service is not None and isinstance(self._service, SupportsHostOps), \
- "RemoteEnv requires a service that supports host operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsHostOps
+ ), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
- assert self._service is not None and isinstance(self._service, SupportsOSOps), \
- "RemoteEnv requires a service that supports host operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsOSOps
+ ), "RemoteEnv requires a service that supports host operations"
self._os_service: SupportsOSOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@@ -98,9 +103,7 @@ class OSEnv(Environment):
return self._is_ready
def teardown(self) -> None:
- """
- Clean up and shut down the host without deprovisioning it.
- """
+ """Clean up and shut down the host without deprovisioning it."""
_LOG.info("OS tear down: %s", self)
(status, params) = self._os_service.shutdown(self._params)
if status.is_pending():
diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py
index 0320b02769..c48b84cfdd 100644
--- a/mlos_bench/mlos_bench/environments/remote/remote_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py
@@ -14,11 +14,11 @@ from typing import Dict, Iterable, Optional, Tuple
from pytz import UTC
-from mlos_bench.environments.status import Status
from mlos_bench.environments.script_env import ScriptEnv
+from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
-from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.host_ops_type import SupportsHostOps
+from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv):
e.g. Application Environment
"""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for remote execution.
@@ -61,24 +63,31 @@ class RemoteEnv(ScriptEnv):
An optional service object (e.g., providing methods to
deploy or reboot a Host, VM, OS, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
self._wait_boot = self.config.get("wait_boot", False)
- assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \
- "RemoteEnv requires a service that supports remote execution operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsRemoteExec
+ ), "RemoteEnv requires a service that supports remote execution operations"
self._remote_exec_service: SupportsRemoteExec = self._service
if self._wait_boot:
- assert self._service is not None and isinstance(self._service, SupportsHostOps), \
- "RemoteEnv requires a service that supports host operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsHostOps
+ ), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
- Check if the environment is ready and set up the application
- and benchmarks on a remote host.
+ Check if the environment is ready and set up the application and benchmarks on a
+ remote host.
Parameters
----------
@@ -143,9 +152,7 @@ class RemoteEnv(ScriptEnv):
return (status, timestamp, output)
def teardown(self) -> None:
- """
- Clean up and shut down the remote environment.
- """
+ """Clean up and shut down the remote environment."""
if self._script_teardown:
_LOG.info("Remote teardown: %s", self)
(status, _timestamp, _output) = self._remote_exec(self._script_teardown)
@@ -170,7 +177,10 @@ class RemoteEnv(ScriptEnv):
env_params = self._get_env_params()
_LOG.debug("Submit script: %s with %s", self, env_params)
(status, output) = self._remote_exec_service.remote_exec(
- script, config=self._params, env_params=env_params)
+ script,
+ config=self._params,
+ env_params=env_params,
+ )
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)
if status in {Status.PENDING, Status.SUCCEEDED}:
(status, output) = self._remote_exec_service.get_remote_exec_results(output)
diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py
index 96a91db292..5d1ba9d800 100644
--- a/mlos_bench/mlos_bench/environments/remote/saas_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py
@@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Cloud-based (configurable) SaaS environment.
-"""
-
-from typing import Optional
+"""Cloud-based (configurable) SaaS environment."""
import logging
+from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
@@ -20,17 +17,17 @@ _LOG = logging.getLogger(__name__)
class SaaSEnv(Environment):
- """
- Cloud-based (configurable) SaaS environment.
- """
+ """Cloud-based (configurable) SaaS environment."""
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for (configurable) cloud-based SaaS instance.
@@ -51,15 +48,22 @@ class SaaSEnv(Environment):
An optional service object
(e.g., providing methods to configure the remote service).
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
- assert self._service is not None and isinstance(self._service, SupportsHostOps), \
- "RemoteEnv requires a service that supports host operations"
+ assert self._service is not None and isinstance(
+ self._service, SupportsHostOps
+ ), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
- assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \
- "SaaSEnv requires a service that supports remote host configuration API"
+ assert self._service is not None and isinstance(
+ self._service, SupportsRemoteConfig
+ ), "SaaSEnv requires a service that supports remote host configuration API"
self._config_service: SupportsRemoteConfig = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@@ -85,7 +89,9 @@ class SaaSEnv(Environment):
return False
(status, _) = self._config_service.configure(
- self._params, self._tunable_params.get_param_values())
+ self._params,
+ self._tunable_params.get_param_values(),
+ )
if not status.is_succeeded():
return False
@@ -94,7 +100,7 @@ class SaaSEnv(Environment):
return False
# Azure Flex DB instances currently require a VM reboot after reconfiguration.
- if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'):
+ if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"):
_LOG.info("Restarting: %s", self)
(status, params) = self._host_service.restart_host(self._params)
if status.is_pending():
diff --git a/mlos_bench/mlos_bench/environments/remote/vm_env.py b/mlos_bench/mlos_bench/environments/remote/vm_env.py
index eae7bf982c..3be95ce2c2 100644
--- a/mlos_bench/mlos_bench/environments/remote/vm_env.py
+++ b/mlos_bench/mlos_bench/environments/remote/vm_env.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-"Remote" VM (Host) Environment.
-"""
+"""Remote VM (Host) Environment."""
import logging
diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py
index 05c5fdec86..fe31d6fb13 100644
--- a/mlos_bench/mlos_bench/environments/script_env.py
+++ b/mlos_bench/mlos_bench/environments/script_env.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base scriptable benchmark environment.
-"""
+"""Base scriptable benchmark environment."""
import abc
import logging
@@ -15,26 +13,25 @@ from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.util import try_parse_val
_LOG = logging.getLogger(__name__)
class ScriptEnv(Environment, metaclass=abc.ABCMeta):
- """
- Base Environment that runs scripts for setup/run/teardown.
- """
+ """Base Environment that runs scripts for setup/run/teardown."""
_RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")
- def __init__(self,
- *,
- name: str,
- config: dict,
- global_config: Optional[dict] = None,
- tunables: Optional[TunableGroups] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: dict,
+ global_config: Optional[dict] = None,
+ tunables: Optional[TunableGroups] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new environment for script execution.
@@ -64,19 +61,29 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
- super().__init__(name=name, config=config, global_config=global_config,
- tunables=tunables, service=service)
+ super().__init__(
+ name=name,
+ config=config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
self._script_setup = self.config.get("setup")
self._script_run = self.config.get("run")
self._script_teardown = self.config.get("teardown")
self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
- self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {})
+ self._shell_env_params_rename: Dict[str, str] = self.config.get(
+ "shell_env_params_rename", {}
+ )
results_stdout_pattern = self.config.get("results_stdout_pattern")
- self._results_stdout_pattern: Optional[re.Pattern[str]] = \
- re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None
+ self._results_stdout_pattern: Optional[re.Pattern[str]] = (
+ re.compile(results_stdout_pattern, flags=re.MULTILINE)
+ if results_stdout_pattern
+ else None
+ )
def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
"""
@@ -117,4 +124,6 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
if not self._results_stdout_pattern:
return {}
_LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
- return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}
+ return {
+ key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)
+ }
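Regarding the results_stdout_pattern handling in ScriptEnv above: the configured regex is compiled with re.MULTILINE, and each (key, value) match extracted from the script's stdout becomes one result entry. A small sketch with a hypothetical pattern and output (the real code parses values with try_parse_val(); float() stands in for it here):

import re

results_stdout_pattern = re.compile(r"^(\w+)=(\S+)$", flags=re.MULTILINE)
stdout = "latency=12.5\nthroughput=100\n"

# findall() on a two-group pattern yields (key, value) tuples, one per matching line.
results = {key: float(val) for (key, val) in results_stdout_pattern.findall(stdout)}
print(results)  # {'latency': 12.5, 'throughput': 100.0}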
diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py
index fbe3dcccf4..f3e0d0ea37 100644
--- a/mlos_bench/mlos_bench/environments/status.py
+++ b/mlos_bench/mlos_bench/environments/status.py
@@ -2,17 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Enum for the status of the benchmark/environment.
-"""
+"""Enum for the status of the benchmark/environment."""
import enum
class Status(enum.Enum):
- """
- Enum for the status of the benchmark/environment.
- """
+ """Enum for the status of the benchmark/environment."""
UNKNOWN = 0
PENDING = 1
@@ -24,9 +20,7 @@ class Status(enum.Enum):
TIMED_OUT = 7
def is_good(self) -> bool:
- """
- Check if the status of the benchmark/environment is good.
- """
+ """Check if the status of the benchmark/environment is good."""
return self in {
Status.PENDING,
Status.READY,
@@ -35,9 +29,8 @@ class Status(enum.Enum):
}
def is_completed(self) -> bool:
- """
- Check if the status of the benchmark/environment is
- one of {SUCCEEDED, CANCELED, FAILED, TIMED_OUT}.
+ """Check if the status of the benchmark/environment is one of {SUCCEEDED,
+ CANCELED, FAILED, TIMED_OUT}.
"""
return self in {
Status.SUCCEEDED,
@@ -47,37 +40,25 @@ class Status(enum.Enum):
}
def is_pending(self) -> bool:
- """
- Check if the status of the benchmark/environment is PENDING.
- """
+ """Check if the status of the benchmark/environment is PENDING."""
return self == Status.PENDING
def is_ready(self) -> bool:
- """
- Check if the status of the benchmark/environment is READY.
- """
+ """Check if the status of the benchmark/environment is READY."""
return self == Status.READY
def is_succeeded(self) -> bool:
- """
- Check if the status of the benchmark/environment is SUCCEEDED.
- """
+ """Check if the status of the benchmark/environment is SUCCEEDED."""
return self == Status.SUCCEEDED
def is_failed(self) -> bool:
- """
- Check if the status of the benchmark/environment is FAILED.
- """
+ """Check if the status of the benchmark/environment is FAILED."""
return self == Status.FAILED
def is_canceled(self) -> bool:
- """
- Check if the status of the benchmark/environment is CANCELED.
- """
+ """Check if the status of the benchmark/environment is CANCELED."""
return self == Status.CANCELED
def is_timed_out(self) -> bool:
- """
- Check if the status of the benchmark/environment is TIMED_OUT.
- """
+ """Check if the status of the benchmark/environment is TIMED_OUT."""
return self == Status.TIMED_OUT
diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py
index e618c18b16..65285e5d66 100644
--- a/mlos_bench/mlos_bench/event_loop_context.py
+++ b/mlos_bench/mlos_bench/event_loop_context.py
@@ -2,25 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-EventLoopContext class definition.
-"""
-
-from asyncio import AbstractEventLoop
-from concurrent.futures import Future
-from typing import Any, Coroutine, Optional, TypeVar
-from threading import Lock as ThreadLock, Thread
+"""EventLoopContext class definition."""
import asyncio
import logging
import sys
+from asyncio import AbstractEventLoop
+from concurrent.futures import Future
+from threading import Lock as ThreadLock
+from threading import Thread
+from typing import Any, Coroutine, Optional, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
-CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name
+CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name
if sys.version_info >= (3, 9):
FutureReturnType: TypeAlias = Future[CoroReturnType]
else:
@@ -31,15 +29,15 @@ _LOG = logging.getLogger(__name__)
class EventLoopContext:
"""
- EventLoopContext encapsulates a background thread for asyncio event
- loop processing as an aid for context managers.
+ EventLoopContext encapsulates a background thread for asyncio event loop processing
+ as an aid for context managers.
- There is generally only expected to be one of these, either as a base
- class instance if it's specific to that functionality or for the full
- mlos_bench process to support parallel trial runners, for instance.
+ There is generally expected to be only one of these: either a base class instance,
+ if it is specific to that functionality, or one for the full mlos_bench process
+ (e.g., to support parallel trial runners).
- It's enter() and exit() routines are expected to be called from the
- caller's context manager routines (e.g., __enter__ and __exit__).
+ Its enter() and exit() routines are expected to be called from the caller's context
+ manager routines (e.g., __enter__ and __exit__).
"""
def __init__(self) -> None:
@@ -49,17 +47,13 @@ class EventLoopContext:
self._event_loop_thread_refcnt: int = 0
def _run_event_loop(self) -> None:
- """
- Runs the asyncio event loop in a background thread.
- """
+ """Runs the asyncio event loop in a background thread."""
assert self._event_loop is not None
asyncio.set_event_loop(self._event_loop)
self._event_loop.run_forever()
def enter(self) -> None:
- """
- Manages starting the background thread for event loop processing.
- """
+ """Manages starting the background thread for event loop processing."""
# Start the background thread if it's not already running.
with self._event_loop_thread_lock:
if not self._event_loop_thread:
@@ -74,9 +68,7 @@ class EventLoopContext:
self._event_loop_thread_refcnt += 1
def exit(self) -> None:
- """
- Manages cleaning up the background thread for event loop processing.
- """
+ """Manages cleaning up the background thread for event loop processing."""
with self._event_loop_thread_lock:
self._event_loop_thread_refcnt -= 1
assert self._event_loop_thread_refcnt >= 0
@@ -92,8 +84,8 @@ class EventLoopContext:
def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
"""
- Runs the given coroutine in the background event loop thread and
- returns a Future that can be used to wait for the result.
+ Runs the given coroutine in the background event loop thread and returns a
+ Future that can be used to wait for the result.
Parameters
----------
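To illustrate the EventLoopContext contract described above (not part of the patch): enter()/exit() reference-count a shared background event-loop thread, and run_coroutine() schedules a coroutine on that loop and returns a concurrent.futures.Future. A hypothetical usage sketch:

import asyncio

from mlos_bench.event_loop_context import EventLoopContext


async def sleep_and_return(value: int) -> int:
    await asyncio.sleep(0.1)
    return value


event_loop_context = EventLoopContext()
event_loop_context.enter()  # starts the background thread on first use
try:
    future = event_loop_context.run_coroutine(sleep_and_return(42))
    print(future.result())  # blocks on the Future until the coroutine finishes -> 42
finally:
    event_loop_context.exit()  # stops the thread once the refcount drops to zero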
diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py
index a9aa9e3f46..23421f195b 100644
--- a/mlos_bench/mlos_bench/launcher.py
+++ b/mlos_bench/mlos_bench/launcher.py
@@ -3,8 +3,8 @@
# Licensed under the MIT License.
#
"""
-A helper class to load the configuration files, parse the command line parameters,
-and instantiate the main components of mlos_bench system.
+A helper class to load the configuration files, parse the command line parameters, and
+instantiate the main components of the mlos_bench system.
It is used in `mlos_bench.run` module to run the benchmark/optimizer from the
command line.
@@ -13,34 +13,26 @@ command line.
import argparse
import logging
import sys
-
from typing import Any, Dict, Iterable, List, Optional, Tuple
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
-from mlos_bench.util import try_parse_val
-
-from mlos_bench.tunables.tunable import TunableValue
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.environments.base_environment import Environment
-
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
-
-from mlos_bench.storage.base_storage import Storage
-
-from mlos_bench.services.base_service import Service
-from mlos_bench.services.local.local_exec import LocalExecService
-from mlos_bench.services.config_persistence import ConfigPersistenceService
-
from mlos_bench.schedulers.base_scheduler import Scheduler
-
+from mlos_bench.services.base_service import Service
+from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
-
+from mlos_bench.storage.base_storage import Storage
+from mlos_bench.tunables.tunable import TunableValue
+from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.util import try_parse_val
_LOG_LEVEL = logging.INFO
-_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
+_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)
_LOG = logging.getLogger(__name__)
@@ -48,22 +40,22 @@ _LOG = logging.getLogger(__name__)
class Launcher:
# pylint: disable=too-few-public-methods,too-many-instance-attributes
- """
- Command line launcher for mlos_bench and mlos_core.
- """
+ """Command line launcher for mlos_bench and mlos_core."""
def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
# pylint: disable=too-many-statements
_LOG.info("Launch: %s", description)
epilog = """
- Additional --key=value pairs can be specified to augment or override values listed in --globals.
- Other required_args values can also be pulled from shell environment variables.
+ Additional --key=value pairs can be specified to augment or override
+ values listed in --globals.
+ Other required_args values can also be pulled from shell environment
+ variables.
- For additional details, please see the website or the README.md files in the source tree:
+ For additional details, please see the website or the README.md files in
+ the source tree:
"""
- parser = argparse.ArgumentParser(description=f"{description} : {long_text}",
- epilog=epilog)
+ parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
(args, args_rest) = self._parse_args(parser, argv)
# Bootstrap config loader: command line takes priority.
@@ -101,16 +93,18 @@ class Launcher:
args_rest,
{key: val for (key, val) in config.items() if key not in vars(args)},
)
- # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI.
+ # experiment_id is generally taken from --globals files, but we also allow
+ # overriding it on the CLI.
# It's useful to keep it there explicitly mostly for the --help output.
if args.experiment_id:
- self.global_config['experiment_id'] = args.experiment_id
- # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line
+ self.global_config["experiment_id"] = args.experiment_id
+ # trial_config_repeat_count is a scheduler property but it's convenient to
+ # set it via command line
if args.trial_config_repeat_count:
self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
# Ensure that the trial_id is present since it gets used by some other
# configs but is typically controlled by the run optimize loop.
- self.global_config.setdefault('trial_id', 1)
+ self.global_config.setdefault("trial_id", 1)
self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
assert isinstance(self.global_config, dict)
@@ -118,24 +112,31 @@ class Launcher:
# --service cli args should override the config file values.
service_files: List[str] = config.get("services", []) + (args.service or [])
assert isinstance(self._parent_service, SupportsConfigLoading)
- self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service)
+ self._parent_service = self._parent_service.load_services(
+ service_files,
+ self.global_config,
+ self._parent_service,
+ )
env_path = args.environment or config.get("environment")
if not env_path:
_LOG.error("No environment config specified.")
- parser.error("At least the Environment config must be specified." +
- " Run `mlos_bench --help` and consult `README.md` for more info.")
+ parser.error(
+ "At least the Environment config must be specified."
+ + " Run `mlos_bench --help` and consult `README.md` for more info."
+ )
self.root_env_config = self._config_loader.resolve_path(env_path)
self.environment: Environment = self._config_loader.load_environment(
- self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service)
+ self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
+ )
_LOG.info("Init environment: %s", self.environment)
# NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
self.tunables = self._init_tunable_values(
args.random_init or config.get("random_init", False),
config.get("random_seed") if args.random_seed is None else args.random_seed,
- config.get("tunable_values", []) + (args.tunable_values or [])
+ config.get("tunable_values", []) + (args.tunable_values or []),
)
_LOG.info("Init tunables: %s", self.tunables)
@@ -145,106 +146,170 @@ class Launcher:
self.storage = self._load_storage(args.storage or config.get("storage"))
_LOG.info("Init storage: %s", self.storage)
- self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True))
+ self.teardown: bool = (
+ bool(args.teardown)
+ if args.teardown is not None
+ else bool(config.get("teardown", True))
+ )
self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
_LOG.info("Init scheduler: %s", self.scheduler)
@property
def config_loader(self) -> ConfigPersistenceService:
- """
- Get the config loader service.
- """
+ """Get the config loader service."""
return self._config_loader
@property
def service(self) -> Service:
- """
- Get the parent service.
- """
+ """Get the parent service."""
return self._parent_service
@staticmethod
- def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]:
- """
- Parse the command line arguments.
- """
+ def _parse_args(
+ parser: argparse.ArgumentParser,
+ argv: Optional[List[str]],
+ ) -> Tuple[argparse.Namespace, List[str]]:
+ """Parse the command line arguments."""
parser.add_argument(
- '--config', required=False,
- help='Main JSON5 configuration file. Its keys are the same as the' +
- ' command line options and can be overridden by the latter.\n' +
- '\n' +
- ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' +
- ' for additional config examples for this and other arguments.')
+ "--config",
+ required=False,
+ help="Main JSON5 configuration file. Its keys are the same as the"
+ + " command line options and can be overridden by the latter.\n"
+ + "\n"
+ + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
+ + " for additional config examples for this and other arguments.",
+ )
parser.add_argument(
- '--log_file', '--log-file', required=False,
- help='Path to the log file. Use stdout if omitted.')
+ "--log_file",
+ "--log-file",
+ required=False,
+ help="Path to the log file. Use stdout if omitted.",
+ )
parser.add_argument(
- '--log_level', '--log-level', required=False, type=str,
- help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' +
- ' Set to DEBUG for debug, WARNING for warnings only.')
+ "--log_level",
+ "--log-level",
+ required=False,
+ type=str,
+ help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}."
+ + " Set to DEBUG for debug, WARNING for warnings only.",
+ )
parser.add_argument(
- '--config_path', '--config-path', '--config-paths', '--config_paths',
- nargs="+", action='extend', required=False,
- help='One or more locations of JSON config files.')
+ "--config_path",
+ "--config-path",
+ "--config-paths",
+ "--config_paths",
+ nargs="+",
+ action="extend",
+ required=False,
+ help="One or more locations of JSON config files.",
+ )
parser.add_argument(
- '--service', '--services',
- nargs='+', action='extend', required=False,
- help='Path to JSON file with the configuration of the service(s) for environment(s) to use.')
+ "--service",
+ "--services",
+ nargs="+",
+ action="extend",
+ required=False,
+ help=(
+ "Path to JSON file with the configuration "
+ "of the service(s) for environment(s) to use."
+ ),
+ )
parser.add_argument(
- '--environment', required=False,
- help='Path to JSON file with the configuration of the benchmarking environment(s).')
+ "--environment",
+ required=False,
+ help="Path to JSON file with the configuration of the benchmarking environment(s).",
+ )
parser.add_argument(
- '--optimizer', required=False,
- help='Path to the optimizer configuration file. If omitted, run' +
- ' a single trial with default (or specified in --tunable_values).')
+ "--optimizer",
+ required=False,
+ help="Path to the optimizer configuration file. If omitted, run"
+ + " a single trial with default (or specified in --tunable_values).",
+ )
parser.add_argument(
- '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int,
- help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.')
+ "--trial_config_repeat_count",
+ "--trial-config-repeat-count",
+ required=False,
+ type=int,
+ help=(
+ "Number of times to repeat each config. "
+ "Default is 1 trial per config, though more may be advised."
+ ),
+ )
parser.add_argument(
- '--scheduler', required=False,
- help='Path to the scheduler configuration file. By default, use' +
- ' a single worker synchronous scheduler.')
+ "--scheduler",
+ required=False,
+ help="Path to the scheduler configuration file. By default, use"
+ + " a single worker synchronous scheduler.",
+ )
parser.add_argument(
- '--storage', required=False,
- help='Path to the storage configuration file.' +
- ' If omitted, use the ephemeral in-memory SQL storage.')
+ "--storage",
+ required=False,
+ help="Path to the storage configuration file."
+ + " If omitted, use the ephemeral in-memory SQL storage.",
+ )
parser.add_argument(
- '--random_init', '--random-init', required=False, default=False,
- dest='random_init', action='store_true',
- help='Initialize tunables with random values. (Before applying --tunable_values).')
+ "--random_init",
+ "--random-init",
+ required=False,
+ default=False,
+ dest="random_init",
+ action="store_true",
+ help="Initialize tunables with random values. (Before applying --tunable_values).",
+ )
parser.add_argument(
- '--random_seed', '--random-seed', required=False, type=int,
- help='Seed to use with --random_init')
+ "--random_seed",
+ "--random-seed",
+ required=False,
+ type=int,
+ help="Seed to use with --random_init",
+ )
parser.add_argument(
- '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False,
- help='Path to one or more JSON files that contain values of the tunable' +
- ' parameters. This can be used for a single trial (when no --optimizer' +
- ' is specified) or as default values for the first run in optimization.')
+ "--tunable_values",
+ "--tunable-values",
+ nargs="+",
+ action="extend",
+ required=False,
+ help="Path to one or more JSON files that contain values of the tunable"
+ + " parameters. This can be used for a single trial (when no --optimizer"
+ + " is specified) or as default values for the first run in optimization.",
+ )
parser.add_argument(
- '--globals', nargs="+", action='extend', required=False,
- help='Path to one or more JSON files that contain additional' +
- ' [private] parameters of the benchmarking environment.')
+ "--globals",
+ nargs="+",
+ action="extend",
+ required=False,
+ help="Path to one or more JSON files that contain additional"
+ + " [private] parameters of the benchmarking environment.",
+ )
parser.add_argument(
- '--no_teardown', '--no-teardown', required=False, default=None,
- dest='teardown', action='store_false',
- help='Disable teardown of the environment after the benchmark.')
+ "--no_teardown",
+ "--no-teardown",
+ required=False,
+ default=None,
+ dest="teardown",
+ action="store_false",
+ help="Disable teardown of the environment after the benchmark.",
+ )
parser.add_argument(
- '--experiment_id', '--experiment-id', required=False, default=None,
+ "--experiment_id",
+ "--experiment-id",
+ required=False,
+ default=None,
help="""
Experiment ID to use for the benchmark.
If omitted, the value from the --cli config or --globals is used.
@@ -254,7 +319,7 @@ class Launcher:
changes are made to config files, scripts, versions, etc.
This is left as a manual operation as detection of what is
"incompatible" is not easily automatable across systems.
- """
+ """,
)
# By default we use the command line arguments, but allow the caller to
@@ -267,9 +332,7 @@ class Launcher:
@staticmethod
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
- """
- Helper function to parse global key/value pairs from the command line.
- """
+ """Helper function to parse global key/value pairs from the command line."""
_LOG.debug("Extra args: %s", cmdline)
config: Dict[str, TunableValue] = {}
@@ -296,16 +359,17 @@ class Launcher:
_LOG.debug("Parsed config: %s", config)
return config
- def _load_config(self,
- args_globals: Iterable[str],
- config_path: Iterable[str],
- args_rest: Iterable[str],
- global_config: Dict[str, Any]) -> Dict[str, Any]:
+ def _load_config(
+ self,
+ args_globals: Iterable[str],
+ config_path: Iterable[str],
+ args_rest: Iterable[str],
+ global_config: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """Get key/value pairs of the global configuration parameters from the specified
+ config files (if any) and command line arguments.
"""
- Get key/value pairs of the global configuration parameters
- from the specified config files (if any) and command line arguments.
- """
- for config_file in (args_globals or []):
+ for config_file in args_globals or []:
conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
assert isinstance(conf, dict)
global_config.update(conf)
@@ -314,19 +378,24 @@ class Launcher:
global_config["config_path"] = config_path
return global_config
- def _init_tunable_values(self, random_init: bool, seed: Optional[int],
- args_tunables: Optional[str]) -> TunableGroups:
- """
- Initialize the tunables and load key/value pairs of the tunable values
- from given JSON files, if specified.
+ def _init_tunable_values(
+ self,
+ random_init: bool,
+ seed: Optional[int],
+ args_tunables: Optional[str],
+ ) -> TunableGroups:
+ """Initialize the tunables and load key/value pairs of the tunable values from
+ given JSON files, if specified.
"""
tunables = self.environment.tunable_params
_LOG.debug("Init tunables: default = %s", tunables)
if random_init:
tunables = MockOptimizer(
- tunables=tunables, service=None,
- config={"start_with_defaults": False, "seed": seed}).suggest()
+ tunables=tunables,
+ service=None,
+ config={"start_with_defaults": False, "seed": seed},
+ ).suggest()
_LOG.debug("Init tunables: random = %s", tunables)
if args_tunables is not None:
@@ -340,50 +409,64 @@ class Launcher:
def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
"""
- Instantiate the Optimizer object from JSON config file, if specified
- in the --optimizer command line option. If config file not specified,
- create a one-shot optimizer to run a single benchmark trial.
+        Instantiate the Optimizer object from the JSON config file, if specified in
+        the --optimizer command line option.
+
+        If no config file is specified, create a one-shot optimizer to run a single
+        benchmark trial.
"""
if args_optimizer is None:
# global_config may contain additional properties, so we need to
# strip those out before instantiating the basic oneshot optimizer.
- config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS}
- return OneShotOptimizer(
- self.tunables, config=config, service=self._parent_service)
+ config = {
+ key: val
+ for key, val in self.global_config.items()
+ if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
+ }
+ return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
assert isinstance(class_config, Dict)
- optimizer = self._config_loader.build_optimizer(tunables=self.tunables,
- service=self._parent_service,
- config=class_config,
- global_config=self.global_config)
+ optimizer = self._config_loader.build_optimizer(
+ tunables=self.tunables,
+ service=self._parent_service,
+ config=class_config,
+ global_config=self.global_config,
+ )
return optimizer
def _load_storage(self, args_storage: Optional[str]) -> Storage:
"""
- Instantiate the Storage object from JSON file provided in the --storage
- command line parameter. If omitted, create an ephemeral in-memory SQL
- storage instead.
+        Instantiate the Storage object from the JSON file provided in the --storage
+        command line parameter.
+
+ If omitted, create an ephemeral in-memory SQL storage instead.
"""
if args_storage is None:
# pylint: disable=import-outside-toplevel
from mlos_bench.storage.sql.storage import SqlStorage
- return SqlStorage(service=self._parent_service,
- config={
- "drivername": "sqlite",
- "database": ":memory:",
- "lazy_schema_create": True,
- })
+
+ return SqlStorage(
+ service=self._parent_service,
+ config={
+ "drivername": "sqlite",
+ "database": ":memory:",
+ "lazy_schema_create": True,
+ },
+ )
class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
assert isinstance(class_config, Dict)
- storage = self._config_loader.build_storage(service=self._parent_service,
- config=class_config,
- global_config=self.global_config)
+ storage = self._config_loader.build_storage(
+ service=self._parent_service,
+ config=class_config,
+ global_config=self.global_config,
+ )
return storage
def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
"""
Instantiate the Scheduler object from JSON file provided in the --scheduler
command line parameter.
+
Create a simple synchronous single-threaded scheduler if omitted.
"""
# Set `teardown` for scheduler only to prevent conflicts with other configs.
@@ -392,6 +475,7 @@ class Launcher:
if args_scheduler is None:
# pylint: disable=import-outside-toplevel
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
+
return SyncScheduler(
# All config values can be overridden from global config
config={
diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py
index f875917251..7cd6a8a25a 100644
--- a/mlos_bench/mlos_bench/optimizers/__init__.py
+++ b/mlos_bench/mlos_bench/optimizers/__init__.py
@@ -2,18 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Interfaces and wrapper classes for optimizers to be used in Autotune.
-"""
+"""Interfaces and wrapper classes for optimizers to be used in Autotune."""
from mlos_bench.optimizers.base_optimizer import Optimizer
+from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
-from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
__all__ = [
- 'Optimizer',
- 'MockOptimizer',
- 'OneShotOptimizer',
- 'MlosCoreOptimizer',
+ "Optimizer",
+ "MockOptimizer",
+ "OneShotOptimizer",
+ "MlosCoreOptimizer",
]
diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py
index 911c624315..6fa7ad87f4 100644
--- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py
@@ -2,34 +2,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base class for an interface between the benchmarking framework
-and mlos_core optimizers.
+"""Base class for an interface between the benchmarking framework and mlos_core
+optimizers.
"""
import logging
from abc import ABCMeta, abstractmethod
-from distutils.util import strtobool # pylint: disable=deprecated-module
-
+from distutils.util import strtobool # pylint: disable=deprecated-module
from types import TracebackType
from typing import Dict, Optional, Sequence, Tuple, Type, Union
-from typing_extensions import Literal
from ConfigSpace import ConfigurationSpace
+from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
-from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
+from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
+from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
_LOG = logging.getLogger(__name__)
-class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
- """
- An abstract interface between the benchmarking framework and mlos_core optimizers.
+class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
+ """An abstract interface between the benchmarking framework and mlos_core
+ optimizers.
"""
# See Also: mlos_bench/mlos_bench/config/schemas/optimizers/optimizer-schema.json
@@ -40,13 +38,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
"start_with_defaults",
}
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
"""
- Create a new optimizer for the given configuration space defined by the tunables.
+ Create a new optimizer for the given configuration space defined by the
+ tunables.
Parameters
----------
@@ -68,19 +69,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
self._seed = int(config.get("seed", 42))
self._in_context = False
- experiment_id = self._global_config.get('experiment_id')
+ experiment_id = self._global_config.get("experiment_id")
self.experiment_id = str(experiment_id).strip() if experiment_id else None
self._iter = 0
# If False, use the optimizer to suggest the initial configuration;
# if True (default), use the already initialized values for the first iteration.
self._start_with_defaults: bool = bool(
- strtobool(str(self._config.pop('start_with_defaults', True))))
- self._max_iter = int(self._config.pop('max_suggestions', 100))
+ strtobool(str(self._config.pop("start_with_defaults", True)))
+ )
+ self._max_iter = int(self._config.pop("max_suggestions", 100))
- opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'})
+ opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"})
self._opt_targets: Dict[str, Literal[1, -1]] = {}
- for (opt_target, opt_dir) in opt_targets.items():
+ for opt_target, opt_dir in opt_targets.items():
if opt_dir == "min":
self._opt_targets[opt_target] = 1
elif opt_dir == "max":
@@ -89,10 +91,9 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}")
def _validate_json_config(self, config: dict) -> None:
- """
- Reconstructs a basic json config that this class might have been
- instantiated from in order to validate configs provided outside the
- file loading mechanism.
+ """Reconstructs a basic json config that this class might have been instantiated
+ from in order to validate configs provided outside the file loading
+ mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@@ -108,21 +109,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
)
return f"{self.name}({opt_targets},config={self._config})"
- def __enter__(self) -> 'Optimizer':
- """
- Enter the optimizer's context.
- """
+ def __enter__(self) -> "Optimizer":
+ """Enter the optimizer's context."""
_LOG.debug("Optimizer START :: %s", self)
assert not self._in_context
self._in_context = True
return self
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
- """
- Exit the context of the optimizer.
- """
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
+ """Exit the context of the optimizer."""
if ex_val is None:
_LOG.debug("Optimizer END :: %s", self)
else:
@@ -154,15 +154,14 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def seed(self) -> int:
- """
- The random seed for the optimizer.
- """
+ """The random seed for the optimizer."""
return self._seed
@property
def start_with_defaults(self) -> bool:
"""
Return True if the optimizer should start with the default values.
+
Note: This parameter is mutable and will be reset to False after the
defaults are first suggested.
"""
@@ -198,16 +197,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def name(self) -> str:
"""
- The name of the optimizer. We save this information in
- mlos_bench storage to track the source of each configuration.
+ The name of the optimizer.
+
+ We save this information in mlos_bench storage to track the source of each
+ configuration.
"""
return self.__class__.__name__
@property
- def targets(self) -> Dict[str, Literal['min', 'max']]:
- """
- A dictionary of {target: direction} of optimization targets.
- """
+ def targets(self) -> Dict[str, Literal["min", "max"]]:
+ """A dictionary of {target: direction} of optimization targets."""
return {
opt_target: "min" if opt_dir == 1 else "max"
for (opt_target, opt_dir) in self._opt_targets.items()
@@ -215,16 +214,18 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def supports_preload(self) -> bool:
- """
- Return True if the optimizer supports pre-loading the data from previous experiments.
+ """Return True if the optimizer supports pre-loading the data from previous
+ experiments.
"""
return True
@abstractmethod
- def bulk_register(self,
- configs: Sequence[dict],
- scores: Sequence[Optional[Dict[str, TunableValue]]],
- status: Optional[Sequence[Status]] = None) -> bool:
+ def bulk_register(
+ self,
+ configs: Sequence[dict],
+ scores: Sequence[Optional[Dict[str, TunableValue]]],
+ status: Optional[Sequence[Status]] = None,
+ ) -> bool:
"""
Pre-load the optimizer with the bulk data from previous experiments.
@@ -242,8 +243,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
is_not_empty : bool
True if there is data to register, false otherwise.
"""
- _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values",
- len(configs or []), len(scores or []), len(status or []))
+ _LOG.info(
+ "Update the optimizer with: %d configs, %d scores, %d status values",
+ len(configs or []),
+ len(scores or []),
+ len(status or []),
+ )
if len(configs or []) != len(scores or []):
raise ValueError("Numbers of configs and scores do not match.")
if status is not None and len(configs or []) != len(status or []):
@@ -256,9 +261,8 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
def suggest(self) -> TunableGroups:
"""
- Generate the next suggestion.
- Base class' implementation increments the iteration count
- and returns the current values of the tunables.
+ Generate the next suggestion. Base class' implementation increments the
+ iteration count and returns the current values of the tunables.
Returns
-------
@@ -272,8 +276,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
return self._tunables.copy()
@abstractmethod
- def register(self, tunables: TunableGroups, status: Status,
- score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
+ def register(
+ self,
+ tunables: TunableGroups,
+ status: Status,
+ score: Optional[Dict[str, TunableValue]] = None,
+ ) -> Optional[Dict[str, float]]:
"""
Register the observation for the given configuration.
@@ -294,18 +302,25 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
Benchmark scores extracted (and possibly transformed)
from the dataframe that's being MINIMIZED.
"""
- _LOG.info("Iteration %d :: Register: %s = %s score: %s",
- self._iter, tunables, status, score)
+ _LOG.info(
+ "Iteration %d :: Register: %s = %s score: %s",
+ self._iter,
+ tunables,
+ status,
+ score,
+ )
if status.is_succeeded() == (score is None): # XOR
raise ValueError("Status and score must be consistent.")
return self._get_scores(status, score)
- def _get_scores(self, status: Status,
- scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]]
- ) -> Optional[Dict[str, float]]:
+ def _get_scores(
+ self,
+ status: Status,
+ scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]],
+ ) -> Optional[Dict[str, float]]:
"""
- Extract a scalar benchmark score from the dataframe.
- Change the sign if we are maximizing.
+ Extract a scalar benchmark score from the dataframe. Change the sign if we are
+ maximizing.
Parameters
----------
@@ -331,7 +346,7 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
assert scores is not None
target_metrics: Dict[str, float] = {}
- for (opt_target, opt_dir) in self._opt_targets.items():
+ for opt_target, opt_dir in self._opt_targets.items():
val = scores[opt_target]
assert val is not None
target_metrics[opt_target] = float(val) * opt_dir
@@ -341,12 +356,15 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
def not_converged(self) -> bool:
"""
Return True if not converged, False otherwise.
+
Base implementation just checks the iteration count.
"""
return self._iter < self._max_iter
@abstractmethod
- def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
+ def get_best_observation(
+ self,
+ ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
"""
Get the best observation so far.
diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py
index 6978a8d410..e4ff9897fa 100644
--- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py
+++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py
@@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Functions to convert TunableGroups to ConfigSpace for use with the mlos_core optimizers.
+"""Functions to convert TunableGroups to ConfigSpace for use with the mlos_core
+optimizers.
"""
import logging
-
from typing import Dict, List, Optional, Tuple, Union
from ConfigSpace import (
@@ -21,9 +20,10 @@ from ConfigSpace import (
Normal,
Uniform,
)
+
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.util import try_parse_val, nullable
+from mlos_bench.util import nullable, try_parse_val
_LOG = logging.getLogger(__name__)
@@ -31,6 +31,7 @@ _LOG = logging.getLogger(__name__)
class TunableValueKind:
"""
Enum for the kind of the tunable value (special or not).
+
It is not a true enum because ConfigSpace wants string values.
"""
@@ -40,15 +41,16 @@ class TunableValueKind:
def _normalize_weights(weights: List[float]) -> List[float]:
- """
- Helper function for normalizing weights to probabilities.
- """
+ """Helper function for normalizing weights to probabilities."""
total = sum(weights)
return [w / total for w in weights]
def _tunable_to_configspace(
- tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace:
+ tunable: Tunable,
+ group_name: Optional[str] = None,
+ cost: int = 0,
+) -> ConfigurationSpace:
"""
Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects,
wrapped in a ConfigurationSpace for composability.
@@ -71,14 +73,17 @@ def _tunable_to_configspace(
meta = {"group": group_name, "cost": cost} # {"scaling": ""}
if tunable.type == "categorical":
- return ConfigurationSpace({
- tunable.name: CategoricalHyperparameter(
- name=tunable.name,
- choices=tunable.categories,
- weights=_normalize_weights(tunable.weights) if tunable.weights else None,
- default_value=tunable.default,
- meta=meta)
- })
+ return ConfigurationSpace(
+ {
+ tunable.name: CategoricalHyperparameter(
+ name=tunable.name,
+ choices=tunable.categories,
+ weights=_normalize_weights(tunable.weights) if tunable.weights else None,
+ default_value=tunable.default,
+ meta=meta,
+ )
+ }
+ )
distribution: Union[Uniform, Normal, Beta, None] = None
if tunable.distribution == "uniform":
@@ -86,12 +91,12 @@ def _tunable_to_configspace(
elif tunable.distribution == "normal":
distribution = Normal(
mu=tunable.distribution_params["mu"],
- sigma=tunable.distribution_params["sigma"]
+ sigma=tunable.distribution_params["sigma"],
)
elif tunable.distribution == "beta":
distribution = Beta(
alpha=tunable.distribution_params["alpha"],
- beta=tunable.distribution_params["beta"]
+ beta=tunable.distribution_params["beta"],
)
elif tunable.distribution is not None:
raise TypeError(f"Invalid Distribution Type: {tunable.distribution}")
@@ -103,22 +108,26 @@ def _tunable_to_configspace(
log=bool(tunable.is_log),
q=nullable(int, tunable.quantization),
distribution=distribution,
- default=(int(tunable.default)
- if tunable.in_range(tunable.default) and tunable.default is not None
- else None),
- meta=meta
+ default=(
+ int(tunable.default)
+ if tunable.in_range(tunable.default) and tunable.default is not None
+ else None
+ ),
+ meta=meta,
)
elif tunable.type == "float":
range_hp = Float(
name=tunable.name,
bounds=tunable.range,
log=bool(tunable.is_log),
- q=tunable.quantization, # type: ignore[arg-type]
+ q=tunable.quantization, # type: ignore[arg-type]
distribution=distribution, # type: ignore[arg-type]
- default=(float(tunable.default)
- if tunable.in_range(tunable.default) and tunable.default is not None
- else None),
- meta=meta
+ default=(
+ float(tunable.default)
+ if tunable.in_range(tunable.default) and tunable.default is not None
+ else None
+ ),
+ meta=meta,
)
else:
raise TypeError(f"Invalid Parameter Type: {tunable.type}")
@@ -136,31 +145,38 @@ def _tunable_to_configspace(
# Create three hyperparameters: one for regular values,
# one for special values, and one to choose between the two.
(special_name, type_name) = special_param_names(tunable.name)
- conf_space = ConfigurationSpace({
- tunable.name: range_hp,
- special_name: CategoricalHyperparameter(
- name=special_name,
- choices=tunable.special,
- weights=special_weights,
- default_value=tunable.default if tunable.default in tunable.special else None,
- meta=meta
- ),
- type_name: CategoricalHyperparameter(
- name=type_name,
- choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
- weights=switch_weights,
- default_value=TunableValueKind.SPECIAL,
- ),
- })
- conf_space.add_condition(EqualsCondition(
- conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL))
- conf_space.add_condition(EqualsCondition(
- conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE))
+ conf_space = ConfigurationSpace(
+ {
+ tunable.name: range_hp,
+ special_name: CategoricalHyperparameter(
+ name=special_name,
+ choices=tunable.special,
+ weights=special_weights,
+ default_value=tunable.default if tunable.default in tunable.special else None,
+ meta=meta,
+ ),
+ type_name: CategoricalHyperparameter(
+ name=type_name,
+ choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
+ weights=switch_weights,
+ default_value=TunableValueKind.SPECIAL,
+ ),
+ }
+ )
+ conf_space.add_condition(
+ EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)
+ )
+ conf_space.add_condition(
+ EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)
+ )
return conf_space
-def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace:
+def tunable_groups_to_configspace(
+ tunables: TunableGroups,
+ seed: Optional[int] = None,
+) -> ConfigurationSpace:
"""
Convert TunableGroups to hyperparameters in ConfigurationSpace.
@@ -178,11 +194,16 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] =
A new ConfigurationSpace instance that corresponds to the input TunableGroups.
"""
space = ConfigurationSpace(seed=seed)
- for (tunable, group) in tunables:
+ for tunable, group in tunables:
space.add_configuration_space(
- prefix="", delimiter="",
+ prefix="",
+ delimiter="",
configuration_space=_tunable_to_configspace(
- tunable, group.name, group.get_current_cost()))
+ tunable,
+ group.name,
+ group.get_current_cost(),
+ ),
+ )
return space
@@ -201,7 +222,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
A ConfigSpace Configuration.
"""
values: Dict[str, TunableValue] = {}
- for (tunable, _group) in tunables:
+ for tunable, _group in tunables:
if tunable.special:
(special_name, type_name) = special_param_names(tunable.name)
if tunable.value in tunable.special:
@@ -219,13 +240,11 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
"""
Remove the fields that correspond to special values in ConfigSpace.
+
     In particular, remove any key suffixes added by `special_param_names`.
"""
data = data.copy()
- specials = [
- special_param_name_strip(k)
- for k in data.keys() if special_param_name_is_temp(k)
- ]
+ specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)]
for k in specials:
(special_name, type_name) = special_param_names(k)
if data[type_name] == TunableValueKind.SPECIAL:
@@ -240,8 +259,8 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
def special_param_names(name: str) -> Tuple[str, str]:
"""
- Generate the names of the auxiliary hyperparameters that correspond
- to a tunable that can have special values.
+ Generate the names of the auxiliary hyperparameters that correspond to a tunable
+ that can have special values.
     NOTE: `!` characters are currently disallowed in Tunable names in order to handle this logic.
diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py
index 0e836212d7..8bcd090415 100644
--- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py
@@ -2,38 +2,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Grid search optimizer for mlos_bench.
-"""
+"""Grid search optimizer for mlos_bench."""
import logging
+from typing import Dict, Iterable, Optional, Sequence, Set, Tuple
-from typing import Dict, Iterable, Set, Optional, Sequence, Tuple
-
-import numpy as np
import ConfigSpace
+import numpy as np
from ConfigSpace.util import generate_grid
from mlos_bench.environments.status import Status
+from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
+from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
+from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
-from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
-from mlos_bench.services.base_service import Service
_LOG = logging.getLogger(__name__)
class GridSearchOptimizer(TrackBestOptimizer):
- """
- Grid search optimizer.
- """
+ """Grid search optimizer."""
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(tunables, config, global_config, service)
# Track the grid as a set of tuples of tunable values and reconstruct the
@@ -52,11 +49,21 @@ class GridSearchOptimizer(TrackBestOptimizer):
def _sanity_check(self) -> None:
size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables])
if size == np.inf:
- raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}")
+ raise ValueError(
+ f"Unquantized tunables are not supported for grid search: {self._tunables}"
+ )
if size > 10000:
- _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables)
+ _LOG.warning(
+ "Large number %d of config points requested for grid search: %s",
+ size,
+ self._tunables,
+ )
if size > self._max_iter:
- _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter)
+ _LOG.warning(
+ "Grid search size %d, is greater than max iterations %d",
+ size,
+ self._max_iter,
+ )
def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
"""
@@ -69,12 +76,14 @@ class GridSearchOptimizer(TrackBestOptimizer):
# names instead of the order given by TunableGroups.
configs = [
configspace_data_to_tunable_values(dict(config))
- for config in
- generate_grid(self.config_space, {
- tunable.name: int(tunable.cardinality)
- for (tunable, _group) in self._tunables
- if tunable.quantization or tunable.type == "int"
- })
+ for config in generate_grid(
+ self.config_space,
+ {
+ tunable.name: int(tunable.cardinality)
+ for (tunable, _group) in self._tunables
+ if tunable.quantization or tunable.type == "int"
+ },
+ )
]
names = set(tuple(configs.keys()) for configs in configs)
assert len(names) == 1
@@ -104,15 +113,17 @@ class GridSearchOptimizer(TrackBestOptimizer):
# See NOTEs above.
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)
- def bulk_register(self,
- configs: Sequence[dict],
- scores: Sequence[Optional[Dict[str, TunableValue]]],
- status: Optional[Sequence[Status]] = None) -> bool:
+ def bulk_register(
+ self,
+ configs: Sequence[dict],
+ scores: Sequence[Optional[Dict[str, TunableValue]]],
+ status: Optional[Sequence[Status]] = None,
+ ) -> bool:
if not super().bulk_register(configs, scores, status):
return False
if status is None:
status = [Status.SUCCEEDED] * len(configs)
- for (params, score, trial_status) in zip(configs, scores, status):
+ for params, score, trial_status in zip(configs, scores, status):
tunables = self._tunables.copy().assign(params)
self.register(tunables, trial_status, score)
if _LOG.isEnabledFor(logging.DEBUG):
@@ -121,9 +132,7 @@ class GridSearchOptimizer(TrackBestOptimizer):
return True
def suggest(self) -> TunableGroups:
- """
- Generate the next grid search suggestion.
- """
+ """Generate the next grid search suggestion."""
tunables = super().suggest()
if self._start_with_defaults:
_LOG.info("Use default values for the first trial")
@@ -153,20 +162,35 @@ class GridSearchOptimizer(TrackBestOptimizer):
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
return tunables
- def register(self, tunables: TunableGroups, status: Status,
- score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
+ def register(
+ self,
+ tunables: TunableGroups,
+ status: Status,
+ score: Optional[Dict[str, TunableValue]] = None,
+ ) -> Optional[Dict[str, float]]:
registered_score = super().register(tunables, status, score)
try:
- config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()))
+ config = dict(
+ ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
+ )
self._suggested_configs.remove(tuple(config.values()))
except KeyError:
- _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables)
+ _LOG.warning(
+ (
+ "Attempted to remove missing config "
+ "(previously registered?) from suggested set: %s"
+ ),
+ tunables,
+ )
return registered_score
def not_converged(self) -> bool:
if self._iter > self._max_iter:
if bool(self._pending_configs):
- _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s",
- len(self._pending_configs), list(self._pending_configs.keys()))
+ _LOG.warning(
+ "Exceeded max iterations, but still have %d pending configs: %s",
+ len(self._pending_configs),
+ list(self._pending_configs.keys()),
+ )
return False
return bool(self._pending_configs)
diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py
index e0235f76b9..e8c1195421 100644
--- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py
@@ -2,72 +2,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A wrapper for mlos_core optimizers for mlos_bench.
-"""
+"""A wrapper for mlos_core optimizers for mlos_bench."""
import logging
import os
-
from types import TracebackType
from typing import Dict, Optional, Sequence, Tuple, Type, Union
-from typing_extensions import Literal
import pandas as pd
-
-from mlos_core.optimizers import (
- BaseOptimizer, OptimizerType, OptimizerFactory, SpaceAdapterType, DEFAULT_OPTIMIZER_TYPE
-)
+from typing_extensions import Literal
from mlos_bench.environments.status import Status
-from mlos_bench.services.base_service import Service
-from mlos_bench.tunables.tunable import TunableValue
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.base_optimizer import Optimizer
-
from mlos_bench.optimizers.convert_configspace import (
TunableValueKind,
configspace_data_to_tunable_values,
special_param_names,
)
+from mlos_bench.services.base_service import Service
+from mlos_bench.tunables.tunable import TunableValue
+from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_core.optimizers import (
+ DEFAULT_OPTIMIZER_TYPE,
+ BaseOptimizer,
+ OptimizerFactory,
+ OptimizerType,
+ SpaceAdapterType,
+)
_LOG = logging.getLogger(__name__)
class MlosCoreOptimizer(Optimizer):
- """
- A wrapper class for the mlos_core optimizers.
- """
+ """A wrapper class for the mlos_core optimizers."""
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(tunables, config, global_config, service)
- opt_type = getattr(OptimizerType, self._config.pop(
- 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name))
+ opt_type = getattr(
+ OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name)
+ )
if opt_type == OptimizerType.SMAC:
- output_directory = self._config.get('output_directory')
+ output_directory = self._config.get("output_directory")
if output_directory is not None:
# If output_directory is specified, turn it into an absolute path.
- self._config['output_directory'] = os.path.abspath(output_directory)
+ self._config["output_directory"] = os.path.abspath(output_directory)
else:
- _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.")
+ _LOG.warning(
+ (
+ "SMAC optimizer output_directory was null. "
+ "SMAC will use a temporary directory."
+ )
+ )
# Make sure max_trials >= max_iterations.
- if 'max_trials' not in self._config:
- self._config['max_trials'] = self._max_iter
- assert int(self._config['max_trials']) >= self._max_iter, \
- f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
+ if "max_trials" not in self._config:
+ self._config["max_trials"] = self._max_iter
+ assert (
+ int(self._config["max_trials"]) >= self._max_iter
+ ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
- if 'run_name' not in self._config and self.experiment_id:
- self._config['run_name'] = self.experiment_id
+ if "run_name" not in self._config and self.experiment_id:
+ self._config["run_name"] = self.experiment_id
- space_adapter_type = self._config.pop('space_adapter_type', None)
- space_adapter_config = self._config.pop('space_adapter_config', {})
+ space_adapter_type = self._config.pop("space_adapter_type", None)
+ space_adapter_config = self._config.pop("space_adapter_config", {})
if space_adapter_type is not None:
space_adapter_type = getattr(SpaceAdapterType, space_adapter_type)
@@ -81,9 +87,12 @@ class MlosCoreOptimizer(Optimizer):
space_adapter_kwargs=space_adapter_config,
)
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
self._opt.cleanup()
return super().__exit__(ex_type, ex_val, ex_tb)
@@ -91,10 +100,12 @@ class MlosCoreOptimizer(Optimizer):
def name(self) -> str:
return f"{self.__class__.__name__}:{self._opt.__class__.__name__}"
- def bulk_register(self,
- configs: Sequence[dict],
- scores: Sequence[Optional[Dict[str, TunableValue]]],
- status: Optional[Sequence[Status]] = None) -> bool:
+ def bulk_register(
+ self,
+ configs: Sequence[dict],
+ scores: Sequence[Optional[Dict[str, TunableValue]]],
+ status: Optional[Sequence[Status]] = None,
+ ) -> bool:
if not super().bulk_register(configs, scores, status):
return False
@@ -102,7 +113,8 @@ class MlosCoreOptimizer(Optimizer):
df_configs = self._to_df(configs) # Impute missing values, if necessary
df_scores = self._adjust_signs_df(
- pd.DataFrame([{} if score is None else score for score in scores]))
+ pd.DataFrame([{} if score is None else score for score in scores])
+ )
opt_targets = list(self._opt_targets)
if status is not None:
@@ -126,17 +138,15 @@ class MlosCoreOptimizer(Optimizer):
return True
def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame:
- """
- In-place adjust the signs of the scores for MINIMIZATION problem.
- """
- for (opt_target, opt_dir) in self._opt_targets.items():
+ """In-place adjust the signs of the scores for MINIMIZATION problem."""
+ for opt_target, opt_dir in self._opt_targets.items():
df_scores[opt_target] *= opt_dir
return df_scores
def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame:
"""
- Select from past trials only the columns required in this experiment and
- impute default values for the tunables that are missing in the dataframe.
+ Select from past trials only the columns required in this experiment and impute
+ default values for the tunables that are missing in the dataframe.
Parameters
----------
@@ -151,7 +161,7 @@ class MlosCoreOptimizer(Optimizer):
df_configs = pd.DataFrame(configs)
tunables_names = list(self._tunables.get_param_values().keys())
missing_cols = set(tunables_names).difference(df_configs.columns)
- for (tunable, _group) in self._tunables:
+ for tunable, _group in self._tunables:
if tunable.name in missing_cols:
df_configs[tunable.name] = tunable.default
else:
@@ -183,22 +193,34 @@ class MlosCoreOptimizer(Optimizer):
df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults)
self._start_with_defaults = False
_LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config)
- return tunables.assign(
- configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
+ return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
- def register(self, tunables: TunableGroups, status: Status,
- score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
- registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION
+ def register(
+ self,
+ tunables: TunableGroups,
+ status: Status,
+ score: Optional[Dict[str, TunableValue]] = None,
+ ) -> Optional[Dict[str, float]]:
+ registered_score = super().register(
+ tunables,
+ status,
+ score,
+ ) # Sign-adjusted for MINIMIZATION
if status.is_completed():
assert registered_score is not None
df_config = self._to_df([tunables.get_param_values()])
_LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config)
# TODO: Specify (in the config) which metrics to pass to the optimizer.
# Issue: https://github.com/microsoft/MLOS/issues/745
- self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float))
+ self._opt.register(
+ configs=df_config,
+ scores=pd.DataFrame([registered_score], dtype=float),
+ )
return registered_score
- def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
+ def get_best_observation(
+ self,
+ ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
(df_config, df_score, _df_context) = self._opt.get_best_observations()
if len(df_config) == 0:
return (None, None)
diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
index 7d2caff8ff..fd157db81a 100644
--- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
@@ -2,35 +2,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Mock optimizer for mlos_bench.
-"""
+"""Mock optimizer for mlos_bench."""
-import random
import logging
-
+import random
from typing import Callable, Dict, Optional, Sequence
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable import Tunable, TunableValue
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
+from mlos_bench.tunables.tunable import Tunable, TunableValue
+from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class MockOptimizer(TrackBestOptimizer):
- """
- Mock optimizer to test the Environment API.
- """
+ """Mock optimizer to test the Environment API."""
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(tunables, config, global_config, service)
rnd = random.Random(self.seed)
self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
@@ -39,15 +35,17 @@ class MockOptimizer(TrackBestOptimizer):
"int": lambda tunable: rnd.randint(*tunable.range),
}
- def bulk_register(self,
- configs: Sequence[dict],
- scores: Sequence[Optional[Dict[str, TunableValue]]],
- status: Optional[Sequence[Status]] = None) -> bool:
+ def bulk_register(
+ self,
+ configs: Sequence[dict],
+ scores: Sequence[Optional[Dict[str, TunableValue]]],
+ status: Optional[Sequence[Status]] = None,
+ ) -> bool:
if not super().bulk_register(configs, scores, status):
return False
if status is None:
status = [Status.SUCCEEDED] * len(configs)
- for (params, score, trial_status) in zip(configs, scores, status):
+ for params, score, trial_status in zip(configs, scores, status):
tunables = self._tunables.copy().assign(params)
self.register(tunables, trial_status, score)
if _LOG.isEnabledFor(logging.DEBUG):
@@ -56,15 +54,13 @@ class MockOptimizer(TrackBestOptimizer):
return True
def suggest(self) -> TunableGroups:
- """
- Generate the next (random) suggestion.
- """
+ """Generate the next (random) suggestion."""
tunables = super().suggest()
if self._start_with_defaults:
_LOG.info("Use default tunable values")
self._start_with_defaults = False
else:
- for (tunable, _group) in tunables:
+ for tunable, _group in tunables:
tunable.value = self._random[tunable.type](tunable)
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
return tunables
diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py
index 314d048298..f41114c185 100644
--- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py
@@ -2,16 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-No-op optimizer for mlos_bench that proposes a single configuration.
-"""
+"""No-op optimizer for mlos_bench that proposes a single configuration."""
import logging
from typing import Optional
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.optimizers.mock_optimizer import MockOptimizer
_LOG = logging.getLogger(__name__)
@@ -19,24 +17,25 @@ _LOG = logging.getLogger(__name__)
class OneShotOptimizer(MockOptimizer):
"""
No-op optimizer that proposes a single configuration and returns.
+
Explicit configs (partial or full) are possible using configuration files.
"""
# TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(tunables, config, global_config, service)
_LOG.info("Run a single iteration for: %s", self._tunables)
self._max_iter = 1 # Always run for just one iteration.
def suggest(self) -> TunableGroups:
- """
- Always produce the same (initial) suggestion.
- """
+ """Always produce the same (initial) suggestion."""
tunables = super().suggest()
self._start_with_defaults = True
return tunables
diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py
index c5d07ab93d..6ad8ab48d2 100644
--- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py
@@ -2,40 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Mock optimizer for mlos_bench.
-"""
+"""Mock optimizer for mlos_bench."""
import logging
from abc import ABCMeta
from typing import Dict, Optional, Tuple, Union
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable import TunableValue
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.services.base_service import Service
+from mlos_bench.tunables.tunable import TunableValue
+from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
- """
- Base Optimizer class that keeps track of the best score and configuration.
- """
+ """Base Optimizer class that keeps track of the best score and configuration."""
- def __init__(self,
- tunables: TunableGroups,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ tunables: TunableGroups,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(tunables, config, global_config, service)
self._best_config: Optional[TunableGroups] = None
self._best_score: Optional[Dict[str, float]] = None
- def register(self, tunables: TunableGroups, status: Status,
- score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
+ def register(
+ self,
+ tunables: TunableGroups,
+ status: Status,
+ score: Optional[Dict[str, TunableValue]] = None,
+ ) -> Optional[Dict[str, float]]:
registered_score = super().register(tunables, status, score)
if status.is_succeeded() and self._is_better(registered_score):
self._best_score = registered_score
@@ -43,13 +44,11 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
return registered_score
def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool:
- """
- Compare the optimization scores to the best ones so far lexicographically.
- """
+ """Compare the optimization scores to the best ones so far lexicographically."""
if self._best_score is None:
return True
assert registered_score is not None
- for (opt_target, best_score) in self._best_score.items():
+ for opt_target, best_score in self._best_score.items():
score = registered_score[opt_target]
if score < best_score:
return True
@@ -57,7 +56,9 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
return False
return False
- def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
+ def get_best_observation(
+ self,
+ ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
if self._best_score is None:
return (None, None)
score = self._get_scores(Status.SUCCEEDED, self._best_score)
diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py
index f3c556c06a..f750f12038 100644
--- a/mlos_bench/mlos_bench/os_environ.py
+++ b/mlos_bench/mlos_bench/os_environ.py
@@ -3,8 +3,8 @@
# Licensed under the MIT License.
#
"""
-Simple platform agnostic abstraction for the OS environment variables.
-Meant as a replacement for os.environ vs nt.environ.
+Simple platform agnostic abstraction for the OS environment variables. Meant as a
+replacement for os.environ vs nt.environ.
Example
-------
@@ -22,16 +22,18 @@ else:
from typing_extensions import TypeAlias
if sys.version_info >= (3, 9):
- EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object
+ # pylint: disable=protected-access,disable=unsubscriptable-object
+ EnvironType: TypeAlias = os._Environ[str]
else:
- EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
+ EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
# Handle case sensitivity differences between platforms.
# https://stackoverflow.com/a/19023293
-if sys.platform == 'win32':
- import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
+if sys.platform == "win32":
+ import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
+
environ: EnvironType = nt.environ
else:
environ: EnvironType = os.environ
-__all__ = ['environ']
+__all__ = ["environ"]
diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py
index 85c8c2b0c5..57c48a87b9 100755
--- a/mlos_bench/mlos_bench/run.py
+++ b/mlos_bench/mlos_bench/run.py
@@ -20,8 +20,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
-def _main(argv: Optional[List[str]] = None
- ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
+def _main(
+ argv: Optional[List[str]] = None,
+) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv)
diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py
index c54e3c0efc..381261e53d 100644
--- a/mlos_bench/mlos_bench/schedulers/__init__.py
+++ b/mlos_bench/mlos_bench/schedulers/__init__.py
@@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Interfaces and implementations of the optimization loop scheduling policies.
-"""
+"""Interfaces and implementations of the optimization loop scheduling policies."""
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
__all__ = [
- 'Scheduler',
- 'SyncScheduler',
+ "Scheduler",
+ "SyncScheduler",
]
diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
index d996a6e00f..e9c051175a 100644
--- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
@@ -2,20 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base class for the optimization loop scheduling policies.
-"""
+"""Base class for the optimization loop scheduling policies."""
import json
import logging
-from datetime import datetime
-
from abc import ABCMeta, abstractmethod
+from datetime import datetime
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type
-from typing_extensions import Literal
from pytz import UTC
+from typing_extensions import Literal
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
@@ -28,22 +25,23 @@ _LOG = logging.getLogger(__name__)
class Scheduler(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
- """
- Base class for the optimization loop scheduling policies.
- """
+ """Base class for the optimization loop scheduling policies."""
- def __init__(self, *,
- config: Dict[str, Any],
- global_config: Dict[str, Any],
- environment: Environment,
- optimizer: Optimizer,
- storage: Storage,
- root_env_config: str):
+ def __init__(
+ self,
+ *,
+ config: Dict[str, Any],
+ global_config: Dict[str, Any],
+ environment: Environment,
+ optimizer: Optimizer,
+ storage: Storage,
+ root_env_config: str,
+ ):
"""
- Create a new instance of the scheduler. The constructor of this
- and the derived classes is called by the persistence service
- after reading the class JSON configuration. Other objects like
- the Environment and Optimizer are provided by the Launcher.
+ Create a new instance of the scheduler. The constructor of this and the derived
+ classes is called by the persistence service after reading the class JSON
+ configuration. Other objects like the Environment and Optimizer are provided by
+ the Launcher.
Parameters
----------
@@ -61,8 +59,11 @@ class Scheduler(metaclass=ABCMeta):
Path to the root environment configuration.
"""
self.global_config = global_config
- config = merge_parameters(dest=config.copy(), source=global_config,
- required_keys=["experiment_id", "trial_id"])
+ config = merge_parameters(
+ dest=config.copy(),
+ source=global_config,
+ required_keys=["experiment_id", "trial_id"],
+ )
self._experiment_id = config["experiment_id"].strip()
self._trial_id = int(config["trial_id"])
@@ -72,7 +73,9 @@ class Scheduler(metaclass=ABCMeta):
self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
if self._trial_config_repeat_count <= 0:
- raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}")
+ raise ValueError(
+ f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
+ )
self._do_teardown = bool(config.get("teardown", True))
@@ -96,10 +99,8 @@ class Scheduler(metaclass=ABCMeta):
"""
return self.__class__.__name__
- def __enter__(self) -> 'Scheduler':
- """
- Enter the scheduler's context.
- """
+ def __enter__(self) -> "Scheduler":
+ """Enter the scheduler's context."""
_LOG.debug("Scheduler START :: %s", self)
assert self.experiment is None
self.environment.__enter__()
@@ -118,13 +119,13 @@ class Scheduler(metaclass=ABCMeta):
).__enter__()
return self
- def __exit__(self,
- ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
- """
- Exit the context of the scheduler.
- """
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
+ """Exit the context of the scheduler."""
if ex_val is None:
_LOG.debug("Scheduler END :: %s", self)
else:
@@ -139,12 +140,14 @@ class Scheduler(metaclass=ABCMeta):
@abstractmethod
def start(self) -> None:
- """
- Start the optimization loop.
- """
+ """Start the optimization loop."""
assert self.experiment is not None
- _LOG.info("START: Experiment: %s Env: %s Optimizer: %s",
- self.experiment, self.environment, self.optimizer)
+ _LOG.info(
+ "START: Experiment: %s Env: %s Optimizer: %s",
+ self.experiment,
+ self.environment,
+ self.optimizer,
+ )
if _LOG.isEnabledFor(logging.INFO):
_LOG.info("Root Environment:\n%s", self.environment.pprint())
@@ -155,6 +158,7 @@ class Scheduler(metaclass=ABCMeta):
def teardown(self) -> None:
"""
Tear down the environment.
+
Call it after the completion of the `.start()` in the scheduler context.
"""
assert self.experiment is not None
@@ -162,17 +166,13 @@ class Scheduler(metaclass=ABCMeta):
self.environment.teardown()
def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
- """
- Get the best observation from the optimizer.
- """
+ """Get the best observation from the optimizer."""
(best_score, best_config) = self.optimizer.get_best_observation()
_LOG.info("Env: %s best score: %s", self.environment, best_score)
return (best_score, best_config)
def load_config(self, config_id: int) -> TunableGroups:
- """
- Load the existing tunable configuration from the storage.
- """
+ """Load the existing tunable configuration from the storage."""
assert self.experiment is not None
tunable_values = self.experiment.load_tunable_config(config_id)
tunables = self.environment.tunable_params.assign(tunable_values)
@@ -183,9 +183,11 @@ class Scheduler(metaclass=ABCMeta):
def _schedule_new_optimizer_suggestions(self) -> bool:
"""
- Optimizer part of the loop. Load the results of the executed trials
- into the optimizer, suggest new configurations, and add them to the queue.
- Return True if optimization is not over, False otherwise.
+ Optimizer part of the loop.
+
+ Load the results of the executed trials into the optimizer, suggest new
+ configurations, and add them to the queue. Return True if optimization is not
+ over, False otherwise.
"""
assert self.experiment is not None
(trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
@@ -201,33 +203,38 @@ class Scheduler(metaclass=ABCMeta):
return not_done
def schedule_trial(self, tunables: TunableGroups) -> None:
- """
- Add a configuration to the queue of trials.
- """
+ """Add a configuration to the queue of trials."""
for repeat_i in range(1, self._trial_config_repeat_count + 1):
- self._add_trial_to_queue(tunables, config={
- # Add some additional metadata to track for the trial such as the
- # optimizer config used.
- # Note: these values are unfortunately mutable at the moment.
- # Consider them as hints of what the config was the trial *started*.
- # It is possible that the experiment configs were changed
- # between resuming the experiment (since that is not currently
- # prevented).
- "optimizer": self.optimizer.name,
- "repeat_i": repeat_i,
- "is_defaults": tunables.is_defaults(),
- **{
- f"opt_{key}_{i}": val
- for (i, opt_target) in enumerate(self.optimizer.targets.items())
- for (key, val) in zip(["target", "direction"], opt_target)
- }
- })
+ self._add_trial_to_queue(
+ tunables,
+ config={
+ # Add some additional metadata to track for the trial such as the
+ # optimizer config used.
+ # Note: these values are unfortunately mutable at the moment.
+                    # Consider them as hints of what the config was when the trial
+                    # *started*. It is possible that the experiment configs were
+                    # changed between resumes of the experiment (since that is not
+                    # currently prevented).
+ "optimizer": self.optimizer.name,
+ "repeat_i": repeat_i,
+ "is_defaults": tunables.is_defaults(),
+ **{
+ f"opt_{key}_{i}": val
+ for (i, opt_target) in enumerate(self.optimizer.targets.items())
+ for (key, val) in zip(["target", "direction"], opt_target)
+ },
+ },
+ )
- def _add_trial_to_queue(self, tunables: TunableGroups,
- ts_start: Optional[datetime] = None,
- config: Optional[Dict[str, Any]] = None) -> None:
+ def _add_trial_to_queue(
+ self,
+ tunables: TunableGroups,
+ ts_start: Optional[datetime] = None,
+ config: Optional[Dict[str, Any]] = None,
+ ) -> None:
"""
Add a configuration to the queue of trials.
+
A wrapper for the `Experiment.new_trial` method.
"""
assert self.experiment is not None
@@ -236,7 +243,9 @@ class Scheduler(metaclass=ABCMeta):
def _run_schedule(self, running: bool = False) -> None:
"""
- Scheduler part of the loop. Check for pending trials in the queue and run them.
+ Scheduler part of the loop.
+
+ Check for pending trials in the queue and run them.
"""
assert self.experiment is not None
for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
@@ -245,6 +254,7 @@ class Scheduler(metaclass=ABCMeta):
def not_done(self) -> bool:
"""
Check the stopping conditions.
+
By default, stop when the optimizer converges or max limit of trials reached.
"""
return self.optimizer.not_converged() and (
@@ -254,7 +264,9 @@ class Scheduler(metaclass=ABCMeta):
@abstractmethod
def run_trial(self, trial: Storage.Trial) -> None:
"""
- Set up and run a single trial. Save the results in the storage.
+ Set up and run a single trial.
+
+ Save the results in the storage.
"""
assert self.experiment is not None
self._trial_count += 1
diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
index a73a493533..e56d15ca17 100644
--- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A simple single-threaded synchronous optimization loop implementation.
-"""
+"""A simple single-threaded synchronous optimization loop implementation."""
import logging
from datetime import datetime
@@ -19,14 +17,10 @@ _LOG = logging.getLogger(__name__)
class SyncScheduler(Scheduler):
- """
- A simple single-threaded synchronous optimization loop implementation.
- """
+ """A simple single-threaded synchronous optimization loop implementation."""
def start(self) -> None:
- """
- Start the optimization loop.
- """
+ """Start the optimization loop."""
super().start()
is_warm_up = self.optimizer.supports_preload
@@ -42,7 +36,9 @@ class SyncScheduler(Scheduler):
def run_trial(self, trial: Storage.Trial) -> None:
"""
- Set up and run a single trial. Save the results in the storage.
+ Set up and run a single trial.
+
+ Save the results in the storage.
"""
super().run_trial(trial)
@@ -53,7 +49,8 @@ class SyncScheduler(Scheduler):
trial.update(Status.FAILED, datetime.now(UTC))
return
- (status, timestamp, results) = self.environment.run() # Block and wait for the final result.
+ # Block and wait for the final result.
+ (status, timestamp, results) = self.environment.run()
_LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results)
# In async mode (TODO), poll the environment for status and telemetry
diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py
index 89e71be815..b768afb09c 100644
--- a/mlos_bench/mlos_bench/services/__init__.py
+++ b/mlos_bench/mlos_bench/services/__init__.py
@@ -2,17 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Services for implementing Environments for mlos_bench.
-"""
+"""Services for implementing Environments for mlos_bench."""
-from mlos_bench.services.base_service import Service
from mlos_bench.services.base_fileshare import FileShareService
+from mlos_bench.services.base_service import Service
from mlos_bench.services.local.local_exec import LocalExecService
-
__all__ = [
- 'Service',
- 'FileShareService',
- 'LocalExecService',
+ "Service",
+ "FileShareService",
+ "LocalExecService",
]
diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py
index b7282089e4..75ff3d2408 100644
--- a/mlos_bench/mlos_bench/services/base_fileshare.py
+++ b/mlos_bench/mlos_bench/services/base_fileshare.py
@@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base class for remote file shares.
-"""
+"""Base class for remote file shares."""
import logging
-
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
@@ -18,14 +15,15 @@ _LOG = logging.getLogger(__name__)
class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
- """
- An abstract base of all file shares.
- """
+ """An abstract base of all file shares."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new file share with a given config.
@@ -43,12 +41,20 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.upload, self.download])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.upload, self.download]),
)
@abstractmethod
- def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
+ def download(
+ self,
+ params: dict,
+ remote_path: str,
+ local_path: str,
+ recursive: bool = True,
+ ) -> None:
"""
Downloads contents from a remote share path to a local path.
@@ -66,11 +72,22 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
if True (the default), download the entire directory tree.
"""
params = params or {}
- _LOG.info("Download from File Share %s recursively: %s -> %s (%s)",
- "" if recursive else "non", remote_path, local_path, params)
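+        # The first "%s" switches the message between "recursively" and "non recursively".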
+ _LOG.info(
+ "Download from File Share %s recursively: %s -> %s (%s)",
+ "" if recursive else "non",
+ remote_path,
+ local_path,
+ params,
+ )
@abstractmethod
- def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
+ def upload(
+ self,
+ params: dict,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
"""
Uploads contents from a local path to remote share path.
@@ -87,5 +104,10 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
if True (the default), upload the entire directory tree.
"""
params = params or {}
- _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)",
- "" if recursive else "non", local_path, remote_path, params)
+ _LOG.info(
+ "Upload to File Share %s recursively: %s -> %s (%s)",
+ "" if recursive else "non",
+ local_path,
+ remote_path,
+ params,
+ )
diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py
index b171568172..c5d9b78c87 100644
--- a/mlos_bench/mlos_bench/services/base_service.py
+++ b/mlos_bench/mlos_bench/services/base_service.py
@@ -2,15 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base class for the service mix-ins.
-"""
+"""Base class for the service mix-ins."""
import json
import logging
-
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
@@ -21,16 +19,16 @@ _LOG = logging.getLogger(__name__)
class Service:
- """
- An abstract base of all Environment Services and used to build up mix-ins.
- """
+ """An abstract base of all Environment Services and used to build up mix-ins."""
@classmethod
- def new(cls,
- class_name: str,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional["Service"] = None) -> "Service":
+ def new(
+ cls,
+ class_name: str,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional["Service"] = None,
+ ) -> "Service":
"""
Factory method for a new service with a given config.
@@ -57,11 +55,13 @@ class Service:
assert issubclass(cls, Service)
return instantiate_from_config(cls, class_name, config, global_config, parent)
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional["Service"] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional["Service"] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new service with a given config.
@@ -101,12 +101,15 @@ class Service:
_LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)
@staticmethod
- def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
- local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]:
+ def merge_methods(
+ ext_methods: Union[Dict[str, Callable], List[Callable], None],
+ local_methods: Union[Dict[str, Callable], List[Callable]],
+ ) -> Dict[str, Callable]:
"""
Merge methods from the external caller with the local ones.
- This function is usually called by the derived class constructor
- just before invoking the constructor of the base class.
+
+ This function is usually called by the derived class constructor just before
+ invoking the constructor of the base class.
"""
if isinstance(local_methods, dict):
local_methods = local_methods.copy()
@@ -138,9 +141,12 @@ class Service:
self._in_context = True
return self
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
"""
Exit the Service mix-in context.
@@ -170,21 +176,24 @@ class Service:
"""
Enters the context for this particular Service instance.
- Called by the base __enter__ method of the Service class so it can be
- used with mix-ins and overridden by subclasses.
+ Called by the base __enter__ method of the Service class so it can be used with
+ mix-ins and overridden by subclasses.
"""
assert not self._in_context
self._in_context = True
return self
- def _exit_context(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def _exit_context(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
"""
Exits the context for this particular Service instance.
- Called by the base __enter__ method of the Service class so it can be
- used with mix-ins and overridden by subclasses.
+ Called by the base __enter__ method of the Service class so it can be used with
+ mix-ins and overridden by subclasses.
"""
# pylint: disable=unused-argument
assert self._in_context
@@ -192,13 +201,13 @@ class Service:
return False
def _validate_json_config(self, config: dict) -> None:
- """
- Reconstructs a basic json config that this class might have been
- instantiated from in order to validate configs provided outside the
- file loading mechanism.
+ """Reconstructs a basic json config that this class might have been instantiated
+ from in order to validate configs provided outside the file loading
+ mechanism.
"""
if self.__class__ == Service:
- # Skip over the case where instantiate a bare base Service class in order to build up a mix-in.
+            # Skip over the case where we instantiate a bare base Service class
+            # in order to build up a mix-in.
assert config == {}
return
json_config: dict = {
@@ -212,9 +221,7 @@ class Service:
return f"{self.__class__.__name__}@{hex(id(self))}"
def pprint(self) -> str:
- """
- Produce a human-readable string listing all public methods of the service.
- """
+ """Produce a human-readable string listing all public methods of the service."""
return f"{self} ::\n" + "\n".join(
f' "{key}": {getattr(val, "__self__", "stand-alone")}'
for (key, val) in self._service_methods.items()
@@ -265,10 +272,11 @@ class Service:
# Unfortunately, by creating a set, we may destroy the ability to
# preserve the context enter/exit order, but hopefully it doesn't
# matter.
- svc_method.__self__ for _, svc_method in self._service_methods.items()
+ svc_method.__self__
+ for _, svc_method in self._service_methods.items()
# Note: some methods are actually stand alone functions, so we need
# to filter them out.
- if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
+ if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
}
def export(self) -> Dict[str, Callable]:
diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py
index 4329d8f7e3..8e8a05b0e8 100644
--- a/mlos_bench/mlos_bench/services/config_persistence.py
+++ b/mlos_bench/mlos_bench/services/config_persistence.py
@@ -2,22 +2,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Helper functions to load, instantiate, and serialize Python objects
-that encapsulate benchmark environments, tunable parameters, and
-service functions.
+"""Helper functions to load, instantiate, and serialize Python objects that encapsulate
+benchmark environments, tunable parameters, and service functions.
"""
+import json # For logging only
+import logging
import os
import sys
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
-import json # For logging only
-import logging
-
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING
-
-import json5 # To read configs with comments and other JSON5 syntax features
-from jsonschema import ValidationError, SchemaError
+import json5 # To read configs with comments and other JSON5 syntax features
+from jsonschema import SchemaError, ValidationError
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
@@ -26,7 +32,12 @@ from mlos_bench.services.base_service import Service
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.util import instantiate_from_config, merge_parameters, path_join, preprocess_dynamic_configs
+from mlos_bench.util import (
+ instantiate_from_config,
+ merge_parameters,
+ path_join,
+ preprocess_dynamic_configs,
+)
if sys.version_info < (3, 10):
from importlib_resources import files
@@ -34,25 +45,27 @@ else:
from importlib.resources import files
if TYPE_CHECKING:
- from mlos_bench.storage.base_storage import Storage
from mlos_bench.schedulers.base_scheduler import Scheduler
+ from mlos_bench.storage.base_storage import Storage
_LOG = logging.getLogger(__name__)
class ConfigPersistenceService(Service, SupportsConfigLoading):
- """
- Collection of methods to deserialize the Environment, Service, and TunableGroups objects.
+ """Collection of methods to deserialize the Environment, Service, and TunableGroups
+ objects.
"""
BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/")
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of config persistence service.
@@ -69,17 +82,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- self.resolve_path,
- self.load_config,
- self.prepare_class_load,
- self.build_service,
- self.build_environment,
- self.load_services,
- self.load_environment,
- self.load_environment_list,
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ self.resolve_path,
+ self.load_config,
+ self.prepare_class_load,
+ self.build_service,
+ self.build_environment,
+ self.load_services,
+ self.load_environment,
+ self.load_environment_list,
+ ],
+ ),
)
self._config_loader_service = self
@@ -107,11 +125,10 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
return list(self._config_path) # make a copy to avoid modifications
- def resolve_path(self, file_path: str,
- extra_paths: Optional[Iterable[str]] = None) -> str:
+ def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
"""
- Prepend the suitable `_config_path` to `path` if the latter is not absolute.
- If `_config_path` is `None` or `path` is absolute, return `path` as is.
+ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
+ `_config_path` is `None` or `path` is absolute, return `path` as is.
Parameters
----------
@@ -138,14 +155,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
_LOG.debug("Path not resolved: %s", file_path)
return file_path
- def load_config(self,
- json_file_name: str,
- schema_type: Optional[ConfigSchema],
- ) -> Dict[str, Any]:
+ def load_config(
+ self,
+ json_file_name: str,
+ schema_type: Optional[ConfigSchema],
+ ) -> Dict[str, Any]:
"""
- Load JSON config file. Search for a file relative to `_config_path`
- if the input path is not absolute.
- This method is exported to be used as a service.
+        Load a JSON config file. Search for a file relative to `_config_path` if the
+        input path is not absolute. This method is exported to be used as a service.
Parameters
----------
@@ -161,16 +178,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
json_file_name = self.resolve_path(json_file_name)
_LOG.info("Load config: %s", json_file_name)
- with open(json_file_name, mode='r', encoding='utf-8') as fh_json:
+ with open(json_file_name, mode="r", encoding="utf-8") as fh_json:
config = json5.load(fh_json)
if schema_type is not None:
try:
schema_type.validate(config)
except (ValidationError, SchemaError) as ex:
- _LOG.error("Failed to validate config %s against schema type %s at %s",
- json_file_name, schema_type.name, schema_type.value)
- raise ValueError(f"Failed to validate config {json_file_name} against " +
- f"schema type {schema_type.name} at {schema_type.value}") from ex
+ _LOG.error(
+ "Failed to validate config %s against schema type %s at %s",
+ json_file_name,
+ schema_type.name,
+ schema_type.value,
+ )
+ raise ValueError(
+ f"Failed to validate config {json_file_name} against "
+ f"schema type {schema_type.name} at {schema_type.value}"
+ ) from ex
if isinstance(config, dict) and config.get("$schema"):
# Remove $schema attributes from the config after we've validated
# them to avoid passing them on to other objects
@@ -181,15 +204,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
del config["$schema"]
else:
_LOG.warning("Config %s is not validated against a schema.", json_file_name)
- return config # type: ignore[no-any-return]
+ return config # type: ignore[no-any-return]
- def prepare_class_load(self, config: Dict[str, Any],
- global_config: Optional[Dict[str, Any]] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]:
+ def prepare_class_load(
+ self,
+ config: Dict[str, Any],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ ) -> Tuple[str, Dict[str, Any]]:
"""
- Extract the class instantiation parameters from the configuration.
- Mix-in the global parameters and resolve the local file system paths,
- where it is required.
+        Extract the class instantiation parameters from the configuration. Mix in the
+        global parameters and resolve the local file system paths where required.
Parameters
----------
@@ -228,19 +253,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
raise ValueError(f"Parameter {key} must be a string or a list")
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Instantiating: %s with config:\n%s",
- class_name, json.dumps(class_config, indent=2))
+ _LOG.debug(
+ "Instantiating: %s with config:\n%s",
+ class_name,
+ json.dumps(class_config, indent=2),
+ )
return (class_name, class_config)
- def build_optimizer(self, *,
- tunables: TunableGroups,
- service: Service,
- config: Dict[str, Any],
- global_config: Optional[Dict[str, Any]] = None) -> Optimizer:
+ def build_optimizer(
+ self,
+ *,
+ tunables: TunableGroups,
+ service: Service,
+ config: Dict[str, Any],
+ global_config: Optional[Dict[str, Any]] = None,
+ ) -> Optimizer:
"""
- Instantiation of mlos_bench Optimizer
- that depend on Service and TunableGroups.
+        Instantiation of an mlos_bench Optimizer that depends on Service and TunableGroups.
A class *MUST* have a constructor that takes four named arguments:
(tunables, config, global_config, service)
@@ -266,18 +296,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
if tunables_path is not None:
tunables = self._load_tunables(tunables_path, tunables)
(class_name, class_config) = self.prepare_class_load(config, global_config)
- inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract]
- tunables=tunables,
- config=class_config,
- global_config=global_config,
- service=service)
+ inst = instantiate_from_config(
+ Optimizer, # type: ignore[type-abstract]
+ class_name,
+ tunables=tunables,
+ config=class_config,
+ global_config=global_config,
+ service=service,
+ )
_LOG.info("Created: Optimizer %s", inst)
return inst
- def build_storage(self, *,
- service: Service,
- config: Dict[str, Any],
- global_config: Optional[Dict[str, Any]] = None) -> "Storage":
+ def build_storage(
+ self,
+ *,
+ service: Service,
+ config: Dict[str, Any],
+ global_config: Optional[Dict[str, Any]] = None,
+ ) -> "Storage":
"""
Instantiation of mlos_bench Storage objects.
@@ -296,21 +332,29 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A new instance of the Storage class.
"""
(class_name, class_config) = self.prepare_class_load(config, global_config)
- from mlos_bench.storage.base_storage import Storage # pylint: disable=import-outside-toplevel
- inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract]
- config=class_config,
- global_config=global_config,
- service=service)
+ # pylint: disable=import-outside-toplevel
+ from mlos_bench.storage.base_storage import Storage
+
+ inst = instantiate_from_config(
+ Storage, # type: ignore[type-abstract]
+ class_name,
+ config=class_config,
+ global_config=global_config,
+ service=service,
+ )
_LOG.info("Created: Storage %s", inst)
return inst
- def build_scheduler(self, *,
- config: Dict[str, Any],
- global_config: Dict[str, Any],
- environment: Environment,
- optimizer: Optimizer,
- storage: "Storage",
- root_env_config: str) -> "Scheduler":
+ def build_scheduler(
+ self,
+ *,
+ config: Dict[str, Any],
+ global_config: Dict[str, Any],
+ environment: Environment,
+ optimizer: Optimizer,
+ storage: "Storage",
+ root_env_config: str,
+ ) -> "Scheduler":
"""
Instantiation of mlos_bench Scheduler.
@@ -335,23 +379,30 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A new instance of the Scheduler.
"""
(class_name, class_config) = self.prepare_class_load(config, global_config)
- from mlos_bench.schedulers.base_scheduler import Scheduler # pylint: disable=import-outside-toplevel
- inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract]
- config=class_config,
- global_config=global_config,
- environment=environment,
- optimizer=optimizer,
- storage=storage,
- root_env_config=root_env_config)
+ # pylint: disable=import-outside-toplevel
+ from mlos_bench.schedulers.base_scheduler import Scheduler
+
+ inst = instantiate_from_config(
+ Scheduler, # type: ignore[type-abstract]
+ class_name,
+ config=class_config,
+ global_config=global_config,
+ environment=environment,
+ optimizer=optimizer,
+ storage=storage,
+ root_env_config=root_env_config,
+ )
_LOG.info("Created: Scheduler %s", inst)
return inst
- def build_environment(self, # pylint: disable=too-many-arguments
- config: Dict[str, Any],
- tunables: TunableGroups,
- global_config: Optional[Dict[str, Any]] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None,
- service: Optional[Service] = None) -> Environment:
+ def build_environment(
+ self, # pylint: disable=too-many-arguments
+ config: Dict[str, Any],
+ tunables: TunableGroups,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ service: Optional[Service] = None,
+ ) -> Environment:
"""
Factory method for a new environment with a given config.
@@ -391,16 +442,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
tunables = self._load_tunables(env_tunables_path, tunables)
_LOG.debug("Creating env: %s :: %s", env_name, env_class)
- env = Environment.new(env_name=env_name, class_name=env_class,
- config=env_config, global_config=global_config,
- tunables=tunables, service=service)
+ env = Environment.new(
+ env_name=env_name,
+ class_name=env_class,
+ config=env_config,
+ global_config=global_config,
+ tunables=tunables,
+ service=service,
+ )
_LOG.info("Created env: %s :: %s", env_name, env)
return env
- def _build_standalone_service(self, config: Dict[str, Any],
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None) -> Service:
+ def _build_standalone_service(
+ self,
+ config: Dict[str, Any],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ ) -> Service:
"""
Factory method for a new service with a given config.
@@ -425,9 +484,12 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
_LOG.info("Created service: %s", service)
return service
- def _build_composite_service(self, config_list: Iterable[Dict[str, Any]],
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None) -> Service:
+ def _build_composite_service(
+ self,
+ config_list: Iterable[Dict[str, Any]],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ ) -> Service:
"""
Factory method for a new service with a given config.
@@ -453,18 +515,21 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
service.register(parent.export())
for config in config_list:
- service.register(self._build_standalone_service(
- config, global_config, service).export())
+ service.register(
+ self._build_standalone_service(config, global_config, service).export()
+ )
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Created mix-in service: %s", service)
return service
- def build_service(self,
- config: Dict[str, Any],
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None) -> Service:
+ def build_service(
+ self,
+ config: Dict[str, Any],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ ) -> Service:
"""
Factory method for a new service with a given config.
@@ -486,8 +551,7 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
services from the list plus the parent mix-in.
"""
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Build service from config:\n%s",
- json.dumps(config, indent=2))
+ _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2))
assert isinstance(config, dict)
config_list: List[Dict[str, Any]]
@@ -502,12 +566,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
return self._build_composite_service(config_list, global_config, parent)
- def load_environment(self, # pylint: disable=too-many-arguments
- json_file_name: str,
- tunables: TunableGroups,
- global_config: Optional[Dict[str, Any]] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None,
- service: Optional[Service] = None) -> Environment:
+ def load_environment(
+ self, # pylint: disable=too-many-arguments
+ json_file_name: str,
+ tunables: TunableGroups,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ service: Optional[Service] = None,
+ ) -> Environment:
"""
Load and build new environment from the config file.
@@ -534,12 +600,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
assert isinstance(config, dict)
return self.build_environment(config, tunables, global_config, parent_args, service)
- def load_environment_list(self, # pylint: disable=too-many-arguments
- json_file_name: str,
- tunables: TunableGroups,
- global_config: Optional[Dict[str, Any]] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None,
- service: Optional[Service] = None) -> List[Environment]:
+ def load_environment_list(
+ self, # pylint: disable=too-many-arguments
+ json_file_name: str,
+ tunables: TunableGroups,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ service: Optional[Service] = None,
+ ) -> List[Environment]:
"""
Load and build a list of environments from the config file.
@@ -564,16 +632,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A list of new benchmarking environments.
"""
config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT)
- return [
- self.build_environment(config, tunables, global_config, parent_args, service)
- ]
+ return [self.build_environment(config, tunables, global_config, parent_args, service)]
- def load_services(self, json_file_names: Iterable[str],
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None) -> Service:
+ def load_services(
+ self,
+ json_file_names: Iterable[str],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ ) -> Service:
"""
- Read the configuration files and bundle all service methods
- from those configs into a single Service object.
+ Read the configuration files and bundle all service methods from those configs
+ into a single Service object.
Parameters
----------
@@ -589,16 +658,18 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
service : Service
A collection of service methods.
"""
- _LOG.info("Load services: %s parent: %s",
- json_file_names, parent.__class__.__name__)
+ _LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__)
service = Service({}, global_config, parent)
for fname in json_file_names:
config = self.load_config(fname, ConfigSchema.SERVICE)
service.register(self.build_service(config, global_config, service).export())
return service
- def _load_tunables(self, json_file_names: Iterable[str],
- parent: TunableGroups) -> TunableGroups:
+ def _load_tunables(
+ self,
+ json_file_names: Iterable[str],
+ parent: TunableGroups,
+ ) -> TunableGroups:
"""
Load a collection of tunable parameters from JSON files into the parent
TunableGroup.
diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py
index f35ea4c7e8..afe9f05d20 100644
--- a/mlos_bench/mlos_bench/services/local/__init__.py
+++ b/mlos_bench/mlos_bench/services/local/__init__.py
@@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Local scheduler side Services for mlos_bench.
-"""
+"""Local scheduler side Services for mlos_bench."""
from mlos_bench.services.local.local_exec import LocalExecService
-
__all__ = [
- 'LocalExecService',
+ "LocalExecService",
]
diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py
index 2ca567dfd4..a1339312a6 100644
--- a/mlos_bench/mlos_bench/services/local/local_exec.py
+++ b/mlos_bench/mlos_bench/services/local/local_exec.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Helper functions to run scripts and commands locally on the scheduler side.
-"""
+"""Helper functions to run scripts and commands locally on the scheduler side."""
import errno
import logging
@@ -12,10 +10,18 @@ import os
import shlex
import subprocess
import sys
-
from string import Template
from typing import (
- Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
)
from mlos_bench.os_environ import environ
@@ -31,9 +37,9 @@ _LOG = logging.getLogger(__name__)
def split_cmdline(cmdline: str) -> Iterable[List[str]]:
"""
- A single command line may contain multiple commands separated by
- special characters (e.g., &&, ||, etc.) so further split the
- commandline into an array of subcommand arrays.
+    A single command line may contain multiple commands separated by special characters
+    (e.g., &&, ||, etc.), so further split the command line into an array of subcommand
+    arrays.
Parameters
----------
@@ -66,16 +72,20 @@ def split_cmdline(cmdline: str) -> Iterable[List[str]]:
class LocalExecService(TempDirContextService, SupportsLocalExec):
"""
- Collection of methods to run scripts and commands in an external process
- on the node acting as the scheduler. Can be useful for data processing
- due to reduced dependency management complications vs the target environment.
+ Collection of methods to run scripts and commands in an external process on the node
+ acting as the scheduler.
+
+ Can be useful for data processing due to reduced dependency management complications
+ vs the target environment.
"""
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of a service to run scripts locally.
@@ -92,14 +102,19 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.local_exec])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.local_exec]),
)
self.abort_on_error = self.config.get("abort_on_error", True)
- def local_exec(self, script_lines: Iterable[str],
- env: Optional[Mapping[str, "TunableValue"]] = None,
- cwd: Optional[str] = None) -> Tuple[int, str, str]:
+ def local_exec(
+ self,
+ script_lines: Iterable[str],
+ env: Optional[Mapping[str, "TunableValue"]] = None,
+ cwd: Optional[str] = None,
+ ) -> Tuple[int, str, str]:
"""
Execute the script lines from `script_lines` in a local process.
@@ -141,8 +156,8 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]:
"""
- Resolves local script path (first token) in the (sub)command line
- tokens to its full path.
+ Resolves local script path (first token) in the (sub)command line tokens to its
+ full path.
Parameters
----------
@@ -167,9 +182,12 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
subcmd_tokens.insert(0, sys.executable)
return subcmd_tokens
- def _local_exec_script(self, script_line: str,
- env_params: Optional[Mapping[str, "TunableValue"]],
- cwd: str) -> Tuple[int, str, str]:
+ def _local_exec_script(
+ self,
+ script_line: str,
+ env_params: Optional[Mapping[str, "TunableValue"]],
+ cwd: str,
+ ) -> Tuple[int, str, str]:
"""
Execute the script from `script_path` in a local process.
@@ -198,7 +216,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
if env_params:
env = {key: str(val) for (key, val) in env_params.items()}
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# A hack to run Python on Windows with env variables set:
env_copy = environ.copy()
env_copy["PYTHONPATH"] = ""
@@ -206,7 +224,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
env = env_copy
try:
- if sys.platform != 'win32':
+ if sys.platform != "win32":
cmd = [" ".join(cmd)]
_LOG.info("Run: %s", cmd)
@@ -214,8 +232,15 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
_LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env))
_LOG.debug("Current working dir: %s", cwd)
- proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True,
- text=True, check=False, capture_output=True)
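+            # Run the command through the shell, capturing stdout/stderr as text.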
+ proc = subprocess.run(
+ cmd,
+ env=env or None,
+ cwd=cwd,
+ shell=True,
+ text=True,
+ check=False,
+ capture_output=True,
+ )
_LOG.debug("Run: return code = %d", proc.returncode)
return (proc.returncode, proc.stdout, proc.stderr)
diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py
index a0cf3e0e57..e65a45934b 100644
--- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py
+++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Helper functions to work with temp files locally on the scheduler side.
-"""
+"""Helper functions to work with temp files locally on the scheduler side."""
import abc
import logging
@@ -21,21 +19,23 @@ _LOG = logging.getLogger(__name__)
class TempDirContextService(Service, metaclass=abc.ABCMeta):
"""
- A *base* service class that provides a method to create a temporary
- directory context for local scripts.
+ A *base* service class that provides a method to create a temporary directory
+ context for local scripts.
- It is inherited by LocalExecService and MockLocalExecService.
- This class is not supposed to be used as a standalone service.
+ It is inherited by LocalExecService and MockLocalExecService. This class is not
+ supposed to be used as a standalone service.
"""
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
- Create a new instance of a service that provides temporary directory context
- for local exec service.
+        Create a new instance of a service that provides a temporary directory context
+        for the local exec service.
Parameters
----------
@@ -50,8 +50,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.temp_dir_context])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.temp_dir_context]),
)
self._temp_dir = self.config.get("temp_dir")
if self._temp_dir:
@@ -61,7 +63,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir)
_LOG.info("%s: temp dir: %s", self, self._temp_dir)
- def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]:
+ def temp_dir_context(
+ self,
+ path: Optional[str] = None,
+ ) -> Union[TemporaryDirectory, nullcontext]:
"""
Create a temp directory or use the provided path.
diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py
index 741593d035..cfe12e3c46 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Azure-specific benchmark environments for mlos_bench.
-"""
+"""Azure-specific benchmark environments for mlos_bench."""
from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService
@@ -12,11 +10,10 @@ from mlos_bench.services.remote.azure.azure_network_services import AzureNetwork
from mlos_bench.services.remote.azure.azure_saas import AzureSaaSConfigService
from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
-
__all__ = [
- 'AzureAuthService',
- 'AzureFileShareService',
- 'AzureNetworkService',
- 'AzureSaaSConfigService',
- 'AzureVMService',
+ "AzureAuthService",
+ "AzureFileShareService",
+ "AzureNetworkService",
+ "AzureSaaSConfigService",
+ "AzureVMService",
]
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
index b1e484c009..619e8eed90 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
@@ -2,19 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for managing VMs on Azure.
-"""
+"""A collection Service functions for managing VMs on Azure."""
import logging
from base64 import b64decode
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
-from pytz import UTC
-
import azure.identity as azure_id
from azure.keyvault.secrets import SecretClient
+from pytz import UTC
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.authenticator_type import SupportsAuth
@@ -24,17 +21,17 @@ _LOG = logging.getLogger(__name__)
class AzureAuthService(Service, SupportsAuth):
- """
- Helper methods to get access to Azure services.
- """
+ """Helper methods to get access to Azure services."""
- _REQ_INTERVAL = 300 # = 5 min
+ _REQ_INTERVAL = 300 # = 5 min
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of Azure authentication services proxy.
@@ -51,11 +48,16 @@ class AzureAuthService(Service, SupportsAuth):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- self.get_access_token,
- self.get_auth_headers,
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ self.get_access_token,
+ self.get_auth_headers,
+ ],
+ ),
)
# This parameter can come from command line as strings, so conversion is needed.
@@ -71,12 +73,13 @@ class AzureAuthService(Service, SupportsAuth):
# Verify info required for SP auth early
if "spClientId" in self.config:
check_required_params(
- self.config, {
+ self.config,
+ {
"spClientId",
"keyVaultName",
"certName",
"tenant",
- }
+ },
)
def _init_sp(self) -> None:
@@ -105,12 +108,14 @@ class AzureAuthService(Service, SupportsAuth):
cert_bytes = b64decode(secret.value)
# Reauthenticate as the service principal.
- self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes)
+ self._cred = azure_id.CertificateCredential(
+ tenant_id=tenant_id,
+ client_id=sp_client_id,
+ certificate_data=cert_bytes,
+ )
def get_access_token(self) -> str:
- """
- Get the access token from Azure CLI, if expired.
- """
+ """Get the access token from Azure CLI, if expired."""
# Ensure we are logged as the Service Principal, if provided
if "spClientId" in self.config:
self._init_sp()
@@ -126,7 +131,5 @@ class AzureAuthService(Service, SupportsAuth):
return self._access_token
def get_auth_headers(self) -> dict:
- """
- Get the authorization part of HTTP headers for REST API calls.
- """
+ """Get the authorization part of HTTP headers for REST API calls."""
return {"Authorization": "Bearer " + self.get_access_token()}
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py
index 187b7c055b..9503d11409 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py
@@ -2,15 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base class for certain Azure Services classes that do deployments.
-"""
+"""Base class for certain Azure Services classes that do deployments."""
import abc
import json
-import time
import logging
-
+import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import requests
@@ -26,33 +23,34 @@ _LOG = logging.getLogger(__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
- """
- Helper methods to manage and deploy Azure resources via REST APIs.
- """
+ """Helper methods to manage and deploy Azure resources via REST APIs."""
- _POLL_INTERVAL = 4 # seconds
- _POLL_TIMEOUT = 300 # seconds
- _REQUEST_TIMEOUT = 5 # seconds
+ _POLL_INTERVAL = 4 # seconds
+ _POLL_TIMEOUT = 300 # seconds
+ _REQUEST_TIMEOUT = 5 # seconds
_REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request
- _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
+ # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
+ _REQUEST_RETRY_BACKOFF_FACTOR = 0.3
# Azure Resources Deployment REST API as described in
# https://docs.microsoft.com/en-us/rest/api/resources/deployments
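+    # Note: the adjacent string literals below are concatenated implicitly.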
_URL_DEPLOY = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Resources" +
- "/deployments/{deployment_name}" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Resources"
+ "/deployments/{deployment_name}"
"?api-version=2022-05-01"
)
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of an Azure Services proxy.
@@ -70,38 +68,49 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"""
super().__init__(config, global_config, parent, methods)
- check_required_params(self.config, [
- "subscription",
- "resourceGroup",
- ])
+ check_required_params(
+ self.config,
+ [
+ "subscription",
+ "resourceGroup",
+ ],
+ )
# These parameters can come from command line as strings, so conversion is needed.
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
- self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES))
- self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR))
+ self._total_retries = int(
+ self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
+ )
+ self._backoff_factor = float(
+ self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
+ )
self._deploy_template = {}
self._deploy_params = {}
if self.config.get("deploymentTemplatePath") is not None:
# TODO: Provide external schema validation?
template = self.config_loader_service.load_config(
- self.config['deploymentTemplatePath'], schema_type=None)
+ self.config["deploymentTemplatePath"],
+ schema_type=None,
+ )
assert template is not None and isinstance(template, dict)
self._deploy_template = template
# Allow for recursive variable expansion as we do with global params and const_args.
- deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config)
+ deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
+ extra_source_dict=global_config
+ )
self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
else:
- _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.")
+ _LOG.info(
+ "No deploymentTemplatePath provided. Deployment services will be unavailable.",
+ )
@property
def deploy_params(self) -> dict:
- """
- Get the deployment parameters.
- """
+ """Get the deployment parameters."""
return self._deploy_params
@abc.abstractmethod
@@ -122,24 +131,24 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
raise NotImplementedError("Should be overridden by subclass.")
def _get_session(self, params: dict) -> requests.Session:
- """
- Get a session object that includes automatic retries and headers for REST API calls.
+ """Get a session object that includes automatic retries and headers for REST API
+ calls.
"""
total_retries = params.get("requestTotalRetries", self._total_retries)
backoff_factor = params.get("requestBackoffFactor", self._backoff_factor)
session = requests.Session()
session.mount(
"https://",
- HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)))
+ HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
+ )
session.headers.update(self._get_headers())
return session
def _get_headers(self) -> dict:
- """
- Get the headers for the REST API calls.
- """
- assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
- "Authorization service not provided. Include service-auth.jsonc?"
+ """Get the headers for the REST API calls."""
+ assert self._parent is not None and isinstance(
+ self._parent, SupportsAuth
+ ), "Authorization service not provided. Include service-auth.jsonc?"
return self._parent.get_auth_headers()
@staticmethod
@@ -235,9 +244,11 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
return (Status.FAILED, {})
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Response: %s\n%s", response,
- json.dumps(response.json(), indent=2)
- if response.content else "")
+ _LOG.debug(
+ "Response: %s\n%s",
+ response,
+ json.dumps(response.json(), indent=2) if response.content else "",
+ )
if response.status_code == 200:
output = response.json()
@@ -252,15 +263,16 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
"""
- Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED.
- Return TIMED_OUT when timing out.
+        Wait for a pending operation on an Azure resource to resolve to SUCCEEDED or
+        FAILED. Return TIMED_OUT when timing out.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
- If True, wait for resource being deployed; otherwise, wait for successful deprovisioning.
+            If True, wait for the resource to be deployed; otherwise, wait for
+            successful deprovisioning.
Returns
-------
@@ -270,15 +282,22 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
params = self._set_default_params(params)
- _LOG.info("Wait for %s to %s", params.get("deploymentName"),
- "provision" if is_setup else "deprovision")
+ _LOG.info(
+ "Wait for %s to %s",
+ params.get("deploymentName"),
+ "provision" if is_setup else "deprovision",
+ )
return self._wait_while(self._check_deployment, Status.PENDING, params)
- def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
- loop_status: Status, params: dict) -> Tuple[Status, dict]:
+ def _wait_while(
+ self,
+ func: Callable[[dict], Tuple[Status, dict]],
+ loop_status: Status,
+ params: dict,
+ ) -> Tuple[Status, dict]:
"""
- Invoke `func` periodically while the status is equal to `loop_status`.
- Return TIMED_OUT when timing out.
+ Invoke `func` periodically while the status is equal to `loop_status`. Return
+ TIMED_OUT when timing out.
Parameters
----------
@@ -297,12 +316,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"""
params = self._set_default_params(params)
config = merge_parameters(
- dest=self.config.copy(), source=params, required_keys=["deploymentName"])
+ dest=self.config.copy(),
+ source=params,
+ required_keys=["deploymentName"],
+ )
poll_period = params.get("pollInterval", self._poll_interval)
- _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s",
- config["deploymentName"], loop_status, poll_period, self._poll_timeout)
+ _LOG.debug(
+ "Wait for %s status %s :: poll %.2f timeout %d s",
+ config["deploymentName"],
+ loop_status,
+ poll_period,
+ self._poll_timeout,
+ )
ts_timeout = time.time() + self._poll_timeout
poll_delay = poll_period
@@ -326,10 +353,10 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
_LOG.warning("Request timed out: %s", params)
return (Status.TIMED_OUT, {})
- def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements
+ def _check_deployment(self, params: dict) -> Tuple[Status, dict]:
+ # pylint: disable=too-many-return-statements
"""
- Check if Azure deployment exists.
- Return SUCCEEDED if true, PENDING otherwise.
+        Check if the Azure deployment exists. Return SUCCEEDED if it does, PENDING otherwise.
Parameters
----------
@@ -352,7 +379,7 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"subscription",
"resourceGroup",
"deploymentName",
- ]
+ ],
)
_LOG.info("Check deployment: %s", config["deploymentName"])
@@ -413,13 +440,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
if not self._deploy_template:
raise ValueError(f"Missing deployment template: {self}")
params = self._set_default_params(params)
- config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"])
+ config = merge_parameters(
+ dest=self.config.copy(),
+ source=params,
+ required_keys=["deploymentName"],
+ )
_LOG.info("Deploy: %s :: %s", config["deploymentName"], params)
params = merge_parameters(dest=self._deploy_params.copy(), source=params)
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Deploy: %s merged params ::\n%s",
- config["deploymentName"], json.dumps(params, indent=2))
+ _LOG.debug(
+ "Deploy: %s merged params ::\n%s",
+ config["deploymentName"],
+ json.dumps(params, indent=2),
+ )
url = self._URL_DEPLOY.format(
subscription=config["subscription"],
@@ -432,22 +466,29 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"mode": "Incremental",
"template": self._deploy_template,
"parameters": {
- key: {"value": val} for (key, val) in params.items()
+ key: {"value": val}
+ for (key, val) in params.items()
if key in self._deploy_template.get("parameters", {})
- }
+ },
}
}
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))
- response = requests.put(url, json=json_req,
- headers=self._get_headers(), timeout=self._request_timeout)
+ response = requests.put(
+ url,
+ json=json_req,
+ headers=self._get_headers(),
+ timeout=self._request_timeout,
+ )
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Response: %s\n%s", response,
- json.dumps(response.json(), indent=2)
- if response.content else "")
+ _LOG.debug(
+ "Response: %s\n%s",
+ response,
+ json.dumps(response.json(), indent=2) if response.content else "",
+ )
else:
_LOG.info("Response: %s", response)
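As a standalone illustration of the polling contract that `_wait_while` documents above (call a check function while its status equals `loop_status`, and give up with `TIMED_OUT` once the poll timeout elapses), here is a rough sketch; `wait_while`, `poll_period`, and `poll_timeout` are illustrative names and defaults, not the service's actual attributes (the real method reads `pollInterval` from params and `self._poll_timeout` from its config).

import time
from typing import Callable, Tuple

from mlos_bench.environments.status import Status


def wait_while(
    check: Callable[[dict], Tuple[Status, dict]],
    loop_status: Status,
    params: dict,
    poll_period: float = 4.0,
    poll_timeout: float = 300.0,
) -> Tuple[Status, dict]:
    """Poll `check` until its status differs from `loop_status` or the deadline passes."""
    ts_timeout = time.time() + poll_timeout
    while time.time() < ts_timeout:
        (status, output) = check(params)
        if status != loop_status:
            return (status, output)
        time.sleep(poll_period)
    return (Status.TIMED_OUT, {})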
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
index af09f4c723..d80ea862c9 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
@@ -2,37 +2,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection FileShare functions for interacting with Azure File Shares.
-"""
+"""A collection FileShare functions for interacting with Azure File Shares."""
-import os
import logging
-
+import os
from typing import Any, Callable, Dict, List, Optional, Set, Union
-from azure.storage.fileshare import ShareClient
from azure.core.exceptions import ResourceNotFoundError
+from azure.storage.fileshare import ShareClient
-from mlos_bench.services.base_service import Service
from mlos_bench.services.base_fileshare import FileShareService
+from mlos_bench.services.base_service import Service
from mlos_bench.util import check_required_params
_LOG = logging.getLogger(__name__)
class AzureFileShareService(FileShareService):
- """
- Helper methods for interacting with Azure File Share
- """
+ """Helper methods for interacting with Azure File Share."""
_SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new file share Service for Azure environments with a given config.
@@ -50,16 +47,19 @@ class AzureFileShareService(FileShareService):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.upload, self.download])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.upload, self.download]),
)
check_required_params(
- self.config, {
+ self.config,
+ {
"storageAccountName",
"storageFileShareName",
"storageAccountKey",
- }
+ },
)
self._share_client = ShareClient.from_share_url(
@@ -70,7 +70,13 @@ class AzureFileShareService(FileShareService):
credential=self.config["storageAccountKey"],
)
- def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
+ def download(
+ self,
+ params: dict,
+ remote_path: str,
+ local_path: str,
+ recursive: bool = True,
+ ) -> None:
super().download(params, remote_path, local_path, recursive)
dir_client = self._share_client.get_directory_client(remote_path)
if dir_client.exists():
@@ -95,16 +101,21 @@ class AzureFileShareService(FileShareService):
# Translate into non-Azure exception:
raise FileNotFoundError(f"Cannot download: {remote_path}") from ex
- def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
+ def upload(
+ self,
+ params: dict,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
super().upload(params, local_path, remote_path, recursive)
self._upload(local_path, remote_path, recursive, set())
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
"""
- Upload contents from a local path to an Azure file share.
- This method is called from `.upload()` above. We need it to avoid exposing
- the `seen` parameter and to make `.upload()` match the base class' virtual
- method.
+ Upload contents from a local path to an Azure file share. This method is called
+ from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
+ make `.upload()` match the base class' virtual method.
Parameters
----------
@@ -143,8 +154,8 @@ class AzureFileShareService(FileShareService):
def _remote_makedirs(self, remote_path: str) -> None:
"""
- Create remote directories for the entire path.
- Succeeds even some or all directories along the path already exist.
+        Create remote directories for the entire path. Succeeds even if some or all
+ directories along the path already exist.
Parameters
----------
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py
index 081d5d842e..29552de4f0 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py
@@ -2,47 +2,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for managing virtual networks on Azure.
-"""
+"""A collection Service functions for managing virtual networks on Azure."""
import logging
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
-from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
-from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
+from mlos_bench.services.remote.azure.azure_deployment_services import (
+ AzureDeploymentService,
+)
+from mlos_bench.services.types.network_provisioner_type import (
+ SupportsNetworkProvisioning,
+)
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
- """
- Helper methods to manage Virtual Networks on Azure.
- """
+ """Helper methods to manage Virtual Networks on Azure."""
     # Azure Network REST API calls as described in
# https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
- # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
+ # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 # pylint: disable=line-too-long # noqa
_URL_DEPROVISION = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Network" +
- "/virtualNetwork/{vnet_name}" +
- "/delete" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Network"
+ "/virtualNetwork/{vnet_name}"
+ "/delete"
"?api-version=2023-05-01"
)
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of Azure Network services proxy.
@@ -59,25 +60,35 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- # SupportsNetworkProvisioning
- self.provision_network,
- self.deprovision_network,
- self.wait_network_deployment,
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ # SupportsNetworkProvisioning
+ self.provision_network,
+ self.deprovision_network,
+ self.wait_network_deployment,
+ ],
+ ),
)
if not self._deploy_template:
- raise ValueError("AzureNetworkService requires a deployment template:\n"
- + f"config={config}\nglobal_config={global_config}")
+ raise ValueError(
+ "AzureNetworkService requires a deployment template:\n"
+ + f"config={config}\nglobal_config={global_config}"
+ )
- def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
+ def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
# Try and provide a semi sane default for the deploymentName if not provided
         # since this is a common way to set the deploymentName and can save some
# config work for the caller.
if "vnetName" in params and "deploymentName" not in params:
params["deploymentName"] = f"{params['vnetName']}-deployment"
- _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
+ _LOG.info(
+ "deploymentName missing from params. Defaulting to '%s'.",
+ params["deploymentName"],
+ )
return params
def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
@@ -148,15 +159,18 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
"resourceGroup",
"deploymentName",
"vnetName",
- ]
+ ],
)
_LOG.info("Deprovision Network: %s", config["vnetName"])
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
- (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vnet_name=config["vnetName"],
- ))
+ (status, results) = self._azure_rest_api_post_helper(
+ config,
+ self._URL_DEPROVISION.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vnet_name=config["vnetName"],
+ ),
+ )
if ignore_errors and status == Status.FAILED:
_LOG.warning("Ignoring error: %s", results)
status = Status.SUCCEEDED
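For clarity, the adjacent string literals in `_URL_DEPROVISION` above are concatenated at compile time, so the template behaves exactly like the earlier `+`-joined version; a minimal sketch with hypothetical subscription, resource group, and VNet names:

from mlos_bench.services.remote.azure.azure_network_services import AzureNetworkService

# Hypothetical identifiers, purely for illustration.
url = AzureNetworkService._URL_DEPROVISION.format(  # pylint: disable=protected-access
    subscription="00000000-0000-0000-0000-000000000000",
    resource_group="my-resource-group",
    vnet_name="my-vnet",
)
# -> "https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000"
#    "/resourceGroups/my-resource-group/providers/Microsoft.Network"
#    "/virtualNetwork/my-vnet/delete?api-version=2023-05-01"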
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py
index 34bec7d25e..042e599f0b 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py
@@ -2,11 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for configuring SaaS instances on Azure.
-"""
+"""A collection Service functions for configuring SaaS instances on Azure."""
import logging
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import requests
@@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
- """
- Helper methods to configure Azure Flex services.
- """
+ """Helper methods to configure Azure Flex services."""
_REQUEST_TIMEOUT = 5 # seconds
@@ -33,20 +28,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
# https://learn.microsoft.com/en-us/rest/api/mariadb/configurations
_URL_CONFIGURE = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/{provider}" +
- "/{server_type}/{vm_name}" +
- "/{update}" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/{provider}"
+ "/{server_type}/{vm_name}"
+ "/{update}"
"?api-version={api_version}"
)
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of Azure services proxy.
@@ -63,18 +60,20 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- self.configure,
- self.is_config_pending
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.configure, self.is_config_pending]),
)
- check_required_params(self.config, {
- "subscription",
- "resourceGroup",
- "provider",
- })
+ check_required_params(
+ self.config,
+ {
+ "subscription",
+ "resourceGroup",
+ "provider",
+ },
+ )
# Provide sane defaults for known DB providers.
provider = self.config.get("provider")
@@ -118,8 +117,7 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
# These parameters can come from command line as strings, so conversion is needed.
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
- def configure(self, config: Dict[str, Any],
- params: Dict[str, Any]) -> Tuple[Status, dict]:
+ def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
Update the parameters of an Azure DB service.
@@ -157,33 +155,39 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
If "isConfigPendingReboot" is set to True, rebooting a VM is necessary.
Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED}
"""
- config = merge_parameters(
- dest=self.config.copy(), source=config, required_keys=["vmName"])
+ config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_get.format(vm_name=config["vmName"])
_LOG.debug("Request: GET %s", url)
- response = requests.put(
- url, headers=self._get_headers(), timeout=self._request_timeout)
+ response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout)
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})
if response.status_code != 200:
return (Status.FAILED, {})
# Currently, Azure Flex servers require a VM reboot.
- return (Status.SUCCEEDED, {"isConfigPendingReboot": any(
- {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']]
- for val in response.json()['value']
- )})
+ return (
+ Status.SUCCEEDED,
+ {
+ "isConfigPendingReboot": any(
+ {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]]
+ for val in response.json()["value"]
+ )
+ },
+ )
def _get_headers(self) -> dict:
- """
- Get the headers for the REST API calls.
- """
- assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
- "Authorization service not provided. Include service-auth.jsonc?"
+ """Get the headers for the REST API calls."""
+ assert self._parent is not None and isinstance(
+ self._parent, SupportsAuth
+ ), "Authorization service not provided. Include service-auth.jsonc?"
return self._parent.get_auth_headers()
- def _config_one(self, config: Dict[str, Any],
- param_name: str, param_value: Any) -> Tuple[Status, dict]:
+ def _config_one(
+ self,
+ config: Dict[str, Any],
+ param_name: str,
+ param_value: Any,
+ ) -> Tuple[Status, dict]:
"""
Update a single parameter of the Azure DB service.
@@ -202,13 +206,15 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
- config = merge_parameters(
- dest=self.config.copy(), source=config, required_keys=["vmName"])
+ config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
_LOG.debug("Request: PUT %s", url)
- response = requests.put(url, headers=self._get_headers(),
- json={"properties": {"value": str(param_value)}},
- timeout=self._request_timeout)
+ response = requests.put(
+ url,
+ headers=self._get_headers(),
+ json={"properties": {"value": str(param_value)}},
+ timeout=self._request_timeout,
+ )
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})
@@ -216,11 +222,10 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
return (Status.SUCCEEDED, {})
return (Status.FAILED, {})
- def _config_many(self, config: Dict[str, Any],
- params: Dict[str, Any]) -> Tuple[Status, dict]:
+ def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
- Update the parameters of an Azure DB service one-by-one.
- (If batch API is not available for it).
+        Update the parameters of an Azure DB service one-by-one (used when the batch
+        API is not available for it).
Parameters
----------
@@ -235,14 +240,13 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
- for (param_name, param_value) in params.items():
+ for param_name, param_value in params.items():
(status, result) = self._config_one(config, param_name, param_value)
if not status.is_succeeded():
return (status, result)
return (Status.SUCCEEDED, {})
- def _config_batch(self, config: Dict[str, Any],
- params: Dict[str, Any]) -> Tuple[Status, dict]:
+ def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
Batch update the parameters of an Azure DB service.
@@ -259,19 +263,21 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
- config = merge_parameters(
- dest=self.config.copy(), source=config, required_keys=["vmName"])
+ config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_set.format(vm_name=config["vmName"])
json_req = {
"value": [
- {"name": key, "properties": {"value": str(val)}}
- for (key, val) in params.items()
+ {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
],
# "resetAllToDefault": "True"
}
_LOG.debug("Request: POST %s", url)
- response = requests.post(url, headers=self._get_headers(),
- json=json_req, timeout=self._request_timeout)
+ response = requests.post(
+ url,
+ headers=self._get_headers(),
+ json=json_req,
+ timeout=self._request_timeout,
+ )
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})
diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py
index 7fdbdc18df..3d390645f5 100644
--- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py
+++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py
@@ -2,33 +2,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for managing VMs on Azure.
-"""
+"""A collection Service functions for managing VMs on Azure."""
import json
import logging
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import requests
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
-from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
-from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
-from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
+from mlos_bench.services.remote.azure.azure_deployment_services import (
+ AzureDeploymentService,
+)
from mlos_bench.services.types.host_ops_type import SupportsHostOps
+from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.os_ops_type import SupportsOSOps
+from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
-class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec):
- """
- Helper methods to manage VMs on Azure.
- """
+class AzureVMService(
+ AzureDeploymentService,
+ SupportsHostProvisioning,
+ SupportsHostOps,
+ SupportsOSOps,
+ SupportsRemoteExec,
+):
+ """Helper methods to manage VMs on Azure."""
# pylint: disable=too-many-ancestors
@@ -37,34 +40,34 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start
_URL_START = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Compute" +
- "/virtualMachines/{vm_name}" +
- "/start" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Compute"
+ "/virtualMachines/{vm_name}"
+ "/start"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off
_URL_STOP = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Compute" +
- "/virtualMachines/{vm_name}" +
- "/powerOff" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Compute"
+ "/virtualMachines/{vm_name}"
+ "/powerOff"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate
_URL_DEALLOCATE = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Compute" +
- "/virtualMachines/{vm_name}" +
- "/deallocate" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Compute"
+ "/virtualMachines/{vm_name}"
+ "/deallocate"
"?api-version=2022-03-01"
)
@@ -76,42 +79,44 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete
# _URL_DEPROVISION = (
- # "https://management.azure.com" +
- # "/subscriptions/{subscription}" +
- # "/resourceGroups/{resource_group}" +
- # "/providers/Microsoft.Compute" +
- # "/virtualMachines/{vm_name}" +
- # "/delete" +
+ # "https://management.azure.com"
+ # "/subscriptions/{subscription}"
+ # "/resourceGroups/{resource_group}"
+ # "/providers/Microsoft.Compute"
+ # "/virtualMachines/{vm_name}"
+ # "/delete"
# "?api-version=2022-03-01"
# )
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart
_URL_REBOOT = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Compute" +
- "/virtualMachines/{vm_name}" +
- "/restart" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Compute"
+ "/virtualMachines/{vm_name}"
+ "/restart"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command
_URL_REXEC_RUN = (
- "https://management.azure.com" +
- "/subscriptions/{subscription}" +
- "/resourceGroups/{resource_group}" +
- "/providers/Microsoft.Compute" +
- "/virtualMachines/{vm_name}" +
- "/runCommand" +
+ "https://management.azure.com"
+ "/subscriptions/{subscription}"
+ "/resourceGroups/{resource_group}"
+ "/providers/Microsoft.Compute"
+ "/virtualMachines/{vm_name}"
+ "/runCommand"
"?api-version=2022-03-01"
)
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of Azure VM services proxy.
@@ -128,26 +133,31 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
New methods to register with the service.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- # SupportsHostProvisioning
- self.provision_host,
- self.deprovision_host,
- self.deallocate_host,
- self.wait_host_deployment,
- # SupportsHostOps
- self.start_host,
- self.stop_host,
- self.restart_host,
- self.wait_host_operation,
- # SupportsOSOps
- self.shutdown,
- self.reboot,
- self.wait_os_operation,
- # SupportsRemoteExec
- self.remote_exec,
- self.get_remote_exec_results,
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ # SupportsHostProvisioning
+ self.provision_host,
+ self.deprovision_host,
+ self.deallocate_host,
+ self.wait_host_deployment,
+ # SupportsHostOps
+ self.start_host,
+ self.stop_host,
+ self.restart_host,
+ self.wait_host_operation,
+ # SupportsOSOps
+ self.shutdown,
+ self.reboot,
+ self.wait_os_operation,
+ # SupportsRemoteExec
+ self.remote_exec,
+ self.get_remote_exec_results,
+ ],
+ ),
)
# As a convenience, allow reading customData out of a file, rather than
@@ -156,19 +166,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# can be done using the `base64()` string function inside the ARM template.
self._custom_data_file = self.config.get("customDataFile", None)
if self._custom_data_file:
- if self._deploy_params.get('customData', None):
+ if self._deploy_params.get("customData", None):
raise ValueError("Both customDataFile and customData are specified.")
- self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file)
- with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh:
+ self._custom_data_file = self.config_loader_service.resolve_path(
+ self._custom_data_file
+ )
+ with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh:
self._deploy_params["customData"] = custom_data_fh.read()
- def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
+ def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
# Try and provide a semi sane default for the deploymentName if not provided
         # since this is a common way to set the deploymentName and can save some
# config work for the caller.
if "vmName" in params and "deploymentName" not in params:
params["deploymentName"] = f"{params['vmName']}-deployment"
- _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
+ _LOG.info(
+ "deploymentName missing from params. Defaulting to '%s'.",
+ params["deploymentName"],
+ )
return params
def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
@@ -263,20 +278,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"resourceGroup",
"deploymentName",
"vmName",
- ]
+ ],
)
_LOG.info("Deprovision VM: %s", config["vmName"])
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
# TODO: Properly deprovision *all* resources specified in the ARM template.
- return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vm_name=config["vmName"],
- ))
+ return self._azure_rest_api_post_helper(
+ config,
+ self._URL_DEPROVISION.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vm_name=config["vmName"],
+ ),
+ )
def deallocate_host(self, params: dict) -> Tuple[Status, dict]:
"""
- Deallocates the VM on Azure by shutting it down then releasing the compute resources.
+ Deallocates the VM on Azure by shutting it down then releasing the compute
+ resources.
         Note: This can cause the VM to arrive on a new host node when it is
restarted, which may have different performance characteristics.
@@ -300,14 +319,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
- ]
+ ],
)
_LOG.info("Deallocate VM: %s", config["vmName"])
- return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vm_name=config["vmName"],
- ))
+ return self._azure_rest_api_post_helper(
+ config,
+ self._URL_DEALLOCATE.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vm_name=config["vmName"],
+ ),
+ )
def start_host(self, params: dict) -> Tuple[Status, dict]:
"""
@@ -332,14 +354,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
- ]
+ ],
)
_LOG.info("Start VM: %s :: %s", config["vmName"], params)
- return self._azure_rest_api_post_helper(config, self._URL_START.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vm_name=config["vmName"],
- ))
+ return self._azure_rest_api_post_helper(
+ config,
+ self._URL_START.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vm_name=config["vmName"],
+ ),
+ )
def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
"""
@@ -366,14 +391,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
- ]
+ ],
)
_LOG.info("Stop VM: %s", config["vmName"])
- return self._azure_rest_api_post_helper(config, self._URL_STOP.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vm_name=config["vmName"],
- ))
+ return self._azure_rest_api_post_helper(
+ config,
+ self._URL_STOP.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vm_name=config["vmName"],
+ ),
+ )
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
return self.stop_host(params, force)
@@ -403,20 +431,27 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
- ]
+ ],
)
_LOG.info("Reboot VM: %s", config["vmName"])
- return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format(
- subscription=config["subscription"],
- resource_group=config["resourceGroup"],
- vm_name=config["vmName"],
- ))
+ return self._azure_rest_api_post_helper(
+ config,
+ self._URL_REBOOT.format(
+ subscription=config["subscription"],
+ resource_group=config["resourceGroup"],
+ vm_name=config["vmName"],
+ ),
+ )
def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
return self.restart_host(params, force)
- def remote_exec(self, script: Iterable[str], config: dict,
- env_params: dict) -> Tuple[Status, dict]:
+ def remote_exec(
+ self,
+ script: Iterable[str],
+ config: dict,
+ env_params: dict,
+ ) -> Tuple[Status, dict]:
"""
Run a command on Azure VM.
@@ -446,7 +481,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
- ]
+ ],
)
if _LOG.isEnabledFor(logging.INFO):
@@ -455,7 +490,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
json_req = {
"commandId": "RunShellScript",
"script": list(script),
- "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()]
+ "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()],
}
url = self._URL_REXEC_RUN.format(
@@ -468,12 +503,18 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
_LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2))
response = requests.post(
- url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout)
+ url,
+ json=json_req,
+ headers=self._get_headers(),
+ timeout=self._request_timeout,
+ )
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Response: %s\n%s", response,
- json.dumps(response.json(), indent=2)
- if response.content else "")
+ _LOG.debug(
+ "Response: %s\n%s",
+ response,
+ json.dumps(response.json(), indent=2) if response.content else "",
+ )
else:
_LOG.info("Response: %s", response)
@@ -481,10 +522,10 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# TODO: extract the results from JSON response
return (Status.SUCCEEDED, config)
elif response.status_code == 202:
- return (Status.PENDING, {
- **config,
- "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")
- })
+ return (
+ Status.PENDING,
+ {**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")},
+ )
else:
_LOG.error("Response: %s :: %s", response, response.text)
# _LOG.error("Bad Request:\n%s", response.request.body)
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py
index 2ab1705a74..cd897649ec 100644
--- a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py
+++ b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py
@@ -4,8 +4,8 @@
#
"""SSH remote service."""
-from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
__all__ = [
"SshHostService",
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py
index f753947aa7..383fcfbd20 100644
--- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py
+++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py
@@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection functions for interacting with SSH servers as file shares.
-"""
+"""A collection functions for interacting with SSH servers as file shares."""
+import logging
from enum import Enum
from typing import Tuple, Union
-import logging
-
-from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection
+from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp
from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.remote.ssh.ssh_service import SshService
@@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
class CopyMode(Enum):
- """
- Copy mode enum.
- """
+ """Copy mode enum."""
DOWNLOAD = 1
UPLOAD = 2
@@ -32,17 +27,23 @@ class CopyMode(Enum):
class SshFileShareService(FileShareService, SshService):
"""A collection of functions for interacting with SSH servers as file shares."""
- async def _start_file_copy(self, params: dict, mode: CopyMode,
- local_path: str, remote_path: str,
- recursive: bool = True) -> None:
+ async def _start_file_copy(
+ self,
+ params: dict,
+ mode: CopyMode,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
# pylint: disable=too-many-arguments
"""
- Starts a file copy operation
+ Starts a file copy operation.
Parameters
----------
params : dict
- Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
+ Flat dictionary of (key, value) pairs of parameters (used for
+ establishing the connection).
mode : CopyMode
Whether to download or upload the file.
local_path : str
@@ -74,40 +75,70 @@ class SshFileShareService(FileShareService, SshService):
raise ValueError(f"Unknown copy mode: {mode}")
return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)
- def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
+ def download(
+ self,
+ params: dict,
+ remote_path: str,
+ local_path: str,
+ recursive: bool = True,
+ ) -> None:
params = merge_parameters(
dest=self.config.copy(),
source=params,
required_keys=[
"ssh_hostname",
- ]
+ ],
)
super().download(params, remote_path, local_path, recursive)
file_copy_future = self._run_coroutine(
- self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive))
+ self._start_file_copy(
+ params,
+ CopyMode.DOWNLOAD,
+ local_path,
+ remote_path,
+ recursive,
+ )
+ )
try:
file_copy_future.result()
except (OSError, SFTPError) as ex:
- _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex)
+ _LOG.error(
+ "Failed to download %s to %s from %s: %s",
+ remote_path,
+ local_path,
+ params,
+ ex,
+ )
if isinstance(ex, SFTPNoSuchFile) or (
- isinstance(ex, SFTPFailure) and ex.code == 4
- and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory"))
+ isinstance(ex, SFTPFailure)
+ and ex.code == 4
+ and any(
+ msg.lower() in ex.reason.lower()
+ for msg in ("File not found", "No such file or directory")
+ )
):
_LOG.warning("File %s does not exist on %s", remote_path, params)
raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
raise ex
- def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
+ def upload(
+ self,
+ params: dict,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
params = merge_parameters(
dest=self.config.copy(),
source=params,
required_keys=[
"ssh_hostname",
- ]
+ ],
)
super().upload(params, local_path, remote_path, recursive)
file_copy_future = self._run_coroutine(
- self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive))
+ self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)
+ )
try:
file_copy_future.result()
except (OSError, SFTPError) as ex:
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py
index 40c84e6300..36f1f7866b 100644
--- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py
+++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py
@@ -2,39 +2,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for managing hosts via SSH.
-"""
+"""A collection Service functions for managing hosts via SSH."""
+import logging
from concurrent.futures import Future
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-import logging
-
-from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError
+from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.ssh.ssh_service import SshService
-from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.os_ops_type import SupportsOSOps
+from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
- """
- Helper methods to manage machines via SSH.
- """
+ """Helper methods to manage machines via SSH."""
# pylint: disable=too-many-instance-attributes
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of an SSH Service.
@@ -53,24 +50,36 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
# Same methods are also provided by the AzureVMService class
# pylint: disable=duplicate-code
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- self.shutdown,
- self.reboot,
- self.wait_os_operation,
- self.remote_exec,
- self.get_remote_exec_results,
- ]))
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ self.shutdown,
+ self.reboot,
+ self.wait_os_operation,
+ self.remote_exec,
+ self.get_remote_exec_results,
+ ],
+ ),
+ )
self._shell = self.config.get("ssh_shell", "/bin/bash")
- async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess:
+ async def _run_cmd(
+ self,
+ params: dict,
+ script: Iterable[str],
+ env_params: dict,
+ ) -> SSHCompletedProcess:
"""
Runs a command asynchronously on a host via SSH.
Parameters
----------
params : dict
- Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
+ Flat dictionary of (key, value) pairs of parameters (used for
+ establishing the connection).
cmd : str
Command(s) to run via shell.
@@ -83,19 +92,29 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
# Script should be an iterable of lines, not an iterable string.
script = [script]
connection, _ = await self._get_client_connection(params)
- # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values.
+ # Note: passing environment variables to SSH servers is typically restricted
+ # to just some LC_* values.
# Handle transferring environment variables by making a script to set them.
env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
- script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()]
+ script_lines = env_script_lines + [
+ line_split for line in script for line_split in line.splitlines()
+ ]
# Note: connection.run() uses "exec" with a shell by default.
- script_str = '\n'.join(script_lines)
+ script_str = "\n".join(script_lines)
_LOG.debug("Running script on %s:\n%s", connection, script_str)
- return await connection.run(script_str,
- check=False,
- timeout=self._request_timeout,
- env=env_params)
+ return await connection.run(
+ script_str,
+ check=False,
+ timeout=self._request_timeout,
+ env=env_params,
+ )
- def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]:
+ def remote_exec(
+ self,
+ script: Iterable[str],
+ config: dict,
+ env_params: dict,
+ ) -> Tuple["Status", dict]:
"""
Start running a command on remote host OS.
@@ -122,9 +141,15 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
source=config,
required_keys=[
"ssh_hostname",
- ]
+ ],
+ )
+ config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
+ self._run_cmd(
+ config,
+ script,
+ env_params,
+ )
)
- config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params))
return (Status.PENDING, config)
def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
@@ -155,7 +180,11 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
return (
- Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED,
+ (
+ Status.SUCCEEDED
+ if result.exit_status == 0 and result.returncode == 0
+ else Status.FAILED
+ ),
{
"stdout": stdout,
"stderr": stderr,
@@ -167,7 +196,8 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
return (Status.FAILED, {"result": result})
def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
- """_summary_
+ """
+        Execute an OS-level operation (e.g., shutdown or reboot) on the remote host.
Parameters
----------
@@ -187,9 +217,9 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
source=params,
required_keys=[
"ssh_hostname",
- ]
+ ],
)
- cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list])
+ cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list])
script = rf"""
if [[ $EUID -ne 0 ]]; then
sudo=$(command -v sudo)
@@ -224,10 +254,10 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
cmd_opts_list = [
- 'shutdown -h now',
- 'poweroff',
- 'halt -p',
- 'systemctl poweroff',
+ "shutdown -h now",
+ "poweroff",
+ "halt -p",
+ "systemctl poweroff",
]
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
@@ -249,18 +279,18 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
cmd_opts_list = [
- 'shutdown -r now',
- 'reboot',
- 'halt --reboot',
- 'systemctl reboot',
- 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1',
+ "shutdown -r now",
+ "reboot",
+ "halt --reboot",
+ "systemctl reboot",
+ "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
]
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
"""
- Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
- Return TIMED_OUT when timing out.
+ Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
+ TIMED_OUT when timing out.
Parameters
----------
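As a small self-contained sketch of the environment-passing technique `_run_cmd` uses above (SSH servers typically accept only a few LC_* variables, so the environment is also injected as an `export` preamble prepended to the script), with hypothetical parameter values:

# Hypothetical script and environment, purely for illustration.
env_params = {"TRIAL_ID": "42", "OUTPUT_DIR": "/tmp/results"}
script = ["echo $TRIAL_ID", "ls $OUTPUT_DIR"]

# Same preamble construction as in _run_cmd above: turn the environment into
# `export` statements and prepend them to the (flattened) script lines.
env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
script_lines = env_script_lines + [
    line_split for line in script for line_split in line.splitlines()
]
script_str = "\n".join(script_lines)
print(script_str)
# export TRIAL_ID='42'
# export OUTPUT_DIR='/tmp/results'
# echo $TRIAL_ID
# ls $OUTPUT_DIR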
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py
index 50ab07d4d2..706764a1f1 100644
--- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py
+++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py
@@ -2,25 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection functions for interacting with SSH servers as file shares.
-"""
-
-from abc import ABCMeta
-from asyncio import Event as CoroEvent, Lock as CoroLock
-from warnings import warn
-from types import TracebackType
-from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union
-from threading import current_thread
+"""A collection functions for interacting with SSH servers as file shares."""
import logging
import os
+from abc import ABCMeta
+from asyncio import Event as CoroEvent
+from asyncio import Lock as CoroLock
+from threading import current_thread
+from types import TracebackType
+from typing import (
+ Any,
+ Callable,
+ Coroutine,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+from warnings import warn
import asyncssh
from asyncssh.connection import SSHClientConnection
+from mlos_bench.event_loop_context import (
+ CoroReturnType,
+ EventLoopContext,
+ FutureReturnType,
+)
from mlos_bench.services.base_service import Service
-from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType
from mlos_bench.util import nullable
_LOG = logging.getLogger(__name__)
@@ -30,13 +43,13 @@ class SshClient(asyncssh.SSHClient):
"""
Wrapper around SSHClient to help provide connection caching and reconnect logic.
- Used by the SshService to try and maintain a single connection to hosts,
- handle reconnects if possible, and use that to run commands rather than
- reconnect for each command.
+ Used by the SshService to try and maintain a single connection to hosts, handle
+ reconnects if possible, and use that to run commands rather than reconnect for each
+ command.
"""
- _CONNECTION_PENDING = 'INIT'
- _CONNECTION_LOST = 'LOST'
+ _CONNECTION_PENDING = "INIT"
+ _CONNECTION_LOST = "LOST"
def __init__(self, *args: tuple, **kwargs: dict):
self._connection_id: str = SshClient._CONNECTION_PENDING
@@ -50,12 +63,16 @@ class SshClient(asyncssh.SSHClient):
@staticmethod
def id_from_connection(connection: SSHClientConnection) -> str:
"""Gets a unique id repr for the connection."""
- return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ return f"{connection._username}@{connection._host}:{connection._port}"
@staticmethod
def id_from_params(connect_params: dict) -> str:
"""Gets a unique id repr for the connection."""
- return f"{connect_params.get('username')}@{connect_params['host']}:{connect_params.get('port')}"
+ return (
+ f"{connect_params.get('username')}@{connect_params['host']}"
+ f":{connect_params.get('port')}"
+ )
def connection_made(self, conn: SSHClientConnection) -> None:
"""
@@ -64,8 +81,12 @@ class SshClient(asyncssh.SSHClient):
Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
"""
self._conn_event.clear()
- _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \
- # pylint: disable=protected-access
+ _LOG.debug(
+ "%s: Connection made by %s: %s",
+ current_thread().name,
+ conn._options.env, # pylint: disable=protected-access
+ conn,
+ )
self._connection_id = SshClient.id_from_connection(conn)
self._connection = conn
self._conn_event.set()
@@ -75,18 +96,26 @@ class SshClient(asyncssh.SSHClient):
self._conn_event.clear()
_LOG.debug("%s: %s", current_thread().name, "connection_lost")
if exc is None:
- _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc)
+ _LOG.debug(
+ "%s: gracefully disconnected ssh from %s: %s",
+ current_thread().name,
+ self._connection_id,
+ exc,
+ )
else:
- _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc)
+ _LOG.debug(
+ "%s: ssh connection lost on %s: %s",
+ current_thread().name,
+ self._connection_id,
+ exc,
+ )
self._connection_id = SshClient._CONNECTION_LOST
self._connection = None
self._conn_event.set()
return super().connection_lost(exc)
async def connection(self) -> Optional[SSHClientConnection]:
- """
- Waits for and returns the SSHClientConnection to be established or lost.
- """
+ """Waits for and returns the SSHClientConnection to be established or lost."""
_LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
await self._conn_event.wait()
_LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
@@ -96,6 +125,7 @@ class SshClient(asyncssh.SSHClient):
class SshClientCache:
"""
Manages a cache of SshClient connections.
+
Note: Only one per event loop thread supported.
See additional details in SshService comments.
"""
@@ -114,6 +144,7 @@ class SshClientCache:
def enter(self) -> None:
"""
Manages the cache lifecycle with reference counting.
+
To be used in the __enter__ method of a caller's context manager.
"""
self._refcnt += 1
@@ -121,6 +152,7 @@ class SshClientCache:
def exit(self) -> None:
"""
Manages the cache lifecycle with reference counting.
+
To be used in the __exit__ method of a caller's context manager.
"""
self._refcnt -= 1
@@ -130,7 +162,10 @@ class SshClientCache:
warn(RuntimeWarning("SshClientCache lock was still held on exit."))
self._cache_lock.release()
- async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]:
+ async def get_client_connection(
+ self,
+ connect_params: dict,
+ ) -> Tuple[SSHClientConnection, SshClient]:
"""
Gets a (possibly cached) client connection.
@@ -153,13 +188,21 @@ class SshClientCache:
_LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
connection = await client.connection()
if not connection:
- _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
+ _LOG.debug(
+ "%s: Removing stale client connection %s from cache.",
+ current_thread().name,
+ connection_id,
+ )
self._cache.pop(connection_id)
# Try to reconnect next.
else:
_LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
if connection_id not in self._cache:
- _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
+ _LOG.debug(
+ "%s: Establishing client connection to %s",
+ current_thread().name,
+ connection_id,
+ )
connection, client = await asyncssh.create_connection(SshClient, **connect_params)
assert isinstance(client, SshClient)
self._cache[connection_id] = (connection, client)
@@ -167,18 +210,14 @@ class SshClientCache:
return self._cache[connection_id]
def cleanup(self) -> None:
- """
- Closes all cached connections.
- """
- for (connection, _) in self._cache.values():
+ """Closes all cached connections."""
+ for connection, _ in self._cache.values():
connection.close()
self._cache = {}
class SshService(Service, metaclass=ABCMeta):
- """
- Base class for SSH services.
- """
+ """Base class for SSH services."""
# AsyncSSH requires an asyncio event loop to be running to work.
# However, running that event loop blocks the main thread.
@@ -210,21 +249,23 @@ class SshService(Service, metaclass=ABCMeta):
_REQUEST_TIMEOUT: Optional[float] = None # seconds
- def __init__(self,
- config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
super().__init__(config, global_config, parent, methods)
# Make sure that the value we allow overriding on a per-connection
# basis are present in the config so merge_parameters can do its thing.
- self.config.setdefault('ssh_port', None)
- assert isinstance(self.config['ssh_port'], (int, type(None)))
- self.config.setdefault('ssh_username', None)
- assert isinstance(self.config['ssh_username'], (str, type(None)))
- self.config.setdefault('ssh_priv_key_path', None)
- assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))
+ self.config.setdefault("ssh_port", None)
+ assert isinstance(self.config["ssh_port"], (int, type(None)))
+ self.config.setdefault("ssh_username", None)
+ assert isinstance(self.config["ssh_username"], (str, type(None)))
+ self.config.setdefault("ssh_priv_key_path", None)
+ assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))
# None can be used to disable the request timeout.
self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
@@ -235,24 +276,25 @@ class SshService(Service, metaclass=ABCMeta):
# In general scripted commands shouldn't need a pty and having one
# available can confuse some commands, though we may need to make
# this configurable in the future.
- 'request_pty': False,
- # By default disable known_hosts checking (since most VMs expected to be dynamically created).
- 'known_hosts': None,
+ "request_pty": False,
+            # By default, disable known_hosts checking (since most VMs are expected
+            # to be dynamically created).
+ "known_hosts": None,
}
- if 'ssh_known_hosts_file' in self.config:
- self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None)
- if isinstance(self._connect_params['known_hosts'], str):
- known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
+ if "ssh_known_hosts_file" in self.config:
+ self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
+ if isinstance(self._connect_params["known_hosts"], str):
+ known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
if not os.path.exists(known_hosts_file):
raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
- self._connect_params['known_hosts'] = known_hosts_file
- if self._connect_params['known_hosts'] is None:
+ self._connect_params["known_hosts"] = known_hosts_file
+ if self._connect_params["known_hosts"] is None:
_LOG.info("%s known_hosts checking is disabled per config.", self)
- if 'ssh_keepalive_interval' in self.config:
- keepalive_internal = self.config.get('ssh_keepalive_interval')
- self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal)
+ if "ssh_keepalive_interval" in self.config:
+            keepalive_interval = self.config.get("ssh_keepalive_interval")
+            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)
def _enter_context(self) -> "SshService":
# Start the background thread if it's not already running.
@@ -262,9 +304,12 @@ class SshService(Service, metaclass=ABCMeta):
super()._enter_context()
return self
- def _exit_context(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def _exit_context(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
# Stop the background thread if it's not needed anymore and potentially
# cleanup the cache as well.
assert self._in_context
@@ -276,6 +321,7 @@ class SshService(Service, metaclass=ABCMeta):
def clear_client_cache(cls) -> None:
"""
Clears the cache of client connections.
+
Note: This may cause in flight operations to fail.
"""
cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()
@@ -319,24 +365,27 @@ class SshService(Service, metaclass=ABCMeta):
# Start with the base config params.
connect_params = self._connect_params.copy()
- connect_params['host'] = params['ssh_hostname'] # required
+ connect_params["host"] = params["ssh_hostname"] # required
- if params.get('ssh_port'):
- connect_params['port'] = int(params.pop('ssh_port'))
- elif self.config['ssh_port']:
- connect_params['port'] = int(self.config['ssh_port'])
+ if params.get("ssh_port"):
+ connect_params["port"] = int(params.pop("ssh_port"))
+ elif self.config["ssh_port"]:
+ connect_params["port"] = int(self.config["ssh_port"])
- if 'ssh_username' in params:
- connect_params['username'] = str(params.pop('ssh_username'))
- elif self.config['ssh_username']:
- connect_params['username'] = str(self.config['ssh_username'])
+ if "ssh_username" in params:
+ connect_params["username"] = str(params.pop("ssh_username"))
+ elif self.config["ssh_username"]:
+ connect_params["username"] = str(self.config["ssh_username"])
- priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path'])
+ priv_key_file: Optional[str] = params.get(
+ "ssh_priv_key_path",
+ self.config["ssh_priv_key_path"],
+ )
if priv_key_file:
priv_key_file = os.path.expanduser(priv_key_file)
if not os.path.exists(priv_key_file):
raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
- connect_params['client_keys'] = [priv_key_file]
+ connect_params["client_keys"] = [priv_key_file]
return connect_params
@@ -355,4 +404,6 @@ class SshService(Service, metaclass=ABCMeta):
The connection and client objects.
"""
assert self._in_context
- return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params))
+ return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
+ self._get_connect_params(params)
+ )
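For reference, the per-connection parameters that the reformatted SshService.__init__ above validates can be supplied through the service config. A minimal sketch, with made-up values that are not part of this change:

# Hypothetical SshService config; keys mirror the ones handled above, values are illustrative.
ssh_service_config = {
    "ssh_port": 22,                        # int or None; may be overridden per connection
    "ssh_username": "mlos",                # str or None
    "ssh_priv_key_path": "~/.ssh/id_rsa",  # expanded via os.path.expanduser() and must exist
    "ssh_known_hosts_file": None,          # None leaves known_hosts checking disabled
    "ssh_keepalive_interval": 30,          # coerced with nullable(int, ...)
    "ssh_request_timeout": None,           # None disables the request timeout
}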
diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py
index 2a9cbe3248..e2d0cb55b5 100644
--- a/mlos_bench/mlos_bench/services/types/__init__.py
+++ b/mlos_bench/mlos_bench/services/types/__init__.py
@@ -2,8 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Service types for implementing declaring Service behavior for Environments to use in mlos_bench.
+"""Service types for implementing declaring Service behavior for Environments to use in
+mlos_bench.
"""
from mlos_bench.services.types.authenticator_type import SupportsAuth
@@ -11,18 +11,19 @@ from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
-from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
+from mlos_bench.services.types.network_provisioner_type import (
+ SupportsNetworkProvisioning,
+)
from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
-
__all__ = [
- 'SupportsAuth',
- 'SupportsConfigLoading',
- 'SupportsFileShareOps',
- 'SupportsHostProvisioning',
- 'SupportsLocalExec',
- 'SupportsNetworkProvisioning',
- 'SupportsRemoteConfig',
- 'SupportsRemoteExec',
+ "SupportsAuth",
+ "SupportsConfigLoading",
+ "SupportsFileShareOps",
+ "SupportsHostProvisioning",
+ "SupportsLocalExec",
+ "SupportsNetworkProvisioning",
+ "SupportsRemoteConfig",
+ "SupportsRemoteExec",
]
diff --git a/mlos_bench/mlos_bench/services/types/authenticator_type.py b/mlos_bench/mlos_bench/services/types/authenticator_type.py
index fcec792d7d..6f99dd6bce 100644
--- a/mlos_bench/mlos_bench/services/types/authenticator_type.py
+++ b/mlos_bench/mlos_bench/services/types/authenticator_type.py
@@ -2,18 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for authentication for the cloud services.
-"""
+"""Protocol interface for authentication for the cloud services."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class SupportsAuth(Protocol):
- """
- Protocol interface for authentication for the cloud services.
- """
+ """Protocol interface for authentication for the cloud services."""
def get_access_token(self) -> str:
"""
diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py
index 401e4c6720..4eb473edff 100644
--- a/mlos_bench/mlos_bench/services/types/config_loader_type.py
+++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py
@@ -2,34 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for helper functions to lookup and load configs.
-"""
+"""Protocol interface for helper functions to lookup and load configs."""
-from typing import Any, Dict, List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Protocol,
+ Union,
+ runtime_checkable,
+)
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.tunables.tunable import TunableValue
-
# Avoid's circular import issues.
if TYPE_CHECKING:
- from mlos_bench.tunables.tunable_groups import TunableGroups
- from mlos_bench.services.base_service import Service
from mlos_bench.environments.base_environment import Environment
+ from mlos_bench.services.base_service import Service
+ from mlos_bench.tunables.tunable_groups import TunableGroups
@runtime_checkable
class SupportsConfigLoading(Protocol):
- """
- Protocol interface for helper functions to lookup and load configs.
- """
+ """Protocol interface for helper functions to lookup and load configs."""
- def resolve_path(self, file_path: str,
- extra_paths: Optional[Iterable[str]] = None) -> str:
+ def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
"""
- Prepend the suitable `_config_path` to `path` if the latter is not absolute.
- If `_config_path` is `None` or `path` is absolute, return `path` as is.
+ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
+ `_config_path` is `None` or `path` is absolute, return `path` as is.
Parameters
----------
@@ -44,11 +48,14 @@ class SupportsConfigLoading(Protocol):
An actual path to the config or script.
"""
- def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]:
+ def load_config(
+ self,
+ json_file_name: str,
+ schema_type: Optional[ConfigSchema],
+ ) -> Union[dict, List[dict]]:
"""
- Load JSON config file. Search for a file relative to `_config_path`
- if the input path is not absolute.
- This method is exported to be used as a service.
+ Load JSON config file. Search for a file relative to `_config_path` if the input
+ path is not absolute. This method is exported to be used as a service.
Parameters
----------
@@ -63,12 +70,14 @@ class SupportsConfigLoading(Protocol):
Free-format dictionary that contains the configuration.
"""
- def build_environment(self, # pylint: disable=too-many-arguments
- config: dict,
- tunables: "TunableGroups",
- global_config: Optional[dict] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None,
- service: Optional["Service"] = None) -> "Environment":
+ def build_environment(
+ self, # pylint: disable=too-many-arguments
+ config: dict,
+ tunables: "TunableGroups",
+ global_config: Optional[dict] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ service: Optional["Service"] = None,
+ ) -> "Environment":
"""
Factory method for a new environment with a given config.
@@ -98,12 +107,13 @@ class SupportsConfigLoading(Protocol):
"""
def load_environment_list( # pylint: disable=too-many-arguments
- self,
- json_file_name: str,
- tunables: "TunableGroups",
- global_config: Optional[dict] = None,
- parent_args: Optional[Dict[str, TunableValue]] = None,
- service: Optional["Service"] = None) -> List["Environment"]:
+ self,
+ json_file_name: str,
+ tunables: "TunableGroups",
+ global_config: Optional[dict] = None,
+ parent_args: Optional[Dict[str, TunableValue]] = None,
+ service: Optional["Service"] = None,
+ ) -> List["Environment"]:
"""
Load and build a list of environments from the config file.
@@ -128,12 +138,15 @@ class SupportsConfigLoading(Protocol):
A list of new benchmarking environments.
"""
- def load_services(self, json_file_names: Iterable[str],
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional["Service"] = None) -> "Service":
+ def load_services(
+ self,
+ json_file_names: Iterable[str],
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional["Service"] = None,
+ ) -> "Service":
"""
- Read the configuration files and bundle all service methods
- from those configs into a single Service object.
+ Read the configuration files and bundle all service methods from those configs
+ into a single Service object.
Parameters
----------
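The SupportsConfigLoading methods above are typically chained: resolve a relative path against the config search path, then load (and optionally schema-validate) the file. A minimal sketch, assuming a config_loader object that implements the protocol and an illustrative file name:

# Hypothetical usage of a SupportsConfigLoading implementation.
path = config_loader.resolve_path("environments/local/local_env.jsonc")
config = config_loader.load_config(path, schema_type=None)  # or a ConfigSchema member
assert isinstance(config, (dict, list))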
diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py
index 87ec9e49da..c69516992b 100644
--- a/mlos_bench/mlos_bench/services/types/fileshare_type.py
+++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py
@@ -2,20 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for file share operations.
-"""
+"""Protocol interface for file share operations."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class SupportsFileShareOps(Protocol):
- """
- Protocol interface for file share operations.
- """
+ """Protocol interface for file share operations."""
- def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
+ def download(
+ self,
+ params: dict,
+ remote_path: str,
+ local_path: str,
+ recursive: bool = True,
+ ) -> None:
"""
Downloads contents from a remote share path to a local path.
@@ -33,7 +35,13 @@ class SupportsFileShareOps(Protocol):
if True (the default), download the entire directory tree.
"""
- def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
+ def upload(
+ self,
+ params: dict,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
"""
Uploads contents from a local path to remote share path.
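The download/upload pair above is symmetric, with an optional recursive flag for directory trees. A hedged sketch, assuming a fileshare_service object that implements SupportsFileShareOps and illustrative paths:

# Hypothetical SupportsFileShareOps usage; params and paths are made up.
fileshare_service.upload(params={}, local_path="./results", remote_path="exp-1/results")
fileshare_service.download(params={}, remote_path="exp-1/results", local_path="./results-copy")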
diff --git a/mlos_bench/mlos_bench/services/types/host_ops_type.py b/mlos_bench/mlos_bench/services/types/host_ops_type.py
index a5d0b5b036..166406714d 100644
--- a/mlos_bench/mlos_bench/services/types/host_ops_type.py
+++ b/mlos_bench/mlos_bench/services/types/host_ops_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Host/VM boot operations.
-"""
+"""Protocol interface for Host/VM boot operations."""
-from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsHostOps(Protocol):
- """
- Protocol interface for Host/VM boot operations.
- """
+ """Protocol interface for Host/VM boot operations."""
def start_host(self, params: dict) -> Tuple["Status", dict]:
"""
diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py
index b3560783fc..1df0716fa1 100644
--- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py
+++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Host/VM provisioning operations.
-"""
+"""Protocol interface for Host/VM provisioning operations."""
-from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsHostProvisioning(Protocol):
- """
- Protocol interface for Host/VM provisioning operations.
- """
+ """Protocol interface for Host/VM provisioning operations."""
def provision_host(self, params: dict) -> Tuple["Status", dict]:
"""
@@ -46,7 +42,8 @@ class SupportsHostProvisioning(Protocol):
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
- If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning.
+            If True, wait for the Host/VM to be deployed; otherwise, wait for successful
+            deprovisioning.
Returns
-------
@@ -74,7 +71,8 @@ class SupportsHostProvisioning(Protocol):
def deallocate_host(self, params: dict) -> Tuple["Status", dict]:
"""
- Deallocates the Host/VM by shutting it down then releasing the compute resources.
+ Deallocates the Host/VM by shutting it down then releasing the compute
+ resources.
Note: This can cause the VM to arrive on a new host node when its
restarted, which may have different performance characteristics.
diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py
index 1a3808bb61..d0c8c357f0 100644
--- a/mlos_bench/mlos_bench/services/types/local_exec_type.py
+++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py
@@ -2,15 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Service types that provide helper functions to run
-scripts and commands locally on the scheduler side.
+"""Protocol interface for Service types that provide helper functions to run scripts and
+commands locally on the scheduler side.
"""
-from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable
-
-import tempfile
import contextlib
+import tempfile
+from typing import (
+ Iterable,
+ Mapping,
+ Optional,
+ Protocol,
+ Tuple,
+ Union,
+ runtime_checkable,
+)
from mlos_bench.tunables.tunable import TunableValue
@@ -18,16 +24,19 @@ from mlos_bench.tunables.tunable import TunableValue
@runtime_checkable
class SupportsLocalExec(Protocol):
"""
- Protocol interface for a collection of methods to run scripts and commands
- in an external process on the node acting as the scheduler. Can be useful
- for data processing due to reduced dependency management complications vs
- the target environment.
- Used in LocalEnv and provided by LocalExecService.
+ Protocol interface for a collection of methods to run scripts and commands in an
+ external process on the node acting as the scheduler.
+
+ Can be useful for data processing due to reduced dependency management complications
+ vs the target environment. Used in LocalEnv and provided by LocalExecService.
"""
- def local_exec(self, script_lines: Iterable[str],
- env: Optional[Mapping[str, TunableValue]] = None,
- cwd: Optional[str] = None) -> Tuple[int, str, str]:
+ def local_exec(
+ self,
+ script_lines: Iterable[str],
+ env: Optional[Mapping[str, TunableValue]] = None,
+ cwd: Optional[str] = None,
+ ) -> Tuple[int, str, str]:
"""
Execute the script lines from `script_lines` in a local process.
@@ -48,7 +57,10 @@ class SupportsLocalExec(Protocol):
A 3-tuple of return code, stdout, and stderr of the script process.
"""
- def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
+ def temp_dir_context(
+ self,
+ path: Optional[str] = None,
+ ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
"""
Create a temp directory or use the provided path.
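As documented above, local_exec() returns a (return code, stdout, stderr) triple and temp_dir_context() yields either a TemporaryDirectory or a nullcontext. A hedged sketch of the intended call pattern, assuming a local_exec_service that implements SupportsLocalExec:

# Hypothetical caller of a SupportsLocalExec service.
with local_exec_service.temp_dir_context() as temp_dir:
    return_code, stdout, stderr = local_exec_service.local_exec(
        ["echo hello"],
        env={"MY_VAR": "42"},
        cwd=temp_dir,
    )
assert return_code == 0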
diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py
index 5b6a9a6936..3525fbdee1 100644
--- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py
+++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Network provisioning operations.
-"""
+"""Protocol interface for Network provisioning operations."""
-from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsNetworkProvisioning(Protocol):
- """
- Protocol interface for Network provisioning operations.
- """
+ """Protocol interface for Network provisioning operations."""
def provision_network(self, params: dict) -> Tuple["Status", dict]:
"""
@@ -46,7 +42,8 @@ class SupportsNetworkProvisioning(Protocol):
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
- If True, wait for Network being deployed; otherwise, wait for successful deprovisioning.
+            If True, wait for the Network to be deployed; otherwise, wait for successful
+            deprovisioning.
Returns
-------
@@ -56,7 +53,11 @@ class SupportsNetworkProvisioning(Protocol):
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
- def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]:
+ def deprovision_network(
+ self,
+ params: dict,
+ ignore_errors: bool = True,
+ ) -> Tuple["Status", dict]:
"""
Deprovisions the Network by deleting it.
diff --git a/mlos_bench/mlos_bench/services/types/os_ops_type.py b/mlos_bench/mlos_bench/services/types/os_ops_type.py
index ba36c6914a..8b727f87a6 100644
--- a/mlos_bench/mlos_bench/services/types/os_ops_type.py
+++ b/mlos_bench/mlos_bench/services/types/os_ops_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Host/OS operations.
-"""
+"""Protocol interface for Host/OS operations."""
-from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsOSOps(Protocol):
- """
- Protocol interface for Host/OS operations.
- """
+ """Protocol interface for Host/OS operations."""
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
"""
@@ -56,8 +52,8 @@ class SupportsOSOps(Protocol):
def wait_os_operation(self, params: dict) -> Tuple["Status", dict]:
"""
- Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
- Return TIMED_OUT when timing out.
+ Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
+ TIMED_OUT when timing out.
Parameters
----------
diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py
index 4008aff576..7e8d0a6e77 100644
--- a/mlos_bench/mlos_bench/services/types/remote_config_type.py
+++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for configuring cloud services.
-"""
+"""Protocol interface for configuring cloud services."""
-from typing import Any, Dict, Protocol, Tuple, TYPE_CHECKING, runtime_checkable
+from typing import TYPE_CHECKING, Any, Dict, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,12 +12,9 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsRemoteConfig(Protocol):
- """
- Protocol interface for configuring cloud services.
- """
+ """Protocol interface for configuring cloud services."""
- def configure(self, config: Dict[str, Any],
- params: Dict[str, Any]) -> Tuple["Status", dict]:
+ def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]:
"""
Update the parameters of a SaaS service in the cloud.
diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py
index 8dd41e51a8..b6285a8f96 100644
--- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py
+++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py
@@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for Service types that provide helper functions to run
-scripts on a remote host OS.
+"""Protocol interface for Service types that provide helper functions to run scripts on
+a remote host OS.
"""
-from typing import Iterable, Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -15,13 +14,16 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsRemoteExec(Protocol):
- """
- Protocol interface for Service types that provide helper functions to run
- scripts on a remote host OS.
+ """Protocol interface for Service types that provide helper functions to run scripts
+ on a remote host OS.
"""
- def remote_exec(self, script: Iterable[str], config: dict,
- env_params: dict) -> Tuple["Status", dict]:
+ def remote_exec(
+ self,
+ script: Iterable[str],
+ config: dict,
+ env_params: dict,
+ ) -> Tuple["Status", dict]:
"""
Run a command on remote host OS.
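remote_exec() above returns a (Status, results) pair rather than raw process output, since the command may complete asynchronously on the target. A hedged sketch, assuming a remote_exec_service implementing SupportsRemoteExec and an illustrative target config:

# Hypothetical call to a SupportsRemoteExec service; the config keys are made up.
status, results = remote_exec_service.remote_exec(
    ["uname -a"],
    config={"ssh_hostname": "demo-host"},
    env_params={},
)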
diff --git a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py
index 0574d25c61..69d24f3fd3 100644
--- a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py
+++ b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Protocol interface for VM provisioning operations.
-"""
+"""Protocol interface for VM provisioning operations."""
-from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsVMOps(Protocol):
- """
- Protocol interface for VM provisioning operations.
- """
+ """Protocol interface for VM provisioning operations."""
def vm_provision(self, params: dict) -> Tuple["Status", dict]:
"""
@@ -122,8 +118,8 @@ class SupportsVMOps(Protocol):
def wait_vm_operation(self, params: dict) -> Tuple["Status", dict]:
"""
- Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
- Return TIMED_OUT when timing out.
+ Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. Return
+ TIMED_OUT when timing out.
Parameters
----------
diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py
index 9ae5c80f36..64e70c20f7 100644
--- a/mlos_bench/mlos_bench/storage/__init__.py
+++ b/mlos_bench/mlos_bench/storage/__init__.py
@@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Interfaces to the storage backends for OS Autotune.
-"""
+"""Interfaces to the storage backends for OS Autotune."""
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.storage_factory import from_config
__all__ = [
- 'Storage',
- 'from_config',
+ "Storage",
+ "from_config",
]
diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py
index b112a7b575..60c27bc522 100644
--- a/mlos_bench/mlos_bench/storage/base_experiment_data.py
+++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py
@@ -2,13 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base interface for accessing the stored benchmark experiment data.
-"""
+"""Base interface for accessing the stored benchmark experiment data."""
from abc import ABCMeta, abstractmethod
-from distutils.util import strtobool # pylint: disable=deprecated-module
-from typing import Dict, Literal, Optional, Tuple, TYPE_CHECKING
+from distutils.util import strtobool # pylint: disable=deprecated-module
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple
import pandas
@@ -16,7 +14,9 @@ from mlos_bench.storage.base_tunable_config_data import TunableConfigData
if TYPE_CHECKING:
from mlos_bench.storage.base_trial_data import TrialData
- from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
+ from mlos_bench.storage.base_tunable_config_trial_group_data import (
+ TunableConfigTrialGroupData,
+ )
class ExperimentData(metaclass=ABCMeta):
@@ -30,12 +30,15 @@ class ExperimentData(metaclass=ABCMeta):
RESULT_COLUMN_PREFIX = "result."
CONFIG_COLUMN_PREFIX = "config."
- def __init__(self, *,
- experiment_id: str,
- description: str,
- root_env_config: str,
- git_repo: str,
- git_commit: str):
+ def __init__(
+ self,
+ *,
+ experiment_id: str,
+ description: str,
+ root_env_config: str,
+ git_repo: str,
+ git_commit: str,
+ ):
self._experiment_id = experiment_id
self._description = description
self._root_env_config = root_env_config
@@ -44,16 +47,12 @@ class ExperimentData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
- """
- ID of the experiment.
- """
+ """ID of the experiment."""
return self._experiment_id
@property
def description(self) -> str:
- """
- Description of the experiment.
- """
+ """Description of the experiment."""
return self._description
@property
@@ -123,7 +122,8 @@ class ExperimentData(metaclass=ABCMeta):
@property
def default_tunable_config_id(self) -> Optional[int]:
"""
- Retrieves the (tunable) config id for the default tunable values for this experiment.
+ Retrieves the (tunable) config id for the default tunable values for this
+ experiment.
Note: this is by *default* the first trial executed for this experiment.
However, it is currently possible that the user changed the tunables config
@@ -140,9 +140,9 @@ class ExperimentData(metaclass=ABCMeta):
trials_items = sorted(self.trials.items())
if not trials_items:
return None
- for (_trial_id, trial) in trials_items:
+ for _trial_id, trial in trials_items:
# Take the first config id marked as "defaults" when it was instantiated.
- if strtobool(str(trial.metadata_dict.get('is_defaults', False))):
+ if strtobool(str(trial.metadata_dict.get("is_defaults", False))):
return trial.tunable_config_id
# Fallback (min trial_id)
return trials_items[0][1].tunable_config_id
@@ -157,7 +157,8 @@ class ExperimentData(metaclass=ABCMeta):
-------
results : pandas.DataFrame
A DataFrame with configurations and results from all trials of the experiment.
- Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
+ Has columns
+ [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
followed by tunable config parameters (prefixed with "config.") and
trial results (prefixed with "result."). The latter can be NULLs if the
trial was not successful.
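The results_df layout described above (trial metadata columns plus "config."- and "result."-prefixed columns) lends itself to simple column filtering. A hedged sketch, assuming an ExperimentData instance exp and a made-up metric name:

# Hypothetical exploration of ExperimentData.results_df.
from mlos_bench.storage.base_experiment_data import ExperimentData

df = exp.results_df
config_cols = [c for c in df.columns if c.startswith(ExperimentData.CONFIG_COLUMN_PREFIX)]
result_cols = [c for c in df.columns if c.startswith(ExperimentData.RESULT_COLUMN_PREFIX)]
best_trial = df.sort_values("result.score").head(1)  # "score" is an illustrative metric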
diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py
index 350b1b8ec8..9c0e88e3d5 100644
--- a/mlos_bench/mlos_bench/storage/base_storage.py
+++ b/mlos_bench/mlos_bench/storage/base_storage.py
@@ -2,15 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base interface for saving and restoring the benchmark data.
-"""
+"""Base interface for saving and restoring the benchmark data."""
import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
-from typing import Optional, List, Tuple, Dict, Iterator, Type, Any
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
+
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
@@ -24,15 +23,16 @@ _LOG = logging.getLogger(__name__)
class Storage(metaclass=ABCMeta):
- """
- An abstract interface between the benchmarking framework
- and storage systems (e.g., SQLite or MLFLow).
+ """An abstract interface between the benchmarking framework and storage systems
+    (e.g., SQLite or MLflow).
"""
- def __init__(self,
- config: Dict[str, Any],
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ config: Dict[str, Any],
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
"""
Create a new storage object.
@@ -48,10 +48,9 @@ class Storage(metaclass=ABCMeta):
self._global_config = global_config or {}
def _validate_json_config(self, config: dict) -> None:
- """
- Reconstructs a basic json config that this class might have been
- instantiated from in order to validate configs provided outside the
- file loading mechanism.
+ """Reconstructs a basic json config that this class might have been instantiated
+ from in order to validate configs provided outside the file loading
+ mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@@ -73,13 +72,16 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
- def experiment(self, *,
- experiment_id: str,
- trial_id: int,
- root_env_config: str,
- description: str,
- tunables: TunableGroups,
- opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment':
+ def experiment(
+ self,
+ *,
+ experiment_id: str,
+ trial_id: int,
+ root_env_config: str,
+ description: str,
+ tunables: TunableGroups,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ ) -> "Storage.Experiment":
"""
Create a new experiment in the storage.
@@ -112,26 +114,31 @@ class Storage(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of the experiment.
+
This class is instantiated in the `Storage.experiment()` method.
"""
- def __init__(self,
- *,
- tunables: TunableGroups,
- experiment_id: str,
- trial_id: int,
- root_env_config: str,
- description: str,
- opt_targets: Dict[str, Literal['min', 'max']]):
+ def __init__(
+ self,
+ *,
+ tunables: TunableGroups,
+ experiment_id: str,
+ trial_id: int,
+ root_env_config: str,
+ description: str,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ ):
self._tunables = tunables.copy()
self._trial_id = trial_id
self._experiment_id = experiment_id
- (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config)
+ (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
+ root_env_config
+ )
self._description = description
self._opt_targets = opt_targets
self._in_context = False
- def __enter__(self) -> 'Storage.Experiment':
+ def __enter__(self) -> "Storage.Experiment":
"""
Enter the context of the experiment.
@@ -143,9 +150,12 @@ class Storage(metaclass=ABCMeta):
self._in_context = True
return self
- def __exit__(self, exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType]) -> Literal[False]:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Literal[False]:
"""
End the context of the experiment.
@@ -156,8 +166,11 @@ class Storage(metaclass=ABCMeta):
_LOG.debug("Finishing experiment: %s", self)
else:
assert exc_type and exc_val
- _LOG.warning("Finishing experiment: %s", self,
- exc_info=(exc_type, exc_val, exc_tb))
+ _LOG.warning(
+ "Finishing experiment: %s",
+ self,
+ exc_info=(exc_type, exc_val, exc_tb),
+ )
assert self._in_context
self._teardown(is_ok)
self._in_context = False
@@ -168,7 +181,8 @@ class Storage(metaclass=ABCMeta):
def _setup(self) -> None:
"""
- Create a record of the new experiment or find an existing one in the storage.
+ Create a record of the new experiment or find an existing one in the
+ storage.
This method is called by `Storage.Experiment.__enter__()`.
"""
@@ -187,36 +201,34 @@ class Storage(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
- """Get the Experiment's ID"""
+ """Get the Experiment's ID."""
return self._experiment_id
@property
def trial_id(self) -> int:
- """Get the current Trial ID"""
+ """Get the current Trial ID."""
return self._trial_id
@property
def description(self) -> str:
- """Get the Experiment's description"""
+ """Get the Experiment's description."""
return self._description
@property
def tunables(self) -> TunableGroups:
- """Get the Experiment's tunables"""
+ """Get the Experiment's tunables."""
return self._tunables
@property
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
- """
- Get the Experiment's optimization targets and directions
- """
+ """Get the Experiment's optimization targets and directions."""
return self._opt_targets
@abstractmethod
def merge(self, experiment_ids: List[str]) -> None:
"""
- Merge in the results of other (compatible) experiments trials.
- Used to help warm up the optimizer for this experiment.
+            Merge in the results of other (compatible) experiments' trials. Used to help
+ warm up the optimizer for this experiment.
Parameters
----------
@@ -226,9 +238,7 @@ class Storage(metaclass=ABCMeta):
@abstractmethod
def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
- """
- Load tunable values for a given config ID.
- """
+ """Load tunable values for a given config ID."""
@abstractmethod
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
@@ -247,8 +257,10 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
- def load(self, last_trial_id: int = -1,
- ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
+ def load(
+ self,
+ last_trial_id: int = -1,
+ ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
"""
Load (tunable values, benchmark scores, status) to warm-up the optimizer.
@@ -268,10 +280,15 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
- def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']:
+ def pending_trials(
+ self,
+ timestamp: datetime,
+ *,
+ running: bool,
+ ) -> Iterator["Storage.Trial"]:
"""
- Return an iterator over the pending trials that are scheduled to run
- on or before the specified timestamp.
+ Return an iterator over the pending trials that are scheduled to run on or
+ before the specified timestamp.
Parameters
----------
@@ -288,8 +305,12 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
- def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
- config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial':
+ def new_trial(
+ self,
+ tunables: TunableGroups,
+ ts_start: Optional[datetime] = None,
+ config: Optional[Dict[str, Any]] = None,
+ ) -> "Storage.Trial":
"""
Create a new experiment run in the storage.
@@ -313,13 +334,20 @@ class Storage(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of a single run of the experiment.
+
This class is instantiated in the `Storage.Experiment.trial()` method.
"""
- def __init__(self, *,
- tunables: TunableGroups, experiment_id: str, trial_id: int,
- tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']],
- config: Optional[Dict[str, Any]] = None):
+ def __init__(
+ self,
+ *,
+ tunables: TunableGroups,
+ experiment_id: str,
+ trial_id: int,
+ tunable_config_id: int,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ config: Optional[Dict[str, Any]] = None,
+ ):
self._tunables = tunables
self._experiment_id = experiment_id
self._trial_id = trial_id
@@ -332,29 +360,23 @@ class Storage(metaclass=ABCMeta):
@property
def trial_id(self) -> int:
- """
- ID of the current trial.
- """
+ """ID of the current trial."""
return self._trial_id
@property
def tunable_config_id(self) -> int:
- """
- ID of the current trial (tunable) configuration.
- """
+ """ID of the current trial (tunable) configuration."""
return self._tunable_config_id
@property
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
- """
- Get the Trial's optimization targets and directions.
- """
+ """Get the Trial's optimization targets and directions."""
return self._opt_targets
@property
def tunables(self) -> TunableGroups:
"""
- Tunable parameters of the current trial
+ Tunable parameters of the current trial.
(e.g., application Environment's "config")
"""
@@ -362,8 +384,8 @@ class Storage(metaclass=ABCMeta):
def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
- Produce a copy of the global configuration updated
- with the parameters of the current trial.
+ Produce a copy of the global configuration updated with the parameters of
+ the current trial.
Note: this is not the target Environment's "config" (i.e., tunable
params), but rather the internal "config" which consists of a
@@ -377,9 +399,12 @@ class Storage(metaclass=ABCMeta):
return config
@abstractmethod
- def update(self, status: Status, timestamp: datetime,
- metrics: Optional[Dict[str, Any]] = None
- ) -> Optional[Dict[str, Any]]:
+ def update(
+ self,
+ status: Status,
+ timestamp: datetime,
+ metrics: Optional[Dict[str, Any]] = None,
+ ) -> Optional[Dict[str, Any]]:
"""
Update the storage with the results of the experiment.
@@ -403,14 +428,21 @@ class Storage(metaclass=ABCMeta):
assert metrics is not None
opt_targets = set(self._opt_targets.keys())
if not opt_targets.issubset(metrics.keys()):
- _LOG.warning("Trial %s :: opt.targets missing: %s",
- self, opt_targets.difference(metrics.keys()))
+ _LOG.warning(
+ "Trial %s :: opt.targets missing: %s",
+ self,
+ opt_targets.difference(metrics.keys()),
+ )
# raise ValueError()
return metrics
@abstractmethod
- def update_telemetry(self, status: Status, timestamp: datetime,
- metrics: List[Tuple[datetime, str, Any]]) -> None:
+ def update_telemetry(
+ self,
+ status: Status,
+ timestamp: datetime,
+ metrics: List[Tuple[datetime, str, Any]],
+ ) -> None:
"""
Save the experiment's telemetry data and intermediate status.
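Storage.Experiment and Storage.Trial above are meant to be driven through context managers and the abstract methods in this hunk. A hedged end-to-end sketch, assuming a concrete storage backend and a tunables TunableGroups instance already exist:

# Hypothetical driver loop over the abstract Storage interface.
from datetime import datetime, timezone

from mlos_bench.environments.status import Status

with storage.experiment(
    experiment_id="demo-experiment",
    trial_id=1,
    root_env_config="environment.jsonc",
    description="illustrative experiment",
    tunables=tunables,
    opt_targets={"score": "min"},
) as experiment:
    trial = experiment.new_trial(tunables)
    trial.update(Status.SUCCEEDED, datetime.now(timezone.utc), metrics={"score": 0.5})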
diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py
index 93d1b62c9b..2a74d77e5e 100644
--- a/mlos_bench/mlos_bench/storage/base_trial_data.py
+++ b/mlos_bench/mlos_bench/storage/base_trial_data.py
@@ -2,40 +2,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base interface for accessing the stored benchmark trial data.
-"""
+"""Base interface for accessing the stored benchmark trial data."""
from abc import ABCMeta, abstractmethod
from datetime import datetime
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
import pandas
from pytz import UTC
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
from mlos_bench.storage.util import kv_df_to_dict
+from mlos_bench.tunables.tunable import TunableValue
if TYPE_CHECKING:
- from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
+ from mlos_bench.storage.base_tunable_config_trial_group_data import (
+ TunableConfigTrialGroupData,
+ )
class TrialData(metaclass=ABCMeta):
"""
Base interface for accessing the stored experiment benchmark trial data.
- A trial is a single run of an experiment with a given configuration (e.g., set
- of tunable parameters).
+ A trial is a single run of an experiment with a given configuration (e.g., set of
+ tunable parameters).
"""
- def __init__(self, *,
- experiment_id: str,
- trial_id: int,
- tunable_config_id: int,
- ts_start: datetime,
- ts_end: Optional[datetime],
- status: Status):
+ def __init__(
+ self,
+ *,
+ experiment_id: str,
+ trial_id: int,
+ tunable_config_id: int,
+ ts_start: datetime,
+ ts_end: Optional[datetime],
+ status: Status,
+ ):
self._experiment_id = experiment_id
self._trial_id = trial_id
self._tunable_config_id = tunable_config_id
@@ -46,7 +49,10 @@ class TrialData(metaclass=ABCMeta):
self._status = status
def __repr__(self) -> str:
- return f"Trial :: {self._experiment_id}:{self._trial_id} cid:{self._tunable_config_id} {self._status.name}"
+ return (
+ f"Trial :: {self._experiment_id}:{self._trial_id} "
+ f"cid:{self._tunable_config_id} {self._status.name}"
+ )
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
@@ -55,44 +61,32 @@ class TrialData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
- """
- ID of the experiment this trial belongs to.
- """
+ """ID of the experiment this trial belongs to."""
return self._experiment_id
@property
def trial_id(self) -> int:
- """
- ID of the trial.
- """
+ """ID of the trial."""
return self._trial_id
@property
def ts_start(self) -> datetime:
- """
- Start timestamp of the trial (UTC).
- """
+ """Start timestamp of the trial (UTC)."""
return self._ts_start
@property
def ts_end(self) -> Optional[datetime]:
- """
- End timestamp of the trial (UTC).
- """
+ """End timestamp of the trial (UTC)."""
return self._ts_end
@property
def status(self) -> Status:
- """
- Status of the trial.
- """
+ """Status of the trial."""
return self._status
@property
def tunable_config_id(self) -> int:
- """
- ID of the (tunable) configuration of the trial.
- """
+ """ID of the (tunable) configuration of the trial."""
return self._tunable_config_id
@property
@@ -112,9 +106,7 @@ class TrialData(metaclass=ABCMeta):
@property
@abstractmethod
def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
- """
- Retrieve the trial's (tunable) config trial group data from the storage.
- """
+ """Retrieve the trial's (tunable) config trial group data from the storage."""
@property
@abstractmethod
diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py
index 0dce110b1b..62751deb8e 100644
--- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py
+++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base interface for accessing the stored benchmark (tunable) config data.
-"""
+"""Base interface for accessing the stored benchmark (tunable) config data."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional
@@ -21,8 +19,7 @@ class TunableConfigData(metaclass=ABCMeta):
A configuration in this context is the set of tunable parameter values.
"""
- def __init__(self, *,
- tunable_config_id: int):
+ def __init__(self, *, tunable_config_id: int):
self._tunable_config_id = tunable_config_id
def __repr__(self) -> str:
@@ -35,9 +32,7 @@ class TunableConfigData(metaclass=ABCMeta):
@property
def tunable_config_id(self) -> int:
- """
- Unique ID of the (tunable) configuration.
- """
+ """Unique ID of the (tunable) configuration."""
return self._tunable_config_id
@property
diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py
index b64524fb85..c01c7544b3 100644
--- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py
+++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base interface for accessing the stored benchmark config trial group data.
-"""
+"""Base interface for accessing the stored benchmark config trial group data."""
from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
import pandas
@@ -19,18 +17,21 @@ if TYPE_CHECKING:
class TunableConfigTrialGroupData(metaclass=ABCMeta):
"""
- Base interface for accessing the stored experiment benchmark tunable config
- trial group data.
+ Base interface for accessing the stored experiment benchmark tunable config trial
+ group data.
A (tunable) config is used to define an instance of values for a set of tunable
parameters for a given experiment and can be used by one or more trial instances
(e.g., for repeats), which we call a (tunable) config trial group.
"""
- def __init__(self, *,
- experiment_id: str,
- tunable_config_id: int,
- tunable_config_trial_group_id: Optional[int] = None):
+ def __init__(
+ self,
+ *,
+ experiment_id: str,
+ tunable_config_id: int,
+ tunable_config_trial_group_id: Optional[int] = None,
+ ):
self._experiment_id = experiment_id
self._tunable_config_id = tunable_config_id
# can be lazily initialized as necessary:
@@ -38,23 +39,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
- """
- ID of the experiment.
- """
+ """ID of the experiment."""
return self._experiment_id
@property
def tunable_config_id(self) -> int:
- """
- ID of the config.
- """
+ """ID of the config."""
return self._tunable_config_id
@abstractmethod
def _get_tunable_config_trial_group_id(self) -> int:
- """
- Retrieve the trial's config_trial_group_id from the storage.
- """
+ """Retrieve the trial's config_trial_group_id from the storage."""
raise NotImplementedError("subclass must implement")
@property
@@ -77,13 +72,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
- return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id
+ return (
+ self._tunable_config_id == other._tunable_config_id
+ and self._experiment_id == other._experiment_id
+ )
@property
@abstractmethod
def tunable_config(self) -> TunableConfigData:
"""
- Retrieve the (tunable) config data for this (tunable) config trial group from the storage.
+ Retrieve the (tunable) config data for this (tunable) config trial group from
+ the storage.
Returns
-------
@@ -94,7 +93,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@abstractmethod
def trials(self) -> Dict[int, "TrialData"]:
"""
- Retrieve the trials' data for this (tunable) config trial group from the storage.
+ Retrieve the trials' data for this (tunable) config trial group from the
+ storage.
Returns
-------
@@ -106,7 +106,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@abstractmethod
def results_df(self) -> pandas.DataFrame:
"""
- Retrieve all results for this (tunable) config trial group as a single DataFrame.
+ Retrieve all results for this (tunable) config trial group as a single
+ DataFrame.
Returns
-------
diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py
index 735e21bcaf..9d749ed35d 100644
--- a/mlos_bench/mlos_bench/storage/sql/__init__.py
+++ b/mlos_bench/mlos_bench/storage/sql/__init__.py
@@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Interfaces to the SQL-based storage backends for OS Autotune.
-"""
+"""Interfaces to the SQL-based storage backends for OS Autotune."""
from mlos_bench.storage.sql.storage import SqlStorage
__all__ = [
- 'SqlStorage',
+ "SqlStorage",
]
diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py
index a3895f065a..3b0c6c31fb 100644
--- a/mlos_bench/mlos_bench/storage/sql/common.py
+++ b/mlos_bench/mlos_bench/storage/sql/common.py
@@ -2,39 +2,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common SQL methods for accessing the stored benchmark data.
-"""
+"""Common SQL methods for accessing the stored benchmark data."""
from typing import Dict, Optional
import pandas
-from sqlalchemy import Engine, Integer, func, and_, select
+from sqlalchemy import Engine, Integer, and_, func, select
from mlos_bench.environments.status import Status
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.sql.schema import DbSchema
-from mlos_bench.util import utcify_timestamp, utcify_nullable_timestamp
+from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
def get_trials(
- engine: Engine,
- schema: DbSchema,
- experiment_id: str,
- tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]:
+ engine: Engine,
+ schema: DbSchema,
+ experiment_id: str,
+ tunable_config_id: Optional[int] = None,
+) -> Dict[int, TrialData]:
"""
- Gets TrialData for the given experiment_data and optionally additionally
- restricted by tunable_config_id.
+    Gets TrialData for the given experiment_data, optionally restricted by
+    tunable_config_id.
+
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
"""
- from mlos_bench.storage.sql.trial_data import TrialSqlData # pylint: disable=import-outside-toplevel,cyclic-import
+ # pylint: disable=import-outside-toplevel,cyclic-import
+ from mlos_bench.storage.sql.trial_data import TrialSqlData
+
with engine.connect() as conn:
# Build up sql a statement for fetching trials.
- stmt = schema.trial.select().where(
- schema.trial.c.exp_id == experiment_id,
- ).order_by(
- schema.trial.c.exp_id.asc(),
- schema.trial.c.trial_id.asc(),
+ stmt = (
+ schema.trial.select()
+ .where(
+ schema.trial.c.exp_id == experiment_id,
+ )
+ .order_by(
+ schema.trial.c.exp_id.asc(),
+ schema.trial.c.trial_id.asc(),
+ )
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@@ -58,27 +64,36 @@ def get_trials(
def get_results_df(
- engine: Engine,
- schema: DbSchema,
- experiment_id: str,
- tunable_config_id: Optional[int] = None) -> pandas.DataFrame:
+ engine: Engine,
+ schema: DbSchema,
+ experiment_id: str,
+ tunable_config_id: Optional[int] = None,
+) -> pandas.DataFrame:
"""
- Gets TrialData for the given experiment_data and optionally additionally
- restricted by tunable_config_id.
+    Gets a results DataFrame for the given experiment_data, optionally restricted
+    by tunable_config_id.
+
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
"""
# pylint: disable=too-many-locals
with engine.connect() as conn:
# Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config.
- tunable_config_group_id_stmt = schema.trial.select().with_only_columns(
- schema.trial.c.exp_id,
- schema.trial.c.config_id,
- func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'),
- ).where(
- schema.trial.c.exp_id == experiment_id,
- ).group_by(
- schema.trial.c.exp_id,
- schema.trial.c.config_id,
+ tunable_config_group_id_stmt = (
+ schema.trial.select()
+ .with_only_columns(
+ schema.trial.c.exp_id,
+ schema.trial.c.config_id,
+ func.min(schema.trial.c.trial_id)
+ .cast(Integer)
+ .label("tunable_config_trial_group_id"),
+ )
+ .where(
+ schema.trial.c.exp_id == experiment_id,
+ )
+ .group_by(
+ schema.trial.c.exp_id,
+ schema.trial.c.config_id,
+ )
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@@ -88,18 +103,22 @@ def get_results_df(
tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()
# Get each trial's metadata.
- cur_trials_stmt = select(
- schema.trial,
- tunable_config_trial_group_id_subquery,
- ).where(
- schema.trial.c.exp_id == experiment_id,
- and_(
- tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
- tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
- ),
- ).order_by(
- schema.trial.c.exp_id.asc(),
- schema.trial.c.trial_id.asc(),
+ cur_trials_stmt = (
+ select(
+ schema.trial,
+ tunable_config_trial_group_id_subquery,
+ )
+ .where(
+ schema.trial.c.exp_id == experiment_id,
+ and_(
+ tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
+ tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
+ ),
+ )
+ .order_by(
+ schema.trial.c.exp_id.asc(),
+ schema.trial.c.trial_id.asc(),
+ )
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@@ -108,39 +127,48 @@ def get_results_df(
)
cur_trials = conn.execute(cur_trials_stmt)
trials_df = pandas.DataFrame(
- [(
- row.trial_id,
- utcify_timestamp(row.ts_start, origin="utc"),
- utcify_nullable_timestamp(row.ts_end, origin="utc"),
- row.config_id,
- row.tunable_config_trial_group_id,
- row.status,
- ) for row in cur_trials.fetchall()],
+ [
+ (
+ row.trial_id,
+ utcify_timestamp(row.ts_start, origin="utc"),
+ utcify_nullable_timestamp(row.ts_end, origin="utc"),
+ row.config_id,
+ row.tunable_config_trial_group_id,
+ row.status,
+ )
+ for row in cur_trials.fetchall()
+ ],
columns=[
- 'trial_id',
- 'ts_start',
- 'ts_end',
- 'tunable_config_id',
- 'tunable_config_trial_group_id',
- 'status',
- ]
+ "trial_id",
+ "ts_start",
+ "ts_end",
+ "tunable_config_id",
+ "tunable_config_trial_group_id",
+ "status",
+ ],
)
# Get each trial's config in wide format.
- configs_stmt = schema.trial.select().with_only_columns(
- schema.trial.c.trial_id,
- schema.trial.c.config_id,
- schema.config_param.c.param_id,
- schema.config_param.c.param_value,
- ).where(
- schema.trial.c.exp_id == experiment_id,
- ).join(
- schema.config_param,
- schema.config_param.c.config_id == schema.trial.c.config_id,
- isouter=True
- ).order_by(
- schema.trial.c.trial_id,
- schema.config_param.c.param_id,
+ configs_stmt = (
+ schema.trial.select()
+ .with_only_columns(
+ schema.trial.c.trial_id,
+ schema.trial.c.config_id,
+ schema.config_param.c.param_id,
+ schema.config_param.c.param_value,
+ )
+ .where(
+ schema.trial.c.exp_id == experiment_id,
+ )
+ .join(
+ schema.config_param,
+ schema.config_param.c.config_id == schema.trial.c.config_id,
+ isouter=True,
+ )
+ .order_by(
+ schema.trial.c.trial_id,
+ schema.config_param.c.param_id,
+ )
)
if tunable_config_id is not None:
configs_stmt = configs_stmt.where(
@@ -148,41 +176,75 @@ def get_results_df(
)
configs = conn.execute(configs_stmt)
configs_df = pandas.DataFrame(
- [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value)
- for row in configs.fetchall()],
- columns=['trial_id', 'tunable_config_id', 'param', 'value']
+ [
+ (
+ row.trial_id,
+ row.config_id,
+ ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
+ row.param_value,
+ )
+ for row in configs.fetchall()
+ ],
+ columns=["trial_id", "tunable_config_id", "param", "value"],
).pivot(
- index=["trial_id", "tunable_config_id"], columns="param", values="value",
+ index=["trial_id", "tunable_config_id"],
+ columns="param",
+ values="value",
)
- configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp)
+ configs_df = configs_df.apply( # type: ignore[assignment] # (fp)
+ pandas.to_numeric,
+ errors="coerce",
+ ).fillna(configs_df)
# Get each trial's results in wide format.
- results_stmt = schema.trial_result.select().with_only_columns(
- schema.trial_result.c.trial_id,
- schema.trial_result.c.metric_id,
- schema.trial_result.c.metric_value,
- ).where(
- schema.trial_result.c.exp_id == experiment_id,
- ).order_by(
- schema.trial_result.c.trial_id,
- schema.trial_result.c.metric_id,
+ results_stmt = (
+ schema.trial_result.select()
+ .with_only_columns(
+ schema.trial_result.c.trial_id,
+ schema.trial_result.c.metric_id,
+ schema.trial_result.c.metric_value,
+ )
+ .where(
+ schema.trial_result.c.exp_id == experiment_id,
+ )
+ .order_by(
+ schema.trial_result.c.trial_id,
+ schema.trial_result.c.metric_id,
+ )
)
if tunable_config_id is not None:
- results_stmt = results_stmt.join(schema.trial, and_(
- schema.trial.c.exp_id == schema.trial_result.c.exp_id,
- schema.trial.c.trial_id == schema.trial_result.c.trial_id,
- schema.trial.c.config_id == tunable_config_id,
- ))
+ results_stmt = results_stmt.join(
+ schema.trial,
+ and_(
+ schema.trial.c.exp_id == schema.trial_result.c.exp_id,
+ schema.trial.c.trial_id == schema.trial_result.c.trial_id,
+ schema.trial.c.config_id == tunable_config_id,
+ ),
+ )
results = conn.execute(results_stmt)
results_df = pandas.DataFrame(
- [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value)
- for row in results.fetchall()],
- columns=['trial_id', 'metric', 'value']
+ [
+ (
+ row.trial_id,
+ ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
+ row.metric_value,
+ )
+ for row in results.fetchall()
+ ],
+ columns=["trial_id", "metric", "value"],
).pivot(
- index="trial_id", columns="metric", values="value",
+ index="trial_id",
+ columns="metric",
+ values="value",
)
- results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp)
+ results_df = results_df.apply( # type: ignore[assignment] # (fp)
+ pandas.to_numeric,
+ errors="coerce",
+ ).fillna(results_df)
# Concat the trials, configs, and results.
- return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \
- .merge(results_df, on="trial_id", how="left")
+ return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
+ results_df,
+ on="trial_id",
+ how="left",
+ )
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py
index 108d3e1d2d..8f6f938945 100644
--- a/mlos_bench/mlos_bench/storage/sql/experiment.py
+++ b/mlos_bench/mlos_bench/storage/sql/experiment.py
@@ -2,43 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Saving and restoring the benchmark data using SQLAlchemy.
-"""
+"""Saving and restoring the benchmark data using SQLAlchemy."""
-import logging
import hashlib
+import logging
from datetime import datetime
-from typing import Optional, Tuple, List, Literal, Dict, Iterator, Any
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
from pytz import UTC
-
-from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select
+from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.trial import Trial
+from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import nullable, utcify_timestamp
_LOG = logging.getLogger(__name__)
class Experiment(Storage.Experiment):
- """
- Logic for retrieving and storing the results of a single experiment.
- """
+ """Logic for retrieving and storing the results of a single experiment."""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- tunables: TunableGroups,
- experiment_id: str,
- trial_id: int,
- root_env_config: str,
- description: str,
- opt_targets: Dict[str, Literal['min', 'max']]):
+ def __init__(
+ self,
+ *,
+ engine: Engine,
+ schema: DbSchema,
+ tunables: TunableGroups,
+ experiment_id: str,
+ trial_id: int,
+ root_env_config: str,
+ description: str,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ ):
super().__init__(
tunables=tunables,
experiment_id=experiment_id,
@@ -56,18 +54,22 @@ class Experiment(Storage.Experiment):
# Get git info and the last trial ID for the experiment.
# pylint: disable=not-callable
exp_info = conn.execute(
- self._schema.experiment.select().with_only_columns(
+ self._schema.experiment.select()
+ .with_only_columns(
self._schema.experiment.c.git_repo,
self._schema.experiment.c.git_commit,
self._schema.experiment.c.root_env_config,
func.max(self._schema.trial.c.trial_id).label("trial_id"),
- ).join(
+ )
+ .join(
self._schema.trial,
self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
- isouter=True
- ).where(
+ isouter=True,
+ )
+ .where(
self._schema.experiment.c.exp_id == self._experiment_id,
- ).group_by(
+ )
+ .group_by(
self._schema.experiment.c.git_repo,
self._schema.experiment.c.git_commit,
self._schema.experiment.c.root_env_config,
@@ -76,33 +78,47 @@ class Experiment(Storage.Experiment):
if exp_info is None:
_LOG.info("Start new experiment: %s", self._experiment_id)
# It's a new experiment: create a record for it in the database.
- conn.execute(self._schema.experiment.insert().values(
- exp_id=self._experiment_id,
- description=self._description,
- git_repo=self._git_repo,
- git_commit=self._git_commit,
- root_env_config=self._root_env_config,
- ))
- conn.execute(self._schema.objectives.insert().values([
- {
- "exp_id": self._experiment_id,
- "optimization_target": opt_target,
- "optimization_direction": opt_dir,
- }
- for (opt_target, opt_dir) in self.opt_targets.items()
- ]))
+ conn.execute(
+ self._schema.experiment.insert().values(
+ exp_id=self._experiment_id,
+ description=self._description,
+ git_repo=self._git_repo,
+ git_commit=self._git_commit,
+ root_env_config=self._root_env_config,
+ )
+ )
+ conn.execute(
+ self._schema.objectives.insert().values(
+ [
+ {
+ "exp_id": self._experiment_id,
+ "optimization_target": opt_target,
+ "optimization_direction": opt_dir,
+ }
+ for (opt_target, opt_dir) in self.opt_targets.items()
+ ]
+ )
+ )
else:
if exp_info.trial_id is not None:
self._trial_id = exp_info.trial_id + 1
- _LOG.info("Continue experiment: %s last trial: %s resume from: %d",
- self._experiment_id, exp_info.trial_id, self._trial_id)
+ _LOG.info(
+ "Continue experiment: %s last trial: %s resume from: %d",
+ self._experiment_id,
+ exp_info.trial_id,
+ self._trial_id,
+ )
# TODO: Sanity check that certain critical configs (e.g.,
# objectives) haven't changed to be incompatible such that a new
# experiment should be started (possibly by prewarming with the
# previous one).
if exp_info.git_commit != self._git_commit:
- _LOG.warning("Experiment %s git expected: %s %s",
- self, exp_info.git_repo, exp_info.git_commit)
+ _LOG.warning(
+ "Experiment %s git expected: %s %s",
+ self,
+ exp_info.git_repo,
+ exp_info.git_commit,
+ )
def merge(self, experiment_ids: List[str]) -> None:
_LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
@@ -115,33 +131,42 @@ class Experiment(Storage.Experiment):
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
with self._engine.connect() as conn:
cur_telemetry = conn.execute(
- self._schema.trial_telemetry.select().where(
+ self._schema.trial_telemetry.select()
+ .where(
self._schema.trial_telemetry.c.exp_id == self._experiment_id,
- self._schema.trial_telemetry.c.trial_id == trial_id
- ).order_by(
+ self._schema.trial_telemetry.c.trial_id == trial_id,
+ )
+ .order_by(
self._schema.trial_telemetry.c.ts,
self._schema.trial_telemetry.c.metric_id,
)
)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
- return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
- for row in cur_telemetry.fetchall()]
+ return [
+ (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
+ for row in cur_telemetry.fetchall()
+ ]
- def load(self, last_trial_id: int = -1,
- ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
+ def load(
+ self,
+ last_trial_id: int = -1,
+ ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
with self._engine.connect() as conn:
cur_trials = conn.execute(
- self._schema.trial.select().with_only_columns(
+ self._schema.trial.select()
+ .with_only_columns(
self._schema.trial.c.trial_id,
self._schema.trial.c.config_id,
self._schema.trial.c.status,
- ).where(
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id > last_trial_id,
- self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']),
- ).order_by(
+ self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
+ )
+ .order_by(
self._schema.trial.c.trial_id.asc(),
)
)
@@ -155,12 +180,24 @@ class Experiment(Storage.Experiment):
stat = Status[trial.status]
status.append(stat)
trial_ids.append(trial.trial_id)
- configs.append(self._get_key_val(
- conn, self._schema.config_param, "param", config_id=trial.config_id))
+ configs.append(
+ self._get_key_val(
+ conn,
+ self._schema.config_param,
+ "param",
+ config_id=trial.config_id,
+ )
+ )
if stat.is_succeeded():
- scores.append(self._get_key_val(
- conn, self._schema.trial_result, "metric",
- exp_id=self._experiment_id, trial_id=trial.trial_id))
+ scores.append(
+ self._get_key_val(
+ conn,
+ self._schema.trial_result,
+ "metric",
+ exp_id=self._experiment_id,
+ trial_id=trial.trial_id,
+ )
+ )
else:
scores.append(None)
@@ -170,55 +207,73 @@ class Experiment(Storage.Experiment):
def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
"""
Helper method to retrieve key-value pairs from the database.
+
(E.g., configurations, results, and telemetry).
"""
cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
select(
column(f"{field}_id"),
column(f"{field}_value"),
- ).select_from(table).where(
- *[column(key) == val for (key, val) in kwargs.items()]
)
+ .select_from(table)
+ .where(*[column(key) == val for (key, val) in kwargs.items()])
+ )
+ # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
+ # avoid naming conflicts.
+ return dict(
+ row._tuple() for row in cur_result.fetchall() # pylint: disable=protected-access
)
- # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts.
- return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access
@staticmethod
- def _save_params(conn: Connection, table: Table,
- params: Dict[str, Any], **kwargs: Any) -> None:
+ def _save_params(
+ conn: Connection,
+ table: Table,
+ params: Dict[str, Any],
+ **kwargs: Any,
+ ) -> None:
if not params:
return
- conn.execute(table.insert(), [
- {
- **kwargs,
- "param_id": key,
- "param_value": nullable(str, val)
- }
- for (key, val) in params.items()
- ])
+ conn.execute(
+ table.insert(),
+ [
+ {**kwargs, "param_id": key, "param_value": nullable(str, val)}
+ for (key, val) in params.items()
+ ],
+ )
def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
timestamp = utcify_timestamp(timestamp, origin="local")
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
if running:
- pending_status = ['PENDING', 'READY', 'RUNNING']
+ pending_status = ["PENDING", "READY", "RUNNING"]
else:
- pending_status = ['PENDING']
+ pending_status = ["PENDING"]
with self._engine.connect() as conn:
- cur_trials = conn.execute(self._schema.trial.select().where(
- self._schema.trial.c.exp_id == self._experiment_id,
- (self._schema.trial.c.ts_start.is_(None) |
- (self._schema.trial.c.ts_start <= timestamp)),
- self._schema.trial.c.ts_end.is_(None),
- self._schema.trial.c.status.in_(pending_status),
- ))
+ cur_trials = conn.execute(
+ self._schema.trial.select().where(
+ self._schema.trial.c.exp_id == self._experiment_id,
+ (
+ self._schema.trial.c.ts_start.is_(None)
+ | (self._schema.trial.c.ts_start <= timestamp)
+ ),
+ self._schema.trial.c.ts_end.is_(None),
+ self._schema.trial.c.status.in_(pending_status),
+ )
+ )
for trial in cur_trials.fetchall():
tunables = self._get_key_val(
- conn, self._schema.config_param, "param",
- config_id=trial.config_id)
+ conn,
+ self._schema.config_param,
+ "param",
+ config_id=trial.config_id,
+ )
config = self._get_key_val(
- conn, self._schema.trial_param, "param",
- exp_id=self._experiment_id, trial_id=trial.trial_id)
+ conn,
+ self._schema.trial_param,
+ "param",
+ exp_id=self._experiment_id,
+ trial_id=trial.trial_id,
+ )
yield Trial(
engine=self._engine,
schema=self._schema,
@@ -233,47 +288,63 @@ class Experiment(Storage.Experiment):
def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
"""
- Get the config ID for the given tunables. If the config does not exist,
- create a new record for it.
+ Get the config ID for the given tunables.
+
+ If the config does not exist, create a new record for it.
"""
- config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest()
- cur_config = conn.execute(self._schema.config.select().where(
- self._schema.config.c.config_hash == config_hash
- )).fetchone()
+ config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
+ cur_config = conn.execute(
+ self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
+ ).fetchone()
if cur_config is not None:
return int(cur_config.config_id) # mypy doesn't know it's always int
# Config not found, create a new one:
- config_id: int = conn.execute(self._schema.config.insert().values(
- config_hash=config_hash)).inserted_primary_key[0]
+ config_id: int = conn.execute(
+ self._schema.config.insert().values(config_hash=config_hash)
+ ).inserted_primary_key[0]
self._save_params(
- conn, self._schema.config_param,
+ conn,
+ self._schema.config_param,
{tunable.name: tunable.value for (tunable, _group) in tunables},
- config_id=config_id)
+ config_id=config_id,
+ )
return config_id
- def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
- config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
+ def new_trial(
+ self,
+ tunables: TunableGroups,
+ ts_start: Optional[datetime] = None,
+ config: Optional[Dict[str, Any]] = None,
+ ) -> Storage.Trial:
# MySQL can round microseconds into the future causing scheduler to skip trials.
# Truncate microseconds to avoid this issue.
- ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(microsecond=0)
+ ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
+ microsecond=0
+ )
_LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
with self._engine.begin() as conn:
try:
config_id = self._get_config_id(conn, tunables)
- conn.execute(self._schema.trial.insert().values(
- exp_id=self._experiment_id,
- trial_id=self._trial_id,
- config_id=config_id,
- ts_start=ts_start,
- status='PENDING',
- ))
+ conn.execute(
+ self._schema.trial.insert().values(
+ exp_id=self._experiment_id,
+ trial_id=self._trial_id,
+ config_id=config_id,
+ ts_start=ts_start,
+ status="PENDING",
+ )
+ )
# Note: config here is the framework config, not the target
# environment config (i.e., tunables).
if config is not None:
self._save_params(
- conn, self._schema.trial_param, config,
- exp_id=self._experiment_id, trial_id=self._trial_id)
+ conn,
+ self._schema.trial_param,
+ config,
+ exp_id=self._experiment_id,
+ trial_id=self._trial_id,
+ )
trial = Trial(
engine=self._engine,
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py
index 31b1d64af0..f29b9fedda 100644
--- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py
+++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py
@@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-An interface to access the experiment benchmark data stored in SQL DB.
-"""
-from typing import Dict, Literal, Optional
-
+"""An interface to access the experiment benchmark data stored in SQL DB."""
import logging
+from typing import Dict, Literal, Optional
import pandas
from sqlalchemy import Engine, Integer, String, func
@@ -15,11 +12,15 @@ from sqlalchemy import Engine, Integer, String, func
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
-from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
+from mlos_bench.storage.base_tunable_config_trial_group_data import (
+ TunableConfigTrialGroupData,
+)
from mlos_bench.storage.sql import common
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
-from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
+from mlos_bench.storage.sql.tunable_config_trial_group_data import (
+ TunableConfigTrialGroupSqlData,
+)
_LOG = logging.getLogger(__name__)
@@ -32,14 +33,17 @@ class ExperimentSqlData(ExperimentData):
scripts and mlos_bench configuration files.
"""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- experiment_id: str,
- description: str,
- root_env_config: str,
- git_repo: str,
- git_commit: str):
+ def __init__(
+ self,
+ *,
+ engine: Engine,
+ schema: DbSchema,
+ experiment_id: str,
+ description: str,
+ root_env_config: str,
+ git_repo: str,
+ git_commit: str,
+ ):
super().__init__(
experiment_id=experiment_id,
description=description,
@@ -54,9 +58,11 @@ class ExperimentSqlData(ExperimentData):
def objectives(self) -> Dict[str, Literal["min", "max"]]:
with self._engine.connect() as conn:
objectives_db_data = conn.execute(
- self._schema.objectives.select().where(
+ self._schema.objectives.select()
+ .where(
self._schema.objectives.c.exp_id == self._experiment_id,
- ).order_by(
+ )
+ .order_by(
self._schema.objectives.c.weight.desc(),
self._schema.objectives.c.optimization_target.asc(),
)
@@ -66,7 +72,8 @@ class ExperimentSqlData(ExperimentData):
for objective in objectives_db_data.fetchall()
}
- # TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed.
+ # TODO: provide a way to get individual data to avoid repeated bulk fetches
+    #       where only small amounts of data are accessed.
# Or else make the TrialData object lazily populate.
@property
@@ -77,13 +84,17 @@ class ExperimentSqlData(ExperimentData):
def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
with self._engine.connect() as conn:
tunable_config_trial_groups = conn.execute(
- self._schema.trial.select().with_only_columns(
+ self._schema.trial.select()
+ .with_only_columns(
self._schema.trial.c.config_id,
- func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
- 'tunable_config_trial_group_id'),
- ).where(
+ func.min(self._schema.trial.c.trial_id)
+ .cast(Integer)
+ .label("tunable_config_trial_group_id"), # pylint: disable=not-callable
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
- ).group_by(
+ )
+ .group_by(
self._schema.trial.c.exp_id,
self._schema.trial.c.config_id,
)
@@ -94,7 +105,7 @@ class ExperimentSqlData(ExperimentData):
schema=self._schema,
experiment_id=self._experiment_id,
tunable_config_id=tunable_config_trial_group.config_id,
- tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,
+ tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id, # pylint:disable=line-too-long # noqa
)
for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
}
@@ -103,11 +114,14 @@ class ExperimentSqlData(ExperimentData):
def tunable_configs(self) -> Dict[int, TunableConfigData]:
with self._engine.connect() as conn:
tunable_configs = conn.execute(
- self._schema.trial.select().with_only_columns(
- self._schema.trial.c.config_id.cast(Integer).label('config_id'),
- ).where(
+ self._schema.trial.select()
+ .with_only_columns(
+ self._schema.trial.c.config_id.cast(Integer).label("config_id"),
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
- ).group_by(
+ )
+ .group_by(
self._schema.trial.c.exp_id,
self._schema.trial.c.config_id,
)
@@ -124,7 +138,8 @@ class ExperimentSqlData(ExperimentData):
@property
def default_tunable_config_id(self) -> Optional[int]:
"""
- Retrieves the (tunable) config id for the default tunable values for this experiment.
+ Retrieves the (tunable) config id for the default tunable values for this
+ experiment.
Note: this is by *default* the first trial executed for this experiment.
However, it is currently possible that the user changed the tunables config
@@ -136,20 +151,28 @@ class ExperimentSqlData(ExperimentData):
"""
with self._engine.connect() as conn:
query_results = conn.execute(
- self._schema.trial.select().with_only_columns(
- self._schema.trial.c.config_id.cast(Integer).label('config_id'),
- ).where(
+ self._schema.trial.select()
+ .with_only_columns(
+ self._schema.trial.c.config_id.cast(Integer).label("config_id"),
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id.in_(
- self._schema.trial_param.select().with_only_columns(
- func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
- "first_trial_id_with_defaults"),
- ).where(
+ self._schema.trial_param.select()
+ .with_only_columns(
+ func.min(self._schema.trial_param.c.trial_id)
+ .cast(Integer)
+ .label("first_trial_id_with_defaults"), # pylint: disable=not-callable
+ )
+ .where(
self._schema.trial_param.c.exp_id == self._experiment_id,
self._schema.trial_param.c.param_id == "is_defaults",
- func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]),
- ).scalar_subquery()
- )
+ func.lower(self._schema.trial_param.c.param_value, type_=String).in_(
+ ["1", "true"]
+ ),
+ )
+ .scalar_subquery()
+ ),
)
)
min_default_trial_row = query_results.fetchone()
@@ -158,17 +181,24 @@ class ExperimentSqlData(ExperimentData):
return min_default_trial_row._tuple()[0]
# fallback logic - assume minimum trial_id for experiment
query_results = conn.execute(
- self._schema.trial.select().with_only_columns(
- self._schema.trial.c.config_id.cast(Integer).label('config_id'),
- ).where(
+ self._schema.trial.select()
+ .with_only_columns(
+ self._schema.trial.c.config_id.cast(Integer).label("config_id"),
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id.in_(
- self._schema.trial.select().with_only_columns(
- func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"),
- ).where(
+ self._schema.trial.select()
+ .with_only_columns(
+ func.min(self._schema.trial.c.trial_id)
+ .cast(Integer)
+ .label("first_trial_id"),
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
- ).scalar_subquery()
- )
+ )
+ .scalar_subquery()
+ ),
)
)
min_trial_row = query_results.fetchone()
diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py
index c59adc1c67..3900568b75 100644
--- a/mlos_bench/mlos_bench/storage/sql/schema.py
+++ b/mlos_bench/mlos_bench/storage/sql/schema.py
@@ -2,17 +2,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-DB schema definition.
-"""
+"""DB schema definition."""
import logging
-from typing import List, Any
+from typing import Any, List
from sqlalchemy import (
- Engine, MetaData, Dialect, create_mock_engine,
- Table, Column, Sequence, Integer, Float, String, DateTime,
- PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint,
+ Column,
+ DateTime,
+ Dialect,
+ Engine,
+ Float,
+ ForeignKeyConstraint,
+ Integer,
+ MetaData,
+ PrimaryKeyConstraint,
+ Sequence,
+ String,
+ Table,
+ UniqueConstraint,
+ create_mock_engine,
)
_LOG = logging.getLogger(__name__)
@@ -38,9 +47,7 @@ class _DDL:
class DbSchema:
- """
- A class to define and create the DB schema.
- """
+ """A class to define and create the DB schema."""
# This class is internal to SqlStorage and is mostly a struct
# for all DB tables, so it's ok to disable the warnings.
@@ -53,9 +60,7 @@ class DbSchema:
_STATUS_LEN = 16
def __init__(self, engine: Engine):
- """
- Declare the SQLAlchemy schema for the database.
- """
+ """Declare the SQLAlchemy schema for the database."""
_LOG.info("Create the DB schema for: %s", engine)
self._engine = engine
# TODO: bind for automatic schema updates? (#649)
@@ -69,7 +74,6 @@ class DbSchema:
Column("root_env_config", String(1024), nullable=False),
Column("git_repo", String(1024), nullable=False),
Column("git_commit", String(40), nullable=False),
-
PrimaryKeyConstraint("exp_id"),
)
@@ -84,20 +88,29 @@ class DbSchema:
# Will need to adjust the insert and return values to support this
# eventually.
Column("weight", Float, nullable=True),
-
PrimaryKeyConstraint("exp_id", "optimization_target"),
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
)
# A workaround for SQLAlchemy issue with autoincrement in DuckDB:
if engine.dialect.name == "duckdb":
- seq_config_id = Sequence('seq_config_id')
- col_config_id = Column("config_id", Integer, seq_config_id,
- server_default=seq_config_id.next_value(),
- nullable=False, primary_key=True)
+ seq_config_id = Sequence("seq_config_id")
+ col_config_id = Column(
+ "config_id",
+ Integer,
+ seq_config_id,
+ server_default=seq_config_id.next_value(),
+ nullable=False,
+ primary_key=True,
+ )
else:
- col_config_id = Column("config_id", Integer, nullable=False,
- primary_key=True, autoincrement=True)
+ col_config_id = Column(
+ "config_id",
+ Integer,
+ nullable=False,
+ primary_key=True,
+ autoincrement=True,
+ )
self.config = Table(
"config",
@@ -116,7 +129,6 @@ class DbSchema:
Column("ts_end", DateTime),
# Should match the text IDs of `mlos_bench.environments.Status` enum:
Column("status", String(self._STATUS_LEN), nullable=False),
-
PrimaryKeyConstraint("exp_id", "trial_id"),
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
@@ -130,7 +142,6 @@ class DbSchema:
Column("config_id", Integer, nullable=False),
Column("param_id", String(self._ID_LEN), nullable=False),
Column("param_value", String(self._PARAM_VALUE_LEN)),
-
PrimaryKeyConstraint("config_id", "param_id"),
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
)
@@ -144,10 +155,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("param_id", String(self._ID_LEN), nullable=False),
Column("param_value", String(self._PARAM_VALUE_LEN)),
-
PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
- ForeignKeyConstraint(["exp_id", "trial_id"],
- [self.trial.c.exp_id, self.trial.c.trial_id]),
+ ForeignKeyConstraint(
+ ["exp_id", "trial_id"],
+ [self.trial.c.exp_id, self.trial.c.trial_id],
+ ),
)
self.trial_status = Table(
@@ -157,10 +169,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("status", String(self._STATUS_LEN), nullable=False),
-
UniqueConstraint("exp_id", "trial_id", "ts"),
- ForeignKeyConstraint(["exp_id", "trial_id"],
- [self.trial.c.exp_id, self.trial.c.trial_id]),
+ ForeignKeyConstraint(
+ ["exp_id", "trial_id"],
+ [self.trial.c.exp_id, self.trial.c.trial_id],
+ ),
)
self.trial_result = Table(
@@ -170,10 +183,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),
-
PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
- ForeignKeyConstraint(["exp_id", "trial_id"],
- [self.trial.c.exp_id, self.trial.c.trial_id]),
+ ForeignKeyConstraint(
+ ["exp_id", "trial_id"],
+ [self.trial.c.exp_id, self.trial.c.trial_id],
+ ),
)
self.trial_telemetry = Table(
@@ -184,26 +198,25 @@ class DbSchema:
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),
-
UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
- ForeignKeyConstraint(["exp_id", "trial_id"],
- [self.trial.c.exp_id, self.trial.c.trial_id]),
+ ForeignKeyConstraint(
+ ["exp_id", "trial_id"],
+ [self.trial.c.exp_id, self.trial.c.trial_id],
+ ),
)
_LOG.debug("Schema: %s", self._meta)
- def create(self) -> 'DbSchema':
- """
- Create the DB schema.
- """
+ def create(self) -> "DbSchema":
+ """Create the DB schema."""
_LOG.info("Create the DB schema")
self._meta.create_all(self._engine)
return self
def __repr__(self) -> str:
"""
- Produce a string with all SQL statements required to create the schema
- from scratch in current SQL dialect.
+ Produce a string with all SQL statements required to create the schema from
+ scratch in current SQL dialect.
That is, return a collection of CREATE TABLE statements and such.
NOTE: this method is quite heavy! We use it only once at startup
diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py
index 1bfe695300..6b6d11e699 100644
--- a/mlos_bench/mlos_bench/storage/sql/storage.py
+++ b/mlos_bench/mlos_bench/storage/sql/storage.py
@@ -2,35 +2,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Saving and restoring the benchmark data in SQL database.
-"""
+"""Saving and restoring the benchmark data in SQL database."""
import logging
from typing import Dict, Literal, Optional
from sqlalchemy import URL, create_engine
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.services.base_service import Service
-from mlos_bench.storage.base_storage import Storage
-from mlos_bench.storage.sql.schema import DbSchema
-from mlos_bench.storage.sql.experiment import Experiment
from mlos_bench.storage.base_experiment_data import ExperimentData
+from mlos_bench.storage.base_storage import Storage
+from mlos_bench.storage.sql.experiment import Experiment
from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
+from mlos_bench.storage.sql.schema import DbSchema
+from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class SqlStorage(Storage):
- """
- An implementation of the Storage interface using SQLAlchemy backend.
- """
+ """An implementation of the Storage interface using SQLAlchemy backend."""
- def __init__(self,
- config: dict,
- global_config: Optional[dict] = None,
- service: Optional[Service] = None):
+ def __init__(
+ self,
+ config: dict,
+ global_config: Optional[dict] = None,
+ service: Optional[Service] = None,
+ ):
super().__init__(config, global_config, service)
lazy_schema_create = self._config.pop("lazy_schema_create", False)
self._log_sql = self._config.pop("log_sql", False)
@@ -47,7 +45,7 @@ class SqlStorage(Storage):
@property
def _schema(self) -> DbSchema:
"""Lazily create schema upon first access."""
- if not hasattr(self, '_db_schema'):
+ if not hasattr(self, "_db_schema"):
self._db_schema = DbSchema(self._engine).create()
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("DDL statements:\n%s", self._schema)
@@ -56,13 +54,16 @@ class SqlStorage(Storage):
def __repr__(self) -> str:
return self._repr
- def experiment(self, *,
- experiment_id: str,
- trial_id: int,
- root_env_config: str,
- description: str,
- tunables: TunableGroups,
- opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment:
+ def experiment(
+ self,
+ *,
+ experiment_id: str,
+ trial_id: int,
+ root_env_config: str,
+ description: str,
+ tunables: TunableGroups,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ ) -> Storage.Experiment:
return Experiment(
engine=self._engine,
schema=self._schema,
diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py
index 4806056e05..6c2cf26cc7 100644
--- a/mlos_bench/mlos_bench/storage/sql/trial.py
+++ b/mlos_bench/mlos_bench/storage/sql/trial.py
@@ -2,40 +2,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Saving and updating benchmark data using SQLAlchemy backend.
-"""
+"""Saving and updating benchmark data using SQLAlchemy backend."""
import logging
from datetime import datetime
-from typing import List, Literal, Optional, Tuple, Dict, Any
+from typing import Any, Dict, List, Literal, Optional, Tuple
-from sqlalchemy import Engine, Connection
+from sqlalchemy import Connection, Engine
from sqlalchemy.exc import IntegrityError
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
+from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import nullable, utcify_timestamp
_LOG = logging.getLogger(__name__)
class Trial(Storage.Trial):
- """
- Store the results of a single run of the experiment in SQL database.
- """
+ """Store the results of a single run of the experiment in SQL database."""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- tunables: TunableGroups,
- experiment_id: str,
- trial_id: int,
- config_id: int,
- opt_targets: Dict[str, Literal['min', 'max']],
- config: Optional[Dict[str, Any]] = None):
+ def __init__(
+ self,
+ *,
+ engine: Engine,
+ schema: DbSchema,
+ tunables: TunableGroups,
+ experiment_id: str,
+ trial_id: int,
+ config_id: int,
+ opt_targets: Dict[str, Literal["min", "max"]],
+ config: Optional[Dict[str, Any]] = None,
+ ):
super().__init__(
tunables=tunables,
experiment_id=experiment_id,
@@ -47,9 +46,12 @@ class Trial(Storage.Trial):
self._engine = engine
self._schema = schema
- def update(self, status: Status, timestamp: datetime,
- metrics: Optional[Dict[str, Any]] = None
- ) -> Optional[Dict[str, Any]]:
+ def update(
+ self,
+ status: Status,
+ timestamp: datetime,
+ metrics: Optional[Dict[str, Any]] = None,
+ ) -> Optional[Dict[str, Any]]:
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = super().update(status, timestamp, metrics)
@@ -59,13 +61,16 @@ class Trial(Storage.Trial):
if status.is_completed():
# Final update of the status and ts_end:
cur_status = conn.execute(
- self._schema.trial.update().where(
+ self._schema.trial.update()
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
- ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
- ).values(
+ ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
+ ),
+ )
+ .values(
status=status.name,
ts_end=timestamp,
)
@@ -73,29 +78,37 @@ class Trial(Storage.Trial):
if cur_status.rowcount not in {1, -1}:
_LOG.warning("Trial %s :: update failed: %s", self, status)
raise RuntimeError(
- f"Failed to update the status of the trial {self} to {status}." +
- f" ({cur_status.rowcount} rows)")
+ f"Failed to update the status of the trial {self} to {status}. "
+ f"({cur_status.rowcount} rows)"
+ )
if metrics:
- conn.execute(self._schema.trial_result.insert().values([
- {
- "exp_id": self._experiment_id,
- "trial_id": self._trial_id,
- "metric_id": key,
- "metric_value": nullable(str, val),
- }
- for (key, val) in metrics.items()
- ]))
+ conn.execute(
+ self._schema.trial_result.insert().values(
+ [
+ {
+ "exp_id": self._experiment_id,
+ "trial_id": self._trial_id,
+ "metric_id": key,
+ "metric_value": nullable(str, val),
+ }
+ for (key, val) in metrics.items()
+ ]
+ )
+ )
else:
# Update of the status and ts_start when starting the trial:
assert metrics is None, f"Unexpected metrics for status: {status}"
cur_status = conn.execute(
- self._schema.trial.update().where(
+ self._schema.trial.update()
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
- ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
- ).values(
+ ["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
+ ),
+ )
+ .values(
status=status.name,
ts_start=timestamp,
)
@@ -108,8 +121,12 @@ class Trial(Storage.Trial):
raise
return metrics
- def update_telemetry(self, status: Status, timestamp: datetime,
- metrics: List[Tuple[datetime, str, Any]]) -> None:
+ def update_telemetry(
+ self,
+ status: Status,
+ timestamp: datetime,
+ metrics: List[Tuple[datetime, str, Any]],
+ ) -> None:
super().update_telemetry(status, timestamp, metrics)
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
@@ -120,33 +137,42 @@ class Trial(Storage.Trial):
# See Also: comments in
with self._engine.begin() as conn:
self._update_status(conn, status, timestamp)
- for (metric_ts, key, val) in metrics:
+ for metric_ts, key, val in metrics:
with self._engine.begin() as conn:
try:
- conn.execute(self._schema.trial_telemetry.insert().values(
- exp_id=self._experiment_id,
- trial_id=self._trial_id,
- ts=metric_ts,
- metric_id=key,
- metric_value=nullable(str, val),
- ))
+ conn.execute(
+ self._schema.trial_telemetry.insert().values(
+ exp_id=self._experiment_id,
+ trial_id=self._trial_id,
+ ts=metric_ts,
+ metric_id=key,
+ metric_value=nullable(str, val),
+ )
+ )
except IntegrityError as ex:
_LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)
def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
"""
Insert a new status record into the database.
+
This call is idempotent.
"""
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
try:
- conn.execute(self._schema.trial_status.insert().values(
- exp_id=self._experiment_id,
- trial_id=self._trial_id,
- ts=timestamp,
- status=status.name,
- ))
+ conn.execute(
+ self._schema.trial_status.insert().values(
+ exp_id=self._experiment_id,
+ trial_id=self._trial_id,
+ ts=timestamp,
+ status=status.name,
+ )
+ )
except IntegrityError as ex:
- _LOG.warning("Status with that timestamp already exists: %s %s :: %s",
- self, timestamp, ex)
+ _LOG.warning(
+ "Status with that timestamp already exists: %s %s :: %s",
+ self,
+ timestamp,
+ ex,
+ )
diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py
index 7353e96e79..40362b25fd 100644
--- a/mlos_bench/mlos_bench/storage/sql/trial_data.py
+++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py
@@ -2,40 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-An interface to access the benchmark trial data stored in SQL DB.
-"""
+"""An interface to access the benchmark trial data stored in SQL DB."""
from datetime import datetime
-from typing import Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import pandas
from sqlalchemy import Engine
+from mlos_bench.environments.status import Status
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
-from mlos_bench.environments.status import Status
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
from mlos_bench.util import utcify_timestamp
if TYPE_CHECKING:
- from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
+ from mlos_bench.storage.base_tunable_config_trial_group_data import (
+ TunableConfigTrialGroupData,
+ )
class TrialSqlData(TrialData):
- """
- An interface to access the trial data stored in the SQL DB.
- """
+ """An interface to access the trial data stored in the SQL DB."""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- experiment_id: str,
- trial_id: int,
- config_id: int,
- ts_start: datetime,
- ts_end: Optional[datetime],
- status: Status):
+ def __init__(
+ self,
+ *,
+ engine: Engine,
+ schema: DbSchema,
+ experiment_id: str,
+ trial_id: int,
+ config_id: int,
+ ts_start: datetime,
+ ts_end: Optional[datetime],
+ status: Status,
+ ):
super().__init__(
experiment_id=experiment_id,
trial_id=trial_id,
@@ -54,49 +55,59 @@ class TrialSqlData(TrialData):
Note: this corresponds to the Trial object's "tunables" property.
"""
- return TunableConfigSqlData(engine=self._engine, schema=self._schema,
- tunable_config_id=self._tunable_config_id)
+ return TunableConfigSqlData(
+ engine=self._engine,
+ schema=self._schema,
+ tunable_config_id=self._tunable_config_id,
+ )
@property
def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
- """
- Retrieve the trial's tunable config group configuration data from the storage.
+ """Retrieve the trial's tunable config group configuration data from the
+ storage.
"""
# pylint: disable=import-outside-toplevel
- from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
- return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema,
- experiment_id=self._experiment_id,
- tunable_config_id=self._tunable_config_id)
+ from mlos_bench.storage.sql.tunable_config_trial_group_data import (
+ TunableConfigTrialGroupSqlData,
+ )
+
+ return TunableConfigTrialGroupSqlData(
+ engine=self._engine,
+ schema=self._schema,
+ experiment_id=self._experiment_id,
+ tunable_config_id=self._tunable_config_id,
+ )
@property
def results_df(self) -> pandas.DataFrame:
- """
- Retrieve the trials' results from the storage.
- """
+ """Retrieve the trials' results from the storage."""
with self._engine.connect() as conn:
cur_results = conn.execute(
- self._schema.trial_result.select().where(
+ self._schema.trial_result.select()
+ .where(
self._schema.trial_result.c.exp_id == self._experiment_id,
- self._schema.trial_result.c.trial_id == self._trial_id
- ).order_by(
+ self._schema.trial_result.c.trial_id == self._trial_id,
+ )
+ .order_by(
self._schema.trial_result.c.metric_id,
)
)
return pandas.DataFrame(
[(row.metric_id, row.metric_value) for row in cur_results.fetchall()],
- columns=['metric', 'value'])
+ columns=["metric", "value"],
+ )
@property
def telemetry_df(self) -> pandas.DataFrame:
- """
- Retrieve the trials' telemetry from the storage.
- """
+ """Retrieve the trials' telemetry from the storage."""
with self._engine.connect() as conn:
cur_telemetry = conn.execute(
- self._schema.trial_telemetry.select().where(
+ self._schema.trial_telemetry.select()
+ .where(
self._schema.trial_telemetry.c.exp_id == self._experiment_id,
- self._schema.trial_telemetry.c.trial_id == self._trial_id
- ).order_by(
+ self._schema.trial_telemetry.c.trial_id == self._trial_id,
+ )
+ .order_by(
self._schema.trial_telemetry.c.ts,
self._schema.trial_telemetry.c.metric_id,
)
@@ -104,8 +115,12 @@ class TrialSqlData(TrialData):
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and augment it on return again here.
return pandas.DataFrame(
- [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()],
- columns=['ts', 'metric', 'value'])
+ [
+ (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
+ for row in cur_telemetry.fetchall()
+ ],
+ columns=["ts", "metric", "value"],
+ )
@property
def metadata_df(self) -> pandas.DataFrame:
@@ -116,13 +131,16 @@ class TrialSqlData(TrialData):
"""
with self._engine.connect() as conn:
cur_params = conn.execute(
- self._schema.trial_param.select().where(
+ self._schema.trial_param.select()
+ .where(
self._schema.trial_param.c.exp_id == self._experiment_id,
- self._schema.trial_param.c.trial_id == self._trial_id
- ).order_by(
+ self._schema.trial_param.c.trial_id == self._trial_id,
+ )
+ .order_by(
self._schema.trial_param.c.param_id,
)
)
return pandas.DataFrame(
[(row.param_id, row.param_value) for row in cur_params.fetchall()],
- columns=['parameter', 'value'])
+ columns=["parameter", "value"],
+ )
diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py
index e484979790..40225039be 100644
--- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py
+++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-An interface to access the tunable config data stored in SQL DB.
-"""
+"""An interface to access the tunable config data stored in SQL DB."""
import pandas
from sqlalchemy import Engine
@@ -20,10 +18,7 @@ class TunableConfigSqlData(TunableConfigData):
A configuration in this context is the set of tunable parameter values.
"""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- tunable_config_id: int):
+ def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int):
super().__init__(tunable_config_id=tunable_config_id)
self._engine = engine
self._schema = schema
@@ -32,12 +27,13 @@ class TunableConfigSqlData(TunableConfigData):
def config_df(self) -> pandas.DataFrame:
with self._engine.connect() as conn:
cur_config = conn.execute(
- self._schema.config_param.select().where(
- self._schema.config_param.c.config_id == self._tunable_config_id
- ).order_by(
+ self._schema.config_param.select()
+ .where(self._schema.config_param.c.config_id == self._tunable_config_id)
+ .order_by(
self._schema.config_param.c.param_id,
)
)
return pandas.DataFrame(
[(row.param_id, row.param_value) for row in cur_config.fetchall()],
- columns=['parameter', 'value'])
+ columns=["parameter", "value"],
+ )
diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py
index 775683133d..5069e435b2 100644
--- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py
+++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py
@@ -2,17 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-An interface to access the tunable config trial group data stored in SQL DB.
-"""
+"""An interface to access the tunable config trial group data stored in SQL DB."""
-from typing import Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Optional
import pandas
from sqlalchemy import Engine, Integer, func
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
-from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
+from mlos_bench.storage.base_tunable_config_trial_group_data import (
+ TunableConfigTrialGroupData,
+)
from mlos_bench.storage.sql import common
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
@@ -23,20 +23,23 @@ if TYPE_CHECKING:
class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
"""
- SQL interface for accessing the stored experiment benchmark tunable config
- trial group data.
+ SQL interface for accessing the stored experiment benchmark tunable config trial
+ group data.
A (tunable) config is used to define an instance of values for a set of tunable
parameters for a given experiment and can be used by one or more trial instances
(e.g., for repeats), which we call a (tunable) config trial group.
"""
- def __init__(self, *,
- engine: Engine,
- schema: DbSchema,
- experiment_id: str,
- tunable_config_id: int,
- tunable_config_trial_group_id: Optional[int] = None):
+ def __init__(
+ self,
+ *,
+ engine: Engine,
+ schema: DbSchema,
+ experiment_id: str,
+ tunable_config_id: int,
+ tunable_config_trial_group_id: Optional[int] = None,
+ ):
super().__init__(
experiment_id=experiment_id,
tunable_config_id=tunable_config_id,
@@ -46,25 +49,28 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
self._schema = schema
def _get_tunable_config_trial_group_id(self) -> int:
- """
- Retrieve the trial's tunable_config_trial_group_id from the storage.
- """
+ """Retrieve the trial's tunable_config_trial_group_id from the storage."""
with self._engine.connect() as conn:
tunable_config_trial_group = conn.execute(
- self._schema.trial.select().with_only_columns(
- func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
- 'tunable_config_trial_group_id'),
- ).where(
+ self._schema.trial.select()
+ .with_only_columns(
+ func.min(self._schema.trial.c.trial_id)
+ .cast(Integer)
+ .label("tunable_config_trial_group_id"), # pylint: disable=not-callable
+ )
+ .where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.config_id == self._tunable_config_id,
- ).group_by(
+ )
+ .group_by(
self._schema.trial.c.exp_id,
self._schema.trial.c.config_id,
)
)
row = tunable_config_trial_group.fetchone()
assert row is not None
- return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy
+ # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy
+ return row._tuple()[0]
@property
def tunable_config(self) -> TunableConfigData:
@@ -77,15 +83,26 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
@property
def trials(self) -> Dict[int, "TrialData"]:
"""
- Retrieve the trials' data for this (tunable) config trial group from the storage.
+ Retrieve the trials' data for this (tunable) config trial group from the
+ storage.
Returns
-------
trials : Dict[int, TrialData]
A dictionary of the trials' data, keyed by trial id.
"""
- return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id)
+ return common.get_trials(
+ self._engine,
+ self._schema,
+ self._experiment_id,
+ self._tunable_config_id,
+ )
@property
def results_df(self) -> pandas.DataFrame:
- return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id)
+ return common.get_results_df(
+ self._engine,
+ self._schema,
+ self._experiment_id,
+ self._tunable_config_id,
+ )
diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py
index faa934b28f..ea0201717d 100644
--- a/mlos_bench/mlos_bench/storage/storage_factory.py
+++ b/mlos_bench/mlos_bench/storage/storage_factory.py
@@ -2,20 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Factory method to create a new Storage instance from configs.
-"""
+"""Factory method to create a new Storage instance from configs."""
-from typing import Any, Optional, List, Dict
+from typing import Any, Dict, List, Optional
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.base_storage import Storage
-def from_config(config_file: str,
- global_configs: Optional[List[str]] = None,
- **kwargs: Any) -> Storage:
+def from_config(
+ config_file: str,
+ global_configs: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> Storage:
"""
Create a new storage object from JSON5 config file.
@@ -36,7 +36,7 @@ def from_config(config_file: str,
config_path: List[str] = kwargs.get("config_path", [])
config_loader = ConfigPersistenceService({"config_path": config_path})
global_config = {}
- for fname in (global_configs or []):
+ for fname in global_configs or []:
config = config_loader.load_config(fname, ConfigSchema.GLOBALS)
global_config.update(config)
config_path += config.get("config_path", [])
diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py
index a4610da8de..173f7d95d6 100644
--- a/mlos_bench/mlos_bench/storage/util.py
+++ b/mlos_bench/mlos_bench/storage/util.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Utility functions for the storage subsystem.
-"""
+"""Utility functions for the storage subsystem."""
from typing import Dict, Optional
@@ -25,16 +23,18 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu
A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where
'parameter' is a string and 'value' is some TunableValue or None.
"""
- if dataframe.columns.tolist() == ['metric', 'value']:
+ if dataframe.columns.tolist() == ["metric", "value"]:
dataframe = dataframe.copy()
- dataframe.rename(columns={'metric': 'parameter'}, inplace=True)
- assert dataframe.columns.tolist() == ['parameter', 'value']
+ dataframe.rename(columns={"metric": "parameter"}, inplace=True)
+ assert dataframe.columns.tolist() == ["parameter", "value"]
data = {}
- for _, row in dataframe.astype('O').iterrows():
- if not isinstance(row['value'], TunableValueTypeTuple):
+ for _, row in dataframe.astype("O").iterrows():
+ if not isinstance(row["value"], TunableValueTypeTuple):
raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}")
- assert isinstance(row['parameter'], str)
- if row['parameter'] in data:
+ assert isinstance(row["parameter"], str)
+ if row["parameter"] in data:
raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe")
- data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value']
+ data[row["parameter"]] = (
+ try_parse_val(row["value"]) if isinstance(row["value"], str) else row["value"]
+ )
return data
diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py
index d1c1781ada..8737057665 100644
--- a/mlos_bench/mlos_bench/tests/__init__.py
+++ b/mlos_bench/mlos_bench/tests/__init__.py
@@ -4,24 +4,23 @@
#
"""
Tests for mlos_bench.
+
Used to make mypy happy about multiple conftest.py modules.
"""
+import filecmp
+import os
+import shutil
+import socket
from datetime import tzinfo
from logging import debug, warning
from subprocess import run
from typing import List, Optional
-import filecmp
-import os
-import socket
-import shutil
-
-import pytz
import pytest
+import pytz
from mlos_bench.util import get_class_from_name, nullable
-
ZONE_NAMES = [
# Explicit time zones.
"UTC",
@@ -31,26 +30,35 @@ ZONE_NAMES = [
None,
]
ZONE_INFO: List[Optional[tzinfo]] = [
- nullable(pytz.timezone, zone_name)
- for zone_name in ZONE_NAMES
+ nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES
]
# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.
-DOCKER = shutil.which('docker')
+DOCKER = shutil.which("docker")
if DOCKER:
- cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True)
+ cmd = run(
+ "docker builder inspect default || docker buildx inspect default",
+ shell=True,
+ check=False,
+ capture_output=True,
+ )
stdout = cmd.stdout.decode()
- if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line):
+ if cmd.returncode != 0 or not any(
+ line for line in stdout.splitlines() if "Platform" in line and "linux" in line
+ ):
debug("Docker is available but missing support for targeting linux platform.")
DOCKER = None
-requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.')
+requires_docker = pytest.mark.skipif(
+ not DOCKER,
+ reason="Docker with Linux support is not available on this system.",
+)
# A decorator for tests that require ssh.
# Use with @requires_ssh above a test_...() function.
-SSH = shutil.which('ssh')
-requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.')
+SSH = shutil.which("ssh")
+requires_ssh = pytest.mark.skipif(not SSH, reason="ssh is not available on this system.")
# A common seed to use to avoid tracking down race conditions and intermingling
# issues of seeds across tests that run in non-deterministic parallel orders.
@@ -61,9 +69,7 @@ SEED = 42
def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
- """
- Gets the full class name from the given name or None on error.
- """
+ """Gets the full class name from the given name or None on error."""
if class_name is None:
return None
try:
@@ -74,9 +80,7 @@ def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
def check_class_name(obj: object, expected_class_name: str) -> bool:
- """
- Compares the class name of the given object with the given name.
- """
+ """Compares the class name of the given object with the given name."""
full_class_name = obj.__class__.__module__ + "." + obj.__class__.__name__
return full_class_name == try_resolve_class_name(expected_class_name)
@@ -121,20 +125,27 @@ def resolve_host_name(host: str) -> Optional[str]:
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
"""
- Compare two directories recursively. Files in each directory are
- assumed to be equal if their names and contents are equal.
+ Compare two directories recursively. Files in each directory are assumed to be equal
+ if their names and contents are equal.
- @param dir1: First directory path
- @param dir2: Second directory path
+    @param dir1: First directory path
+    @param dir2: Second directory path
- @return: True if the directory trees are the same and
- there were no errors while accessing the directories or files,
- False otherwise.
+ @return: True if the directory trees are the same and there were no errors while
+ accessing the directories or files, False otherwise.
"""
# See Also: https://stackoverflow.com/a/6681395
dirs_cmp = filecmp.dircmp(dir1, dir2)
- if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0:
- warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}")
+ if (
+ len(dirs_cmp.left_only) > 0
+ or len(dirs_cmp.right_only) > 0
+ or len(dirs_cmp.funny_files) > 0
+ ):
+ warning(
+ (
+ f"Found differences in dir trees {dir1}, {dir2}:\n"
+ f"{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}"
+ )
+ )
return False
(_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
if len(mismatch) > 0 or len(errors) > 0:
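The `requires_docker` and `requires_ssh` markers defined above are plain `pytest.mark.skipif` decorators. A minimal sketch of how a test module is expected to consume them (the test names here are hypothetical, only the imported names come from the module above):

```python
from mlos_bench.tests import SEED, requires_docker, requires_ssh


@requires_docker
def test_something_in_a_container() -> None:
    """Skipped automatically unless a Linux-capable docker builder is available."""
    assert SEED == 42


@requires_ssh
def test_something_over_ssh() -> None:
    """Skipped automatically unless an ssh client is installed on this system."""
    assert SEED == 42
```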
diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py
index 75b2ae0cbe..b2b6146a56 100644
--- a/mlos_bench/mlos_bench/tests/config/__init__.py
+++ b/mlos_bench/mlos_bench/tests/config/__init__.py
@@ -2,14 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Helper functions for config example loading tests.
-"""
-
-from typing import Callable, List, Optional
+"""Helper functions for config example loading tests."""
import os
import sys
+from typing import Callable, List, Optional
from mlos_bench.util import path_join
@@ -22,10 +19,13 @@ else:
BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/")
-def locate_config_examples(root_dir: str,
- config_examples_dir: str,
- examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]:
- """Locates all config examples in the given directory.
+def locate_config_examples(
+ root_dir: str,
+ config_examples_dir: str,
+ examples_filter: Optional[Callable[[List[str]], List[str]]] = None,
+) -> List[str]:
+ """
+ Locates all config examples in the given directory.
Parameters
----------
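The reformatted `locate_config_examples` signature above is the entry point used by all of the config example tests that follow. A small sketch of the call pattern, assuming a hypothetical filter that keeps only `.jsonc` examples:

```python
from typing import List

from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.tests.config import locate_config_examples


def keep_jsonc(configs_to_filter: List[str]) -> List[str]:
    """Hypothetical filter: keep only the .jsonc config examples."""
    return [path for path in configs_to_filter if path.endswith(".jsonc")]


configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    "environments",
    keep_jsonc,
)
assert configs
```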
diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py
index 6fb341ff44..3db11e6cb2 100644
--- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py
@@ -2,27 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading storage config examples.
-"""
-
-from typing import List
+"""Tests for loading storage config examples."""
import logging
import sys
+from typing import List
import pytest
-from mlos_bench.tests import check_class_name
-from mlos_bench.tests.config import locate_config_examples, BUILTIN_TEST_CONFIG_PATH
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments import Environment
+from mlos_bench.launcher import Launcher
from mlos_bench.optimizers import Optimizer
-from mlos_bench.storage import Storage
from mlos_bench.schedulers import Scheduler
from mlos_bench.services.config_persistence import ConfigPersistenceService
-from mlos_bench.launcher import Launcher
+from mlos_bench.storage import Storage
+from mlos_bench.tests import check_class_name
+from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
from mlos_bench.util import path_join
if sys.version_info < (3, 10):
@@ -45,15 +41,26 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]:
configs = [
- *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs),
- *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs),
+ *locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+ ),
+ *locate_config_examples(
+ BUILTIN_TEST_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+ ),
]
assert configs
@pytest.mark.skip(reason="Use full Launcher test (below) instead now.")
@pytest.mark.parametrize("config_path", configs)
-def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover
+def test_load_cli_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None: # pragma: no cover
"""Tests loading a config example."""
# pylint: disable=too-complex
config = config_loader_service.load_config(config_path, ConfigSchema.CLI)
@@ -63,7 +70,7 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic
assert isinstance(config_paths, list)
config_paths.reverse()
for path in config_paths:
- config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access
+ config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access
# Foreach arg that references another file, see if we can at least load that too.
args_to_skip = {
@@ -100,7 +107,10 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic
@pytest.mark.parametrize("config_path", configs)
-def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_cli_config_examples_via_launcher(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a config example via the Launcher."""
config = config_loader_service.load_config(config_path, ConfigSchema.CLI)
assert isinstance(config, dict)
@@ -108,10 +118,13 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers
# Try to load the CLI config by instantiating a launcher.
# To do this we need to make sure to give it a few extra paths and globals
# to look for for our examples.
- cli_args = f"--config {config_path}" + \
- f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \
- f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \
+ cli_args = (
+ f"--config {config_path}"
+ f" --config-path {files('mlos_bench.config')} "
+ f" --config-path {files('mlos_bench.tests.config')}"
+ f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}"
f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc"
+ )
launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split())
assert launcher
@@ -122,15 +135,16 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers
assert isinstance(config_paths, list)
for path in config_paths:
# Note: Checks that the order is maintained are handled in launcher_parse_args.py
- assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \
- f"Expected {path} to be in {launcher.config_loader.config_paths}"
+ assert any(
+ config_path.endswith(path) for config_path in launcher.config_loader.config_paths
+ ), f"Expected {path} to be in {launcher.config_loader.config_paths}"
- if 'experiment_id' in config:
- assert launcher.global_config['experiment_id'] == config['experiment_id']
- if 'trial_id' in config:
- assert launcher.global_config['trial_id'] == config['trial_id']
+ if "experiment_id" in config:
+ assert launcher.global_config["experiment_id"] == config["experiment_id"]
+ if "trial_id" in config:
+ assert launcher.global_config["trial_id"] == config["trial_id"]
- expected_log_level = logging.getLevelName(config.get('log_level', "INFO"))
+ expected_log_level = logging.getLevelName(config.get("log_level", "INFO"))
if isinstance(expected_log_level, int):
expected_log_level = logging.getLevelName(expected_log_level)
current_log_level = logging.getLevelName(logging.root.getEffectiveLevel())
@@ -138,7 +152,7 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers
# TODO: Check that the log_file handler is set correctly.
- expected_teardown = config.get('teardown', True)
+ expected_teardown = config.get("teardown", True)
assert launcher.teardown == expected_teardown
# Note: Testing of "globals" processing handled in launcher_parse_args_test.py
@@ -147,22 +161,34 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers
# Launcher loaded the expected types as well.
assert isinstance(launcher.environment, Environment)
- env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT)
+ env_config = launcher.config_loader.load_config(
+ config["environment"],
+ ConfigSchema.ENVIRONMENT,
+ )
assert check_class_name(launcher.environment, env_config["class"])
assert isinstance(launcher.optimizer, Optimizer)
if "optimizer" in config:
- opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER)
+ opt_config = launcher.config_loader.load_config(
+ config["optimizer"],
+ ConfigSchema.OPTIMIZER,
+ )
assert check_class_name(launcher.optimizer, opt_config["class"])
assert isinstance(launcher.storage, Storage)
if "storage" in config:
- storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE)
+ storage_config = launcher.config_loader.load_config(
+ config["storage"],
+ ConfigSchema.STORAGE,
+ )
assert check_class_name(launcher.storage, storage_config["class"])
assert isinstance(launcher.scheduler, Scheduler)
if "scheduler" in config:
- scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER)
+ scheduler_config = launcher.config_loader.load_config(
+ config["scheduler"],
+ ConfigSchema.SCHEDULER,
+ )
assert check_class_name(launcher.scheduler, scheduler_config["class"])
# TODO: Check that the launcher assigns the tunables values as expected.
diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py
index fdcb3370cf..5f9167dc85 100644
--- a/mlos_bench/mlos_bench/tests/config/conftest.py
+++ b/mlos_bench/mlos_bench/tests/config/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test fixtures for mlos_bench config loader tests.
-"""
+"""Test fixtures for mlos_bench config loader tests."""
import sys
@@ -22,9 +20,11 @@ else:
@pytest.fixture
def config_loader_service() -> ConfigPersistenceService:
"""Config loader service fixture."""
- return ConfigPersistenceService(config={
- "config_path": [
- str(files("mlos_bench.tests.config")),
- path_join(str(files("mlos_bench.tests.config")), "globals"),
- ]
- })
+ return ConfigPersistenceService(
+ config={
+ "config_path": [
+ str(files("mlos_bench.tests.config")),
+ path_join(str(files("mlos_bench.tests.config")), "globals"),
+ ]
+ }
+ )
diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py
index d2f975fd85..fe5e651d95 100644
--- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py
@@ -2,23 +2,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading environment config examples.
-"""
+"""Tests for loading environment config examples."""
import logging
from typing import List
import pytest
-from mlos_bench.tests.config import locate_config_examples
-
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tunables.tunable_groups import TunableGroups
-
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
@@ -29,16 +25,27 @@ CONFIG_TYPE = "environments"
def filter_configs(configs_to_filter: List[str]) -> List[str]:
"""If necessary, filter out json files that aren't for the module we're testing."""
- configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")]
+ configs_to_filter = [
+ config_path
+ for config_path in configs_to_filter
+ if not config_path.endswith("-tunables.jsonc")
+ ]
return configs_to_filter
-configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs)
+configs = locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+)
assert configs
@pytest.mark.parametrize("config_path", configs)
-def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_environment_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading an environment config example."""
envs = load_environment_config_examples(config_loader_service, config_path)
for env in envs:
@@ -46,11 +53,17 @@ def test_load_environment_config_examples(config_loader_service: ConfigPersisten
assert isinstance(env, Environment)
-def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]:
+def load_environment_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> List[Environment]:
"""Loads an environment config example."""
# Make sure that any "required_args" are provided.
- global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS)
- global_config.setdefault('trial_id', 1) # normally populated by Launcher
+ global_config = config_loader_service.load_config(
+ "experiments/experiment_test_config.jsonc",
+ ConfigSchema.GLOBALS,
+ )
+ global_config.setdefault("trial_id", 1) # normally populated by Launcher
# Make sure we have the required services for the envs being used.
mock_service_configs = [
@@ -62,24 +75,41 @@ def load_environment_config_examples(config_loader_service: ConfigPersistenceSer
"services/remote/mock/mock_auth_service.jsonc",
]
- tunable_groups = TunableGroups() # base tunable groups that all others get built on
+ tunable_groups = TunableGroups() # base tunable groups that all others get built on
for mock_service_config_path in mock_service_configs:
- mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE)
- config_loader_service.register(config_loader_service.build_service(
- config=mock_service_config, parent=config_loader_service).export())
+ mock_service_config = config_loader_service.load_config(
+ mock_service_config_path,
+ ConfigSchema.SERVICE,
+ )
+ config_loader_service.register(
+ config_loader_service.build_service(
+ config=mock_service_config,
+ parent=config_loader_service,
+ ).export()
+ )
envs = config_loader_service.load_environment_list(
- config_path, tunable_groups, global_config, service=config_loader_service)
+ config_path,
+ tunable_groups,
+ global_config,
+ service=config_loader_service,
+ )
return envs
-composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/")
+composite_configs = locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ "environments/root/",
+)
assert composite_configs
@pytest.mark.parametrize("config_path", composite_configs)
-def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_composite_env_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a composite env config example."""
envs = load_environment_config_examples(config_loader_service, config_path)
assert len(envs) == 1
@@ -92,17 +122,22 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist
assert child_env.tunable_params is not None
checked_child_env_groups = set()
- for (child_tunable, child_group) in child_env.tunable_params:
+ for child_tunable, child_group in child_env.tunable_params:
# Lookup that tunable in the composite env.
assert child_tunable in composite_env.tunable_params
- (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable)
- assert child_tunable is composite_tunable # Check that the tunables are the same object.
+ (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(
+ child_tunable
+ )
+ # Check that the tunables are the same object.
+ assert child_tunable is composite_tunable
if child_group.name not in checked_child_env_groups:
assert child_group is composite_group
checked_child_env_groups.add(child_group.name)
- # Check that when we change a child env, it's value is reflected in the composite env as well.
- # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs.
+    # Check that when we change a child env, its value is reflected in the
+ # composite env as well.
+ # That is to say, they refer to the same objects, despite having
+ # potentially been loaded from separate configs.
if child_tunable.is_categorical:
old_cat_value = child_tunable.category
assert child_tunable.value == old_cat_value
diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py
index 4d3a2602b5..5940962478 100644
--- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py
@@ -2,19 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading globals config examples.
-"""
+"""Tests for loading globals config examples."""
import logging
from typing import List
import pytest
-from mlos_bench.tests.config import locate_config_examples, BUILTIN_TEST_CONFIG_PATH
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.config_persistence import ConfigPersistenceService
-
+from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
@@ -30,16 +26,35 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]:
configs = [
- # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs),
- *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs),
- *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs),
- *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs),
+ # *locate_config_examples(
+ # ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ # CONFIG_TYPE,
+ # filter_configs,
+ # ),
+ *locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ "experiments",
+ filter_configs,
+ ),
+ *locate_config_examples(
+ BUILTIN_TEST_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+ ),
+ *locate_config_examples(
+ BUILTIN_TEST_CONFIG_PATH,
+ "experiments",
+ filter_configs,
+ ),
]
assert configs
@pytest.mark.parametrize("config_path", configs)
-def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_globals_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a config example."""
config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS)
assert isinstance(config, dict)
diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py
index bd9099b608..4feefb8440 100644
--- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py
@@ -2,23 +2,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading optimizer config examples.
-"""
+"""Tests for loading optimizer config examples."""
import logging
from typing import List
import pytest
-from mlos_bench.tests.config import locate_config_examples
-
from mlos_bench.config.schemas import ConfigSchema
-from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.optimizers.base_optimizer import Optimizer
+from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_class_from_name
-
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
@@ -32,12 +28,19 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]:
return configs_to_filter
-configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs)
+configs = locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+)
assert configs
@pytest.mark.parametrize("config_path", configs)
-def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_optimizer_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a config example."""
config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER)
assert isinstance(config, dict)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py
index 7e6edacbb3..02cdc4fdee 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py
@@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common tests for config schemas and their validation and test cases.
-"""
+"""Common tests for config schemas and their validation and test cases."""
+import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, Set
-import os
-
import json5
import jsonschema
import pytest
@@ -23,9 +20,7 @@ from mlos_bench.tests.config import locate_config_examples
# A dataclass to make pylint happy.
@dataclass
class SchemaTestType:
- """
- The different type of schema test cases we expect to have.
- """
+ """The different type of schema test cases we expect to have."""
test_case_type: str
test_case_subtypes: Set[str]
@@ -35,17 +30,18 @@ class SchemaTestType:
# The different type of schema test cases we expect to have.
-_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in (
- SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}),
- SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}),
-)}
+_SCHEMA_TEST_TYPES = {
+ x.test_case_type: x
+ for x in (
+ SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}),
+ SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}),
+ )
+}
@dataclass
-class SchemaTestCaseInfo():
- """
- Some basic info about a schema test case.
- """
+class SchemaTestCaseInfo:
+ """Some basic info about a schema test case."""
config: Dict[str, Any]
test_case_file: str
@@ -57,27 +53,27 @@ class SchemaTestCaseInfo():
def check_schema_dir_layout(test_cases_root: str) -> None:
- """
- Makes sure the directory layout matches what we expect so we aren't missing
- any extra configs or test cases.
+ """Makes sure the directory layout matches what we expect so we aren't missing any
+ extra configs or test cases.
"""
for test_case_dir in os.listdir(test_cases_root):
- if test_case_dir == 'README.md':
+ if test_case_dir == "README.md":
continue
if test_case_dir not in _SCHEMA_TEST_TYPES:
raise NotImplementedError(f"Unhandled test case type: {test_case_dir}")
for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)):
- if test_case_subdir == 'README.md':
+ if test_case_subdir == "README.md":
continue
if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes:
- raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}")
+ raise NotImplementedError(
+ f"Unhandled test case subtype {test_case_subdir} "
+ f"for test case type {test_case_dir}"
+ )
@dataclass
class TestCases:
- """
- A container for test cases by type.
- """
+ """A container for test cases by type."""
by_path: Dict[str, SchemaTestCaseInfo]
by_type: Dict[str, Dict[str, SchemaTestCaseInfo]]
@@ -85,18 +81,22 @@ class TestCases:
def get_schema_test_cases(test_cases_root: str) -> TestCases:
- """
- Gets a dict of schema test cases from the given root.
- """
- test_cases = TestCases(by_path={},
- by_type={x: {} for x in _SCHEMA_TEST_TYPES},
- by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes})
+ """Gets a dict of schema test cases from the given root."""
+ test_cases = TestCases(
+ by_path={},
+ by_type={x: {} for x in _SCHEMA_TEST_TYPES},
+ by_subtype={
+ y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes
+ },
+ )
check_schema_dir_layout(test_cases_root)
# Note: we sort the test cases so that we can deterministically test them in parallel.
- for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items():
+ for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items():
for test_case_subtype in schema_test_type.test_case_subtypes:
- for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)):
- with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh:
+ for test_case_file in locate_config_examples(
+ test_cases_root, os.path.join(test_case_type, test_case_subtype)
+ ):
+ with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh:
try:
test_case_info = SchemaTestCaseInfo(
config=json5.load(test_case_fh),
@@ -105,8 +105,12 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases:
test_case_subtype=test_case_subtype,
)
test_cases.by_path[test_case_info.test_case_file] = test_case_info
- test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info
- test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info
+ test_cases.by_type[test_case_info.test_case_type][
+ test_case_info.test_case_file
+ ] = test_case_info
+ test_cases.by_subtype[test_case_info.test_case_subtype][
+ test_case_info.test_case_file
+ ] = test_case_info
except Exception as ex:
raise RuntimeError("Failed to load test case: " + test_case_file) from ex
assert test_cases
@@ -118,7 +122,10 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases:
return test_cases
-def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
+def check_test_case_against_schema(
+ test_case: SchemaTestCaseInfo,
+ schema_type: ConfigSchema,
+) -> None:
"""
Checks the given test case against the given schema.
@@ -143,9 +150,12 @@ def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: C
raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}")
-def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None:
- """
- Checks that the config fails to validate if extra params are present in certain places.
+def check_test_case_config_with_extra_param(
+ test_case: SchemaTestCaseInfo,
+ schema_type: ConfigSchema,
+) -> None:
+ """Checks that the config fails to validate if extra params are present in certain
+ places.
"""
config = deepcopy(test_case.config)
schema_type.validate(config)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py
index ffc0add973..a47395e2d2 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py
@@ -2,20 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for CLI schema validation.
-"""
+"""Tests for CLI schema validation."""
from os import path
import pytest
from mlos_bench.config.schemas import ConfigSchema
-
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -26,27 +24,31 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_cli_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the CLI config validates against the schema.
- """
+ """Checks that the CLI config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.CLI)
if TEST_CASES.by_path[test_case_name].test_case_type != "bad":
# Unified schema has a hard time validating bad configs, so we skip it.
- # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them,
- # so adding/removing params doesn't invalidate it against all of the config types.
+ # The trouble is that tunable-values, cli, globals all look like flat dicts
+ # with minor constraints on them, so adding/removing params doesn't
+ # invalidate it against all of the config types.
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_cli_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the cli config fails to validate if extra params are present in
+ certain places.
"""
- Checks that the cli config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.CLI,
+ )
if TEST_CASES.by_path[test_case_name].test_case_type != "bad":
# Unified schema has a hard time validating bad configs, so we skip it.
- # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them,
- # so adding/removing params doesn't invalidate it against all of the config types.
+ # The trouble is that tunable-values, cli, globals all look like flat dicts
+ # with minor constraints on them, so adding/removing params doesn't
+ # invalidate it against all of the config types.
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py
index 8d1c5135d0..3819f1848e 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py
@@ -2,26 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for environment schema validation.
-"""
+"""Tests for environment schema validation."""
from os import path
import pytest
-from mlos_core.tests import get_all_concrete_subclasses
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.environments.script_env import ScriptEnv
-
from mlos_bench.tests import try_resolve_class_name
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
+from mlos_core.tests import get_all_concrete_subclasses
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -34,48 +31,60 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Dynamically enumerate some of the cases we want to make sure we cover.
NON_CONFIG_ENV_CLASSES = {
- ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python.
+ # ScriptEnv is ABCMeta abstract, but there's no good way to test that
+ # dynamically in Python.
+ ScriptEnv,
}
-expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__
- for subclass
- in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench')
- if subclass not in NON_CONFIG_ENV_CLASSES]
+expected_environment_class_names = [
+ subclass.__module__ + "." + subclass.__name__
+ for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench")
+ if subclass not in NON_CONFIG_ENV_CLASSES
+]
assert expected_environment_class_names
COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__
-expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names
- if subclass_name != COMPOSITE_ENV_CLASS_NAME]
+expected_leaf_environment_class_names = [
+ subclass_name
+ for subclass_name in expected_environment_class_names
+ if subclass_name != COMPOSITE_ENV_CLASS_NAME
+]
# Do the full cross product of all the test cases and all the Environment types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("env_class", expected_environment_class_names)
def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_bench Environment type.
+ """Checks to see if there is a given type of test case for the given mlos_bench
+ Environment type.
"""
for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
if try_resolve_class_name(test_case.config.get("class")) == env_class:
return
raise NotImplementedError(
- f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}")
+ f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}"
+ )
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_environment_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the environment config validates against the schema.
- """
+ """Checks that the environment config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_environment_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the environment config fails to validate if extra params are present
+ in certain places.
"""
- Checks that the environment config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT)
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.ENVIRONMENT,
+ )
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.UNIFIED,
+ )
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py
index 59d3ddd866..bcfc0aeb79 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for CLI schema validation.
-"""
+"""Tests for CLI schema validation."""
from os import path
import pytest
from mlos_bench.config.schemas import ConfigSchema
-
-from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ get_schema_test_cases,
+)
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -24,14 +23,14 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_globals_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the CLI config validates against the schema.
- """
+ """Checks that the CLI config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS)
if TEST_CASES.by_path[test_case_name].test_case_type != "bad":
# Unified schema has a hard time validating bad configs, so we skip it.
- # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them,
- # so adding/removing params doesn't invalidate it against all of the config types.
+ # The trouble is that tunable-values, cli, globals all look like flat dicts
+ # with minor constraints on them, so adding/removing params doesn't
+ # invalidate it against all of the config types.
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py
index e69c50c4bd..87c7dd7a27 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py
@@ -2,28 +2,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for optimizer schema validation.
-"""
+"""Tests for optimizer schema validation."""
from os import path
from typing import Optional
import pytest
+from mlos_bench.config.schemas import ConfigSchema
+from mlos_bench.optimizers.base_optimizer import Optimizer
+from mlos_bench.tests import try_resolve_class_name
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
from mlos_core.optimizers import OptimizerType
from mlos_core.spaces.adapters import SpaceAdapterType
from mlos_core.tests import get_all_concrete_subclasses
-from mlos_bench.config.schemas import ConfigSchema
-from mlos_bench.optimizers.base_optimizer import Optimizer
-
-from mlos_bench.tests import try_resolve_class_name
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
-
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
# - enumerate and try to check that we've covered all the cases
@@ -34,12 +31,17 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Dynamically enumerate some of the cases we want to make sure we cover.
-expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__
- for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract]
- pkg_name='mlos_bench')]
+expected_mlos_bench_optimizer_class_names = [
+ subclass.__module__ + "." + subclass.__name__
+ for subclass in get_all_concrete_subclasses(
+ Optimizer, # type: ignore[type-abstract]
+ pkg_name="mlos_bench",
+ )
+]
assert expected_mlos_bench_optimizer_class_names
-# Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None).
+# Also make sure that we check for configs where the optimizer_type or
+# space_adapter_type is left unspecified (None).
expected_mlos_core_optimizer_types = list(OptimizerType) + [None]
assert expected_mlos_core_optimizer_types
@@ -51,15 +53,21 @@ assert expected_mlos_core_space_adapter_types
# Do the full cross product of all the test cases and all the optimizer types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names)
-def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_bench optimizer type.
+def test_case_coverage_mlos_bench_optimizer_type(
+ test_case_subtype: str,
+ mlos_bench_optimizer_type: str,
+) -> None:
+ """Checks to see if there is a given type of test case for the given mlos_bench
+ optimizer type.
"""
for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type:
return
raise NotImplementedError(
- f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}")
+ f"Missing test case for subtype {test_case_subtype} "
+ f"for Optimizer class {mlos_bench_optimizer_type}"
+ )
+
# Being a little lazy for the moment and relaxing the requirement that we have
# a subtype test case for each optimizer and space adapter combo.
@@ -68,60 +76,79 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be
@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
# @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types)
-def test_case_coverage_mlos_core_optimizer_type(test_case_type: str,
- mlos_core_optimizer_type: Optional[OptimizerType]) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_core optimizer type.
+def test_case_coverage_mlos_core_optimizer_type(
+ test_case_type: str,
+ mlos_core_optimizer_type: Optional[OptimizerType],
+) -> None:
+ """Checks to see if there is a given type of test case for the given mlos_core
+ optimizer type.
"""
optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name
for test_case in TEST_CASES.by_type[test_case_type].values():
- if try_resolve_class_name(test_case.config.get("class")) \
- == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer":
+ if (
+ try_resolve_class_name(test_case.config.get("class"))
+ == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer"
+ ):
optimizer_type = None
if test_case.config.get("config"):
optimizer_type = test_case.config["config"].get("optimizer_type", None)
if optimizer_type == optimizer_name:
return
raise NotImplementedError(
- f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}")
+ f"Missing test case for type {test_case_type} "
+ f"for MlosCore Optimizer type {mlos_core_optimizer_type}"
+ )
@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
# @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types)
-def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str,
- mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None:
+def test_case_coverage_mlos_core_space_adapter_type(
+ test_case_type: str,
+ mlos_core_space_adapter_type: Optional[SpaceAdapterType],
+) -> None:
+ """Checks to see if there is a given type of test case for the given mlos_core space
+ adapter type.
"""
- Checks to see if there is a given type of test case for the given mlos_core space adapter type.
- """
- space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name
+ space_adapter_name = (
+ None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name
+ )
for test_case in TEST_CASES.by_type[test_case_type].values():
- if try_resolve_class_name(test_case.config.get("class")) \
- == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer":
+ if (
+ try_resolve_class_name(test_case.config.get("class"))
+ == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer"
+ ):
space_adapter_type = None
if test_case.config.get("config"):
space_adapter_type = test_case.config["config"].get("space_adapter_type", None)
if space_adapter_type == space_adapter_name:
return
raise NotImplementedError(
- f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}")
+ f"Missing test case for type {test_case_type} "
+ f"for SpaceAdapter type {mlos_core_space_adapter_type}"
+ )
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_optimizer_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the optimizer config validates against the schema.
- """
+ """Checks that the optimizer config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_optimizer_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the optimizer config fails to validate if extra params are present in
+ certain places.
"""
- Checks that the optimizer config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER)
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.OPTIMIZER,
+ )
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.UNIFIED,
+ )
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py
index 6e625b8ef2..56945739d7 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py
@@ -2,24 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for schedulers schema validation.
-"""
+"""Tests for schedulers schema validation."""
from os import path
import pytest
-from mlos_core.tests import get_all_concrete_subclasses
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.schedulers.base_scheduler import Scheduler
-
from mlos_bench.tests import try_resolve_class_name
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
+from mlos_core.tests import get_all_concrete_subclasses
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -31,9 +28,13 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Dynamically enumerate some of the cases we want to make sure we cover.
-expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__
- for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract]
- pkg_name='mlos_bench')]
+expected_mlos_bench_scheduler_class_names = [
+ subclass.__module__ + "." + subclass.__name__
+ for subclass in get_all_concrete_subclasses(
+ Scheduler, # type: ignore[type-abstract]
+ pkg_name="mlos_bench",
+ )
+]
assert expected_mlos_bench_scheduler_class_names
# Do the full cross product of all the test cases and all the scheduler types.
@@ -41,35 +42,45 @@ assert expected_mlos_bench_scheduler_class_names
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names)
-def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_bench scheduler type.
+def test_case_coverage_mlos_bench_scheduler_type(
+ test_case_subtype: str,
+ mlos_bench_scheduler_type: str,
+) -> None:
+ """Checks to see if there is a given type of test case for the given mlos_bench
+ scheduler type.
"""
for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type:
return
raise NotImplementedError(
- f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}")
+ f"Missing test case for subtype {test_case_subtype} "
+ f"for Scheduler class {mlos_bench_scheduler_type}"
+ )
+
# Now we actually perform all of those validation tests.
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_scheduler_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the scheduler config validates against the schema.
- """
+ """Checks that the scheduler config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_scheduler_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the scheduler config fails to validate if extra params are present in
+ certain places.
"""
- Checks that the scheduler config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER)
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.SCHEDULER,
+ )
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.UNIFIED,
+ )
if __name__ == "__main__":
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py
index c96346ad7b..e8b95ad85b 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py
@@ -2,29 +2,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for service schema validation.
-"""
+"""Tests for service schema validation."""
from os import path
from typing import Any, Dict, List
import pytest
-from mlos_core.tests import get_all_concrete_subclasses
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.temp_dir_context import TempDirContextService
-from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
+from mlos_bench.services.remote.azure.azure_deployment_services import (
+ AzureDeploymentService,
+)
from mlos_bench.services.remote.ssh.ssh_service import SshService
-
from mlos_bench.tests import try_resolve_class_name
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
+from mlos_core.tests import get_all_concrete_subclasses
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -37,16 +36,21 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Dynamically enumerate some of the cases we want to make sure we cover.
NON_CONFIG_SERVICE_CLASSES = {
- ConfigPersistenceService, # configured thru the launcher cli args
- TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python.
- AzureDeploymentService, # ABCMeta abstract base class
- SshService, # ABCMeta abstract base class
+    # configured through the launcher cli args
+ ConfigPersistenceService,
+ # ABCMeta abstract class, but no good way to test that dynamically in Python.
+ TempDirContextService,
+ # ABCMeta abstract base class
+ AzureDeploymentService,
+ # ABCMeta abstract base class
+ SshService,
}
-expected_service_class_names = [subclass.__module__ + "." + subclass.__name__
- for subclass
- in get_all_concrete_subclasses(Service, pkg_name='mlos_bench')
- if subclass not in NON_CONFIG_SERVICE_CLASSES]
+expected_service_class_names = [
+ subclass.__module__ + "." + subclass.__name__
+ for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench")
+ if subclass not in NON_CONFIG_SERVICE_CLASSES
+]
assert expected_service_class_names
@@ -54,13 +58,13 @@ assert expected_service_class_names
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("service_class", expected_service_class_names)
def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_bench Service type.
+ """Checks to see if there is a given type of test case for the given mlos_bench
+ Service type.
"""
for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
config_list: List[Dict[str, Any]]
if not isinstance(test_case.config, dict):
- continue # type: ignore[unreachable]
+ continue # type: ignore[unreachable]
if "class" not in test_case.config:
config_list = test_case.config["services"]
else:
@@ -69,24 +73,30 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c
if try_resolve_class_name(config.get("class")) == service_class:
return
raise NotImplementedError(
- f"Missing test case for subtype {test_case_subtype} for service class {service_class}")
+ f"Missing test case for subtype {test_case_subtype} for service class {service_class}"
+ )
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_service_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the service config validates against the schema.
- """
+ """Checks that the service config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_service_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the service config fails to validate if extra params are present in
+ certain places.
"""
- Checks that the service config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE)
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.SERVICE,
+ )
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.UNIFIED,
+ )
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py
index 7c42b85c4b..c3dd4ced81 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py
@@ -2,24 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for storage schema validation.
-"""
+"""Tests for storage schema validation."""
from os import path
import pytest
-from mlos_core.tests import get_all_concrete_subclasses
-
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.storage.base_storage import Storage
-
from mlos_bench.tests import try_resolve_class_name
-from mlos_bench.tests.config.schemas import (get_schema_test_cases,
- check_test_case_against_schema,
- check_test_case_config_with_extra_param)
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ check_test_case_config_with_extra_param,
+ get_schema_test_cases,
+)
+from mlos_core.tests import get_all_concrete_subclasses
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -29,9 +26,13 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Dynamically enumerate some of the cases we want to make sure we cover.
-expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__
- for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract]
- pkg_name='mlos_bench')]
+expected_mlos_bench_storage_class_names = [
+ subclass.__module__ + "." + subclass.__name__
+ for subclass in get_all_concrete_subclasses(
+ Storage, # type: ignore[type-abstract]
+ pkg_name="mlos_bench",
+ )
+]
assert expected_mlos_bench_storage_class_names
# Do the full cross product of all the test cases and all the storage types.
@@ -39,36 +40,48 @@ assert expected_mlos_bench_storage_class_names
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names)
-def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None:
- """
- Checks to see if there is a given type of test case for the given mlos_bench storage type.
+def test_case_coverage_mlos_bench_storage_type(
+ test_case_subtype: str,
+ mlos_bench_storage_type: str,
+) -> None:
+ """Checks to see if there is a given type of test case for the given mlos_bench
+ storage type.
"""
for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type:
return
raise NotImplementedError(
- f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}")
+ f"Missing test case for subtype {test_case_subtype} "
+ f"for Storage class {mlos_bench_storage_type}"
+ )
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_storage_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the storage config validates against the schema.
- """
+ """Checks that the storage config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_storage_configs_with_extra_param(test_case_name: str) -> None:
+ """Checks that the storage config fails to validate if extra params are present in
+ certain places.
"""
- Checks that the storage config fails to validate if extra params are present in certain places.
- """
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE)
- check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED)
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.STORAGE,
+ )
+ check_test_case_config_with_extra_param(
+ TEST_CASES.by_type["good"][test_case_name],
+ ConfigSchema.UNIFIED,
+ )
-if __name__ == '__main__':
- pytest.main([__file__, '-n0'],)
+if __name__ == "__main__":
+ pytest.main(
+ [__file__, "-n0"],
+ )
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py
index fda78a19f9..f0694fc50f 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for tunable params schema validation.
-"""
+"""Tests for tunable params schema validation."""
from os import path
import pytest
from mlos_bench.config.schemas import ConfigSchema
-
-from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ get_schema_test_cases,
+)
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -24,10 +23,9 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_tunable_params_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the tunable params config validates against the schema.
- """
+ """Checks that the tunable params config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS)
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py
index 043bb725bc..9b24a39d75 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py
+++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for tunable values schema validation.
-"""
+"""Tests for tunable values schema validation."""
from os import path
import pytest
from mlos_bench.config.schemas import ConfigSchema
-
-from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema
-
+from mlos_bench.tests.config.schemas import (
+ check_test_case_against_schema,
+ get_schema_test_cases,
+)
# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
@@ -24,14 +23,14 @@ TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases
# Now we actually perform all of those validation tests.
+
@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_tunable_values_configs_against_schema(test_case_name: str) -> None:
- """
- Checks that the tunable values config validates against the schema.
- """
+ """Checks that the tunable values config validates against the schema."""
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES)
if TEST_CASES.by_path[test_case_name].test_case_type != "bad":
# Unified schema has a hard time validating bad configs, so we skip it.
- # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them,
- # so adding/removing params doesn't invalidate it against all of the config types.
+ # The trouble is that tunable-values, cli, globals all look like flat dicts
+ # with minor constraints on them, so adding/removing params doesn't
+ # invalidate it against all of the config types.
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)
diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py
index f3da324dee..5545327080 100644
--- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py
@@ -2,20 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading service config examples.
-"""
+"""Tests for loading service config examples."""
import logging
from typing import List
import pytest
-from mlos_bench.tests.config import locate_config_examples
-
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
-
+from mlos_bench.tests.config import locate_config_examples
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
@@ -27,19 +23,30 @@ CONFIG_TYPE = "services"
def filter_configs(configs_to_filter: List[str]) -> List[str]:
"""If necessary, filter out json files that aren't for the module we're testing."""
+
def predicate(config_path: str) -> bool:
- arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc")
+ arm_template = config_path.find(
+ "services/remote/azure/arm-templates/"
+ ) >= 0 and config_path.endswith(".jsonc")
setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0
return not (arm_template or setup_rg_scripts)
+
return [config_path for config_path in configs_to_filter if predicate(config_path)]
-configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs)
+configs = locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+)
assert configs
@pytest.mark.parametrize("config_path", configs)
-def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_service_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a config example."""
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Make an instance of the class based on the config.
diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py
index 039d49948f..bb9161144a 100644
--- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py
+++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py
@@ -2,22 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for loading storage config examples.
-"""
+"""Tests for loading storage config examples."""
import logging
from typing import List
import pytest
-from mlos_bench.tests.config import locate_config_examples
-
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.base_storage import Storage
+from mlos_bench.tests.config import locate_config_examples
from mlos_bench.util import get_class_from_name
-
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
@@ -31,12 +27,19 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]:
return configs_to_filter
-configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs)
+configs = locate_config_examples(
+ ConfigPersistenceService.BUILTIN_CONFIG_PATH,
+ CONFIG_TYPE,
+ filter_configs,
+)
assert configs
@pytest.mark.parametrize("config_path", configs)
-def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None:
+def test_load_storage_config_examples(
+ config_loader_service: ConfigPersistenceService,
+ config_path: str,
+) -> None:
"""Tests loading a config example."""
config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE)
assert isinstance(config, dict)
diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py
index f30fe43585..a13c57a2cd 100644
--- a/mlos_bench/mlos_bench/tests/conftest.py
+++ b/mlos_bench/mlos_bench/tests/conftest.py
@@ -2,23 +2,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common fixtures for mock TunableGroups and Environment objects.
-"""
-
-from typing import Any, Generator, List
+"""Common fixtures for mock TunableGroups and Environment objects."""
import os
-
-from fasteners import InterProcessLock, InterProcessReaderWriterLock
-from pytest_docker.plugin import get_docker_services, Services as DockerServices
+from typing import Any, Generator, List
import pytest
+from fasteners import InterProcessLock, InterProcessReaderWriterLock
+from pytest_docker.plugin import Services as DockerServices
+from pytest_docker.plugin import get_docker_services
from mlos_bench.environments.mock_env import MockEnv
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.tests import SEED, tunable_groups_fixtures
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
# -- Ignore pylint complaints about pytest references to
@@ -33,9 +29,7 @@ covariant_group = tunable_groups_fixtures.covariant_group
@pytest.fixture
def mock_env(tunable_groups: TunableGroups) -> MockEnv:
- """
- Test fixture for MockEnv.
- """
+ """Test fixture for MockEnv."""
return MockEnv(
name="Test Env",
config={
@@ -44,15 +38,13 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv:
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
@pytest.fixture
def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv:
- """
- Test fixture for MockEnv.
- """
+ """Test fixture for MockEnv."""
return MockEnv(
name="Test Env No Noise",
config={
@@ -61,7 +53,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv:
"mock_env_range": [60, 120],
"mock_env_metrics": ["score", "other_score"],
},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
@@ -105,30 +97,37 @@ def docker_compose_project_name(short_testrun_uid: str) -> str:
@pytest.fixture(scope="session")
-def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock:
+def docker_services_lock(
+ shared_temp_dir: str,
+ short_testrun_uid: str,
+) -> InterProcessReaderWriterLock:
"""
- Gets a pytest session lock for xdist workers to mark when they're using the
- docker services.
+ Gets a pytest session lock for xdist workers to mark when they're using the docker
+ services.
Yields
------
A lock to ensure that setup/teardown operations don't happen while a
worker is using the docker services.
"""
- return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock")
+ return InterProcessReaderWriterLock(
+ f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock"
+ )
@pytest.fixture(scope="session")
def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessLock:
"""
- Gets a pytest session lock between xdist workers for the docker
- setup/teardown operations.
+ Gets a pytest session lock between xdist workers for the docker setup/teardown
+ operations.
Yields
------
A lock to ensure that only one worker is doing setup/teardown at a time.
"""
- return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock")
+ return InterProcessLock(
+ f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock"
+ )
@pytest.fixture(scope="session")
@@ -141,8 +140,8 @@ def locked_docker_services(
docker_setup_teardown_lock: InterProcessLock,
docker_services_lock: InterProcessReaderWriterLock,
) -> Generator[DockerServices, Any, None]:
- """
- A locked version of the docker_services fixture to implement xdist single instance locking.
+ """A locked version of the docker_services fixture to implement xdist single
+ instance locking.
"""
# pylint: disable=too-many-arguments
# Mark the services as in use with the reader lock.
diff --git a/mlos_bench/mlos_bench/tests/dict_templater_test.py b/mlos_bench/mlos_bench/tests/dict_templater_test.py
index 63219d9246..4b64f50fd4 100644
--- a/mlos_bench/mlos_bench/tests/dict_templater_test.py
+++ b/mlos_bench/mlos_bench/tests/dict_templater_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for DictTemplater class.
-"""
+"""Unit tests for DictTemplater class."""
from copy import deepcopy
from typing import Any, Dict
@@ -46,9 +44,7 @@ def source_template_dict() -> Dict[str, Any]:
def test_no_side_effects(source_template_dict: Dict[str, Any]) -> None:
- """
- Test that the templater does not modify the source dictionary.
- """
+ """Test that the templater does not modify the source dictionary."""
source_template_dict_copy = deepcopy(source_template_dict)
results = DictTemplater(source_template_dict_copy).expand_vars()
assert results
@@ -56,9 +52,7 @@ def test_no_side_effects(source_template_dict: Dict[str, Any]) -> None:
def test_secondary_expansion(source_template_dict: Dict[str, Any]) -> None:
- """
- Test that internal expansions work as expected.
- """
+ """Test that internal expansions work as expected."""
results = DictTemplater(source_template_dict).expand_vars()
assert results == {
"extra_str-ref": "$extra_str-ref",
@@ -85,9 +79,7 @@ def test_secondary_expansion(source_template_dict: Dict[str, Any]) -> None:
def test_os_env_expansion(source_template_dict: Dict[str, Any]) -> None:
- """
- Test that expansions from OS env work as expected.
- """
+ """Test that expansions from OS env work as expected."""
environ["extra_str"] = "os-env-extra_str"
environ["string"] = "shouldn't be used"
@@ -117,9 +109,7 @@ def test_os_env_expansion(source_template_dict: Dict[str, Any]) -> None:
def test_from_extras_expansion(source_template_dict: Dict[str, Any]) -> None:
- """
- Test that
- """
+ """Test that."""
extra_source_dict = {
"extra_str": "str-from-extras",
"string": "shouldn't be used",
diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py
index e33188a9e3..01155de0b2 100644
--- a/mlos_bench/mlos_bench/tests/environments/__init__.py
+++ b/mlos_bench/mlos_bench/tests/environments/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests helpers for mlos_bench.environments.
-"""
+"""Tests helpers for mlos_bench.environments."""
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
@@ -12,16 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
import pytest
from mlos_bench.environments.base_environment import Environment
-
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-def check_env_success(env: Environment,
- tunable_groups: TunableGroups,
- expected_results: Dict[str, TunableValue],
- expected_telemetry: List[Tuple[datetime, str, Any]],
- global_config: Optional[dict] = None) -> None:
+def check_env_success(
+ env: Environment,
+ tunable_groups: TunableGroups,
+ expected_results: Dict[str, TunableValue],
+ expected_telemetry: List[Tuple[datetime, str, Any]],
+ global_config: Optional[dict] = None,
+) -> None:
"""
Set up an environment and run a test experiment there.
@@ -51,13 +50,13 @@ def check_env_success(env: Environment,
assert telemetry == pytest.approx(expected_telemetry, nan_ok=True)
env_context.teardown()
- assert not env_context._is_ready # pylint: disable=protected-access
+ assert not env_context._is_ready # pylint: disable=protected-access
def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None:
"""
- Set up a local environment and run a test experiment there;
- Make sure the environment `.status()` call fails.
+    Set up a local environment and run a test experiment there; make sure the
+    environment `.status()` call fails.
Parameters
----------
diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py
index 69253e31c1..04f9e8c54c 100644
--- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py
@@ -2,16 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for base environment class functionality.
-"""
+"""Unit tests for base environment class functionality."""
from typing import Dict
import pytest
-from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.environments.base_environment import Environment
+from mlos_bench.tunables.tunable import TunableValue
_GROUPS = {
"group": ["a", "b"],
@@ -25,40 +23,43 @@ _GROUPS = {
def test_expand_groups() -> None:
- """
- Check the dollar variable expansion for tunable groups.
- """
+ """Check the dollar variable expansion for tunable groups."""
assert Environment._expand_groups(
- ["begin", "$list", "$empty", "$str", "end"],
- _GROUPS) == ["begin", "c", "d", "efg", "end"]
+ [
+ "begin",
+ "$list",
+ "$empty",
+ "$str",
+ "end",
+ ],
+ _GROUPS,
+ ) == [
+ "begin",
+ "c",
+ "d",
+ "efg",
+ "end",
+ ]
def test_expand_groups_empty_input() -> None:
- """
- Make sure an empty group stays empty.
- """
+ """Make sure an empty group stays empty."""
assert Environment._expand_groups([], _GROUPS) == []
def test_expand_groups_empty_list() -> None:
- """
- Make sure an empty group expansion works properly.
- """
+ """Make sure an empty group expansion works properly."""
assert not Environment._expand_groups(["$empty"], _GROUPS)
def test_expand_groups_unknown() -> None:
- """
- Make sure we fail on unknown $GROUP names expansion.
- """
+ """Make sure we fail on unknown $GROUP names expansion."""
with pytest.raises(KeyError):
Environment._expand_groups(["$list", "$UNKNOWN", "$str", "end"], _GROUPS)
def test_expand_const_args() -> None:
- """
- Test expansion of const args via expand_vars.
- """
+ """Test expansion of const args via expand_vars."""
const_args: Dict[str, TunableValue] = {
"a": "b",
"foo": "$bar/baz",
diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py
index e135a868ef..0d81ec7847 100644
--- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py
@@ -2,17 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Check how the services get inherited and overridden in child environments.
-"""
+"""Check how the services get inherited and overridden in child environments."""
import os
import pytest
from mlos_bench.environments.composite_env import CompositeEnv
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
+from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import path_join
# pylint: disable=redefined-outer-name
@@ -20,9 +18,7 @@ from mlos_bench.util import path_join
@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
- """
- Test fixture for CompositeEnv with services included on multiple levels.
- """
+ """Test fixture for CompositeEnv with services included on multiple levels."""
return CompositeEnv(
name="Root",
config={
@@ -40,28 +36,26 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"name": "Env 3 :: tmp_other_3",
"class": "mlos_bench.environments.mock_env.MockEnv",
"include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"],
- }
+ },
]
},
tunables=tunable_groups,
service=LocalExecService(
- config={
- "temp_dir": "_test_tmp_global"
- },
- parent=ConfigPersistenceService({
- "config_path": [
- path_join(os.path.dirname(__file__), "../config", abs_path=True),
- ]
- })
- )
+ config={"temp_dir": "_test_tmp_global"},
+ parent=ConfigPersistenceService(
+ {
+ "config_path": [
+ path_join(os.path.dirname(__file__), "../config", abs_path=True),
+ ]
+ }
+ ),
+ ),
)
def test_composite_services(composite_env: CompositeEnv) -> None:
- """
- Check that each environment gets its own instance of the services.
- """
- for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")):
+ """Check that each environment gets its own instance of the services."""
+ for i, path in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")):
service = composite_env.children[i]._service # pylint: disable=protected-access
assert service is not None and hasattr(service, "temp_dir_context")
with service.temp_dir_context() as temp_dir:
diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py
index fd7c022939..80463ea3d9 100644
--- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py
@@ -2,24 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for composite environment.
-"""
+"""Unit tests for composite environment."""
import pytest
from mlos_bench.environments.composite_env import CompositeEnv
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
- """
- Test fixture for CompositeEnv.
- """
+ """Test fixture for CompositeEnv."""
return CompositeEnv(
name="Composite Test Environment",
config={
@@ -28,7 +24,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"vm_server_name": "Mock Server VM",
"vm_client_name": "Mock Client VM",
"someConst": "root",
- "global_param": "default"
+ "global_param": "default",
},
"children": [
{
@@ -43,7 +39,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"required_args": ["vmName", "someConst", "global_param"],
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
- }
+ },
},
{
"name": "Mock Server Environment 2",
@@ -53,12 +49,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"const_args": {
"vmName": "$vm_server_name",
"EnvId": 2,
- "global_param": "local"
+ "global_param": "local",
},
"required_args": ["vmName"],
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
- }
+ },
},
{
"name": "Mock Control Environment 3",
@@ -72,79 +68,79 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"required_args": ["vmName", "vm_server_name", "vm_client_name"],
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
- }
- }
- ]
+ },
+ },
+ ],
},
tunables=tunable_groups,
service=ConfigPersistenceService({}),
- global_config={
- "global_param": "global_value"
- }
+ global_config={"global_param": "global_value"},
)
def test_composite_env_params(composite_env: CompositeEnv) -> None:
"""
- Check that the const_args from the parent environment get propagated to the children.
+ Check that the const_args from the parent environment get propagated to the
+ children.
+
NOTE: The current logic is that variables flow down via required_args and const_args, parent
"""
assert composite_env.children[0].parameters == {
- "vmName": "Mock Client VM", # const_args from the parent thru variable substitution
- "EnvId": 1, # const_args from the child
- "vmSize": "Standard_B4ms", # tunable_params from the parent
- "someConst": "root", # pulled in from parent via required_args
- "global_param": "global_value" # pulled in from the global_config
+ "vmName": "Mock Client VM", # const_args from the parent thru variable substitution
+ "EnvId": 1, # const_args from the child
+ "vmSize": "Standard_B4ms", # tunable_params from the parent
+ "someConst": "root", # pulled in from parent via required_args
+ "global_param": "global_value", # pulled in from the global_config
}
assert composite_env.children[1].parameters == {
- "vmName": "Mock Server VM", # const_args from the parent
- "EnvId": 2, # const_args from the child
- "idle": "halt", # tunable_params from the parent
+ "vmName": "Mock Server VM", # const_args from the parent
+ "EnvId": 2, # const_args from the child
+ "idle": "halt", # tunable_params from the parent
# "someConst": "root" # not required, so not passed from the parent
- "global_param": "global_value" # pulled in from the global_config
+ "global_param": "global_value", # pulled in from the global_config
}
assert composite_env.children[2].parameters == {
- "vmName": "Mock Control VM", # const_args from the parent
- "EnvId": 3, # const_args from the child
- "idle": "halt", # tunable_params from the parent
+ "vmName": "Mock Control VM", # const_args from the parent
+ "EnvId": 3, # const_args from the child
+ "idle": "halt", # tunable_params from the parent
# "someConst": "root" # not required, so not passed from the parent
"vm_client_name": "Mock Client VM",
- "vm_server_name": "Mock Server VM"
+ "vm_server_name": "Mock Server VM",
# "global_param": "global_value" # not required, so not picked from the global_config
}
def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None:
- """
- Check that the child environments update their tunable parameters.
- """
- tunable_groups.assign({
- "vmSize": "Standard_B2s",
- "idle": "mwait",
- "kernel_sched_migration_cost_ns": 100000,
- })
+ """Check that the child environments update their tunable parameters."""
+ tunable_groups.assign(
+ {
+ "vmSize": "Standard_B2s",
+ "idle": "mwait",
+ "kernel_sched_migration_cost_ns": 100000,
+ }
+ )
with composite_env as env_context:
assert env_context.setup(tunable_groups)
assert composite_env.children[0].parameters == {
- "vmName": "Mock Client VM", # const_args from the parent
- "EnvId": 1, # const_args from the child
- "vmSize": "Standard_B2s", # tunable_params from the parent
- "someConst": "root", # pulled in from parent via required_args
- "global_param": "global_value" # pulled in from the global_config
+ "vmName": "Mock Client VM", # const_args from the parent
+ "EnvId": 1, # const_args from the child
+ "vmSize": "Standard_B2s", # tunable_params from the parent
+ "someConst": "root", # pulled in from parent via required_args
+ "global_param": "global_value", # pulled in from the global_config
}
assert composite_env.children[1].parameters == {
- "vmName": "Mock Server VM", # const_args from the parent
- "EnvId": 2, # const_args from the child
- "idle": "mwait", # tunable_params from the parent
+ "vmName": "Mock Server VM", # const_args from the parent
+ "EnvId": 2, # const_args from the child
+ "idle": "mwait", # tunable_params from the parent
# "someConst": "root" # not required, so not passed from the parent
- "global_param": "global_value" # pulled in from the global_config
+ "global_param": "global_value", # pulled in from the global_config
}
assert composite_env.children[2].parameters == {
- "vmName": "Mock Control VM", # const_args from the parent
- "EnvId": 3, # const_args from the child
- "idle": "mwait", # tunable_params from the parent
+ "vmName": "Mock Control VM", # const_args from the parent
+ "EnvId": 3, # const_args from the child
+ "idle": "mwait", # tunable_params from the parent
"vm_client_name": "Mock Client VM",
"vm_server_name": "Mock Server VM",
# "global_param": "global_value" # not required, so not picked from the global_config
@@ -153,9 +149,7 @@ def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: Tunabl
@pytest.fixture
def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
- """
- Test fixture for CompositeEnv.
- """
+ """Test fixture for CompositeEnv."""
return CompositeEnv(
name="Composite Test Environment",
config={
@@ -163,7 +157,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"const_args": {
"vm_server_name": "Mock Server VM",
"vm_client_name": "Mock Client VM",
- "someConst": "root"
+ "someConst": "root",
},
"children": [
{
@@ -191,11 +185,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"EnvId",
"someConst",
"vm_server_name",
- "global_param"
+ "global_param",
],
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
- }
+ },
},
# ...
],
@@ -220,76 +214,78 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
"required_args": ["vmName", "EnvId", "vm_client_name"],
"mock_env_range": [60, 120],
"mock_env_metrics": ["score"],
- }
+ },
},
# ...
],
},
},
-
- ]
+ ],
},
tunables=tunable_groups,
service=ConfigPersistenceService({}),
- global_config={
- "global_param": "global_value"
- }
+ global_config={"global_param": "global_value"},
)
def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None:
"""
- Check that the const_args from the parent environment get propagated to the children.
+ Check that the const_args from the parent environment get propagated to the
+ children.
+
NOTE: The current logic is that variables flow down via required_args and const_args, parent
"""
assert isinstance(nested_composite_env.children[0], CompositeEnv)
assert nested_composite_env.children[0].children[0].parameters == {
- "vmName": "Mock Client VM", # const_args from the parent thru variable substitution
- "EnvId": 1, # const_args from the child
- "vmSize": "Standard_B4ms", # tunable_params from the parent
- "someConst": "root", # pulled in from parent via required_args
+ "vmName": "Mock Client VM", # const_args from the parent thru variable substitution
+ "EnvId": 1, # const_args from the child
+ "vmSize": "Standard_B4ms", # tunable_params from the parent
+ "someConst": "root", # pulled in from parent via required_args
"vm_server_name": "Mock Server VM",
- "global_param": "global_value" # pulled in from the global_config
+ "global_param": "global_value", # pulled in from the global_config
}
assert isinstance(nested_composite_env.children[1], CompositeEnv)
assert nested_composite_env.children[1].children[0].parameters == {
- "vmName": "Mock Server VM", # const_args from the parent
- "EnvId": 2, # const_args from the child
- "idle": "halt", # tunable_params from the parent
+ "vmName": "Mock Server VM", # const_args from the parent
+ "EnvId": 2, # const_args from the child
+ "idle": "halt", # tunable_params from the parent
# "someConst": "root" # not required, so not passed from the parent
"vm_client_name": "Mock Client VM",
# "global_param": "global_value" # not required, so not picked from the global_config
}
-def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None:
- """
- Check that the child environments update their tunable parameters.
- """
- tunable_groups.assign({
- "vmSize": "Standard_B2s",
- "idle": "mwait",
- "kernel_sched_migration_cost_ns": 100000,
- })
+def test_nested_composite_env_setup(
+ nested_composite_env: CompositeEnv,
+ tunable_groups: TunableGroups,
+) -> None:
+ """Check that the child environments update their tunable parameters."""
+ tunable_groups.assign(
+ {
+ "vmSize": "Standard_B2s",
+ "idle": "mwait",
+ "kernel_sched_migration_cost_ns": 100000,
+ }
+ )
with nested_composite_env as env_context:
assert env_context.setup(tunable_groups)
assert isinstance(nested_composite_env.children[0], CompositeEnv)
assert nested_composite_env.children[0].children[0].parameters == {
- "vmName": "Mock Client VM", # const_args from the parent
- "EnvId": 1, # const_args from the child
- "vmSize": "Standard_B2s", # tunable_params from the parent
- "someConst": "root", # pulled in from parent via required_args
+ "vmName": "Mock Client VM", # const_args from the parent
+ "EnvId": 1, # const_args from the child
+ "vmSize": "Standard_B2s", # tunable_params from the parent
+ "someConst": "root", # pulled in from parent via required_args
"vm_server_name": "Mock Server VM",
- "global_param": "global_value" # pulled in from the global_config
+ "global_param": "global_value", # pulled in from the global_config
}
assert isinstance(nested_composite_env.children[1], CompositeEnv)
assert nested_composite_env.children[1].children[0].parameters == {
- "vmName": "Mock Server VM", # const_args from the parent
- "EnvId": 2, # const_args from the child
- "idle": "mwait", # tunable_params from the parent
+ "vmName": "Mock Server VM", # const_args from the parent
+ "EnvId": 2, # const_args from the child
+ "idle": "mwait", # tunable_params from the parent
# "someConst": "root" # not required, so not passed from the parent
"vm_client_name": "Mock Client VM",
}
diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py
index 7395aa3e15..4c4fcd5dae 100644
--- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test the selection of tunables / tunable groups for the environment.
-"""
+"""Test the selection of tunables / tunable groups for the environment."""
from mlos_bench.environments.mock_env import MockEnv
from mlos_bench.services.config_persistence import ConfigPersistenceService
@@ -12,13 +10,11 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_one_group(tunable_groups: TunableGroups) -> None:
- """
- Make sure only one tunable group is available to the environment.
- """
+ """Make sure only one tunable group is available to the environment."""
env = MockEnv(
name="Test Env",
config={"tunable_params": ["provision"]},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
assert env.tunable_params.get_param_values() == {
"vmSize": "Standard_B4ms",
@@ -26,13 +22,11 @@ def test_one_group(tunable_groups: TunableGroups) -> None:
def test_two_groups(tunable_groups: TunableGroups) -> None:
- """
- Make sure only the selected tunable groups are available to the environment.
- """
+ """Make sure only the selected tunable groups are available to the environment."""
env = MockEnv(
name="Test Env",
config={"tunable_params": ["provision", "kernel"]},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
assert env.tunable_params.get_param_values() == {
"vmSize": "Standard_B4ms",
@@ -42,9 +36,8 @@ def test_two_groups(tunable_groups: TunableGroups) -> None:
def test_two_groups_setup(tunable_groups: TunableGroups) -> None:
- """
- Make sure only the selected tunable groups are available to the environment,
- the set is not changed after calling the `.setup()` method.
+ """Make sure only the selected tunable groups are available to the environment, the
+ set is not changed after calling the `.setup()` method.
"""
env = MockEnv(
name="Test Env",
@@ -55,7 +48,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None:
"const_param2": "foo",
},
},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
expected_params = {
"vmSize": "Standard_B4ms",
@@ -77,34 +70,22 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None:
def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None:
- """
- Make sure that no tunable groups are available to the environment by default.
- """
- env = MockEnv(
- name="Test Env",
- config={},
- tunables=tunable_groups
- )
+ """Make sure that no tunable groups are available to the environment by default."""
+ env = MockEnv(name="Test Env", config={}, tunables=tunable_groups)
assert env.tunable_params.get_param_values() == {}
def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None:
+ """Make sure that no tunable groups are available to the environment when explicitly
+ specifying an empty list of tunable_params.
"""
- Make sure that no tunable groups are available to the environment
- when explicitly specifying an empty list of tunable_params.
- """
- env = MockEnv(
- name="Test Env",
- config={"tunable_params": []},
- tunables=tunable_groups
- )
+ env = MockEnv(name="Test Env", config={"tunable_params": []}, tunables=tunable_groups)
assert env.tunable_params.get_param_values() == {}
def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None:
- """
- Make sure that no tunable groups are available to the environment by default
- and it does not change after the setup.
+ """Make sure that no tunable groups are available to the environment by default and
+ it does not change after the setup.
"""
env = MockEnv(
name="Test Env",
@@ -114,7 +95,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None:
"const_param2": "foo",
},
},
- tunables=tunable_groups
+ tunables=tunable_groups,
)
assert env.tunable_params.get_param_values() == {}
@@ -130,16 +111,13 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None:
def test_loader_level_include() -> None:
- """
- Make sure only the selected tunable groups are available to the environment,
- the set is not changed after calling the `.setup()` method.
+ """Make sure only the selected tunable groups are available to the environment, the
+ set is not changed after calling the `.setup()` method.
"""
env_json = {
"class": "mlos_bench.environments.mock_env.MockEnv",
"name": "Test Env",
- "include_tunables": [
- "environments/os/linux/boot/linux-boot-tunables.jsonc"
- ],
+ "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"],
"config": {
"tunable_params": ["linux-kernel-boot"],
"const_args": {
@@ -148,12 +126,14 @@ def test_loader_level_include() -> None:
},
},
}
- loader = ConfigPersistenceService({
- "config_path": [
- "mlos_bench/config",
- "mlos_bench/examples",
- ]
- })
+ loader = ConfigPersistenceService(
+ {
+ "config_path": [
+ "mlos_bench/config",
+ "mlos_bench/examples",
+ ]
+ }
+ )
env = loader.build_environment(config=env_json, tunables=TunableGroups())
expected_params = {
"align_va_addr": "on",
diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py
index 5d8fc32c6b..d0a954f6fc 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py
@@ -4,6 +4,7 @@
#
"""
Tests for mlos_bench.environments.local.
+
Used to make mypy happy about multiple conftest.py modules.
"""
@@ -32,14 +33,20 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L
env : LocalEnv
A new instance of the local environment.
"""
- return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups,
- service=LocalExecService(parent=ConfigPersistenceService()))
+ return LocalEnv(
+ name="TestLocalEnv",
+ config=config,
+ tunables=tunable_groups,
+ service=LocalExecService(parent=ConfigPersistenceService()),
+ )
-def create_composite_local_env(tunable_groups: TunableGroups,
- global_config: Dict[str, Any],
- params: Dict[str, Any],
- local_configs: List[Dict[str, Any]]) -> CompositeEnv:
+def create_composite_local_env(
+ tunable_groups: TunableGroups,
+ global_config: Dict[str, Any],
+ params: Dict[str, Any],
+ local_configs: List[Dict[str, Any]],
+) -> CompositeEnv:
"""
Create a CompositeEnv with several LocalEnv instances.
@@ -70,7 +77,7 @@ def create_composite_local_env(tunable_groups: TunableGroups,
"config": config,
}
for (i, config) in enumerate(local_configs)
- ]
+ ],
},
tunables=tunable_groups,
global_config=global_config,
diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py
index 4815e1c50d..4d15a6fcee 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py
@@ -2,20 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for the composition of several LocalEnv benchmark environments.
-"""
+"""Unit tests for the composition of several LocalEnv benchmark environments."""
import sys
from datetime import datetime, timedelta, tzinfo
from typing import Optional
-from pytz import UTC
import pytest
+from pytz import UTC
-from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.tests import ZONE_INFO
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.environments.local import create_composite_local_env
-from mlos_bench.tests import ZONE_INFO
+from mlos_bench.tunables.tunable_groups import TunableGroups
def _format_str(zone_info: Optional[tzinfo]) -> str:
@@ -28,8 +26,9 @@ def _format_str(zone_info: Optional[tzinfo]) -> str:
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None:
"""
- Produce benchmark and telemetry data in TWO local environments
- and combine the results.
+ Produce benchmark and telemetry data in TWO local environments and combine the
+ results.
+
Also checks that global configs flow down at least one level of CompositeEnv
to its children without being explicitly specified in the CompositeEnv so they
can be used in the shell_envs by its children.
@@ -43,7 +42,7 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo
time_str1 = ts1.strftime(format_str)
time_str2 = ts2.strftime(format_str)
- (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "")
+ (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "")
env = create_composite_local_env(
tunable_groups=tunable_groups,
@@ -66,9 +65,12 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo
},
"required_args": ["errors", "reads"],
"shell_env_params": [
- "latency", # const_args overridden by the composite env
- "errors", # Comes from the parent const_args
- "reads" # const_args overridden by the global config
+ # const_args overridden by the composite env
+ "latency",
+ # Comes from the parent const_args
+ "errors",
+ # const_args overridden by the global config
+ "reads",
],
"run": [
"echo 'metric,value' > output.csv",
@@ -90,9 +92,12 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo
},
"required_args": ["writes"],
"shell_env_params": [
- "throughput", # const_args overridden by the composite env
- "score", # Comes from the local const_args
- "writes" # Comes straight from the global config
+ # const_args overridden by the composite env
+ "throughput",
+ # Comes from the local const_args
+ "score",
+ # Comes straight from the global config
+ "writes",
],
"run": [
"echo 'metric,value' > output.csv",
@@ -106,12 +111,13 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo
],
"read_results_file": "output.csv",
"read_telemetry_file": "telemetry.csv",
- }
- ]
+ },
+ ],
)
check_env_success(
- env, tunable_groups,
+ env,
+ tunable_groups,
expected_results={
"latency": 4.2,
"throughput": 768.0,
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py
index 7b4de8c237..684e7e13f6 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py
@@ -2,34 +2,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for extracting data from LocalEnv stdout.
-"""
+"""Unit tests for extracting data from LocalEnv stdout."""
import sys
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.environments.local import create_local_env
+from mlos_bench.tunables.tunable_groups import TunableGroups
def test_local_env_stdout(tunable_groups: TunableGroups) -> None:
- """
- Print benchmark results to stdout and capture them in the LocalEnv.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'Benchmark results:'", # This line should be ignored
- "echo 'latency,111'",
- "echo 'throughput,222'",
- "echo 'score,0.999'",
- "echo 'a,0,b,1'",
- ],
- "results_stdout_pattern": r"(\w+),([0-9.]+)",
- })
+ """Print benchmark results to stdout and capture them in the LocalEnv."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'Benchmark results:'", # This line should be ignored
+ "echo 'latency,111'",
+ "echo 'throughput,222'",
+ "echo 'score,0.999'",
+ "echo 'a,0,b,1'",
+ ],
+ "results_stdout_pattern": r"(\w+),([0-9.]+)",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 111.0,
"throughput": 222.0,
@@ -42,22 +42,24 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None:
def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None:
- """
- Print benchmark results to stdout and capture them in the LocalEnv.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'Benchmark results:'", # This line should be ignored
- "echo 'latency,111'",
- "echo 'throughput,222'",
- "echo 'score,0.999'",
- "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern
- ],
- "results_stdout_pattern": r"^(\w+),([0-9.]+)$",
- })
+ """Print benchmark results to stdout and capture them in the LocalEnv."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'Benchmark results:'", # This line should be ignored
+ "echo 'latency,111'",
+ "echo 'throughput,222'",
+ "echo 'score,0.999'",
+ "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern
+ ],
+ "results_stdout_pattern": r"^(\w+),([0-9.]+)$",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 111.0,
"throughput": 222.0,
@@ -69,27 +71,31 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None:
def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None:
+ """Print benchmark results to *BOTH* stdout and a file and extract the results from
+ both.
"""
- Print benchmark results to *BOTH* stdout and a file and extract the results from both.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'latency,111'",
- "echo 'throughput,222'",
- "echo 'score,0.999'",
- "echo 'stdout-msg,string'",
- "echo '-------------------'", # Should be ignored
- "echo 'metric,value' > output.csv",
- "echo 'extra1,333' >> output.csv",
- "echo 'extra2,444' >> output.csv",
- "echo 'file-msg,string' >> output.csv",
- ],
- "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)",
- "read_results_file": "output.csv",
- })
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'latency,111'",
+ "echo 'throughput,222'",
+ "echo 'score,0.999'",
+ "echo 'stdout-msg,string'",
+ "echo '-------------------'", # Should be ignored
+ "echo 'metric,value' > output.csv",
+ "echo 'extra1,333' >> output.csv",
+ "echo 'extra2,444' >> output.csv",
+ "echo 'file-msg,string' >> output.csv",
+ ],
+ "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)",
+ "read_results_file": "output.csv",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 111.0,
"throughput": 222.0,
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py
index ba104da542..9cda41f14d 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py
@@ -2,20 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for telemetry and status of LocalEnv benchmark environment.
-"""
+"""Unit tests for telemetry and status of LocalEnv benchmark environment."""
from datetime import datetime, timedelta, tzinfo
from typing import Optional
+import pytest
from pytz import UTC
-import pytest
-
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tests import ZONE_INFO
-from mlos_bench.tests.environments import check_env_success, check_env_fail_telemetry
+from mlos_bench.tests.environments import check_env_fail_telemetry, check_env_success
from mlos_bench.tests.environments.local import create_local_env
+from mlos_bench.tunables.tunable_groups import TunableGroups
def _format_str(zone_info: Optional[tzinfo]) -> str:
@@ -27,9 +24,7 @@ def _format_str(zone_info: Optional[tzinfo]) -> str:
# FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...`
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None:
- """
- Produce benchmark and telemetry data in a local script and read it.
- """
+ """Produce benchmark and telemetry data in a local script and read it."""
ts1 = datetime.now(zone_info)
ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second
ts2 = ts1 + timedelta(minutes=1)
@@ -38,25 +33,29 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[
time_str1 = ts1.strftime(format_str)
time_str2 = ts2.strftime(format_str)
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'metric,value' > output.csv",
- "echo 'latency,4.1' >> output.csv",
- "echo 'throughput,512' >> output.csv",
- "echo 'score,0.95' >> output.csv",
- "echo '-------------------'", # This output does not go anywhere
- "echo 'timestamp,metric,value' > telemetry.csv",
- f"echo {time_str1},cpu_load,0.65 >> telemetry.csv",
- f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
- f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
- f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
- ],
- "read_results_file": "output.csv",
- "read_telemetry_file": "telemetry.csv",
- })
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'metric,value' > output.csv",
+ "echo 'latency,4.1' >> output.csv",
+ "echo 'throughput,512' >> output.csv",
+ "echo 'score,0.95' >> output.csv",
+ "echo '-------------------'", # This output does not go anywhere
+ "echo 'timestamp,metric,value' > telemetry.csv",
+ f"echo {time_str1},cpu_load,0.65 >> telemetry.csv",
+ f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
+ f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
+ f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
+ ],
+ "read_results_file": "output.csv",
+ "read_telemetry_file": "telemetry.csv",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 4.1,
"throughput": 512.0,
@@ -73,10 +72,11 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[
# FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...`
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None:
- """
- Read the telemetry data with no header.
- """
+def test_local_env_telemetry_no_header(
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Read the telemetry data with no header."""
ts1 = datetime.now(zone_info)
ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second
ts2 = ts1 + timedelta(minutes=1)
@@ -85,18 +85,22 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info:
time_str1 = ts1.strftime(format_str)
time_str2 = ts2.strftime(format_str)
- local_env = create_local_env(tunable_groups, {
- "run": [
- f"echo {time_str1},cpu_load,0.65 > telemetry.csv",
- f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
- f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
- f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
- ],
- "read_telemetry_file": "telemetry.csv",
- })
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ f"echo {time_str1},cpu_load,0.65 > telemetry.csv",
+ f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
+ f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
+ f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
+ ],
+ "read_telemetry_file": "telemetry.csv",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={},
expected_telemetry=[
(ts1.astimezone(UTC), "cpu_load", 0.65),
@@ -107,12 +111,18 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info:
)
-@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa
+@pytest.mark.filterwarnings(
+ (
+ "ignore:.*(Could not infer format, so each element will be parsed individually, "
+ "falling back to `dateutil`).*:UserWarning::0"
+ )
+) # pylint: disable=line-too-long # noqa
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None:
- """
- Read the telemetry data with incorrect header.
- """
+def test_local_env_telemetry_wrong_header(
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Read the telemetry data with incorrect header."""
ts1 = datetime.now(zone_info)
ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second
ts2 = ts1 + timedelta(minutes=1)
@@ -121,25 +131,26 @@ def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_in
time_str1 = ts1.strftime(format_str)
time_str2 = ts2.strftime(format_str)
- local_env = create_local_env(tunable_groups, {
- "run": [
- # Error: the data is correct, but the header has unexpected column names
- "echo 'ts,metric_name,metric_value' > telemetry.csv",
- f"echo {time_str1},cpu_load,0.65 >> telemetry.csv",
- f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
- f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
- f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
- ],
- "read_telemetry_file": "telemetry.csv",
- })
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ # Error: the data is correct, but the header has unexpected column names
+ "echo 'ts,metric_name,metric_value' > telemetry.csv",
+ f"echo {time_str1},cpu_load,0.65 >> telemetry.csv",
+ f"echo {time_str1},mem_usage,10240 >> telemetry.csv",
+ f"echo {time_str2},cpu_load,0.8 >> telemetry.csv",
+ f"echo {time_str2},mem_usage,20480 >> telemetry.csv",
+ ],
+ "read_telemetry_file": "telemetry.csv",
+ },
+ )
check_env_fail_telemetry(local_env, tunable_groups)
def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None:
- """
- Fail when the telemetry data has wrong format.
- """
+ """Fail when the telemetry data has wrong format."""
zone_info = UTC
ts1 = datetime.now(zone_info)
ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second
@@ -149,33 +160,37 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None:
time_str1 = ts1.strftime(format_str)
time_str2 = ts2.strftime(format_str)
- local_env = create_local_env(tunable_groups, {
- "run": [
- # Error: too many columns
- f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv",
- f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv",
- f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv",
- f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv",
- ],
- "read_telemetry_file": "telemetry.csv",
- })
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ # Error: too many columns
+ f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv",
+ f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv",
+ f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv",
+ f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv",
+ ],
+ "read_telemetry_file": "telemetry.csv",
+ },
+ )
check_env_fail_telemetry(local_env, tunable_groups)
def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None:
- """
- Fail when the telemetry data has wrong format.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- # Error: field 1 must be a timestamp
- "echo 1,cpu_load,0.65 > telemetry.csv",
- "echo 2,mem_usage,10240 >> telemetry.csv",
- "echo 3,cpu_load,0.8 >> telemetry.csv",
- "echo 4,mem_usage,20480 >> telemetry.csv",
- ],
- "read_telemetry_file": "telemetry.csv",
- })
+ """Fail when the telemetry data has wrong format."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ # Error: field 1 must be a timestamp
+ "echo 1,cpu_load,0.65 > telemetry.csv",
+ "echo 2,mem_usage,10240 >> telemetry.csv",
+ "echo 3,cpu_load,0.8 >> telemetry.csv",
+ "echo 4,mem_usage,20480 >> telemetry.csv",
+ ],
+ "read_telemetry_file": "telemetry.csv",
+ },
+ )
check_env_fail_telemetry(local_env, tunable_groups)
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py
index 9fcd26ead2..25eea76b17 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py
@@ -2,32 +2,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for LocalEnv benchmark environment.
-"""
+"""Unit tests for LocalEnv benchmark environment."""
import pytest
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.environments.local import create_local_env
+from mlos_bench.tunables.tunable_groups import TunableGroups
def test_local_env(tunable_groups: TunableGroups) -> None:
- """
- Produce benchmark and telemetry data in a local script and read it.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'metric,value' > output.csv",
- "echo 'latency,10' >> output.csv",
- "echo 'throughput,66' >> output.csv",
- "echo 'score,0.9' >> output.csv",
- ],
- "read_results_file": "output.csv",
- })
+ """Produce benchmark and telemetry data in a local script and read it."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'metric,value' > output.csv",
+ "echo 'latency,10' >> output.csv",
+ "echo 'throughput,66' >> output.csv",
+ "echo 'score,0.9' >> output.csv",
+ ],
+ "read_results_file": "output.csv",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 10.0,
"throughput": 66.0,
@@ -38,12 +38,10 @@ def test_local_env(tunable_groups: TunableGroups) -> None:
def test_local_env_service_context(tunable_groups: TunableGroups) -> None:
+ """Basic check that context support for Service mixins are handled when environment
+ contexts are entered.
"""
- Basic check that context support for Service mixins are handled when environment contexts are entered.
- """
- local_env = create_local_env(tunable_groups, {
- "run": ["echo NA"]
- })
+ local_env = create_local_env(tunable_groups, {"run": ["echo NA"]})
# pylint: disable=protected-access
assert local_env._service
assert not local_env._service._in_context
@@ -51,27 +49,28 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None:
with local_env as env_context:
assert env_context._in_context
assert local_env._service._in_context
- assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive)
+ assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive)
assert all(svc._in_context for svc in local_env._service._service_contexts)
assert all(svc._in_context for svc in local_env._service._services)
- assert not local_env._service._in_context # type: ignore[unreachable] # (false positive)
+ assert not local_env._service._in_context # type: ignore[unreachable] # (false positive)
assert not local_env._service._service_contexts
assert not any(svc._in_context for svc in local_env._service._services)
def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None:
- """
- Fail if the results are not in the expected format.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- # No header
- "echo 'latency,10' > output.csv",
- "echo 'throughput,66' >> output.csv",
- "echo 'score,0.9' >> output.csv",
- ],
- "read_results_file": "output.csv",
- })
+ """Fail if the results are not in the expected format."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ # No header
+ "echo 'latency,10' > output.csv",
+ "echo 'throughput,66' >> output.csv",
+ "echo 'score,0.9' >> output.csv",
+ ],
+ "read_results_file": "output.csv",
+ },
+ )
with local_env as env_context:
assert env_context.setup(tunable_groups)
@@ -80,19 +79,21 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None:
def test_local_env_wide(tunable_groups: TunableGroups) -> None:
- """
- Produce benchmark data in wide format and read it.
- """
- local_env = create_local_env(tunable_groups, {
- "run": [
- "echo 'latency,throughput,score' > output.csv",
- "echo '10,66,0.9' >> output.csv",
- ],
- "read_results_file": "output.csv",
- })
+ """Produce benchmark data in wide format and read it."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "run": [
+ "echo 'latency,throughput,score' > output.csv",
+ "echo '10,66,0.9' >> output.csv",
+ ],
+ "read_results_file": "output.csv",
+ },
+ )
check_env_success(
- local_env, tunable_groups,
+ local_env,
+ tunable_groups,
expected_results={
"latency": 10,
"throughput": 66,
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py
index ac7ff257e1..ef90155f0f 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py
@@ -2,71 +2,68 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for passing shell environment variables into LocalEnv scripts.
-"""
+"""Unit tests for passing shell environment variables into LocalEnv scripts."""
import sys
import pytest
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.environments.local import create_local_env
+from mlos_bench.tunables.tunable_groups import TunableGroups
def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None:
- """
- Check that LocalEnv can set shell environment variables.
- """
- local_env = create_local_env(tunable_groups, {
- "const_args": {
- "const_arg": 111, # Passed into "shell_env_params"
- "other_arg": 222, # NOT passed into "shell_env_params"
+ """Check that LocalEnv can set shell environment variables."""
+ local_env = create_local_env(
+ tunable_groups,
+ {
+ "const_args": {
+ "const_arg": 111, # Passed into "shell_env_params"
+ "other_arg": 222, # NOT passed into "shell_env_params"
+ },
+ "tunable_params": ["kernel"],
+ "shell_env_params": [
+ "const_arg", # From "const_arg"
+ "kernel_sched_latency_ns", # From "tunable_params"
+ ],
+ "run": [
+ "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv",
+ f"echo {shell_subcmd} >> output.csv",
+ ],
+ "read_results_file": "output.csv",
},
- "tunable_params": ["kernel"],
- "shell_env_params": [
- "const_arg", # From "const_arg"
- "kernel_sched_latency_ns", # From "tunable_params"
- ],
- "run": [
- "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv",
- f"echo {shell_subcmd} >> output.csv",
- ],
- "read_results_file": "output.csv",
- })
+ )
check_env_success(local_env, tunable_groups, expected, [])
-@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only")
+@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only")
def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None:
- """
- Check that LocalEnv can set shell environment variables in sh-like shell.
- """
+ """Check that LocalEnv can set shell environment variables in sh-like shell."""
_run_local_env(
tunable_groups,
shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns",
expected={
- "const_arg": 111, # From "const_args"
- "other_arg": float("NaN"), # Not included in "shell_env_params"
- "unknown_arg": float("NaN"), # Unknown/undefined variable
- "kernel_sched_latency_ns": 2000000, # From "tunable_params"
- }
+ "const_arg": 111, # From "const_args"
+ "other_arg": float("NaN"), # Not included in "shell_env_params"
+ "unknown_arg": float("NaN"), # Unknown/undefined variable
+ "kernel_sched_latency_ns": 2000000, # From "tunable_params"
+ },
)
-@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only")
+@pytest.mark.skipif(sys.platform != "win32", reason="Windows only")
def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None:
- """
- Check that LocalEnv can set shell environment variables on Windows / cmd shell.
+ """Check that LocalEnv can set shell environment variables on Windows / cmd
+ shell.
"""
_run_local_env(
tunable_groups,
shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%",
expected={
- "const_arg": 111, # From "const_args"
- "other_arg": r"%other_arg%", # Not included in "shell_env_params"
- "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable
- "kernel_sched_latency_ns": 2000000, # From "tunable_params"
- }
+ "const_arg": 111, # From "const_args"
+ "other_arg": r"%other_arg%", # Not included in "shell_env_params"
+ "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable
+ "kernel_sched_latency_ns": 2000000, # From "tunable_params"
+ },
)
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py
index bb455b8b76..9c40d422e7 100644
--- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py
@@ -2,49 +2,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for passing shell environment variables into LocalEnv scripts.
-"""
+"""Unit tests for passing shell environment variables into LocalEnv scripts."""
import pytest
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
-
-from mlos_bench.tests.services.remote.mock.mock_fileshare_service import MockFileShareService
+from mlos_bench.tests.services.remote.mock.mock_fileshare_service import (
+ MockFileShareService,
+)
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
@pytest.fixture(scope="module")
def mock_fileshare_service() -> MockFileShareService:
- """
- Create a new mock FileShareService instance.
- """
+ """Create a new mock FileShareService instance."""
return MockFileShareService(
config={"fileShareName": "MOCK_FILESHARE"},
- parent=LocalExecService(parent=ConfigPersistenceService())
+ parent=LocalExecService(parent=ConfigPersistenceService()),
)
@pytest.fixture
-def local_fileshare_env(tunable_groups: TunableGroups,
- mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv:
- """
- Create a LocalFileShareEnv instance.
- """
+def local_fileshare_env(
+ tunable_groups: TunableGroups,
+ mock_fileshare_service: MockFileShareService,
+) -> LocalFileShareEnv:
+ """Create a LocalFileShareEnv instance."""
env = LocalFileShareEnv(
name="TestLocalFileShareEnv",
config={
"const_args": {
"experiment_id": "EXP_ID", # Passed into "shell_env_params"
- "trial_id": 222, # NOT passed into "shell_env_params"
+ "trial_id": 222, # NOT passed into "shell_env_params"
},
"tunable_params": ["boot"],
"shell_env_params": [
- "trial_id", # From "const_arg"
- "idle", # From "tunable_params", == "halt"
+ "trial_id", # From "const_arg"
+ "idle", # From "tunable_params", == "halt"
],
"upload": [
{
@@ -56,9 +53,7 @@ def local_fileshare_env(tunable_groups: TunableGroups,
"to": "$experiment_id/$trial_id/input/data_$idle.csv",
},
],
- "run": [
- "echo No-op run"
- ],
+ "run": ["echo No-op run"],
"download": [
{
"from": "$experiment_id/$trial_id/$idle/data.csv",
@@ -72,12 +67,13 @@ def local_fileshare_env(tunable_groups: TunableGroups,
return env
-def test_local_fileshare_env(tunable_groups: TunableGroups,
- mock_fileshare_service: MockFileShareService,
- local_fileshare_env: LocalFileShareEnv) -> None:
- """
- Test that the LocalFileShareEnv correctly expands the `$VAR` variables
- in the upload and download sections of the config.
+def test_local_fileshare_env(
+ tunable_groups: TunableGroups,
+ mock_fileshare_service: MockFileShareService,
+ local_fileshare_env: LocalFileShareEnv,
+) -> None:
+ """Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the
+ upload and download sections of the config.
"""
with local_fileshare_env as env_context:
assert env_context.setup(tunable_groups)
diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py
index 608edbf9ef..3a82d8dfd3 100644
--- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py
+++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mock benchmark environment.
-"""
+"""Unit tests for mock benchmark environment."""
import pytest
from mlos_bench.environments.mock_env import MockEnv
@@ -12,9 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> None:
- """
- Check the default values of the mock environment.
- """
+ """Check the default values of the mock environment."""
with mock_env as env_context:
assert env_context.setup(tunable_groups)
(status, _ts, data) = env_context.run()
@@ -29,9 +25,7 @@ def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> N
def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None:
- """
- Check the default values of the mock environment.
- """
+ """Check the default values of the mock environment."""
with mock_env_no_noise as env_context:
assert env_context.setup(tunable_groups)
for _ in range(10):
@@ -42,23 +36,26 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr
assert data["score"] == pytest.approx(75.0, 0.01)
-@pytest.mark.parametrize(('tunable_values', 'expected_score'), [
- ({
- "vmSize": "Standard_B2ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 250000
- }, 66.4),
- ({
- "vmSize": "Standard_B4ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 40000
- }, 74.06),
-])
-def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
- tunable_values: dict, expected_score: float) -> None:
- """
- Check the benchmark values of the mock environment after the assignment.
- """
+@pytest.mark.parametrize(
+ ("tunable_values", "expected_score"),
+ [
+ (
+ {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000},
+ 66.4,
+ ),
+ (
+ {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000},
+ 74.06,
+ ),
+ ],
+)
+def test_mock_env_assign(
+ mock_env: MockEnv,
+ tunable_groups: TunableGroups,
+ tunable_values: dict,
+ expected_score: float,
+) -> None:
+ """Check the benchmark values of the mock environment after the assignment."""
with mock_env as env_context:
tunable_groups.assign(tunable_values)
assert env_context.setup(tunable_groups)
@@ -68,23 +65,27 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups,
assert data["score"] == pytest.approx(expected_score, 0.01)
-@pytest.mark.parametrize(('tunable_values', 'expected_score'), [
- ({
- "vmSize": "Standard_B2ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 250000
- }, 67.5),
- ({
- "vmSize": "Standard_B4ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 40000
- }, 75.1),
-])
-def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv,
- tunable_groups: TunableGroups,
- tunable_values: dict, expected_score: float) -> None:
- """
- Check the benchmark values of the noiseless mock environment after the assignment.
+@pytest.mark.parametrize(
+ ("tunable_values", "expected_score"),
+ [
+ (
+ {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000},
+ 67.5,
+ ),
+ (
+ {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000},
+ 75.1,
+ ),
+ ],
+)
+def test_mock_env_no_noise_assign(
+ mock_env_no_noise: MockEnv,
+ tunable_groups: TunableGroups,
+ tunable_values: dict,
+ expected_score: float,
+) -> None:
+ """Check the benchmark values of the noiseless mock environment after the
+ assignment.
"""
with mock_env_no_noise as env_context:
tunable_groups.assign(tunable_values)
diff --git a/mlos_bench/mlos_bench/tests/environments/remote/__init__.py b/mlos_bench/mlos_bench/tests/environments/remote/__init__.py
index f8a576c536..a72cac05db 100644
--- a/mlos_bench/mlos_bench/tests/environments/remote/__init__.py
+++ b/mlos_bench/mlos_bench/tests/environments/remote/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Helpers for RemoteEnv tests.
-"""
+"""Helpers for RemoteEnv tests."""
diff --git a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py
index 4e9e4197e8..257e37fa9e 100644
--- a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py
+++ b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Fixtures for the RemoteEnv tests using SSH Services.
-"""
+"""Fixtures for the RemoteEnv tests using SSH Services."""
import mlos_bench.tests.services.remote.ssh.fixtures as ssh_fixtures
diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py
index 36ea7c324b..e3a12bd3ed 100644
--- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py
+++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py
@@ -2,26 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for RemoveEnv benchmark environment via local SSH test services.
-"""
-
-from typing import Dict
+"""Unit tests for RemoveEnv benchmark environment via local SSH test services."""
import os
import sys
+from typing import Dict
import numpy as np
-
import pytest
from mlos_bench.services.config_persistence import ConfigPersistenceService
-from mlos_bench.tunables.tunable import TunableValue
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.tests import requires_docker
from mlos_bench.tests.environments import check_env_success
from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
+from mlos_bench.tunables.tunable import TunableValue
+from mlos_bench.tunables.tunable_groups import TunableGroups
if sys.version_info < (3, 10):
from importlib_resources import files
@@ -31,9 +26,7 @@ else:
@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
- """
- Produce benchmark and telemetry data in a local script and read it.
- """
+ """Produce benchmark and telemetry data in a local script and read it."""
global_config: Dict[str, TunableValue] = {
"ssh_hostname": ssh_test_server.hostname,
"ssh_port": ssh_test_server.get_port(),
@@ -41,25 +34,34 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
"ssh_priv_key_path": ssh_test_server.id_rsa_path,
}
- service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]})
+ service = ConfigPersistenceService(
+ config={"config_path": [str(files("mlos_bench.tests.config"))]}
+ )
config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc")
- env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service)
+ env = service.load_environment(
+ config_path,
+ TunableGroups(),
+ global_config=global_config,
+ service=service,
+ )
check_env_success(
- env, env.tunable_params,
+ env,
+ env.tunable_params,
expected_results={
"hostname": ssh_test_server.service_name,
"username": ssh_test_server.username,
"score": 0.9,
- "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number"
+ "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number"
"test_param": "unset",
"FOO": "unset",
"ssh_username": "unset",
},
expected_telemetry=[],
)
- assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \
- "output-downloaded.csv should have been cleaned up by temp_dir context"
+ assert not os.path.exists(
+ os.path.join(os.getcwd(), "output-downloaded.csv")
+ ), "output-downloaded.csv should have been cleaned up by temp_dir context"
if __name__ == "__main__":
diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py
index 80b252f255..eb92c4c132 100644
--- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py
+++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py
@@ -2,21 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.event_loop_context background thread logic.
-"""
+"""Tests for mlos_bench.event_loop_context background thread logic."""
import asyncio
import sys
import time
-
from asyncio import AbstractEventLoop
from threading import Thread
from types import TracebackType
from typing import Optional, Type
-from typing_extensions import Literal
import pytest
+from typing_extensions import Literal
from mlos_bench.event_loop_context import EventLoopContext
@@ -24,6 +21,7 @@ from mlos_bench.event_loop_context import EventLoopContext
class EventLoopContextCaller:
"""
Simple class to test the EventLoopContext.
+
See Also: SshService
"""
@@ -41,16 +39,21 @@ class EventLoopContextCaller:
self.EVENT_LOOP_CONTEXT.enter()
self._in_context = True
- def __exit__(self, ex_type: Optional[Type[BaseException]],
- ex_val: Optional[BaseException],
- ex_tb: Optional[TracebackType]) -> Literal[False]:
+ def __exit__(
+ self,
+ ex_type: Optional[Type[BaseException]],
+ ex_val: Optional[BaseException],
+ ex_tb: Optional[TracebackType],
+ ) -> Literal[False]:
assert self._in_context
self.EVENT_LOOP_CONTEXT.exit()
self._in_context = False
return False
-@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0")
+@pytest.mark.filterwarnings(
+ "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0"
+)
def test_event_loop_context() -> None:
"""Test event loop context background thread setup/cleanup handling."""
# pylint: disable=protected-access,too-many-statements
@@ -69,7 +72,9 @@ def test_event_loop_context() -> None:
# After we enter the instance context, we should have a background thread.
with event_loop_caller_instance_1:
assert event_loop_caller_instance_1._in_context
- assert isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable]
+ assert ( # type: ignore[unreachable]
+ isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread, Thread)
+ )
# Give the thread a chance to start.
# Mostly important on the underpowered Windows CI machines.
time.sleep(0.25)
@@ -88,12 +93,16 @@ def test_event_loop_context() -> None:
assert event_loop_caller_instance_1._in_context
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2
# We should only get one thread for all instances.
- assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \
- is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \
+ assert (
+ EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread
+ is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread
is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread
- assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \
- is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \
+ )
+ assert (
+ EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop
+ is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop
is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop
+ )
assert not event_loop_caller_instance_2._in_context
@@ -105,30 +114,40 @@ def test_event_loop_context() -> None:
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running()
start = time.time()
- future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo'))
+ future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(
+ asyncio.sleep(0.1, result="foo")
+ )
assert 0.0 <= time.time() - start < 0.1
- assert future.result(timeout=0.2) == 'foo'
+ assert future.result(timeout=0.2) == "foo"
assert 0.1 <= time.time() - start <= 0.2
# Once we exit the last context, the background thread should be stopped
# and unusable for running co-routines.
- assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives)
+ assert ( # type: ignore[unreachable] # (false positives)
+ EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None
+ )
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0
assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None
assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running()
# Check that the event loop has no more tasks.
- assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready')
+ assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready")
# Windows ProactorEventLoopPolicy adds a dummy task.
- if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop):
+ if sys.platform == "win32" and isinstance(
+ EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop
+ ):
assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1
else:
assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0
- assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled')
+ assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled")
assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0
- with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"):
- future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo'))
+ with pytest.raises(
+ AssertionError
+ ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"):
+ future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(
+ asyncio.sleep(0.1, result="foo")
+ )
raise ValueError(f"Future should not have been available to wait on {future.result()}")
# Test that when re-entering the context we have the same event loop.
@@ -139,12 +158,14 @@ def test_event_loop_context() -> None:
# Test running again.
start = time.time()
- future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo'))
+ future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(
+ asyncio.sleep(0.1, result="foo")
+ )
assert 0.0 <= time.time() - start < 0.1
- assert future.result(timeout=0.2) == 'foo'
+ assert future.result(timeout=0.2) == "foo"
assert 0.1 <= time.time() - start <= 0.2
-if __name__ == '__main__':
+if __name__ == "__main__":
# For debugging in Windows which has issues with pytest detection in vscode.
pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"])
diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py
index 90aa7e08f7..6fe340c9eb 100644
--- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py
+++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests to check the launcher and the main optimization loop in-process.
-"""
+"""Unit tests to check the launcher and the main optimization loop in-process."""
from typing import List
@@ -14,24 +12,36 @@ from mlos_bench.run import _main
@pytest.mark.parametrize(
- ("argv", "expected_score"), [
- ([
- "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc",
- "--trial_config_repeat_count", "5",
- "--mock_env_seed", "-1", # Deterministic Mock Environment.
- ], 67.40329),
- ([
- "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc",
- "--trial_config_repeat_count", "3",
- "--max_suggestions", "3",
- "--mock_env_seed", "42", # Noisy Mock Environment.
- ], 64.53897),
- ]
+ ("argv", "expected_score"),
+ [
+ (
+ [
+ "--config",
+ "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc",
+ "--trial_config_repeat_count",
+ "5",
+ "--mock_env_seed",
+ "-1", # Deterministic Mock Environment.
+ ],
+ 67.40329,
+ ),
+ (
+ [
+ "--config",
+ "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc",
+ "--trial_config_repeat_count",
+ "3",
+ "--max_suggestions",
+ "3",
+ "--mock_env_seed",
+ "42", # Noisy Mock Environment.
+ ],
+ 64.53897,
+ ),
+ ],
)
def test_main_bench(argv: List[str], expected_score: float) -> None:
- """
- Run mlos_bench optimization loop with given config and check the results.
- """
+ """Run mlos_bench optimization loop with given config and check the results."""
(score, _config) = _main(argv)
assert score is not None
assert pytest.approx(score["score"], 1e-5) == expected_score
diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
index 90e52bb880..f577f21526 100644
--- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
+++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
@@ -14,11 +14,10 @@ from typing import List
import pytest
-from mlos_bench.launcher import Launcher
-from mlos_bench.optimizers import OneShotOptimizer, MlosCoreOptimizer
-from mlos_bench.os_environ import environ
from mlos_bench.config.schemas import ConfigSchema
-from mlos_bench.util import path_join
+from mlos_bench.launcher import Launcher
+from mlos_bench.optimizers import MlosCoreOptimizer, OneShotOptimizer
+from mlos_bench.os_environ import environ
from mlos_bench.schedulers import SyncScheduler
from mlos_bench.services.types import (
SupportsAuth,
@@ -28,6 +27,7 @@ from mlos_bench.services.types import (
SupportsRemoteExec,
)
from mlos_bench.tests import check_class_name
+from mlos_bench.util import path_join
if sys.version_info < (3, 10):
from importlib_resources import files
@@ -48,15 +48,16 @@ def config_paths() -> List[str]:
"""
return [
path_join(os.getcwd(), abs_path=True),
- str(files('mlos_bench.config')),
- str(files('mlos_bench.tests.config')),
+ str(files("mlos_bench.config")),
+ str(files("mlos_bench.tests.config")),
]
def test_launcher_args_parse_1(config_paths: List[str]) -> None:
"""
- Test that using multiple --globals arguments works and that multiple space
- separated options to --config-paths works.
+    Test that using multiple --globals arguments works and that multiple space-separated
+    options to --config-paths work.
+
Check $var expansion and Environment loading.
"""
# The VSCode pytest wrapper actually starts in a different directory before
@@ -64,20 +65,23 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None:
# variable so we use a separate variable.
# See global_test_config.jsonc for more details.
environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd()
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# Some env tweaks for platform compatibility.
- environ['USER'] = environ['USERNAME']
+ environ["USER"] = environ["USERNAME"]
# This is part of the minimal required args by the Launcher.
- env_conf_path = 'environments/mock/mock_env.jsonc'
- cli_args = '--config-paths ' + ' '.join(config_paths) + \
- ' --service services/remote/mock/mock_auth_service.jsonc' + \
- ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \
- ' --scheduler schedulers/sync_scheduler.jsonc' + \
- f' --environment {env_conf_path}' + \
- ' --globals globals/global_test_config.jsonc' + \
- ' --globals globals/global_test_extra_config.jsonc' \
- ' --test_global_value_2 from-args'
+ env_conf_path = "environments/mock/mock_env.jsonc"
+ cli_args = (
+ "--config-paths "
+ + " ".join(config_paths)
+ + " --service services/remote/mock/mock_auth_service.jsonc"
+ + " --service services/remote/mock/mock_remote_exec_service.jsonc"
+ + " --scheduler schedulers/sync_scheduler.jsonc"
+ + f" --environment {env_conf_path}"
+ + " --globals globals/global_test_config.jsonc"
+ + " --globals globals/global_test_extra_config.jsonc"
+ " --test_global_value_2 from-args"
+ )
launcher = Launcher(description=__name__, argv=cli_args.split())
# Check that the parent service
assert isinstance(launcher.service, SupportsAuth)
@@ -85,27 +89,28 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None:
assert isinstance(launcher.service, SupportsLocalExec)
assert isinstance(launcher.service, SupportsRemoteExec)
# Check that the first --globals file is loaded and $var expansion is handled.
- assert launcher.global_config['experiment_id'] == 'MockExperiment'
- assert launcher.global_config['testVmName'] == 'MockExperiment-vm'
+ assert launcher.global_config["experiment_id"] == "MockExperiment"
+ assert launcher.global_config["testVmName"] == "MockExperiment-vm"
# Check that secondary expansion also works.
- assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet'
+ assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet"
# Check that the second --globals file is loaded.
- assert launcher.global_config['test_global_value'] == 'from-file'
+ assert launcher.global_config["test_global_value"] == "from-file"
# Check overriding values in a file from the command line.
- assert launcher.global_config['test_global_value_2'] == 'from-args'
+ assert launcher.global_config["test_global_value_2"] == "from-args"
# Check that we can expand a $var in a config file that references an environment variable.
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \
- == path_join(os.getcwd(), "foo", abs_path=True)
- assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}'
+ assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join(
+ os.getcwd(), "foo", abs_path=True
+ )
+ assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}"
assert launcher.teardown
# Check that the environment that got loaded looks to be of the right type.
env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT)
- assert check_class_name(launcher.environment, env_config['class'])
+ assert check_class_name(launcher.environment, env_config["class"])
# Check that the optimizer looks right.
assert isinstance(launcher.optimizer, OneShotOptimizer)
# Check that the optimizer got initialized with defaults.
assert launcher.optimizer.tunable_params.is_defaults()
- assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer
+ assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer
# Check that we pick up the right scheduler config:
assert isinstance(launcher.scheduler, SyncScheduler)
assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access
@@ -113,8 +118,7 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None:
def test_launcher_args_parse_2(config_paths: List[str]) -> None:
- """
- Test multiple --config-path instances, --config file vs --arg, --var=val
+ """Test multiple --config-path instances, --config file vs --arg, --var=val
overrides, $var templates, option args, --random-init, etc.
"""
# The VSCode pytest wrapper actually starts in a different directory before
@@ -122,23 +126,25 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None:
# variable so we use a separate variable.
# See global_test_config.jsonc for more details.
environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd()
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# Some env tweaks for platform compatibility.
- environ['USER'] = environ['USERNAME']
+ environ["USER"] = environ["USERNAME"]
- config_file = 'cli/test-cli-config.jsonc'
- globals_file = 'globals/global_test_config.jsonc'
- cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \
- f' --config {config_file}' + \
- ' --service services/remote/mock/mock_auth_service.jsonc' + \
- ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \
- f' --globals {globals_file}' + \
- ' --experiment_id MockeryExperiment' + \
- ' --no-teardown' + \
- ' --random-init' + \
- ' --random-seed 1234' + \
- ' --trial-config-repeat-count 5' + \
- ' --max_trials 200'
+ config_file = "cli/test-cli-config.jsonc"
+ globals_file = "globals/global_test_config.jsonc"
+ cli_args = (
+ " ".join([f"--config-path {config_path}" for config_path in config_paths])
+ + f" --config {config_file}"
+ + " --service services/remote/mock/mock_auth_service.jsonc"
+ + " --service services/remote/mock/mock_remote_exec_service.jsonc"
+ + f" --globals {globals_file}"
+ + " --experiment_id MockeryExperiment"
+ + " --no-teardown"
+ + " --random-init"
+ + " --random-seed 1234"
+ + " --trial-config-repeat-count 5"
+ + " --max_trials 200"
+ )
launcher = Launcher(description=__name__, argv=cli_args.split())
# Check that the parent service
assert isinstance(launcher.service, SupportsAuth)
@@ -148,35 +154,42 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None:
assert isinstance(launcher.service, SupportsRemoteExec)
# Check that the --globals file is loaded and $var expansion is handled
# using the value provided on the CLI.
- assert launcher.global_config['experiment_id'] == 'MockeryExperiment'
- assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm'
+ assert launcher.global_config["experiment_id"] == "MockeryExperiment"
+ assert launcher.global_config["testVmName"] == "MockeryExperiment-vm"
# Check that secondary expansion also works.
- assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet'
+ assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet"
# Check that we can expand a $var in a config file that references an environment variable.
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \
- == path_join(os.getcwd(), "foo", abs_path=True)
- assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}'
+ assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join(
+ os.getcwd(), "foo", abs_path=True
+ )
+ assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}"
assert not launcher.teardown
config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI)
- assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']]
+ assert launcher.config_loader.config_paths == [
+ path_join(path, abs_path=True) for path in config_paths + config["config_path"]
+ ]
# Check that the environment that got loaded looks to be of the right type.
- env_config_file = config['environment']
+ env_config_file = config["environment"]
env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT)
- assert check_class_name(launcher.environment, env_config['class'])
+ assert check_class_name(launcher.environment, env_config["class"])
# Check that the optimizer looks right.
assert isinstance(launcher.optimizer, MlosCoreOptimizer)
- opt_config_file = config['optimizer']
+ opt_config_file = config["optimizer"]
opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER)
globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS)
# The actual global_config gets overwritten as a part of processing, so to test
# this we read the original value out of the source files.
- orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100))
- assert launcher.optimizer.max_iterations \
- == orig_max_iters \
- == launcher.global_config['max_suggestions']
+ orig_max_iters = globals_file_config.get(
+ "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100)
+ )
+ assert (
+ launcher.optimizer.max_iterations
+ == orig_max_iters
+ == launcher.global_config["max_suggestions"]
+ )
# Check that the optimizer got initialized with random values instead of the defaults.
# Note: the environment doesn't get updated until suggest() is called to
@@ -193,12 +206,12 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None:
assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access
# Check that the value from the file is overridden by the CLI arg.
- assert config['random_seed'] == 42
+ assert config["random_seed"] == 42
# TODO: This isn't actually respected yet because the `--random-init` only
# applies to a temporary Optimizer used to populate the initial values via
# random sampling.
# assert launcher.optimizer.seed == 1234
-if __name__ == '__main__':
+if __name__ == "__main__":
pytest.main([__file__, "-n1"])
diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py
index d8caf7537e..1ae5af7e11 100644
--- a/mlos_bench/mlos_bench/tests/launcher_run_test.py
+++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py
@@ -2,17 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests to check the main CLI launcher.
-"""
+"""Unit tests to check the main CLI launcher."""
import os
import re
from typing import List
import pytest
-from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.util import path_join
# pylint: disable=redefined-outer-name
@@ -20,30 +18,33 @@ from mlos_bench.util import path_join
@pytest.fixture
def root_path() -> str:
- """
- Root path of mlos_bench project.
- """
+ """Root path of mlos_bench project."""
return path_join(os.path.dirname(__file__), "../../..", abs_path=True)
@pytest.fixture
def local_exec_service() -> LocalExecService:
- """
- Test fixture for LocalExecService.
- """
- return LocalExecService(parent=ConfigPersistenceService({
- "config_path": [
- "mlos_bench/config",
- "mlos_bench/examples",
- ]
- }))
+ """Test fixture for LocalExecService."""
+ return LocalExecService(
+ parent=ConfigPersistenceService(
+ {
+ "config_path": [
+ "mlos_bench/config",
+ "mlos_bench/examples",
+ ]
+ }
+ )
+ )
-def _launch_main_app(root_path: str, local_exec_service: LocalExecService,
- cli_config: str, re_expected: List[str]) -> None:
- """
- Run mlos_bench command-line application with given config
- and check the results in the log.
+def _launch_main_app(
+ root_path: str,
+ local_exec_service: LocalExecService,
+ cli_config: str,
+ re_expected: List[str],
+) -> None:
+ """Run mlos_bench command-line application with given config and check the results
+ in the log.
"""
with local_exec_service.temp_dir_context() as temp_dir:
@@ -52,10 +53,13 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService,
# temp_dir = '/tmp'
log_path = path_join(temp_dir, "mock-test.log")
(return_code, _stdout, _stderr) = local_exec_service.local_exec(
- ["./mlos_bench/mlos_bench/run.py" +
- " --config_path ./mlos_bench/mlos_bench/tests/config/" +
- f" {cli_config} --log_file '{log_path}'"],
- cwd=root_path)
+ [
+ "./mlos_bench/mlos_bench/run.py"
+ + " --config_path ./mlos_bench/mlos_bench/tests/config/"
+ + f" {cli_config} --log_file '{log_path}'"
+ ],
+ cwd=root_path,
+ )
assert return_code == 0
try:
@@ -74,64 +78,63 @@ _RE_DATE = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}"
def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecService) -> None:
- """
- Run mlos_bench command-line application with mock benchmark config
- and default tunable values and check the results in the log.
+ """Run mlos_bench command-line application with mock benchmark config and default
+ tunable values and check the results in the log.
"""
_launch_main_app(
- root_path, local_exec_service,
- " --config cli/mock-bench.jsonc" +
- " --trial_config_repeat_count 5" +
- " --mock_env_seed -1", # Deterministic Mock Environment.
+ root_path,
+ local_exec_service,
+ " --config cli/mock-bench.jsonc"
+ + " --trial_config_repeat_count 5"
+ + " --mock_env_seed -1", # Deterministic Mock Environment.
[
- f"^{_RE_DATE} run\\.py:\\d+ " +
- r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$",
- ]
+ f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$",
+ ],
)
def test_launch_main_app_bench_values(
- root_path: str, local_exec_service: LocalExecService) -> None:
- """
- Run mlos_bench command-line application with mock benchmark config
- and user-specified tunable values and check the results in the log.
+ root_path: str,
+ local_exec_service: LocalExecService,
+) -> None:
+ """Run mlos_bench command-line application with mock benchmark config and user-
+ specified tunable values and check the results in the log.
"""
_launch_main_app(
- root_path, local_exec_service,
- " --config cli/mock-bench.jsonc" +
- " --tunable_values tunable-values/tunable-values-example.jsonc" +
- " --trial_config_repeat_count 5" +
- " --mock_env_seed -1", # Deterministic Mock Environment.
+ root_path,
+ local_exec_service,
+ " --config cli/mock-bench.jsonc"
+ + " --tunable_values tunable-values/tunable-values-example.jsonc"
+ + " --trial_config_repeat_count 5"
+ + " --mock_env_seed -1", # Deterministic Mock Environment.
[
- f"^{_RE_DATE} run\\.py:\\d+ " +
- r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$",
- ]
+ f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$",
+ ],
)
def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecService) -> None:
- """
- Run mlos_bench command-line application with mock optimization config
- and check the results in the log.
+ """Run mlos_bench command-line application with mock optimization config and check
+ the results in the log.
"""
_launch_main_app(
- root_path, local_exec_service,
- "--config cli/mock-opt.jsonc" +
- " --trial_config_repeat_count 3" +
- " --max_suggestions 3" +
- " --mock_env_seed 42", # Noisy Mock Environment.
+ root_path,
+ local_exec_service,
+ "--config cli/mock-opt.jsonc"
+ + " --trial_config_repeat_count 3"
+ + " --max_suggestions 3"
+ + " --mock_env_seed 42", # Noisy Mock Environment.
[
# Iteration 1: Expect first value to be the baseline
- f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " +
- r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$",
+ f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ "
+ + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$",
# Iteration 2: The result may not always be deterministic
- f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " +
- r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$",
+ f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ "
+ + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$",
# Iteration 3: non-deterministic (depends on the optimizer)
- f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " +
- r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$",
+ f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ "
+ + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$",
# Final result: baseline is the optimum for the mock environment
- f"^{_RE_DATE} run\\.py:\\d+ " +
- r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$",
- ]
+ f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$",
+ ],
)
diff --git a/mlos_bench/mlos_bench/tests/optimizers/__init__.py b/mlos_bench/mlos_bench/tests/optimizers/__init__.py
index 509ecbd842..dbee44936d 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/__init__.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/__init__.py
@@ -4,5 +4,6 @@
#
"""
Tests for mlos_bench.optimizers.
+
Used to make mypy happy about multiple conftest.py modules.
"""
diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py
index 7149a79c93..6b660f7fea 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py
@@ -2,59 +2,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test fixtures for mlos_bench optimizers.
-"""
+"""Test fixtures for mlos_bench optimizers."""
from typing import List
import pytest
-from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
-
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.tests import SEED
+from mlos_bench.tunables.tunable_groups import TunableGroups
@pytest.fixture
def mock_configs() -> List[dict]:
- """
- Mock configurations of earlier experiments.
- """
+ """Mock configurations of earlier experiments."""
return [
{
- 'vmSize': 'Standard_B4ms',
- 'idle': 'halt',
- 'kernel_sched_migration_cost_ns': 50000,
- 'kernel_sched_latency_ns': 1000000,
+ "vmSize": "Standard_B4ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": 50000,
+ "kernel_sched_latency_ns": 1000000,
},
{
- 'vmSize': 'Standard_B4ms',
- 'idle': 'halt',
- 'kernel_sched_migration_cost_ns': 40000,
- 'kernel_sched_latency_ns': 2000000,
+ "vmSize": "Standard_B4ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": 40000,
+ "kernel_sched_latency_ns": 2000000,
},
{
- 'vmSize': 'Standard_B4ms',
- 'idle': 'mwait',
- 'kernel_sched_migration_cost_ns': -1, # Special value
- 'kernel_sched_latency_ns': 3000000,
+ "vmSize": "Standard_B4ms",
+ "idle": "mwait",
+ "kernel_sched_migration_cost_ns": -1, # Special value
+ "kernel_sched_latency_ns": 3000000,
},
{
- 'vmSize': 'Standard_B2s',
- 'idle': 'mwait',
- 'kernel_sched_migration_cost_ns': 200000,
- 'kernel_sched_latency_ns': 4000000,
- }
+ "vmSize": "Standard_B2s",
+ "idle": "mwait",
+ "kernel_sched_migration_cost_ns": 200000,
+ "kernel_sched_latency_ns": 4000000,
+ },
]
@pytest.fixture
def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer:
- """
- Test fixture for MockOptimizer that ignores the initial configuration.
- """
+ """Test fixture for MockOptimizer that ignores the initial configuration."""
return MockOptimizer(
tunables=tunable_groups,
service=None,
@@ -62,48 +55,34 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer:
"optimization_targets": {"score": "min"},
"max_suggestions": 5,
"start_with_defaults": False,
- "seed": SEED
+ "seed": SEED,
},
)
@pytest.fixture
def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer:
- """
- Test fixture for MockOptimizer.
- """
+ """Test fixture for MockOptimizer."""
return MockOptimizer(
tunables=tunable_groups,
service=None,
- config={
- "optimization_targets": {"score": "min"},
- "max_suggestions": 5,
- "seed": SEED
- },
+ config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED},
)
@pytest.fixture
def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer:
- """
- Test fixture for MockOptimizer.
- """
+ """Test fixture for MockOptimizer."""
return MockOptimizer(
tunables=tunable_groups,
service=None,
- config={
- "optimization_targets": {"score": "max"},
- "max_suggestions": 10,
- "seed": SEED
- },
+ config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED},
)
@pytest.fixture
def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- Test fixture for mlos_core FLAML optimizer.
- """
+ """Test fixture for mlos_core FLAML optimizer."""
return MlosCoreOptimizer(
tunables=tunable_groups,
service=None,
@@ -118,9 +97,7 @@ def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
@pytest.fixture
def flaml_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- Test fixture for mlos_core FLAML optimizer.
- """
+ """Test fixture for mlos_core FLAML optimizer."""
return MlosCoreOptimizer(
tunables=tunable_groups,
service=None,
@@ -143,9 +120,7 @@ SMAC_ITERATIONS = 10
@pytest.fixture
def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- Test fixture for mlos_core SMAC optimizer.
- """
+ """Test fixture for mlos_core SMAC optimizer."""
return MlosCoreOptimizer(
tunables=tunable_groups,
service=None,
@@ -164,9 +139,7 @@ def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
@pytest.fixture
def smac_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- Test fixture for mlos_core SMAC optimizer.
- """
+ """Test fixture for mlos_core SMAC optimizer."""
return MlosCoreOptimizer(
tunables=tunable_groups,
service=None,
diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py
index 9e43b3731e..769bf8859d 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py
@@ -2,15 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for grid search mlos_bench optimizer.
-"""
-
-from typing import Dict, List
+"""Unit tests for grid search mlos_bench optimizer."""
import itertools
import math
import random
+from typing import Dict, List
import pytest
@@ -19,14 +16,12 @@ from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
-
# pylint: disable=redefined-outer-name
+
@pytest.fixture
def grid_search_tunables_config() -> dict:
- """
- Test fixture for grid search optimizer tunables config.
- """
+ """Test fixture for grid search optimizer tunables config."""
return {
"grid": {
"cost": 1,
@@ -53,46 +48,56 @@ def grid_search_tunables_config() -> dict:
@pytest.fixture
-def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]:
+def grid_search_tunables_grid(
+ grid_search_tunables: TunableGroups,
+) -> List[Dict[str, TunableValue]]:
"""
Test fixture for grid from tunable groups.
+
Used to check that the grids are the same (ignoring order).
"""
- tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None]
- tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None)
- return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values))
+ tunables_params_values = [
+ tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None
+ ]
+ tunable_names = tuple(
+ tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None
+ )
+ return list(
+ dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)
+ )
@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
- """
- Test fixture for grid search optimizer tunables.
- """
+ """Test fixture for grid search optimizer tunables."""
return TunableGroups(grid_search_tunables_config)
@pytest.fixture
-def grid_search_opt(grid_search_tunables: TunableGroups,
- grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer:
- """
- Test fixture for grid search optimizer.
- """
+def grid_search_opt(
+ grid_search_tunables: TunableGroups,
+ grid_search_tunables_grid: List[Dict[str, TunableValue]],
+) -> GridSearchOptimizer:
+ """Test fixture for grid search optimizer."""
assert len(grid_search_tunables) == 3
# Test the convergence logic by controlling the number of iterations to be not a
# multiple of the number of elements in the grid.
max_iterations = len(grid_search_tunables_grid) * 2 - 3
- return GridSearchOptimizer(tunables=grid_search_tunables, config={
- "max_suggestions": max_iterations,
- "optimization_targets": {"score": "max", "other_score": "min"},
- })
+ return GridSearchOptimizer(
+ tunables=grid_search_tunables,
+ config={
+ "max_suggestions": max_iterations,
+ "optimization_targets": {"score": "max", "other_score": "min"},
+ },
+ )
-def test_grid_search_grid(grid_search_opt: GridSearchOptimizer,
- grid_search_tunables: TunableGroups,
- grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
- """
- Make sure that grid search optimizer initializes and works correctly.
- """
+def test_grid_search_grid(
+ grid_search_opt: GridSearchOptimizer,
+ grid_search_tunables: TunableGroups,
+ grid_search_tunables_grid: List[Dict[str, TunableValue]],
+) -> None:
+ """Make sure that grid search optimizer initializes and works correctly."""
# Check the size.
expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables)
assert expected_grid_size > len(grid_search_tunables)
@@ -116,12 +121,12 @@ def test_grid_search_grid(grid_search_opt: GridSearchOptimizer,
# assert grid_search_opt.pending_configs == grid_search_tunables_grid
-def test_grid_search(grid_search_opt: GridSearchOptimizer,
- grid_search_tunables: TunableGroups,
- grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
- """
- Make sure that grid search optimizer initializes and works correctly.
- """
+def test_grid_search(
+ grid_search_opt: GridSearchOptimizer,
+ grid_search_tunables: TunableGroups,
+ grid_search_tunables_grid: List[Dict[str, TunableValue]],
+) -> None:
+ """Make sure that grid search optimizer initializes and works correctly."""
score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
status = Status.SUCCEEDED
suggestion = grid_search_opt.suggest()
@@ -145,7 +150,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer,
grid_search_tunables_grid.remove(default_config)
assert default_config not in grid_search_opt.pending_configs
assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
- assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)
+ assert all(
+ config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
+ )
# The next suggestion should be a different element in the grid search.
suggestion = grid_search_opt.suggest()
@@ -159,7 +166,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer,
grid_search_tunables_grid.remove(suggestion.get_param_values())
assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
- assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)
+ assert all(
+ config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
+ )
# We consider not_converged as either having reached "max_suggestions" or an empty grid?
@@ -173,7 +182,8 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer,
assert not list(grid_search_opt.suggested_configs)
assert not grid_search_opt.not_converged()
- # But if we still have iterations left, we should be able to suggest again by refilling the grid.
+ # But if we still have iterations left, we should be able to suggest again by
+ # refilling the grid.
assert grid_search_opt.current_iteration < grid_search_opt.max_iterations
assert grid_search_opt.suggest()
assert list(grid_search_opt.pending_configs)
@@ -191,8 +201,7 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer,
def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
- """
- Make sure that grid search optimizer works correctly when suggest and register
+ """Make sure that grid search optimizer works correctly when suggest and register
are called out of order.
"""
# pylint: disable=too-many-locals
@@ -225,7 +234,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
assert best_suggestion_dict not in grid_search_opt.suggested_configs
best_suggestion_score: Dict[str, TunableValue] = {}
- for (opt_target, opt_dir) in grid_search_opt.targets.items():
+ for opt_target, opt_dir in grid_search_opt.targets.items():
val = score[opt_target]
assert isinstance(val, (int, float))
best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1
@@ -239,36 +248,53 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
# Check bulk register
suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
- assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
- assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested)
+ assert all(
+ suggestion.get_param_values() not in grid_search_opt.pending_configs
+ for suggestion in suggested
+ )
+ assert all(
+ suggestion.get_param_values() in grid_search_opt.suggested_configs
+ for suggestion in suggested
+ )
# Those new suggestions also shouldn't be in the set of previously suggested configs.
assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested)
- grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested],
- [score] * len(suggested),
- [status] * len(suggested))
+ grid_search_opt.bulk_register(
+ [suggestion.get_param_values() for suggestion in suggested],
+ [score] * len(suggested),
+ [status] * len(suggested),
+ )
- assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
- assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested)
+ assert all(
+ suggestion.get_param_values() not in grid_search_opt.pending_configs
+ for suggestion in suggested
+ )
+ assert all(
+ suggestion.get_param_values() not in grid_search_opt.suggested_configs
+ for suggestion in suggested
+ )
best_score, best_config = grid_search_opt.get_best_observation()
assert best_score == best_suggestion_score
assert best_config == best_suggestion
-def test_grid_search_register(grid_search_opt: GridSearchOptimizer,
- grid_search_tunables: TunableGroups) -> None:
- """
- Make sure that the `.register()` method adjusts the score signs correctly.
- """
+def test_grid_search_register(
+ grid_search_opt: GridSearchOptimizer,
+ grid_search_tunables: TunableGroups,
+) -> None:
+ """Make sure that the `.register()` method adjusts the score signs correctly."""
assert grid_search_opt.register(
- grid_search_tunables, Status.SUCCEEDED, {
+ grid_search_tunables,
+ Status.SUCCEEDED,
+ {
"score": 1.0,
"other_score": 2.0,
- }) == {
- "score": -1.0, # max
- "other_score": 2.0, # min
+ },
+ ) == {
+ "score": -1.0, # max
+ "other_score": 2.0, # min
}
assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == {
diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py
index d356466e58..4494cba3ef 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py
@@ -2,26 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mock mlos_bench optimizer.
-"""
+"""Unit tests for mock mlos_bench optimizer."""
import pytest
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
-
from mlos_bench.tests import SEED
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
@pytest.fixture
def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- Test fixture for mlos_core SMAC optimizer.
- """
+ """Test fixture for mlos_core SMAC optimizer."""
return MlosCoreOptimizer(
tunables=tunable_groups,
service=None,
@@ -35,21 +30,18 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
"optimizer_type": "SMAC",
"seed": SEED,
# "start_with_defaults": False,
- })
+ },
+ )
@pytest.fixture
def mock_scores() -> list:
- """
- A list of fake benchmark scores to test the optimizers.
- """
+ """A list of fake benchmark scores to test the optimizers."""
return [88.88, 66.66, 99.99]
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None:
- """
- Make sure that llamatune+smac optimizer initializes and works correctly.
- """
+ """Make sure that llamatune+smac optimizer initializes and works correctly."""
for score in mock_scores:
assert llamatune_opt.not_converged()
tunables = llamatune_opt.suggest()
@@ -62,6 +54,6 @@ def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list
assert best_score["score"] == pytest.approx(66.66, 0.01)
-if __name__ == '__main__':
+if __name__ == "__main__":
# For attaching a debugger for debugging:
pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__])
diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py
index f36e3c149c..043e457375 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for internal methods of the `MlosCoreOptimizer`.
-"""
+"""Unit tests for internal methods of the `MlosCoreOptimizer`."""
from typing import List
@@ -12,72 +10,67 @@ import pandas
import pytest
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.tests import SEED
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name, protected-access
@pytest.fixture
def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
- """
- An instance of a mlos_core optimizer (FLAML-based).
- """
+ """An instance of a mlos_core optimizer (FLAML-based)."""
test_opt_config = {
- 'optimizer_type': 'FLAML',
- 'max_suggestions': 10,
- 'seed': SEED,
+ "optimizer_type": "FLAML",
+ "max_suggestions": 10,
+ "seed": SEED,
}
return MlosCoreOptimizer(tunable_groups, test_opt_config)
def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> None:
- """
- Test `MlosCoreOptimizer._to_df()` method on tunables that have special values.
- """
+ """Test `MlosCoreOptimizer._to_df()` method on tunables that have special values."""
df_config = mlos_core_optimizer._to_df(mock_configs)
assert isinstance(df_config, pandas.DataFrame)
assert df_config.shape == (4, 6)
assert set(df_config.columns) == {
- 'kernel_sched_latency_ns',
- 'kernel_sched_migration_cost_ns',
- 'kernel_sched_migration_cost_ns!type',
- 'kernel_sched_migration_cost_ns!special',
- 'idle',
- 'vmSize',
+ "kernel_sched_latency_ns",
+ "kernel_sched_migration_cost_ns",
+ "kernel_sched_migration_cost_ns!type",
+ "kernel_sched_migration_cost_ns!special",
+ "idle",
+ "vmSize",
}
- assert df_config.to_dict(orient='records') == [
+ assert df_config.to_dict(orient="records") == [
{
- 'idle': 'halt',
- 'kernel_sched_latency_ns': 1000000,
- 'kernel_sched_migration_cost_ns': 50000,
- 'kernel_sched_migration_cost_ns!special': None,
- 'kernel_sched_migration_cost_ns!type': 'range',
- 'vmSize': 'Standard_B4ms',
+ "idle": "halt",
+ "kernel_sched_latency_ns": 1000000,
+ "kernel_sched_migration_cost_ns": 50000,
+ "kernel_sched_migration_cost_ns!special": None,
+ "kernel_sched_migration_cost_ns!type": "range",
+ "vmSize": "Standard_B4ms",
},
{
- 'idle': 'halt',
- 'kernel_sched_latency_ns': 2000000,
- 'kernel_sched_migration_cost_ns': 40000,
- 'kernel_sched_migration_cost_ns!special': None,
- 'kernel_sched_migration_cost_ns!type': 'range',
- 'vmSize': 'Standard_B4ms',
+ "idle": "halt",
+ "kernel_sched_latency_ns": 2000000,
+ "kernel_sched_migration_cost_ns": 40000,
+ "kernel_sched_migration_cost_ns!special": None,
+ "kernel_sched_migration_cost_ns!type": "range",
+ "vmSize": "Standard_B4ms",
},
{
- 'idle': 'mwait',
- 'kernel_sched_latency_ns': 3000000,
- 'kernel_sched_migration_cost_ns': None, # The value is special!
- 'kernel_sched_migration_cost_ns!special': -1,
- 'kernel_sched_migration_cost_ns!type': 'special',
- 'vmSize': 'Standard_B4ms',
+ "idle": "mwait",
+ "kernel_sched_latency_ns": 3000000,
+ "kernel_sched_migration_cost_ns": None, # The value is special!
+ "kernel_sched_migration_cost_ns!special": -1,
+ "kernel_sched_migration_cost_ns!type": "special",
+ "vmSize": "Standard_B4ms",
},
{
- 'idle': 'mwait',
- 'kernel_sched_latency_ns': 4000000,
- 'kernel_sched_migration_cost_ns': 200000,
- 'kernel_sched_migration_cost_ns!special': None,
- 'kernel_sched_migration_cost_ns!type': 'range',
- 'vmSize': 'Standard_B2s',
+ "idle": "mwait",
+ "kernel_sched_latency_ns": 4000000,
+ "kernel_sched_migration_cost_ns": 200000,
+ "kernel_sched_migration_cost_ns!special": None,
+ "kernel_sched_migration_cost_ns!type": "range",
+ "vmSize": "Standard_B2s",
},
]
diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py
index b10571095b..23aa56e48c 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py
@@ -2,36 +2,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mock mlos_bench optimizer.
-"""
+"""Unit tests for mock mlos_bench optimizer."""
import os
-import sys
import shutil
+import sys
import pytest
-from mlos_bench.util import path_join
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.tests import SEED
-
+from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.util import path_join
from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
-_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/'
-_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test.
+_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/"
+_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test.
def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None:
- """
- Test invalid max_trials initialization of mlos_core SMAC optimizer.
- """
+ """Test invalid max_trials initialization of mlos_core SMAC optimizer."""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'max_trials': 10,
- 'max_suggestions': 11,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "max_trials": 10,
+ "max_suggestions": 11,
+ "seed": SEED,
}
with pytest.raises(AssertionError):
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
@@ -39,29 +33,27 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups)
def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> None:
- """
- Test max_trials initialization of mlos_core SMAC optimizer.
- """
+ """Test max_trials initialization of mlos_core SMAC optimizer."""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'max_suggestions': 123,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "max_suggestions": 123,
+ "seed": SEED,
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
# pylint: disable=protected-access
assert isinstance(opt._opt, SmacOptimizer)
- assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions']
+ assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"]
def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None:
- """
- Test absolute path output directory initialization of mlos_core SMAC optimizer.
+ """Test absolute path output directory initialization of mlos_core SMAC
+ optimizer.
"""
output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR)
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'output_directory': output_dir,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "output_directory": output_dir,
+ "seed": SEED,
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
assert isinstance(opt, MlosCoreOptimizer)
@@ -69,76 +61,88 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr
assert isinstance(opt._opt, SmacOptimizer)
# Final portions of the path are generated by SMAC when run_name is not specified.
assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith(
- str(test_opt_config['output_directory']))
+ str(test_opt_config["output_directory"])
+ )
shutil.rmtree(output_dir)
def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGroups) -> None:
- """
- Test relative path output directory initialization of mlos_core SMAC optimizer.
+ """Test relative path output directory initialization of mlos_core SMAC
+ optimizer.
"""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'output_directory': _OUTPUT_DIR,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "output_directory": _OUTPUT_DIR,
+ "seed": SEED,
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
assert isinstance(opt, MlosCoreOptimizer)
# pylint: disable=protected-access
assert isinstance(opt._opt, SmacOptimizer)
assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith(
- path_join(os.getcwd(), str(test_opt_config['output_directory'])))
+ path_join(os.getcwd(), str(test_opt_config["output_directory"]))
+ )
shutil.rmtree(_OUTPUT_DIR)
-def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None:
- """
- Test relative path output directory initialization of mlos_core SMAC optimizer.
+def test_init_mlos_core_smac_relative_output_directory_with_run_name(
+ tunable_groups: TunableGroups,
+) -> None:
+ """Test relative path output directory initialization of mlos_core SMAC
+ optimizer.
"""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'output_directory': _OUTPUT_DIR,
- 'run_name': 'test_run',
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "output_directory": _OUTPUT_DIR,
+ "run_name": "test_run",
+ "seed": SEED,
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
assert isinstance(opt, MlosCoreOptimizer)
# pylint: disable=protected-access
assert isinstance(opt._opt, SmacOptimizer)
assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith(
- path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name'])))
+ path_join(
+ os.getcwd(), str(test_opt_config["output_directory"]), str(test_opt_config["run_name"])
+ )
+ )
shutil.rmtree(_OUTPUT_DIR)
-def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None:
- """
- Test relative path output directory initialization of mlos_core SMAC optimizer.
+def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(
+ tunable_groups: TunableGroups,
+) -> None:
+ """Test relative path output directory initialization of mlos_core SMAC
+ optimizer.
"""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'output_directory': _OUTPUT_DIR,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "output_directory": _OUTPUT_DIR,
+ "seed": SEED,
}
global_config = {
- 'experiment_id': 'experiment_id',
+ "experiment_id": "experiment_id",
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config)
assert isinstance(opt, MlosCoreOptimizer)
# pylint: disable=protected-access
assert isinstance(opt._opt, SmacOptimizer)
assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith(
- path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id']))
+ path_join(
+ os.getcwd(),
+ str(test_opt_config["output_directory"]),
+ global_config["experiment_id"],
+ )
+ )
shutil.rmtree(_OUTPUT_DIR)
def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups) -> None:
- """
- Test random output directory initialization of mlos_core SMAC optimizer.
- """
+ """Test random output directory initialization of mlos_core SMAC optimizer."""
test_opt_config = {
- 'optimizer_type': 'SMAC',
- 'output_directory': None,
- 'seed': SEED,
+ "optimizer_type": "SMAC",
+ "output_directory": None,
+ "seed": SEED,
}
opt = MlosCoreOptimizer(tunable_groups, test_opt_config)
assert isinstance(opt, MlosCoreOptimizer)
diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py
index a94a315939..05305de50b 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mock mlos_bench optimizer.
-"""
+"""Unit tests for mock mlos_bench optimizer."""
import pytest
@@ -16,51 +14,57 @@ from mlos_bench.optimizers.mock_optimizer import MockOptimizer
@pytest.fixture
def mock_configurations_no_defaults() -> list:
- """
- A list of 2-tuples of (tunable_values, score) to test the optimizers.
- """
+ """A list of 2-tuples of (tunable_values, score) to test the optimizers."""
return [
- ({
- "vmSize": "Standard_B4ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 13112,
- "kernel_sched_latency_ns": 796233790,
- }, 88.88),
- ({
- "vmSize": "Standard_B2ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 117026,
- "kernel_sched_latency_ns": 149827706,
- }, 66.66),
- ({
- "vmSize": "Standard_B4ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": 354785,
- "kernel_sched_latency_ns": 795285932,
- }, 99.99),
+ (
+ {
+ "vmSize": "Standard_B4ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": 13112,
+ "kernel_sched_latency_ns": 796233790,
+ },
+ 88.88,
+ ),
+ (
+ {
+ "vmSize": "Standard_B2ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": 117026,
+ "kernel_sched_latency_ns": 149827706,
+ },
+ 66.66,
+ ),
+ (
+ {
+ "vmSize": "Standard_B4ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": 354785,
+ "kernel_sched_latency_ns": 795285932,
+ },
+ 99.99,
+ ),
]
@pytest.fixture
def mock_configurations(mock_configurations_no_defaults: list) -> list:
- """
- A list of 2-tuples of (tunable_values, score) to test the optimizers.
- """
+ """A list of 2-tuples of (tunable_values, score) to test the optimizers."""
return [
- ({
- "vmSize": "Standard_B4ms",
- "idle": "halt",
- "kernel_sched_migration_cost_ns": -1,
- "kernel_sched_latency_ns": 2000000,
- }, 88.88),
+ (
+ {
+ "vmSize": "Standard_B4ms",
+ "idle": "halt",
+ "kernel_sched_migration_cost_ns": -1,
+ "kernel_sched_latency_ns": 2000000,
+ },
+ 88.88,
+ ),
] + mock_configurations_no_defaults
def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
- """
- Run several iterations of the optimizer and return the best score.
- """
- for (tunable_values, score) in mock_configurations:
+ """Run several iterations of the optimizer and return the best score."""
+ for tunable_values, score in mock_configurations:
assert mock_opt.not_converged()
tunables = mock_opt.suggest()
assert tunables.get_param_values() == tunable_values
@@ -73,34 +77,28 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float:
def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> None:
- """
- Make sure that mock optimizer produces consistent suggestions.
- """
+ """Make sure that mock optimizer produces consistent suggestions."""
score = _optimize(mock_opt, mock_configurations)
assert score == pytest.approx(66.66, 0.01)
-def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer,
- mock_configurations_no_defaults: list) -> None:
- """
- Make sure that mock optimizer produces consistent suggestions.
- """
+def test_mock_optimizer_no_defaults(
+ mock_opt_no_defaults: MockOptimizer,
+ mock_configurations_no_defaults: list,
+) -> None:
+ """Make sure that mock optimizer produces consistent suggestions."""
score = _optimize(mock_opt_no_defaults, mock_configurations_no_defaults)
assert score == pytest.approx(66.66, 0.01)
def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None:
- """
- Check the maximization mode of the mock optimizer.
- """
+ """Check the maximization mode of the mock optimizer."""
score = _optimize(mock_opt_max, mock_configurations)
assert score == pytest.approx(99.99, 0.01)
def test_mock_optimizer_register_fail(mock_opt: MockOptimizer) -> None:
- """
- Check the input acceptance conditions for Optimizer.register().
- """
+ """Check the input acceptance conditions for Optimizer.register()."""
tunables = mock_opt.suggest()
mock_opt.register(tunables, Status.SUCCEEDED, {"score": 10})
mock_opt.register(tunables, Status.FAILED)
diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py
index f2805e9322..cbbd2a627d 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mock mlos_bench optimizer.
-"""
+"""Unit tests for mock mlos_bench optimizer."""
from typing import Dict, List, Optional
@@ -12,8 +10,8 @@ import pytest
from mlos_bench.environments.status import Status
from mlos_bench.optimizers.base_optimizer import Optimizer
-from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.tunables.tunable import TunableValue
# pylint: disable=redefined-outer-name
@@ -23,19 +21,15 @@ from mlos_bench.tunables.tunable import TunableValue
def mock_configs_str(mock_configs: List[dict]) -> List[dict]:
"""
Same as `mock_configs` above, but with all values converted to strings.
+
(This can happen when we retrieve the data from storage).
"""
- return [
- {key: str(val) for (key, val) in config.items()}
- for config in mock_configs
- ]
+ return [{key: str(val) for (key, val) in config.items()} for config in mock_configs]
@pytest.fixture
def mock_scores() -> List[Optional[Dict[str, TunableValue]]]:
- """
- Mock benchmark results from earlier experiments.
- """
+ """Mock benchmark results from earlier experiments."""
return [
None,
{"score": 88.88},
@@ -46,19 +40,17 @@ def mock_scores() -> List[Optional[Dict[str, TunableValue]]]:
@pytest.fixture
def mock_status() -> List[Status]:
- """
- Mock status values for earlier experiments.
- """
+ """Mock status values for earlier experiments."""
return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED]
-def _test_opt_update_min(opt: Optimizer,
- configs: List[dict],
- scores: List[Optional[Dict[str, TunableValue]]],
- status: Optional[List[Status]] = None) -> None:
- """
- Test the bulk update of the optimizer on the minimization problem.
- """
+def _test_opt_update_min(
+ opt: Optimizer,
+ configs: List[dict],
+ scores: List[Optional[Dict[str, TunableValue]]],
+ status: Optional[List[Status]] = None,
+) -> None:
+ """Test the bulk update of the optimizer on the minimization problem."""
opt.bulk_register(configs, scores, status)
(score, tunables) = opt.get_best_observation()
assert score is not None
@@ -68,17 +60,17 @@ def _test_opt_update_min(opt: Optimizer,
"vmSize": "Standard_B4ms",
"idle": "mwait",
"kernel_sched_migration_cost_ns": -1,
- 'kernel_sched_latency_ns': 3000000,
+ "kernel_sched_latency_ns": 3000000,
}
-def _test_opt_update_max(opt: Optimizer,
- configs: List[dict],
- scores: List[Optional[Dict[str, TunableValue]]],
- status: Optional[List[Status]] = None) -> None:
- """
- Test the bulk update of the optimizer on the maximization problem.
- """
+def _test_opt_update_max(
+ opt: Optimizer,
+ configs: List[dict],
+ scores: List[Optional[Dict[str, TunableValue]]],
+ status: Optional[List[Status]] = None,
+) -> None:
+ """Test the bulk update of the optimizer on the maximization problem."""
opt.bulk_register(configs, scores, status)
(score, tunables) = opt.get_best_observation()
assert score is not None
@@ -88,82 +80,82 @@ def _test_opt_update_max(opt: Optimizer,
"vmSize": "Standard_B2s",
"idle": "mwait",
"kernel_sched_migration_cost_ns": 200000,
- 'kernel_sched_latency_ns': 4000000,
+ "kernel_sched_latency_ns": 4000000,
}
-def test_update_mock_min(mock_opt: MockOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the mock optimizer on the minimization problem.
- """
+def test_update_mock_min(
+ mock_opt: MockOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the mock optimizer on the minimization problem."""
_test_opt_update_min(mock_opt, mock_configs, mock_scores, mock_status)
# make sure the first suggestion after bulk load is *NOT* the default config:
assert mock_opt.suggest().get_param_values() == {
"vmSize": "Standard_B4ms",
"idle": "halt",
"kernel_sched_migration_cost_ns": 13112,
- 'kernel_sched_latency_ns': 796233790,
+ "kernel_sched_latency_ns": 796233790,
}
-def test_update_mock_min_str(mock_opt: MockOptimizer,
- mock_configs_str: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the mock optimizer with all-strings data.
- """
+def test_update_mock_min_str(
+ mock_opt: MockOptimizer,
+ mock_configs_str: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the mock optimizer with all-strings data."""
_test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status)
-def test_update_mock_max(mock_opt_max: MockOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the mock optimizer on the maximization problem.
- """
+def test_update_mock_max(
+ mock_opt_max: MockOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the mock optimizer on the maximization problem."""
_test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status)
-def test_update_flaml(flaml_opt: MlosCoreOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the FLAML optimizer.
- """
+def test_update_flaml(
+ flaml_opt: MlosCoreOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the FLAML optimizer."""
_test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status)
-def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the FLAML optimizer.
- """
+def test_update_flaml_max(
+ flaml_opt_max: MlosCoreOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the FLAML optimizer."""
_test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status)
-def test_update_smac(smac_opt: MlosCoreOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the SMAC optimizer.
- """
+def test_update_smac(
+ smac_opt: MlosCoreOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the SMAC optimizer."""
_test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status)
-def test_update_smac_max(smac_opt_max: MlosCoreOptimizer,
- mock_configs: List[dict],
- mock_scores: List[Optional[Dict[str, TunableValue]]],
- mock_status: List[Status]) -> None:
- """
- Test the bulk update of the SMAC optimizer.
- """
+def test_update_smac_max(
+ smac_opt_max: MlosCoreOptimizer,
+ mock_configs: List[dict],
+ mock_scores: List[Optional[Dict[str, TunableValue]]],
+ mock_status: List[Status],
+) -> None:
+ """Test the bulk update of the SMAC optimizer."""
_test_opt_update_max(smac_opt_max, mock_configs, mock_scores, mock_status)
diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py
index 183db1dc62..db46189e44 100644
--- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py
+++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py
@@ -2,27 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Toy optimization loop to test the optimizers on mock benchmark environment.
-"""
-
-from typing import Tuple
+"""Toy optimization loop to test the optimizers on mock benchmark environment."""
import logging
+from typing import Tuple
import pytest
-from mlos_core.util import config_to_dataframe
-from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
-from mlos_bench.optimizers.convert_configspace import tunable_values_to_configuration
-
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.mock_env import MockEnv
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.base_optimizer import Optimizer
-from mlos_bench.optimizers.mock_optimizer import MockOptimizer
+from mlos_bench.optimizers.convert_configspace import tunable_values_to_configuration
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
-
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
+from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
+from mlos_core.util import config_to_dataframe
# For debugging purposes output some warnings which are captured with failed tests.
DEBUG = True
@@ -32,9 +27,7 @@ if DEBUG:
def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
- """
- Toy optimization loop.
- """
+ """Toy optimization loop."""
assert opt.not_converged()
while opt.not_converged():
@@ -59,7 +52,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
(status, _ts, output) = env_context.run()
assert status.is_succeeded()
assert output is not None
- score = output['score']
+ score = output["score"]
assert isinstance(score, float)
assert 60 <= score <= 120
logger("score: %s", str(score))
@@ -72,11 +65,8 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]:
return (best_score["score"], best_tunables)
-def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
- mock_opt: MockOptimizer) -> None:
- """
- Toy optimization loop with mock environment and optimizer.
- """
+def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimizer) -> None:
+ """Toy optimization loop with mock environment and optimizer."""
(score, tunables) = _optimize(mock_env_no_noise, mock_opt)
assert score == pytest.approx(64.9, 0.01)
assert tunables.get_param_values() == {
@@ -87,11 +77,11 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv,
}
-def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv,
- mock_opt_no_defaults: MockOptimizer) -> None:
- """
- Toy optimization loop with mock environment and optimizer.
- """
+def test_mock_optimization_loop_no_defaults(
+ mock_env_no_noise: MockEnv,
+ mock_opt_no_defaults: MockOptimizer,
+) -> None:
+ """Toy optimization loop with mock environment and optimizer."""
(score, tunables) = _optimize(mock_env_no_noise, mock_opt_no_defaults)
assert score == pytest.approx(60.97, 0.01)
assert tunables.get_param_values() == {
@@ -102,11 +92,8 @@ def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv,
}
-def test_flaml_optimization_loop(mock_env_no_noise: MockEnv,
- flaml_opt: MlosCoreOptimizer) -> None:
- """
- Toy optimization loop with mock environment and FLAML optimizer.
- """
+def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer) -> None:
+ """Toy optimization loop with mock environment and FLAML optimizer."""
(score, tunables) = _optimize(mock_env_no_noise, flaml_opt)
assert score == pytest.approx(60.15, 0.01)
assert tunables.get_param_values() == {
@@ -118,11 +105,8 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv,
# @pytest.mark.skip(reason="SMAC is not deterministic")
-def test_smac_optimization_loop(mock_env_no_noise: MockEnv,
- smac_opt: MlosCoreOptimizer) -> None:
- """
- Toy optimization loop with mock environment and SMAC optimizer.
- """
+def test_smac_optimization_loop(mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer) -> None:
+ """Toy optimization loop with mock environment and SMAC optimizer."""
(score, tunables) = _optimize(mock_env_no_noise, smac_opt)
expected_score = 70.33
expected_tunable_values = {
diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py
index 1971c01799..a0b56eeb03 100644
--- a/mlos_bench/mlos_bench/tests/services/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/__init__.py
@@ -4,6 +4,7 @@
#
"""
Tests for mlos_bench.services.
+
Used to make mypy happy about multiple conftest.py modules.
"""
@@ -11,8 +12,8 @@ from .local import MockLocalExecService
from .remote import MockFileShareService, MockRemoteExecService, MockVMService
__all__ = [
- 'MockLocalExecService',
- 'MockFileShareService',
- 'MockRemoteExecService',
- 'MockVMService',
+ "MockLocalExecService",
+ "MockFileShareService",
+ "MockRemoteExecService",
+ "MockVMService",
]
diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py
index 55dc15a8d4..0be2ac7749 100644
--- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py
+++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py
@@ -2,19 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for configuration persistence service.
-"""
+"""Unit tests for configuration persistence service."""
import os
import sys
+
import pytest
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.util import path_join
-
if sys.version_info < (3, 9):
from importlib_resources import files
else:
@@ -26,24 +24,24 @@ else:
@pytest.fixture
def config_persistence_service() -> ConfigPersistenceService:
- """
- Test fixture for ConfigPersistenceService.
- """
- return ConfigPersistenceService({
- "config_path": [
- "./non-existent-dir/test/foo/bar", # Non-existent config path
- ".", # cwd
- str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests)
- # Shouldn't be necessary since we automatically add this.
- # str(files("mlos_bench.config").joinpath("")), # Stock configs
- ]
- })
+ """Test fixture for ConfigPersistenceService."""
+ return ConfigPersistenceService(
+ {
+ "config_path": [
+ "./non-existent-dir/test/foo/bar", # Non-existent config path
+ ".", # cwd
+ str(
+ files("mlos_bench.tests.config").joinpath("")
+ ), # Test configs (relative to mlos_bench/tests)
+ # Shouldn't be necessary since we automatically add this.
+ # str(files("mlos_bench.config").joinpath("")), # Stock configs
+ ]
+ }
+ )
def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None:
- """
- Check that CWD is in the search path in the correct place.
- """
+ """Check that CWD is in the search path in the correct place."""
# pylint: disable=protected-access
assert config_persistence_service._config_path is not None
cwd = path_join(os.getcwd(), abs_path=True)
@@ -53,9 +51,7 @@ def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersisten
def test_cwd_in_default_search_path() -> None:
- """
- Checks that the CWD is prepended to the search path if not explicitly present.
- """
+ """Checks that the CWD is prepended to the search path if not explicitly present."""
# pylint: disable=protected-access
config_persistence_service = ConfigPersistenceService()
assert config_persistence_service._config_path is not None
@@ -66,9 +62,7 @@ def test_cwd_in_default_search_path() -> None:
def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None:
- """
- Check if we can actually find a file somewhere in `config_path`.
- """
+ """Check if we can actually find a file somewhere in `config_path`."""
# pylint: disable=protected-access
assert config_persistence_service._config_path is not None
assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in config_persistence_service._config_path
@@ -78,14 +72,12 @@ def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService
assert os.path.exists(path)
assert os.path.samefile(
ConfigPersistenceService.BUILTIN_CONFIG_PATH,
- os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path])
+ os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]),
)
def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None:
- """
- Check if we can actually find a file somewhere in `config_path`.
- """
+ """Check if we can actually find a file somewhere in `config_path`."""
file_path = "tunable-values/tunable-values-example.jsonc"
path = config_persistence_service.resolve_path(file_path)
assert path.endswith(file_path)
@@ -93,9 +85,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> N
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None:
- """
- Check if non-existent file resolves without using `config_path`.
- """
+ """Check if non-existent file resolves without using `config_path`."""
file_path = "foo/non-existent-config.json"
path = config_persistence_service.resolve_path(file_path)
assert not os.path.exists(path)
@@ -103,11 +93,13 @@ def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService)
def test_load_config(config_persistence_service: ConfigPersistenceService) -> None:
+ """Check if we can successfully load a config file located relative to
+ `config_path`.
"""
- Check if we can successfully load a config file located relative to `config_path`.
- """
- tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc",
- ConfigSchema.TUNABLE_VALUES)
+ tunables_data = config_persistence_service.load_config(
+ "tunable-values/tunable-values-example.jsonc",
+ ConfigSchema.TUNABLE_VALUES,
+ )
assert tunables_data is not None
assert isinstance(tunables_data, dict)
assert len(tunables_data) >= 1
diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py
index c6dbf7c021..79778d3c25 100644
--- a/mlos_bench/mlos_bench/tests/services/local/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py
@@ -4,11 +4,12 @@
#
"""
Tests for mlos_bench.services.local.
+
Used to make mypy happy about multiple conftest.py modules.
"""
from .mock import MockLocalExecService
__all__ = [
- 'MockLocalExecService',
+ "MockLocalExecService",
]
diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py
index 6f8549aee7..e3890149bd 100644
--- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py
+++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py
@@ -2,19 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for LocalExecService to run Python scripts locally.
-"""
-
-from typing import Any, Dict
+"""Unit tests for LocalExecService to run Python scripts locally."""
import json
+from typing import Any, Dict
import pytest
-from mlos_bench.tunables.tunable import TunableValue
-from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.services.local.local_exec import LocalExecService
+from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.util import path_join
# pylint: disable=redefined-outer-name
@@ -22,16 +19,12 @@ from mlos_bench.util import path_join
@pytest.fixture
def local_exec_service() -> LocalExecService:
- """
- Test fixture for LocalExecService.
- """
+ """Test fixture for LocalExecService."""
return LocalExecService(parent=ConfigPersistenceService())
def test_run_python_script(local_exec_service: LocalExecService) -> None:
- """
- Run a Python script using a local_exec service.
- """
+ """Run a Python script using a local_exec service."""
input_file = "./input-params.json"
meta_file = "./input-params-meta.json"
output_file = "./config-kernel.sh"
@@ -57,11 +50,14 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None:
json.dump(params_meta, fh_meta)
script_path = local_exec_service.config_loader_service.resolve_path(
- "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py")
+ "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py"
+ )
- (return_code, _stdout, stderr) = local_exec_service.local_exec([
- f"{script_path} {input_file} {meta_file} {output_file}"
- ], cwd=temp_dir, env=params)
+ (return_code, _stdout, stderr) = local_exec_service.local_exec(
+ [f"{script_path} {input_file} {meta_file} {output_file}"],
+ cwd=temp_dir,
+ env=params,
+ )
assert stderr.strip() == ""
assert return_code == 0
diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py
index 6e56b3bbe2..7165496f9d 100644
--- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py
+++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py
@@ -2,17 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for the service to run the scripts locally.
-"""
+"""Unit tests for the service to run the scripts locally."""
import sys
import tempfile
-import pytest
import pandas
+import pytest
-from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline
from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline
from mlos_bench.util import path_join
# pylint: disable=redefined-outer-name
@@ -21,36 +19,34 @@ from mlos_bench.util import path_join
def test_split_cmdline() -> None:
- """
- Test splitting a commandline into subcommands.
- """
- cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)"
+ """Test splitting a commandline into subcommands."""
+ cmdline = (
+ ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)"
+ )
assert list(split_cmdline(cmdline)) == [
- ['.', 'env.sh'],
- ['&&'],
- ['('],
- ['echo', 'hello'],
- ['&&'],
- ['echo', 'world'],
- ['|'],
- ['tee'],
- ['>'],
- ['/tmp/test'],
- ['||'],
- ['echo', 'foo'],
- ['&&'],
- ['echo', '$var'],
- [';'],
- ['true'],
- [')'],
+ [".", "env.sh"],
+ ["&&"],
+ ["("],
+ ["echo", "hello"],
+ ["&&"],
+ ["echo", "world"],
+ ["|"],
+ ["tee"],
+ [">"],
+ ["/tmp/test"],
+ ["||"],
+ ["echo", "foo"],
+ ["&&"],
+ ["echo", "$var"],
+ [";"],
+ ["true"],
+ [")"],
]
@pytest.fixture
def local_exec_service() -> LocalExecService:
- """
- Test fixture for LocalExecService.
- """
+ """Test fixture for LocalExecService."""
config = {
"abort_on_error": True,
}
@@ -58,25 +54,24 @@ def local_exec_service() -> LocalExecService:
def test_resolve_script(local_exec_service: LocalExecService) -> None:
- """
- Test local script resolution logic with complex subcommand names.
- """
+ """Test local script resolution logic with complex subcommand names."""
script = "os/linux/runtime/scripts/local/generate_kernel_config_script.py"
script_abspath = local_exec_service.config_loader_service.resolve_path(script)
orig_cmdline = f". env.sh && {script} --input foo"
expected_cmdline = f". env.sh && {script_abspath} --input foo"
subcmds_tokens = split_cmdline(orig_cmdline)
# pylint: disable=protected-access
- subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens]
+ subcmds_tokens = [
+ local_exec_service._resolve_cmdline_script_path(subcmd_tokens)
+ for subcmd_tokens in subcmds_tokens
+ ]
cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens]
expanded_cmdline = " ".join(cmdline_tokens)
assert expanded_cmdline == expected_cmdline
def test_run_script(local_exec_service: LocalExecService) -> None:
- """
- Run a script locally and check the results.
- """
+ """Run a script locally and check the results."""
# `echo` should work on all platforms
(return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello"])
assert return_code == 0
@@ -85,30 +80,23 @@ def test_run_script(local_exec_service: LocalExecService) -> None:
def test_run_script_multiline(local_exec_service: LocalExecService) -> None:
- """
- Run a multiline script locally and check the results.
- """
+ """Run a multiline script locally and check the results."""
# `echo` should work on all platforms
- (return_code, stdout, stderr) = local_exec_service.local_exec([
- "echo hello",
- "echo world"
- ])
+ (return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello", "echo world"])
assert return_code == 0
assert stdout.strip().split() == ["hello", "world"]
assert stderr.strip() == ""
def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None:
- """
- Run a multiline script locally and pass the environment variables to it.
- """
+ """Run a multiline script locally and pass the environment variables to it."""
# `echo` should work on all platforms
- (return_code, stdout, stderr) = local_exec_service.local_exec([
- r"echo $var", # Unix shell
- r"echo %var%" # Windows cmd
- ], env={"var": "VALUE", "int_var": 10})
+ (return_code, stdout, stderr) = local_exec_service.local_exec(
+ [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd
+ env={"var": "VALUE", "int_var": 10},
+ )
assert return_code == 0
- if sys.platform == 'win32':
+ if sys.platform == "win32":
assert stdout.strip().split() == ["$var", "VALUE"]
else:
assert stdout.strip().split() == ["VALUE", "%var%"]
@@ -116,46 +104,48 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None:
def test_run_script_read_csv(local_exec_service: LocalExecService) -> None:
- """
- Run a script locally and read the resulting CSV file.
- """
+ """Run a script locally and read the resulting CSV file."""
with local_exec_service.temp_dir_context() as temp_dir:
- (return_code, stdout, stderr) = local_exec_service.local_exec([
- "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows
- "echo '111,222' >> output.csv",
- "echo '333,444' >> output.csv",
- ], cwd=temp_dir)
+ (return_code, stdout, stderr) = local_exec_service.local_exec(
+ [
+ "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows
+ "echo '111,222' >> output.csv",
+ "echo '333,444' >> output.csv",
+ ],
+ cwd=temp_dir,
+ )
assert return_code == 0
assert stdout.strip() == ""
assert stderr.strip() == ""
data = pandas.read_csv(path_join(temp_dir, "output.csv"))
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# Workaround for Python's subprocess module on Windows adding a
# space inbetween the col1,col2 arg and the redirect symbol which
# cmd poorly interprets as being part of the original string arg.
# Without this, we get "col2 " as the second column name.
- data.rename(str.rstrip, axis='columns', inplace=True)
+ data.rename(str.rstrip, axis="columns", inplace=True)
assert all(data.col1 == [111, 333])
assert all(data.col2 == [222, 444])
def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None:
- """
- Write data a temp location and run a script that updates it there.
- """
+ """Write data a temp location and run a script that updates it there."""
with local_exec_service.temp_dir_context() as temp_dir:
input_file = "input.txt"
with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input:
fh_input.write("hello\n")
- (return_code, stdout, stderr) = local_exec_service.local_exec([
- f"echo 'world' >> {input_file}",
- f"echo 'test' >> {input_file}",
- ], cwd=temp_dir)
+ (return_code, stdout, stderr) = local_exec_service.local_exec(
+ [
+ f"echo 'world' >> {input_file}",
+ f"echo 'test' >> {input_file}",
+ ],
+ cwd=temp_dir,
+ )
assert return_code == 0
assert stdout.strip() == ""
@@ -166,37 +156,35 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None
def test_run_script_fail(local_exec_service: LocalExecService) -> None:
- """
- Try to run a non-existent command.
- """
+ """Try to run a non-existent command."""
(return_code, stdout, _stderr) = local_exec_service.local_exec(["foo_bar_baz hello"])
assert return_code != 0
assert stdout.strip() == ""
def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> None:
- """
- Try to run a series of commands, one of which fails, and abort early.
- """
- (return_code, stdout, _stderr) = local_exec_service.local_exec([
- "echo hello",
- "cmd /c 'exit 1'" if sys.platform == 'win32' else "false",
- "echo world",
- ])
+ """Try to run a series of commands, one of which fails, and abort early."""
+ (return_code, stdout, _stderr) = local_exec_service.local_exec(
+ [
+ "echo hello",
+ "cmd /c 'exit 1'" if sys.platform == "win32" else "false",
+ "echo world",
+ ]
+ )
assert return_code != 0
assert stdout.strip() == "hello"
def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> None:
- """
- Try to run a series of commands, one of which fails, but let it pass.
- """
+ """Try to run a series of commands, one of which fails, but let it pass."""
local_exec_service.abort_on_error = False
- (return_code, stdout, _stderr) = local_exec_service.local_exec([
- "echo hello",
- "cmd /c 'exit 1'" if sys.platform == 'win32' else "false",
- "echo world",
- ])
+ (return_code, stdout, _stderr) = local_exec_service.local_exec(
+ [
+ "echo hello",
+ "cmd /c 'exit 1'" if sys.platform == "win32" else "false",
+ "echo world",
+ ]
+ )
assert return_code == 0
assert stdout.splitlines() == [
"hello",
@@ -205,22 +193,26 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No
def test_temp_dir_path_expansion() -> None:
- """
- Test that we can control the temp_dir path using globals expansion.
- """
+ """Test that we can control the temp_dir path using globals expansion."""
# Create a temp dir for the test.
# Normally this would be a real path set on the CLI or in a global config,
# but for test purposes we still want it to be dynamic and cleaned up after
# the fact.
with tempfile.TemporaryDirectory() as temp_dir:
global_config = {
- "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench"
+ "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench"
}
config = {
# The temp_dir for the LocalExecService should get expanded via workdir global config.
"temp_dir": "$workdir/temp",
}
- local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService())
+ local_exec_service = LocalExecService(
+ config, global_config, parent=ConfigPersistenceService()
+ )
# pylint: disable=protected-access
assert isinstance(local_exec_service._temp_dir, str)
- assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True)
+ assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(
+ temp_dir,
+ "temp",
+ abs_path=True,
+ )
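
The hunks above all follow one mechanical rule from the new black setup: a call whose arguments fit within the 99-column limit stays on a single line; otherwise, or when a trailing "magic" comma is present, every argument moves onto its own line and keeps a trailing comma. A minimal, self-contained sketch of that rule (toy function, illustrative only, not from the patch):

```python
def greet(greeting: str, name: str, punctuation: str = "!") -> str:
    """Toy function used only to illustrate the formatting rule."""
    return f"{greeting}, {name}{punctuation}"


# Fits within the 99-column limit: black keeps the call on one line.
short = greet("hello", "world")

# A trailing ("magic") comma forces black to give every argument its own line.
long = greet(
    "hello",
    "world",
    punctuation="?",
)
```
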
diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py
index eede9383bc..2bae6d8dbd 100644
--- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Mock local services for testing purposes.
-"""
+"""Mock local services for testing purposes."""
from .mock_local_exec_service import MockLocalExecService
__all__ = [
- 'MockLocalExecService',
+ "MockLocalExecService",
]
diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py
index 588b94d8ea..39934c40e8 100644
--- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py
+++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py
@@ -2,13 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking local exec.
-"""
+"""A collection Service functions for mocking local exec."""
import logging
from typing import (
- Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
)
from mlos_bench.services.base_service import Service
@@ -22,20 +29,23 @@ _LOG = logging.getLogger(__name__)
class MockLocalExecService(TempDirContextService, SupportsLocalExec):
- """
- Mock methods for LocalExecService testing.
- """
+ """Mock methods for LocalExecService testing."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.local_exec])
+ config, global_config, parent, self.merge_methods(methods, [self.local_exec])
)
- def local_exec(self, script_lines: Iterable[str],
- env: Optional[Mapping[str, "TunableValue"]] = None,
- cwd: Optional[str] = None) -> Tuple[int, str, str]:
+ def local_exec(
+ self,
+ script_lines: Iterable[str],
+ env: Optional[Mapping[str, "TunableValue"]] = None,
+ cwd: Optional[str] = None,
+ ) -> Tuple[int, str, str]:
return (0, "", "")
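
Since the mock's `local_exec` unconditionally returns `(0, "", "")`, tests that only exercise service registration can use it with no setup at all. A small usage sketch; the no-argument constructor is an assumption based on the all-optional signature shown above:

```python
# Sketch only: assumes the base Service accepts the all-default constructor shown above.
from mlos_bench.tests.services.local.mock import MockLocalExecService

svc = MockLocalExecService()
# Every script "succeeds" with empty stdout/stderr, as defined in the mock above.
assert svc.local_exec(["echo hello"]) == (0, "", "")
```
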
diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py
index 835738015b..cebea96912 100644
--- a/mlos_bench/mlos_bench/tests/services/mock_service.py
+++ b/mlos_bench/mlos_bench/tests/services/mock_service.py
@@ -15,39 +15,44 @@ from mlos_bench.services.base_service import Service
@runtime_checkable
class SupportsSomeMethod(Protocol):
- """Protocol for some_method"""
+ """Protocol for some_method."""
def some_method(self) -> str:
- """some_method"""
+ """some_method."""
def some_other_method(self) -> str:
- """some_other_method"""
+ """some_other_method."""
class MockServiceBase(Service, SupportsSomeMethod):
"""A base service class for testing."""
def __init__(
- self,
- config: Optional[dict] = None,
- global_config: Optional[dict] = None,
- parent: Optional[Service] = None,
- methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None:
+ self,
+ config: Optional[dict] = None,
+ global_config: Optional[dict] = None,
+ parent: Optional[Service] = None,
+ methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None,
+ ) -> None:
super().__init__(
config,
global_config,
parent,
- self.merge_methods(methods, [
- self.some_method,
- self.some_other_method,
- ]))
+ self.merge_methods(
+ methods,
+ [
+ self.some_method,
+ self.some_other_method,
+ ],
+ ),
+ )
def some_method(self) -> str:
- """some_method"""
+ """some_method."""
return f"{self}: base.some_method"
def some_other_method(self) -> str:
- """some_other_method"""
+ """some_other_method."""
return f"{self}: base.some_other_method"
@@ -57,5 +62,5 @@ class MockServiceChild(MockServiceBase, SupportsSomeMethod):
# Intentionally includes no constructor.
def some_method(self) -> str:
- """some_method"""
+ """some_method."""
return f"{self}: child.some_method"
diff --git a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py
index e8a87ab684..b486afdb7c 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py
@@ -4,6 +4,7 @@
#
"""
Tests for mlos_bench.services.remote.
+
Used to make mypy happy about multiple conftest.py modules.
"""
@@ -12,7 +13,7 @@ from .mock.mock_remote_exec_service import MockRemoteExecService
from .mock.mock_vm_service import MockVMService
__all__ = [
- 'MockFileShareService',
- 'MockRemoteExecService',
- 'MockVMService',
+ "MockFileShareService",
+ "MockRemoteExecService",
+ "MockVMService",
]
diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py
index dc3b5469be..d45db2383e 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py
@@ -2,19 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests helpers for mlos_bench.services.remote.azure.
-"""
+"""Tests helpers for mlos_bench.services.remote.azure."""
+import json
from io import BytesIO
-import json
import urllib3
def make_httplib_json_response(status: int, json_data: dict) -> urllib3.HTTPResponse:
- """
- Prepare a json response object for use with urllib3
- """
+ """Prepare a json response object for use with urllib3."""
data = json.dumps(json_data).encode("utf-8")
response = urllib3.HTTPResponse(
status=status,
diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py
index 199c42f1fb..79090a2f5f 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.azure.azure_fileshare
-"""
+"""Tests for mlos_bench.services.remote.azure.azure_fileshare."""
import os
-from unittest.mock import MagicMock, Mock, patch, call
+from unittest.mock import MagicMock, Mock, call, patch
from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService
@@ -18,7 +16,11 @@ from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareServi
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
-def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
+def test_download_file(
+ mock_makedirs: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
filename = "test.csv"
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
@@ -26,8 +28,9 @@ def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fil
local_path = f"{local_folder}/{filename}"
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}
- with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \
- patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client:
+ with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object(
+ mock_share_client, "get_directory_client"
+ ) as mock_get_directory_client:
mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))
azure_fileshare.download(config, remote_path, local_path)
@@ -47,38 +50,43 @@ def make_dir_client_returns(remote_folder: str) -> dict:
return {
remote_folder: Mock(
exists=Mock(return_value=True),
- list_directories_and_files=Mock(return_value=[
- {"name": "a_folder", "is_directory": True},
- {"name": "a_file_1.csv", "is_directory": False},
- ])
+ list_directories_and_files=Mock(
+ return_value=[
+ {"name": "a_folder", "is_directory": True},
+ {"name": "a_file_1.csv", "is_directory": False},
+ ]
+ ),
),
f"{remote_folder}/a_folder": Mock(
exists=Mock(return_value=True),
- list_directories_and_files=Mock(return_value=[
- {"name": "a_file_2.csv", "is_directory": False},
- ])
- ),
- f"{remote_folder}/a_file_1.csv": Mock(
- exists=Mock(return_value=False)
- ),
- f"{remote_folder}/a_folder/a_file_2.csv": Mock(
- exists=Mock(return_value=False)
+ list_directories_and_files=Mock(
+ return_value=[
+ {"name": "a_file_2.csv", "is_directory": False},
+ ]
+ ),
),
+ f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)),
+ f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)),
}
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
-def test_download_folder_non_recursive(mock_makedirs: MagicMock,
- mock_open: MagicMock,
- azure_fileshare: AzureFileShareService) -> None:
+def test_download_folder_non_recursive(
+ mock_makedirs: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
- mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
+ mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}
- with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
- patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
+ with patch.object(
+ mock_share_client, "get_directory_client"
+ ) as mock_get_directory_client, patch.object(
+ mock_share_client, "get_file_client"
+ ) as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
@@ -87,47 +95,67 @@ def test_download_folder_non_recursive(mock_makedirs: MagicMock,
mock_get_file_client.assert_called_with(
f"{remote_folder}/a_file_1.csv",
)
- mock_get_directory_client.assert_has_calls([
- call(remote_folder),
- call(f"{remote_folder}/a_file_1.csv"),
- ], any_order=True)
+ mock_get_directory_client.assert_has_calls(
+ [
+ call(remote_folder),
+ call(f"{remote_folder}/a_file_1.csv"),
+ ],
+ any_order=True,
+ )
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs")
-def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
+def test_download_folder_recursive(
+ mock_makedirs: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
- mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
+ mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}
- with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \
- patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
+ with patch.object(
+ mock_share_client, "get_directory_client"
+ ) as mock_get_directory_client, patch.object(
+ mock_share_client, "get_file_client"
+ ) as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
azure_fileshare.download(config, remote_folder, local_folder, recursive=True)
- mock_get_file_client.assert_has_calls([
- call(f"{remote_folder}/a_file_1.csv"),
- call(f"{remote_folder}/a_folder/a_file_2.csv"),
- ], any_order=True)
- mock_get_directory_client.assert_has_calls([
- call(remote_folder),
- call(f"{remote_folder}/a_file_1.csv"),
- call(f"{remote_folder}/a_folder"),
- call(f"{remote_folder}/a_folder/a_file_2.csv"),
- ], any_order=True)
+ mock_get_file_client.assert_has_calls(
+ [
+ call(f"{remote_folder}/a_file_1.csv"),
+ call(f"{remote_folder}/a_folder/a_file_2.csv"),
+ ],
+ any_order=True,
+ )
+ mock_get_directory_client.assert_has_calls(
+ [
+ call(remote_folder),
+ call(f"{remote_folder}/a_file_1.csv"),
+ call(f"{remote_folder}/a_folder"),
+ call(f"{remote_folder}/a_folder/a_file_2.csv"),
+ ],
+ any_order=True,
+ )
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
-def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None:
+def test_upload_file(
+ mock_isdir: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
filename = "test.csv"
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
remote_path = f"{remote_folder}/{filename}"
local_path = f"{local_folder}/{filename}"
- mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
+ mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
mock_isdir.return_value = False
config: dict = {}
@@ -142,7 +170,8 @@ def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshar
class MyDirEntry:
# pylint: disable=too-few-public-methods
- """Dummy class for os.DirEntry"""
+ """Dummy class for os.DirEntry."""
+
def __init__(self, name: str, is_a_dir: bool):
self.name = name
self.is_a_dir = is_a_dir
@@ -176,7 +205,7 @@ def process_paths(input_path: str) -> str:
skip_prefix = os.getcwd()
# Remove prefix from os.path.abspath if there
if input_path == os.path.abspath(input_path):
- result = input_path[len(skip_prefix) + 1:]
+ result = input_path[(len(skip_prefix) + 1) :]
else:
result = input_path
# Change file seps to unix-style
@@ -186,17 +215,19 @@ def process_paths(input_path: str) -> str:
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir")
-def test_upload_directory_non_recursive(mock_scandir: MagicMock,
- mock_isdir: MagicMock,
- mock_open: MagicMock,
- azure_fileshare: AzureFileShareService) -> None:
+def test_upload_directory_non_recursive(
+ mock_scandir: MagicMock,
+ mock_isdir: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
scandir_returns = make_scandir_returns(local_folder)
isdir_returns = make_isdir_returns(local_folder)
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
- mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
+ mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
@@ -208,23 +239,28 @@ def test_upload_directory_non_recursive(mock_scandir: MagicMock,
@patch("mlos_bench.services.remote.azure.azure_fileshare.open")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir")
@patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir")
-def test_upload_directory_recursive(mock_scandir: MagicMock,
- mock_isdir: MagicMock,
- mock_open: MagicMock,
- azure_fileshare: AzureFileShareService) -> None:
+def test_upload_directory_recursive(
+ mock_scandir: MagicMock,
+ mock_isdir: MagicMock,
+ mock_open: MagicMock,
+ azure_fileshare: AzureFileShareService,
+) -> None:
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
scandir_returns = make_scandir_returns(local_folder)
isdir_returns = make_isdir_returns(local_folder)
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
- mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
+ mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
azure_fileshare.upload(config, local_folder, remote_folder, recursive=True)
- mock_get_file_client.assert_has_calls([
- call(f"{remote_folder}/a_file_1.csv"),
- call(f"{remote_folder}/a_folder/a_file_2.csv"),
- ], any_order=True)
+ mock_get_file_client.assert_has_calls(
+ [
+ call(f"{remote_folder}/a_file_1.csv"),
+ call(f"{remote_folder}/a_folder/a_file_2.csv"),
+ ],
+ any_order=True,
+ )
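
black's handling of the two-manager `with` statements above (splitting `patch.object(...)` across lines mid-expression) is arguably the least readable part of this file. One possible alternative, shown only for comparison and assuming Python 3.10+ parenthesized context managers are acceptable to the project:

```python
from unittest.mock import MagicMock, patch

mock_share_client = MagicMock()
with (
    patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client,
    patch.object(mock_share_client, "get_file_client") as mock_get_file_client,
):
    # The patched attribute records calls on the context-manager mock.
    mock_share_client.get_directory_client("a/remote/folder")
    mock_get_directory_client.assert_called_once_with("a/remote/folder")
    mock_get_file_client.assert_not_called()
```
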
diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py
index 22fec74c74..87dd78fd5a 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.azure.azure_network_services
-"""
+"""Tests for mlos_bench.services.remote.azure.azure_network_services."""
from unittest.mock import MagicMock, patch
@@ -12,33 +10,37 @@ import pytest
import requests.exceptions as requests_ex
from mlos_bench.environments.status import Status
-
from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
from mlos_bench.services.remote.azure.azure_network_services import AzureNetworkService
-
from mlos_bench.tests.services.remote.azure import make_httplib_json_response
@pytest.mark.parametrize(
- ("total_retries", "operation_status"), [
+ ("total_retries", "operation_status"),
+ [
(2, Status.SUCCEEDED),
(1, Status.FAILED),
(0, Status.FAILED),
- ])
+ ],
+)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
-def test_wait_network_deployment_retry(mock_getconn: MagicMock,
- total_retries: int,
- operation_status: Status,
- azure_network_service: AzureNetworkService) -> None:
- """
- Test retries of the network deployment operation.
- """
+def test_wait_network_deployment_retry(
+ mock_getconn: MagicMock,
+ total_retries: int,
+ operation_status: Status,
+ azure_network_service: AzureNetworkService,
+) -> None:
+ """Test retries of the network deployment operation."""
# Simulate intermittent connection issues with multiple connection errors
# Sufficient retry attempts should result in success, otherwise a graceful failure state
mock_getconn.return_value.getresponse.side_effect = [
make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
+ requests_ex.ConnectionError(
+ "Connection aborted", OSError(107, "Transport endpoint is not connected")
+ ),
+ requests_ex.ConnectionError(
+ "Connection aborted", OSError(107, "Transport endpoint is not connected")
+ ),
make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}),
]
@@ -51,33 +53,38 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock,
"subscription": "TEST_SUB1",
"resourceGroup": "TEST_RG1",
},
- is_setup=True)
+ is_setup=True,
+ )
assert status == operation_status
@pytest.mark.parametrize(
- ("operation_name", "accepts_params"), [
+ ("operation_name", "accepts_params"),
+ [
("deprovision_network", True),
- ])
+ ],
+)
@pytest.mark.parametrize(
- ("http_status_code", "operation_status"), [
+ ("http_status_code", "operation_status"),
+ [
(200, Status.SUCCEEDED),
(202, Status.PENDING),
# These should succeed since we set ignore_errors=True by default
(401, Status.SUCCEEDED),
(404, Status.SUCCEEDED),
- ])
+ ],
+)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
# pylint: disable=too-many-arguments
-def test_network_operation_status(mock_requests: MagicMock,
- azure_network_service: AzureNetworkService,
- operation_name: str,
- accepts_params: bool,
- http_status_code: int,
- operation_status: Status) -> None:
- """
- Test network operation status.
- """
+def test_network_operation_status(
+ mock_requests: MagicMock,
+ azure_network_service: AzureNetworkService,
+ operation_name: str,
+ accepts_params: bool,
+ http_status_code: int,
+ operation_status: Status,
+) -> None:
+ """Test network operation status."""
mock_response = MagicMock()
mock_response.status_code = http_status_code
mock_requests.post.return_value = mock_response
@@ -91,22 +98,28 @@ def test_network_operation_status(mock_requests: MagicMock,
@pytest.fixture
-def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None:
- """
- Tests creating a network services without a deployment template (should fail).
- """
+def test_azure_network_service_no_deployment_template(
+ azure_auth_service: AzureAuthService,
+) -> None:
+ """Tests creating a network services without a deployment template (should fail)."""
with pytest.raises(ValueError):
- _ = AzureNetworkService(config={
- "deploymentTemplatePath": None,
- "deploymentTemplateParameters": {
- "location": "westus2",
+ _ = AzureNetworkService(
+ config={
+ "deploymentTemplatePath": None,
+ "deploymentTemplateParameters": {
+ "location": "westus2",
+ },
},
- }, parent=azure_auth_service)
+ parent=azure_auth_service,
+ )
with pytest.raises(ValueError):
- _ = AzureNetworkService(config={
- # "deploymentTemplatePath": None,
- "deploymentTemplateParameters": {
- "location": "westus2",
+ _ = AzureNetworkService(
+ config={
+ # "deploymentTemplatePath": None,
+ "deploymentTemplateParameters": {
+ "location": "westus2",
+ },
},
- }, parent=azure_auth_service)
+ parent=azure_auth_service,
+ )
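
A reading note for the reformatted decorators above: stacking two `@pytest.mark.parametrize` decorators makes pytest generate the cross-product of both parameter sets, so `test_network_operation_status` runs once per (operation, status-code) pair. A tiny self-contained illustration with hypothetical names:

```python
import pytest


@pytest.mark.parametrize("operation_name", ["deprovision_network"])
@pytest.mark.parametrize(("http_status_code", "expect_success"), [(200, True), (404, True)])
def test_cross_product(operation_name: str, http_status_code: int, expect_success: bool) -> None:
    # 1 operation x 2 status codes = 2 generated test cases.
    assert expect_success
```
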
diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py
index 97bf904a56..6418da01a9 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.azure.azure_vm_services
-"""
+"""Tests for mlos_bench.services.remote.azure.azure_vm_services."""
from copy import deepcopy
from unittest.mock import MagicMock, patch
@@ -13,33 +11,39 @@ import pytest
import requests.exceptions as requests_ex
from mlos_bench.environments.status import Status
-
from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
-
from mlos_bench.tests.services.remote.azure import make_httplib_json_response
@pytest.mark.parametrize(
- ("total_retries", "operation_status"), [
+ ("total_retries", "operation_status"),
+ [
(2, Status.SUCCEEDED),
(1, Status.FAILED),
(0, Status.FAILED),
- ])
+ ],
+)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
-def test_wait_host_deployment_retry(mock_getconn: MagicMock,
- total_retries: int,
- operation_status: Status,
- azure_vm_service: AzureVMService) -> None:
- """
- Test retries of the host deployment operation.
- """
+def test_wait_host_deployment_retry(
+ mock_getconn: MagicMock,
+ total_retries: int,
+ operation_status: Status,
+ azure_vm_service: AzureVMService,
+) -> None:
+ """Test retries of the host deployment operation."""
# Simulate intermittent connection issues with multiple connection errors
# Sufficient retry attempts should result in success, otherwise a graceful failure state
mock_getconn.return_value.getresponse.side_effect = [
make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
+ requests_ex.ConnectionError(
+ "Connection aborted",
+ OSError(107, "Transport endpoint is not connected"),
+ ),
+ requests_ex.ConnectionError(
+ "Connection aborted",
+ OSError(107, "Transport endpoint is not connected"),
+ ),
make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}),
]
@@ -52,16 +56,17 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock,
"subscription": "TEST_SUB1",
"resourceGroup": "TEST_RG1",
},
- is_setup=True)
+ is_setup=True,
+ )
assert status == operation_status
def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
- """
- Test expanding template params recursively.
- """
+ """Test expanding template params recursively."""
config = {
- "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
+ "deploymentTemplatePath": (
+ "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
+ ),
"subscription": "TEST_SUB1",
"resourceGroup": "TEST_RG1",
"deploymentTemplateParameters": {
@@ -77,17 +82,23 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAut
}
azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
assert azure_vm_service.deploy_params["location"] == global_config["location"]
- assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}'
- assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg'
+ assert (
+ azure_vm_service.deploy_params["vmMeta"]
+ == f'{global_config["vmName"]}-{global_config["location"]}'
+ )
+ assert (
+ azure_vm_service.deploy_params["vmNsg"]
+ == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg'
+ )
def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
- """
- Test loading custom data from a file.
- """
+ """Test loading custom data from a file."""
config = {
"customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
- "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
+ "deploymentTemplatePath": (
+ "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
+ ),
"subscription": "TEST_SUB1",
"resourceGroup": "TEST_RG1",
"deploymentTemplateParameters": {
@@ -100,14 +111,15 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N
}
with pytest.raises(ValueError):
config_with_custom_data = deepcopy(config)
- config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index]
+ config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] # pylint: disable=line-too-long # noqa
AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)
azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
- assert azure_vm_service.deploy_params['customData']
+ assert azure_vm_service.deploy_params["customData"]
@pytest.mark.parametrize(
- ("operation_name", "accepts_params"), [
+ ("operation_name", "accepts_params"),
+ [
("start_host", True),
("stop_host", True),
("shutdown", True),
@@ -115,25 +127,28 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N
("deallocate_host", True),
("restart_host", True),
("reboot", True),
- ])
+ ],
+)
@pytest.mark.parametrize(
- ("http_status_code", "operation_status"), [
+ ("http_status_code", "operation_status"),
+ [
(200, Status.SUCCEEDED),
(202, Status.PENDING),
(401, Status.FAILED),
(404, Status.FAILED),
- ])
+ ],
+)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
# pylint: disable=too-many-arguments
-def test_vm_operation_status(mock_requests: MagicMock,
- azure_vm_service: AzureVMService,
- operation_name: str,
- accepts_params: bool,
- http_status_code: int,
- operation_status: Status) -> None:
- """
- Test VM operation status.
- """
+def test_vm_operation_status(
+ mock_requests: MagicMock,
+ azure_vm_service: AzureVMService,
+ operation_name: str,
+ accepts_params: bool,
+ http_status_code: int,
+ operation_status: Status,
+) -> None:
+ """Test VM operation status."""
mock_response = MagicMock()
mock_response.status_code = http_status_code
mock_requests.post.return_value = mock_response
@@ -147,15 +162,17 @@ def test_vm_operation_status(mock_requests: MagicMock,
@pytest.mark.parametrize(
- ("operation_name", "accepts_params"), [
+ ("operation_name", "accepts_params"),
+ [
("provision_host", True),
- ])
-def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService,
- operation_name: str,
- accepts_params: bool) -> None:
- """
- Test VM operation status for an incomplete service config.
- """
+ ],
+)
+def test_vm_operation_invalid(
+ azure_vm_service_remote_exec_only: AzureVMService,
+ operation_name: str,
+ accepts_params: bool,
+) -> None:
+ """Test VM operation status for an incomplete service config."""
operation = getattr(azure_vm_service_remote_exec_only, operation_name)
with pytest.raises(ValueError):
(_, _) = operation({"vmName": "test-vm"}) if accepts_params else operation()
@@ -163,11 +180,12 @@ def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService,
@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
-def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock,
- azure_vm_service: AzureVMService) -> None:
- """
- Test waiting for the completion of the remote VM operation.
- """
+def test_wait_vm_operation_ready(
+ mock_session: MagicMock,
+ mock_sleep: MagicMock,
+ azure_vm_service: AzureVMService,
+) -> None:
+ """Test waiting for the completion of the remote VM operation."""
# Mock response header
async_url = "DUMMY_ASYNC_URL"
retry_after = 12345
@@ -185,23 +203,19 @@ def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock,
status, _ = azure_vm_service.wait_host_operation(params)
- assert (async_url, ) == mock_session.return_value.get.call_args[0]
- assert (retry_after, ) == mock_sleep.call_args[0]
+ assert (async_url,) == mock_session.return_value.get.call_args[0]
+ assert (retry_after,) == mock_sleep.call_args[0]
assert status.is_succeeded()
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
-def test_wait_vm_operation_timeout(mock_session: MagicMock,
- azure_vm_service: AzureVMService) -> None:
- """
- Test the time out of the remote VM operation.
- """
+def test_wait_vm_operation_timeout(
+ mock_session: MagicMock,
+ azure_vm_service: AzureVMService,
+) -> None:
+ """Test the time out of the remote VM operation."""
# Mock response header
- params = {
- "asyncResultsUrl": "DUMMY_ASYNC_URL",
- "vmName": "test-vm",
- "pollInterval": 1
- }
+ params = {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1}
mock_status_response = MagicMock(status_code=200)
mock_status_response.json.return_value = {
@@ -214,25 +228,33 @@ def test_wait_vm_operation_timeout(mock_session: MagicMock,
@pytest.mark.parametrize(
- ("total_retries", "operation_status"), [
+ ("total_retries", "operation_status"),
+ [
(2, Status.SUCCEEDED),
(1, Status.FAILED),
(0, Status.FAILED),
- ])
+ ],
+)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
-def test_wait_vm_operation_retry(mock_getconn: MagicMock,
- total_retries: int,
- operation_status: Status,
- azure_vm_service: AzureVMService) -> None:
- """
- Test the retries of the remote VM operation.
- """
+def test_wait_vm_operation_retry(
+ mock_getconn: MagicMock,
+ total_retries: int,
+ operation_status: Status,
+ azure_vm_service: AzureVMService,
+) -> None:
+ """Test the retries of the remote VM operation."""
# Simulate intermittent connection issues with multiple connection errors
# Sufficient retry attempts should result in success, otherwise a graceful failure state
mock_getconn.return_value.getresponse.side_effect = [
make_httplib_json_response(200, {"status": "InProgress"}),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
- requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")),
+ requests_ex.ConnectionError(
+ "Connection aborted",
+ OSError(107, "Transport endpoint is not connected"),
+ ),
+ requests_ex.ConnectionError(
+ "Connection aborted",
+ OSError(107, "Transport endpoint is not connected"),
+ ),
make_httplib_json_response(200, {"status": "InProgress"}),
make_httplib_json_response(200, {"status": "Succeeded"}),
]
@@ -243,61 +265,76 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock,
"requestTotalRetries": total_retries,
"asyncResultsUrl": "https://DUMMY_ASYNC_URL",
"vmName": "test-vm",
- })
+ }
+ )
assert status == operation_status
@pytest.mark.parametrize(
- ("http_status_code", "operation_status"), [
+ ("http_status_code", "operation_status"),
+ [
(200, Status.SUCCEEDED),
(202, Status.PENDING),
(401, Status.FAILED),
(404, Status.FAILED),
- ])
+ ],
+)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
-def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService,
- http_status_code: int, operation_status: Status) -> None:
- """
- Test waiting for completion of the remote execution on Azure.
- """
+def test_remote_exec_status(
+ mock_requests: MagicMock,
+ azure_vm_service_remote_exec_only: AzureVMService,
+ http_status_code: int,
+ operation_status: Status,
+) -> None:
+ """Test waiting for completion of the remote execution on Azure."""
script = ["command_1", "command_2"]
mock_response = MagicMock()
mock_response.status_code = http_status_code
- mock_response.json = MagicMock(return_value={
- "fake response": "body as json to dict",
- })
+ mock_response.json = MagicMock(
+ return_value={
+ "fake response": "body as json to dict",
+ }
+ )
mock_requests.post.return_value = mock_response
- status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={})
+ status, _ = azure_vm_service_remote_exec_only.remote_exec(
+ script,
+ config={"vmName": "test-vm"},
+ env_params={},
+ )
assert status == operation_status
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
-def test_remote_exec_headers_output(mock_requests: MagicMock,
- azure_vm_service_remote_exec_only: AzureVMService) -> None:
- """
- Check if HTTP headers from the remote execution on Azure are correct.
- """
+def test_remote_exec_headers_output(
+ mock_requests: MagicMock,
+ azure_vm_service_remote_exec_only: AzureVMService,
+) -> None:
+ """Check if HTTP headers from the remote execution on Azure are correct."""
async_url_key = "asyncResultsUrl"
async_url_value = "DUMMY_ASYNC_URL"
script = ["command_1", "command_2"]
mock_response = MagicMock()
mock_response.status_code = 202
- mock_response.headers = {
- "Azure-AsyncOperation": async_url_value
- }
- mock_response.json = MagicMock(return_value={
- "fake response": "body as json to dict",
- })
+ mock_response.headers = {"Azure-AsyncOperation": async_url_value}
+ mock_response.json = MagicMock(
+ return_value={
+ "fake response": "body as json to dict",
+ }
+ )
mock_requests.post.return_value = mock_response
- _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={
- "param_1": 123,
- "param_2": "abc",
- })
+ _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
+ script,
+ config={"vmName": "test-vm"},
+ env_params={
+ "param_1": 123,
+ "param_2": "abc",
+ },
+ )
assert async_url_key in cmd_output
assert cmd_output[async_url_key] == async_url_value
@@ -305,15 +342,13 @@ def test_remote_exec_headers_output(mock_requests: MagicMock,
assert mock_requests.post.call_args[1]["json"] == {
"commandId": "RunShellScript",
"script": script,
- "parameters": [
- {"name": "param_1", "value": 123},
- {"name": "param_2", "value": "abc"}
- ]
+ "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}],
}
@pytest.mark.parametrize(
- ("operation_status", "wait_output", "results_output"), [
+ ("operation_status", "wait_output", "results_output"),
+ [
(
Status.SUCCEEDED,
{
@@ -325,16 +360,19 @@ def test_remote_exec_headers_output(mock_requests: MagicMock,
}
}
},
- {"stdout": "DUMMY_STDOUT_STDERR"}
+ {"stdout": "DUMMY_STDOUT_STDERR"},
),
(Status.PENDING, {}, {}),
(Status.FAILED, {}, {}),
- ])
-def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status,
- wait_output: dict, results_output: dict) -> None:
- """
- Test getting the results of the remote execution on Azure.
- """
+ ],
+)
+def test_get_remote_exec_results(
+ azure_vm_service_remote_exec_only: AzureVMService,
+ operation_status: Status,
+ wait_output: dict,
+ results_output: dict,
+) -> None:
+ """Test getting the results of the remote execution on Azure."""
params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}
mock_wait_host_operation = MagicMock()
diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py
index 45feb1aa50..ad7bae26ee 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Configuration test fixtures for azure_vm_services in mlos_bench.
-"""
+"""Configuration test fixtures for azure_vm_services in mlos_bench."""
from unittest.mock import patch
@@ -13,9 +11,9 @@ import pytest
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.remote.azure import (
AzureAuthService,
+ AzureFileShareService,
AzureNetworkService,
AzureVMService,
- AzureFileShareService,
)
# pylint: disable=redefined-outer-name
@@ -23,18 +21,16 @@ from mlos_bench.services.remote.azure import (
@pytest.fixture
def config_persistence_service() -> ConfigPersistenceService:
- """
- Test fixture for ConfigPersistenceService.
- """
+ """Test fixture for ConfigPersistenceService."""
return ConfigPersistenceService()
@pytest.fixture
-def azure_auth_service(config_persistence_service: ConfigPersistenceService,
- monkeypatch: pytest.MonkeyPatch) -> AzureAuthService:
- """
- Creates a dummy AzureAuthService for tests that require it.
- """
+def azure_auth_service(
+ config_persistence_service: ConfigPersistenceService,
+ monkeypatch: pytest.MonkeyPatch,
+) -> AzureAuthService:
+ """Creates a dummy AzureAuthService for tests that require it."""
auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service)
monkeypatch.setattr(auth, "get_access_token", lambda: "TEST_TOKEN")
return auth
@@ -42,67 +38,79 @@ def azure_auth_service(config_persistence_service: ConfigPersistenceService,
@pytest.fixture
def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkService:
- """
- Creates a dummy Azure VM service for tests that require it.
- """
- return AzureNetworkService(config={
- "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
- "subscription": "TEST_SUB",
- "resourceGroup": "TEST_RG",
- "deploymentTemplateParameters": {
- "location": "westus2",
+ """Creates a dummy Azure VM service for tests that require it."""
+ return AzureNetworkService(
+ config={
+ "deploymentTemplatePath": (
+ "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
+ ),
+ "subscription": "TEST_SUB",
+ "resourceGroup": "TEST_RG",
+ "deploymentTemplateParameters": {
+ "location": "westus2",
+ },
+ "pollInterval": 1,
+ "pollTimeout": 2,
},
- "pollInterval": 1,
- "pollTimeout": 2
- }, global_config={
- "deploymentName": "TEST_DEPLOYMENT-VNET",
- "vnetName": "test-vnet", # Should come from the upper-level config
- }, parent=azure_auth_service)
+ global_config={
+ "deploymentName": "TEST_DEPLOYMENT-VNET",
+ "vnetName": "test-vnet", # Should come from the upper-level config
+ },
+ parent=azure_auth_service,
+ )
@pytest.fixture
def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService:
- """
- Creates a dummy Azure VM service for tests that require it.
- """
- return AzureVMService(config={
- "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
- "subscription": "TEST_SUB",
- "resourceGroup": "TEST_RG",
- "deploymentTemplateParameters": {
- "location": "westus2",
+ """Creates a dummy Azure VM service for tests that require it."""
+ return AzureVMService(
+ config={
+ "deploymentTemplatePath": (
+ "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
+ ),
+ "subscription": "TEST_SUB",
+ "resourceGroup": "TEST_RG",
+ "deploymentTemplateParameters": {
+ "location": "westus2",
+ },
+ "pollInterval": 1,
+ "pollTimeout": 2,
},
- "pollInterval": 1,
- "pollTimeout": 2
- }, global_config={
- "deploymentName": "TEST_DEPLOYMENT-VM",
- "vmName": "test-vm", # Should come from the upper-level config
- }, parent=azure_auth_service)
+ global_config={
+ "deploymentName": "TEST_DEPLOYMENT-VM",
+ "vmName": "test-vm", # Should come from the upper-level config
+ },
+ parent=azure_auth_service,
+ )
@pytest.fixture
def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> AzureVMService:
- """
- Creates a dummy Azure VM service with no deployment template.
- """
- return AzureVMService(config={
- "subscription": "TEST_SUB",
- "resourceGroup": "TEST_RG",
- "pollInterval": 1,
- "pollTimeout": 2,
- }, global_config={
- "vmName": "test-vm", # Should come from the upper-level config
- }, parent=azure_auth_service)
+ """Creates a dummy Azure VM service with no deployment template."""
+ return AzureVMService(
+ config={
+ "subscription": "TEST_SUB",
+ "resourceGroup": "TEST_RG",
+ "pollInterval": 1,
+ "pollTimeout": 2,
+ },
+ global_config={
+ "vmName": "test-vm", # Should come from the upper-level config
+ },
+ parent=azure_auth_service,
+ )
@pytest.fixture
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
- """
- Creates a dummy AzureFileShareService for tests that require it.
- """
+ """Creates a dummy AzureFileShareService for tests that require it."""
with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"):
- return AzureFileShareService(config={
- "storageAccountName": "TEST_ACCOUNT_NAME",
- "storageFileShareName": "TEST_FS_NAME",
- "storageAccountKey": "TEST_ACCOUNT_KEY"
- }, global_config={}, parent=config_persistence_service)
+ return AzureFileShareService(
+ config={
+ "storageAccountName": "TEST_ACCOUNT_NAME",
+ "storageFileShareName": "TEST_FS_NAME",
+ "storageAccountKey": "TEST_ACCOUNT_KEY",
+ },
+ global_config={},
+ parent=config_persistence_service,
+ )
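
The fixtures above chain together: `config_persistence_service` feeds `azure_auth_service`, which in turn feeds the VM and network service fixtures, while `azure_fileshare` builds directly on the persistence service. A test therefore only names the outermost fixture it needs; a minimal, hypothetical consumer:

```python
# Hypothetical test (not part of the patch); pytest resolves the whole fixture chain above.
from mlos_bench.services.remote.azure import AzureVMService


def test_vm_service_fixture(azure_vm_service: AzureVMService) -> None:
    assert isinstance(azure_vm_service, AzureVMService)
```
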
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py
index d86cbdf2a3..a12bde8a23 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Mock remote services for testing purposes.
-"""
+"""Mock remote services for testing purposes."""
from typing import Any, Tuple
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py
index b9474f0709..482f9ee2a9 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking authentication.
-"""
+"""A collection Service functions for mocking authentication."""
import logging
from typing import Any, Callable, Dict, List, Optional, Union
@@ -16,20 +14,26 @@ _LOG = logging.getLogger(__name__)
class MockAuthService(Service, SupportsAuth):
- """
- A collection Service functions for mocking authentication ops.
- """
+ """A collection Service functions for mocking authentication ops."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [
- self.get_access_token,
- self.get_auth_headers,
- ])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ [
+ self.get_access_token,
+ self.get_auth_headers,
+ ],
+ ),
)
def get_access_token(self) -> str:
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py
index c09b31b299..2d227e635e 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py
@@ -2,50 +2,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking file share ops.
-"""
+"""A collection Service functions for mocking file share ops."""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from mlos_bench.services.base_service import Service
from mlos_bench.services.base_fileshare import FileShareService
+from mlos_bench.services.base_service import Service
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
_LOG = logging.getLogger(__name__)
class MockFileShareService(FileShareService, SupportsFileShareOps):
- """
- A collection Service functions for mocking file share ops.
- """
+ """A collection Service functions for mocking file share ops."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, [self.upload, self.download])
+ config,
+ global_config,
+ parent,
+ self.merge_methods(methods, [self.upload, self.download]),
)
self._upload: List[Tuple[str, str]] = []
self._download: List[Tuple[str, str]] = []
- def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
+ def upload(
+ self,
+ params: dict,
+ local_path: str,
+ remote_path: str,
+ recursive: bool = True,
+ ) -> None:
self._upload.append((local_path, remote_path))
- def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
+ def download(
+ self,
+ params: dict,
+ remote_path: str,
+ local_path: str,
+ recursive: bool = True,
+ ) -> None:
self._download.append((remote_path, local_path))
def get_upload(self) -> List[Tuple[str, str]]:
- """
- Get the list of files that were uploaded.
- """
+ """Get the list of files that were uploaded."""
return self._upload
def get_download(self) -> List[Tuple[str, str]]:
- """
- Get the list of files that were downloaded.
- """
+ """Get the list of files that were downloaded."""
return self._download
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py
index 6d2bd058b9..a483432023 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py
@@ -2,26 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking managing (Virtual) Networks.
-"""
+"""A collection Service functions for mocking managing (Virtual) Networks."""
from typing import Any, Callable, Dict, List, Optional, Union
from mlos_bench.services.base_service import Service
-from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
+from mlos_bench.services.types.network_provisioner_type import (
+ SupportsNetworkProvisioning,
+)
from mlos_bench.tests.services.remote.mock import mock_operation
class MockNetworkService(Service, SupportsNetworkProvisioning):
- """
- Mock Network service for testing.
- """
+ """Mock Network service for testing."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of mock network services proxy.
@@ -36,13 +37,19 @@ class MockNetworkService(Service, SupportsNetworkProvisioning):
Parent service that can provide mixin functions.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, {
- name: mock_operation for name in (
- # SupportsNetworkProvisioning:
- "provision_network",
- "deprovision_network",
- "wait_network_deployment",
- )
- })
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ {
+ name: mock_operation
+ for name in (
+ # SupportsNetworkProvisioning:
+ "provision_network",
+ "deprovision_network",
+ "wait_network_deployment",
+ )
+ },
+ ),
)
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py
index ee99251c64..57f90ccd4d 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking remote script execution.
-"""
+"""A collection Service functions for mocking remote script execution."""
from typing import Any, Callable, Dict, List, Optional, Union
@@ -14,14 +12,15 @@ from mlos_bench.tests.services.remote.mock import mock_operation
class MockRemoteExecService(Service, SupportsRemoteExec):
- """
- Mock remote script execution service.
- """
+ """Mock remote script execution service."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of mock remote exec service.
@@ -36,9 +35,14 @@ class MockRemoteExecService(Service, SupportsRemoteExec):
Parent service that can provide mixin functions.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, {
- "remote_exec": mock_operation,
- "get_remote_exec_results": mock_operation,
- })
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ {
+ "remote_exec": mock_operation,
+ "get_remote_exec_results": mock_operation,
+ },
+ ),
)
diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py
index 88896c3f16..0d093df48f 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py
@@ -2,28 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-A collection Service functions for mocking managing VMs.
-"""
+"""A collection Service functions for mocking managing VMs."""
from typing import Any, Callable, Dict, List, Optional, Union
from mlos_bench.services.base_service import Service
-from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.host_ops_type import SupportsHostOps
+from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.tests.services.remote.mock import mock_operation
class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps):
- """
- Mock VM service for testing.
- """
+ """Mock VM service for testing."""
- def __init__(self, config: Optional[Dict[str, Any]] = None,
- global_config: Optional[Dict[str, Any]] = None,
- parent: Optional[Service] = None,
- methods: Union[Dict[str, Callable], List[Callable], None] = None):
+ def __init__(
+ self,
+ config: Optional[Dict[str, Any]] = None,
+ global_config: Optional[Dict[str, Any]] = None,
+ parent: Optional[Service] = None,
+ methods: Union[Dict[str, Callable], List[Callable], None] = None,
+ ):
"""
Create a new instance of mock VM services proxy.
@@ -38,23 +37,29 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, Supports
Parent service that can provide mixin functions.
"""
super().__init__(
- config, global_config, parent,
- self.merge_methods(methods, {
- name: mock_operation for name in (
- # SupportsHostProvisioning:
- "wait_host_deployment",
- "provision_host",
- "deprovision_host",
- "deallocate_host",
- # SupportsHostOps:
- "start_host",
- "stop_host",
- "restart_host",
- "wait_host_operation",
- # SupportsOsOps:
- "shutdown",
- "reboot",
- "wait_os_operation",
- )
- })
+ config,
+ global_config,
+ parent,
+ self.merge_methods(
+ methods,
+ {
+ name: mock_operation
+ for name in (
+ # SupportsHostProvisioning:
+ "wait_host_deployment",
+ "provision_host",
+ "deprovision_host",
+ "deallocate_host",
+ # SupportsHostOps:
+ "start_host",
+ "stop_host",
+ "restart_host",
+ "wait_host_operation",
+ # SupportsOsOps:
+ "shutdown",
+ "reboot",
+ "wait_os_operation",
+ )
+ },
+ ),
)
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py
index eb285ffc7d..78bd4b1bab 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common data classes for the SSH service tests.
-"""
+"""Common data classes for the SSH service tests."""
from dataclasses import dataclass
from subprocess import run
@@ -14,20 +12,17 @@ from pytest_docker.plugin import Services as DockerServices
from mlos_bench.tests import check_socket
-
# The SSH test server port and name.
# See Also: docker-compose.yml
SSH_TEST_SERVER_PORT = 2254
-SSH_TEST_SERVER_NAME = 'ssh-server'
-ALT_TEST_SERVER_NAME = 'alt-server'
-REBOOT_TEST_SERVER_NAME = 'reboot-server'
+SSH_TEST_SERVER_NAME = "ssh-server"
+ALT_TEST_SERVER_NAME = "alt-server"
+REBOOT_TEST_SERVER_NAME = "reboot-server"
@dataclass
class SshTestServerInfo:
- """
- A data class for SshTestServerInfo.
- """
+ """A data class for SshTestServerInfo."""
compose_project_name: str
service_name: str
@@ -40,11 +35,19 @@ class SshTestServerInfo:
"""
Gets the port that the SSH test server is listening on.
- Note: this value can change when the service restarts so we can't rely on the DockerServices.
+            Note: this value can change when the service restarts, so we can't rely on
+ the DockerServices.
"""
if self._port is None or uncached:
- port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}",
- shell=True, check=True, capture_output=True)
+ port_cmd = run(
+ (
+ f"docker compose -p {self.compose_project_name} "
+ f"port {self.service_name} {SSH_TEST_SERVER_PORT}"
+ ),
+ shell=True,
+ check=True,
+ capture_output=True,
+ )
self._port = int(port_cmd.stdout.decode().strip().split(":")[1])
return self._port
@@ -60,6 +63,7 @@ class SshTestServerInfo:
def to_connect_params(self, uncached: bool = False) -> dict:
"""
Convert to a connect_params dict for SshClient.
+
See Also: mlos_bench.services.remote.ssh.ssh_service.SshService._get_connect_params()
"""
return {
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py
index 1bb910ed77..34006985af 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Fixtures for the SSH service tests.
-"""
+"""Fixtures for the SSH service tests."""
import mlos_bench.tests.services.remote.ssh.fixtures as ssh_fixtures
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py
index 1706a42969..f4042cf62f 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py
@@ -8,57 +8,58 @@ Fixtures for the SSH service tests.
Note: these are not in the conftest.py file because they are also used by remote_ssh_env_test.py
"""
-from typing import Generator
-from subprocess import run
-
import os
import sys
import tempfile
+from subprocess import run
+from typing import Generator
import pytest
from pytest_docker.plugin import Services as DockerServices
-from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
-
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.tests import resolve_host_name
-from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo,
- ALT_TEST_SERVER_NAME,
- REBOOT_TEST_SERVER_NAME,
- SSH_TEST_SERVER_NAME,
- wait_docker_service_socket)
+from mlos_bench.tests.services.remote.ssh import (
+ ALT_TEST_SERVER_NAME,
+ REBOOT_TEST_SERVER_NAME,
+ SSH_TEST_SERVER_NAME,
+ SshTestServerInfo,
+ wait_docker_service_socket,
+)
# pylint: disable=redefined-outer-name
-HOST_DOCKER_NAME = 'host.docker.internal'
+HOST_DOCKER_NAME = "host.docker.internal"
@pytest.fixture(scope="session")
def ssh_test_server_hostname() -> str:
"""Returns the local hostname to use to connect to the test ssh server."""
- if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME):
+ if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME):
# On Linux, if we're running in a docker container, we can use the
# --add-host (extra_hosts in docker-compose.yml) to refer to the host IP.
return HOST_DOCKER_NAME
# Docker (Desktop) for Windows (WSL2) uses a special networking magic
# to refer to the host machine as `localhost` when exposing ports.
# In all other cases, assume we're executing directly inside conda on the host.
- return 'localhost'
+ return "localhost"
@pytest.fixture(scope="session")
-def ssh_test_server(ssh_test_server_hostname: str,
- docker_compose_project_name: str,
- locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]:
+def ssh_test_server(
+ ssh_test_server_hostname: str,
+ docker_compose_project_name: str,
+ locked_docker_services: DockerServices,
+) -> Generator[SshTestServerInfo, None, None]:
"""
- Fixture for getting the ssh test server services setup via docker-compose
- using pytest-docker.
+    Fixture for getting the ssh test server services set up via docker-compose using
+ pytest-docker.
Yields the (hostname, port, username, id_rsa_path) of the test server.
- Once the session is over, the docker containers are torn down, and the
- temporary file holding the dynamically generated private key of the test
- server is deleted.
+ Once the session is over, the docker containers are torn down, and the temporary
+ file holding the dynamically generated private key of the test server is deleted.
"""
# Get a copy of the ssh id_rsa key from the test ssh server.
with tempfile.NamedTemporaryFile() as id_rsa_file:
@@ -66,25 +67,42 @@ def ssh_test_server(ssh_test_server_hostname: str,
compose_project_name=docker_compose_project_name,
service_name=SSH_TEST_SERVER_NAME,
hostname=ssh_test_server_hostname,
- username='root',
- id_rsa_path=id_rsa_file.name)
- wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port())
+ username="root",
+ id_rsa_path=id_rsa_file.name,
+ )
+ wait_docker_service_socket(
+ locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()
+ )
id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa"
- docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}"
- cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True)
+ docker_cp_cmd = (
+ f"docker compose -p {docker_compose_project_name} "
+ f"cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}"
+ )
+ cmd = run(
+ docker_cp_cmd.split(),
+ check=True,
+ cwd=os.path.dirname(__file__),
+ capture_output=True,
+ text=True,
+ )
if cmd.returncode != 0:
- raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container "
- + f"[return={cmd.returncode}]: {str(cmd.stderr)}")
+ raise RuntimeError(
+ f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container "
+ + f"[return={cmd.returncode}]: {str(cmd.stderr)}"
+ )
os.chmod(id_rsa_file.name, 0o600)
yield ssh_test_server_info
# NamedTempFile deleted on context exit
@pytest.fixture(scope="session")
-def alt_test_server(ssh_test_server: SshTestServerInfo,
- locked_docker_services: DockerServices) -> SshTestServerInfo:
+def alt_test_server(
+ ssh_test_server: SshTestServerInfo,
+ locked_docker_services: DockerServices,
+) -> SshTestServerInfo:
"""
Fixture for getting the second ssh test server info from the docker-compose.yml.
+
See additional notes in the ssh_test_server fixture above.
"""
# Note: The alt-server uses the same image as the ssh-server container, so
@@ -95,16 +113,22 @@ def alt_test_server(ssh_test_server: SshTestServerInfo,
service_name=ALT_TEST_SERVER_NAME,
hostname=ssh_test_server.hostname,
username=ssh_test_server.username,
- id_rsa_path=ssh_test_server.id_rsa_path)
- wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port())
+ id_rsa_path=ssh_test_server.id_rsa_path,
+ )
+ wait_docker_service_socket(
+ locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()
+ )
return alt_test_server_info
@pytest.fixture(scope="session")
-def reboot_test_server(ssh_test_server: SshTestServerInfo,
- locked_docker_services: DockerServices) -> SshTestServerInfo:
+def reboot_test_server(
+ ssh_test_server: SshTestServerInfo,
+ locked_docker_services: DockerServices,
+) -> SshTestServerInfo:
"""
Fixture for getting the third ssh test server info from the docker-compose.yml.
+
See additional notes in the ssh_test_server fixture above.
"""
# Note: The reboot-server uses the same image as the ssh-server container, so
@@ -115,8 +139,13 @@ def reboot_test_server(ssh_test_server: SshTestServerInfo,
service_name=REBOOT_TEST_SERVER_NAME,
hostname=ssh_test_server.hostname,
username=ssh_test_server.username,
- id_rsa_path=ssh_test_server.id_rsa_path)
- wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port())
+ id_rsa_path=ssh_test_server.id_rsa_path,
+ )
+ wait_docker_service_socket(
+ locked_docker_services,
+ reboot_test_server_info.hostname,
+ reboot_test_server_info.get_port(),
+ )
return reboot_test_server_info
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py
index ab25093fd0..c0bb730a1e 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py
@@ -2,34 +2,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.ssh.ssh_services
-"""
+"""Tests for mlos_bench.services.remote.ssh.ssh_services."""
+import os
+import tempfile
from contextlib import contextmanager
from os.path import basename
from pathlib import Path
from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name
from typing import Any, Dict, Generator, List
-import os
-import tempfile
-
import pytest
-from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
-from mlos_bench.util import path_join
-
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.tests import are_dir_trees_equal, requires_docker
from mlos_bench.tests.services.remote.ssh import SshTestServerInfo
+from mlos_bench.util import path_join
@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
"""
- Provides a context manager for a temporary file that can be closed and
- still unlinked.
+ Provides a context manager for a temporary file that can be closed and still
+ unlinked.
Since Windows doesn't allow us to reopen the file while it's still open we
need to handle deletion ourselves separately.
@@ -54,8 +50,10 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None,
@requires_docker
-def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
- ssh_fileshare_service: SshFileShareService) -> None:
+def test_ssh_fileshare_single_file(
+ ssh_test_server: SshTestServerInfo,
+ ssh_fileshare_service: SshFileShareService,
+) -> None:
"""Test the SshFileShareService single file download/upload."""
with ssh_fileshare_service:
config = ssh_test_server.to_ssh_service_config()
@@ -68,7 +66,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
lines = [line + "\n" for line in lines]
# 1. Write a local file and upload it.
- with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file:
+ with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file:
temp_file.writelines(lines)
temp_file.flush()
temp_file.close()
@@ -80,7 +78,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
)
# 2. Download the remote file and compare the contents.
- with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file:
+ with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file:
temp_file.close()
ssh_fileshare_service.download(
params=config,
@@ -88,14 +86,16 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo,
local_path=temp_file.name,
)
# Download will replace the inode at that name, so we need to reopen the file.
- with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h:
+ with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h:
read_lines = temp_file_h.readlines()
assert read_lines == lines
@requires_docker
-def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo,
- ssh_fileshare_service: SshFileShareService) -> None:
+def test_ssh_fileshare_recursive(
+ ssh_test_server: SshTestServerInfo,
+ ssh_fileshare_service: SshFileShareService,
+) -> None:
"""Test the SshFileShareService recursive download/upload."""
with ssh_fileshare_service:
config = ssh_test_server.to_ssh_service_config()
@@ -115,14 +115,16 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo,
"bar",
],
}
- files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()}
+ files_lines = {
+ path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()
+ }
with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
# Setup the directory structure.
- for (file_path, lines) in files_lines.items():
+ for file_path, lines in files_lines.items():
path = Path(tempdir1, file_path)
path.parent.mkdir(parents=True, exist_ok=True)
- with open(path, mode='w+t', encoding='utf-8') as temp_file:
+ with open(path, mode="w+t", encoding="utf-8") as temp_file:
temp_file.writelines(lines)
temp_file.flush()
assert os.path.getsize(path) > 0
@@ -149,15 +151,17 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo,
@requires_docker
-def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo,
- ssh_fileshare_service: SshFileShareService) -> None:
+def test_ssh_fileshare_download_file_dne(
+ ssh_test_server: SshTestServerInfo,
+ ssh_fileshare_service: SshFileShareService,
+) -> None:
"""Test the SshFileShareService single file download that doesn't exist."""
with ssh_fileshare_service:
config = ssh_test_server.to_ssh_service_config()
canary_str = "canary"
- with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file:
+ with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file:
temp_file.writelines([canary_str])
temp_file.flush()
temp_file.close()
@@ -168,20 +172,22 @@ def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo,
remote_path="/tmp/file-dne.txt",
local_path=temp_file.name,
)
- with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h:
+ with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h:
read_lines = temp_file_h.readlines()
assert read_lines == [canary_str]
@requires_docker
-def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo,
- ssh_host_service: SshHostService,
- ssh_fileshare_service: SshFileShareService) -> None:
+def test_ssh_fileshare_upload_file_dne(
+ ssh_test_server: SshTestServerInfo,
+ ssh_host_service: SshHostService,
+ ssh_fileshare_service: SshFileShareService,
+) -> None:
"""Test the SshFileShareService single file upload that doesn't exist."""
with ssh_host_service, ssh_fileshare_service:
config = ssh_test_server.to_ssh_service_config()
- path = '/tmp/upload-file-src-dne.txt'
+ path = "/tmp/upload-file-src-dne.txt"
with pytest.raises(OSError):
ssh_fileshare_service.upload(
params=config,
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py
index 03b7eb56a8..003a8e6433 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py
@@ -2,39 +2,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.ssh.ssh_host_service
-"""
-
-from subprocess import CalledProcessError, run
+"""Tests for mlos_bench.services.remote.ssh.ssh_host_service."""
import logging
import time
+from subprocess import CalledProcessError, run
from pytest_docker.plugin import Services as DockerServices
from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_service import SshClient
-
from mlos_bench.tests import requires_docker
-from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo,
- ALT_TEST_SERVER_NAME,
- REBOOT_TEST_SERVER_NAME,
- SSH_TEST_SERVER_NAME,
- wait_docker_service_socket)
+from mlos_bench.tests.services.remote.ssh import (
+ ALT_TEST_SERVER_NAME,
+ REBOOT_TEST_SERVER_NAME,
+ SSH_TEST_SERVER_NAME,
+ SshTestServerInfo,
+ wait_docker_service_socket,
+)
_LOG = logging.getLogger(__name__)
@requires_docker
-def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
- alt_test_server: SshTestServerInfo,
- ssh_host_service: SshHostService) -> None:
+def test_ssh_service_remote_exec(
+ ssh_test_server: SshTestServerInfo,
+ alt_test_server: SshTestServerInfo,
+ ssh_host_service: SshHostService,
+) -> None:
"""
Test the SshHostService remote_exec.
- This checks state of the service across multiple invocations and states to
- check for internal cache handling logic as well.
+    This checks the state of the service across multiple invocations and states to verify
+    the internal cache handling logic as well.
"""
# pylint: disable=protected-access
with ssh_host_service:
@@ -42,7 +42,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params())
assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None
- connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id)
+ connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(
+ connection_id
+ )
assert connection_client is None
(status, results_info) = ssh_host_service.remote_exec(
@@ -57,7 +59,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
assert results["stdout"].strip() == SSH_TEST_SERVER_NAME
# Check that the client caching is behaving as expected.
- connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id]
+ connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[
+ connection_id
+ ]
assert connection is not None
assert connection._username == ssh_test_server.username
assert connection._host == ssh_test_server.hostname
@@ -72,7 +76,8 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
script=["hostname"],
config=alt_test_server.to_ssh_service_config(),
env_params={
- "UNUSED": "unused", # unused, making sure it doesn't carry over with cached connections
+ # unused, making sure it doesn't carry over with cached connections
+ "UNUSED": "unused",
},
)
assert status.is_pending()
@@ -91,13 +96,15 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
},
)
status, results = ssh_host_service.get_remote_exec_results(results_info)
- assert status.is_failed() # should retain exit code from "false"
+ assert status.is_failed() # should retain exit code from "false"
stdout = str(results["stdout"])
assert stdout.splitlines() == [
"BAR=bar",
"UNUSED=",
]
- connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id]
+ connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[
+ connection_id
+ ]
assert connection._local_port == local_port
# Close the connection (gracefully)
@@ -114,7 +121,7 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
config=config,
# Also test interacting with environment_variables.
env_params={
- 'FOO': 'foo',
+ "FOO": "foo",
},
)
status, results = ssh_host_service.get_remote_exec_results(results_info)
@@ -127,20 +134,22 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo,
"BAZ=",
]
# Make sure it looks like we reconnected.
- connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id]
+ connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[
+ connection_id
+ ]
assert connection._local_port != local_port
# Make sure the cache is cleaned up on context exit.
assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0
-def check_ssh_service_reboot(docker_services: DockerServices,
- reboot_test_server: SshTestServerInfo,
- ssh_host_service: SshHostService,
- graceful: bool) -> None:
- """
- Check the SshHostService reboot operation.
- """
+def check_ssh_service_reboot(
+ docker_services: DockerServices,
+ reboot_test_server: SshTestServerInfo,
+ ssh_host_service: SshHostService,
+ graceful: bool,
+) -> None:
+ """Check the SshHostService reboot operation."""
# Note: rebooting changes the port number unfortunately, but makes it
# easier to check for success.
# Also, it may cause issues with other parallel unit tests, so we run it as
@@ -148,11 +157,7 @@ def check_ssh_service_reboot(docker_services: DockerServices,
with ssh_host_service:
reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True)
(status, results_info) = ssh_host_service.remote_exec(
- script=[
- 'echo "sleeping..."',
- 'sleep 30',
- 'echo "should not reach this point"'
- ],
+ script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'],
config=reboot_test_srv_ssh_svc_conf,
env_params={},
)
@@ -161,8 +166,10 @@ def check_ssh_service_reboot(docker_services: DockerServices,
time.sleep(1)
# Now try to restart the server.
- (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf,
- force=not graceful)
+ (status, reboot_results_info) = ssh_host_service.reboot(
+ params=reboot_test_srv_ssh_svc_conf,
+ force=not graceful,
+ )
assert status.is_pending()
(status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info)
@@ -183,19 +190,34 @@ def check_ssh_service_reboot(docker_services: DockerServices,
time.sleep(1)
# try to reconnect and see if the port changed
try:
- run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False)
+ run_res = run(
+ "docker ps | grep mlos_bench-test- | grep reboot",
+ shell=True,
+ capture_output=True,
+ check=False,
+ )
print(run_res.stdout.decode())
print(run_res.stderr.decode())
- reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True)
- if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]:
+ reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(
+ uncached=True
+ )
+ if (
+ reboot_test_srv_ssh_svc_conf_new["ssh_port"]
+ != reboot_test_srv_ssh_svc_conf["ssh_port"]
+ ):
break
except CalledProcessError as ex:
_LOG.info("Failed to check port for reboot test server: %s", ex)
- assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]
+ assert (
+ reboot_test_srv_ssh_svc_conf_new["ssh_port"]
+ != reboot_test_srv_ssh_svc_conf["ssh_port"]
+ )
- wait_docker_service_socket(docker_services,
- reboot_test_server.hostname,
- reboot_test_srv_ssh_svc_conf_new["ssh_port"])
+ wait_docker_service_socket(
+ docker_services,
+ reboot_test_server.hostname,
+ reboot_test_srv_ssh_svc_conf_new["ssh_port"],
+ )
(status, results_info) = ssh_host_service.remote_exec(
script=["hostname"],
@@ -208,12 +230,22 @@ def check_ssh_service_reboot(docker_services: DockerServices,
@requires_docker
-def test_ssh_service_reboot(locked_docker_services: DockerServices,
- reboot_test_server: SshTestServerInfo,
- ssh_host_service: SshHostService) -> None:
- """
- Test the SshHostService reboot operation.
- """
+def test_ssh_service_reboot(
+ locked_docker_services: DockerServices,
+ reboot_test_server: SshTestServerInfo,
+ ssh_host_service: SshHostService,
+) -> None:
+ """Test the SshHostService reboot operation."""
# Grouped together to avoid parallel runner interactions.
- check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True)
- check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False)
+ check_ssh_service_reboot(
+ locked_docker_services,
+ reboot_test_server,
+ ssh_host_service,
+ graceful=True,
+ )
+ check_ssh_service_reboot(
+ locked_docker_services,
+ reboot_test_server,
+ ssh_host_service,
+ graceful=False,
+ )
diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py
index fd0804ba15..5b335477a9 100644
--- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py
+++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py
@@ -2,34 +2,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench.services.remote.ssh.SshService base class.
-"""
+"""Tests for mlos_bench.services.remote.ssh.SshService base class."""
import asyncio
-from importlib.metadata import version, PackageNotFoundError
import time
-
+from importlib.metadata import PackageNotFoundError, version
from subprocess import run
from threading import Thread
import pytest
from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture
-from mlos_bench.services.remote.ssh.ssh_service import SshService
-from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
-
-from mlos_bench.tests import requires_docker, requires_ssh, check_socket, resolve_host_name
-from mlos_bench.tests.services.remote.ssh import SshTestServerInfo, ALT_TEST_SERVER_NAME, SSH_TEST_SERVER_NAME
-
+from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
+from mlos_bench.services.remote.ssh.ssh_service import SshService
+from mlos_bench.tests import (
+ check_socket,
+ requires_docker,
+ requires_ssh,
+ resolve_host_name,
+)
+from mlos_bench.tests.services.remote.ssh import (
+ ALT_TEST_SERVER_NAME,
+ SSH_TEST_SERVER_NAME,
+ SshTestServerInfo,
+)
if version("pytest") >= "8.0.0":
try:
# We replaced pytest-lazy-fixture with pytest-lazy-fixtures:
# https://github.com/TvoroG/pytest-lazy-fixture/issues/65
if version("pytest-lazy-fixture"):
- raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it.")
+ raise UserWarning(
+ "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it."
+ )
except PackageNotFoundError:
# OK: pytest-lazy-fixture not installed
pass
@@ -37,12 +43,14 @@ if version("pytest") >= "8.0.0":
@requires_docker
@requires_ssh
-@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [
- (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME),
- (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME),
-])
-def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo,
- server_name: str) -> None:
+@pytest.mark.parametrize(
+ ["ssh_test_server_info", "server_name"],
+ [
+ (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME),
+ (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME),
+ ],
+)
+def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, server_name: str) -> None:
"""Check for the pytest-docker ssh test infra."""
assert ssh_test_server_info.service_name == server_name
@@ -51,20 +59,22 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo,
local_port = ssh_test_server_info.get_port()
assert check_socket(ip_addr, local_port)
- ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \
- + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \
+ ssh_cmd = (
+ "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new "
+ + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} "
+ f"-p {local_port} {ssh_test_server_info.hostname} hostname"
- cmd = run(ssh_cmd.split(),
- capture_output=True,
- text=True,
- check=True)
+ )
+ cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True)
assert cmd.stdout.strip() == server_name
-@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0")
+@pytest.mark.filterwarnings(
+ "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0"
+)
def test_ssh_service_context_handler() -> None:
"""
Test the SSH service context manager handling.
+
See Also: test_event_loop_context
"""
# pylint: disable=protected-access
@@ -81,7 +91,9 @@ def test_ssh_service_context_handler() -> None:
# After we enter the SshService instance context, we should have a background thread.
with ssh_host_service:
assert ssh_host_service._in_context
- assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable]
+ assert ( # type: ignore[unreachable]
+ isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread)
+ )
# Give the thread a chance to start.
# Mostly important on the underpowered Windows CI machines.
time.sleep(0.25)
@@ -94,17 +106,23 @@ def test_ssh_service_context_handler() -> None:
with ssh_fileshare_service:
assert ssh_fileshare_service._in_context
assert ssh_host_service._in_context
- assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \
- is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \
+ assert (
+ SshService._EVENT_LOOP_CONTEXT._event_loop_thread
+ is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread
is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread
- assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \
- is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \
+ )
+ assert (
+ SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE
+ is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE
is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE
+ )
assert not ssh_fileshare_service._in_context
# And that instance should be unusable after we are outside the context.
- with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"):
- future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo'))
+ with pytest.raises(
+ AssertionError
+ ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"):
+ future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result="foo"))
raise ValueError(f"Future should not have been available to wait on {future.result()}")
# The background thread should remain running since we have another context still open.
@@ -112,6 +130,6 @@ def test_ssh_service_context_handler() -> None:
assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None
-if __name__ == '__main__':
+if __name__ == "__main__":
# For debugging in Windows which has issues with pytest detection in vscode.
pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"])
diff --git a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py
index 736a3d5ef2..088223279b 100644
--- a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py
+++ b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py
@@ -2,21 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for Service method registering.
-"""
+"""Unit tests for Service method registering."""
import pytest
from mlos_bench.services.base_service import Service
-
-from mlos_bench.tests.services.mock_service import SupportsSomeMethod, MockServiceBase, MockServiceChild
+from mlos_bench.tests.services.mock_service import (
+ MockServiceBase,
+ MockServiceChild,
+ SupportsSomeMethod,
+)
def test_service_method_register_without_constructor() -> None:
- """
- Test registering a method without a constructor.
- """
+ """Test registering a method without a constructor."""
# pylint: disable=protected-access
some_base_service = MockServiceBase()
some_child_service = MockServiceChild()
diff --git a/mlos_bench/mlos_bench/tests/storage/__init__.py b/mlos_bench/mlos_bench/tests/storage/__init__.py
index ca5a3b33dd..c3b294cae1 100644
--- a/mlos_bench/mlos_bench/tests/storage/__init__.py
+++ b/mlos_bench/mlos_bench/tests/storage/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench storage.
-"""
+"""Tests for mlos_bench storage."""
CONFIG_COUNT = 10
CONFIG_TRIAL_REPEAT_COUNT = 3
diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py
index 2c16df65c4..52b0fdcd53 100644
--- a/mlos_bench/mlos_bench/tests/storage/conftest.py
+++ b/mlos_bench/mlos_bench/tests/storage/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Export test fixtures for mlos_bench storage.
-"""
+"""Export test fixtures for mlos_bench storage."""
import mlos_bench.tests.storage.sql.fixtures as sql_storage_fixtures
@@ -19,7 +17,9 @@ exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage
mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage
exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials
exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials
-mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials
+mixed_numerics_exp_storage_with_trials = (
+ sql_storage_fixtures.mixed_numerics_exp_storage_with_trials
+)
exp_data = sql_storage_fixtures.exp_data
exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data
mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_context_test.py b/mlos_bench/mlos_bench/tests/storage/exp_context_test.py
index e2b1d7c26b..f0bfa1d127 100644
--- a/mlos_bench/mlos_bench/tests/storage/exp_context_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/exp_context_test.py
@@ -2,16 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for the storage subsystem.
-"""
+"""Unit tests for the storage subsystem."""
from mlos_bench.storage.base_storage import Storage
def test_exp_context(exp_storage: Storage.Experiment) -> None:
- """
- Try to retrieve old experimental data from the empty storage.
- """
+ """Try to retrieve old experimental data from the empty storage."""
# pylint: disable=protected-access
assert exp_storage._in_context
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
index c37dc433b0..256c5d3b38 100644
--- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
@@ -2,44 +2,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for loading the experiment metadata.
-"""
+"""Unit tests for loading the experiment metadata."""
-from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.base_experiment_data import ExperimentData
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
+from mlos_bench.storage.base_storage import Storage
from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT
+from mlos_bench.tunables.tunable_groups import TunableGroups
def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None:
- """
- Try to retrieve old experimental data from the empty storage.
- """
+ """Try to retrieve old experimental data from the empty storage."""
exp = storage.experiments[exp_storage.experiment_id]
assert exp.experiment_id == exp_storage.experiment_id
assert exp.description == exp_storage.description
assert exp.objectives == exp_storage.opt_targets
-def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None:
- """Tests the root_env_config property of ExperimentData"""
+def test_exp_data_root_env_config(
+ exp_storage: Storage.Experiment,
+ exp_data: ExperimentData,
+) -> None:
+ """Tests the root_env_config property of ExperimentData."""
# pylint: disable=protected-access
- assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit)
+ assert exp_data.root_env_config == (
+ exp_storage._root_env_config,
+ exp_storage._git_repo,
+ exp_storage._git_commit,
+ )
-def test_exp_trial_data_objectives(storage: Storage,
- exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups) -> None:
- """
- Start a new trial and check the storage for the trial data.
- """
+def test_exp_trial_data_objectives(
+ storage: Storage,
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+) -> None:
+ """Start a new trial and check the storage for the trial data."""
- trial_opt_new = exp_storage.new_trial(tunable_groups, config={
- "opt_target": "some-other-target",
- "opt_direction": "max",
- })
+ trial_opt_new = exp_storage.new_trial(
+ tunable_groups,
+ config={
+ "opt_target": "some-other-target",
+ "opt_direction": "max",
+ },
+ )
assert trial_opt_new.config() == {
"experiment_id": exp_storage.experiment_id,
"trial_id": trial_opt_new.trial_id,
@@ -47,10 +52,13 @@ def test_exp_trial_data_objectives(storage: Storage,
"opt_direction": "max",
}
- trial_opt_old = exp_storage.new_trial(tunable_groups, config={
- "opt_target": "back-compat",
- # "opt_direction": "max", # missing
- })
+ trial_opt_old = exp_storage.new_trial(
+ tunable_groups,
+ config={
+ "opt_target": "back-compat",
+ # "opt_direction": "max", # missing
+ },
+ )
assert trial_opt_old.config() == {
"experiment_id": exp_storage.experiment_id,
"trial_id": trial_opt_old.trial_id,
@@ -68,21 +76,26 @@ def test_exp_trial_data_objectives(storage: Storage,
def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
- """Tests the results_df property of ExperimentData"""
+ """Tests the results_df property of ExperimentData."""
results_df = exp_data.results_df
expected_trials_count = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT
assert len(results_df) == expected_trials_count
assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT
assert len(results_df["trial_id"].unique()) == expected_trials_count
obj_target = next(iter(exp_data.objectives))
- assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count
+ assert (
+ len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count
+ )
(tunable, _covariant_group) = next(iter(tunable_groups))
- assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count
+ assert (
+ len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name])
+ == expected_trials_count
+ )
def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None:
"""
- Tests the tunable_config_trial_group_id property of ExperimentData.results_df
+ Tests the tunable_config_trial_group_id property of ExperimentData.results_df.
See Also: test_exp_trial_data_tunable_config_trial_group_id()
"""
@@ -110,32 +123,34 @@ def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: Experime
def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None:
"""
- Tests the tunable_config_trial_groups property of ExperimentData
+ Tests the tunable_config_trial_groups property of ExperimentData.
This tests bulk loading of the tunable_config_trial_groups.
"""
# Should be keyed by config_id.
assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1))
# Which should match the objects.
- assert [config_trial_group.tunable_config_id
- for config_trial_group in exp_data.tunable_config_trial_groups.values()
- ] == list(range(1, CONFIG_COUNT + 1))
+ assert [
+ config_trial_group.tunable_config_id
+ for config_trial_group in exp_data.tunable_config_trial_groups.values()
+ ] == list(range(1, CONFIG_COUNT + 1))
# And the tunable_config_trial_group_id should also match the minimum trial_id.
- assert [config_trial_group.tunable_config_trial_group_id
- for config_trial_group in exp_data.tunable_config_trial_groups.values()
- ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT))
+ assert [
+ config_trial_group.tunable_config_trial_group_id
+ for config_trial_group in exp_data.tunable_config_trial_groups.values()
+ ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT))
def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None:
- """Tests the tunable_configs property of ExperimentData"""
+ """Tests the tunable_configs property of ExperimentData."""
# Should be keyed by config_id.
assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1))
# Which should match the objects.
- assert [config.tunable_config_id
- for config in exp_data.tunable_configs.values()
- ] == list(range(1, CONFIG_COUNT + 1))
+ assert [config.tunable_config_id for config in exp_data.tunable_configs.values()] == list(
+ range(1, CONFIG_COUNT + 1)
+ )
def test_exp_data_default_config_id(exp_data: ExperimentData) -> None:
- """Tests the default_tunable_config_id property of ExperimentData"""
+ """Tests the default_tunable_config_id property of ExperimentData."""
assert exp_data.default_tunable_config_id == 1
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
index cd6b17be74..6a7d05bb2a 100644
--- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
@@ -2,26 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for the storage subsystem.
-"""
+"""Unit tests for the storage subsystem."""
from datetime import datetime, tzinfo
from typing import Optional
+import pytest
from pytz import UTC
-import pytest
-
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tests import ZONE_INFO
+from mlos_bench.tunables.tunable_groups import TunableGroups
def test_exp_load_empty(exp_storage: Storage.Experiment) -> None:
- """
- Try to retrieve old experimental data from the empty storage.
- """
+ """Try to retrieve old experimental data from the empty storage."""
(trial_ids, configs, scores, status) = exp_storage.load()
assert not trial_ids
assert not configs
@@ -30,20 +25,18 @@ def test_exp_load_empty(exp_storage: Storage.Experiment) -> None:
def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None:
- """
- Try to retrieve pending experiments from the empty storage.
- """
+ """Try to retrieve pending experiments from the empty storage."""
trials = list(exp_storage.pending_trials(datetime.now(UTC), running=True))
assert not trials
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_pending(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Start a trial and check that it is pending.
- """
+def test_exp_trial_pending(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Start a trial and check that it is pending."""
trial = exp_storage.new_trial(tunable_groups)
(pending,) = list(exp_storage.pending_trials(datetime.now(zone_info), running=True))
assert pending.trial_id == trial.trial_id
@@ -51,14 +44,14 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment,
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_pending_many(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Start THREE trials and check that both are pending.
- """
- config1 = tunable_groups.copy().assign({'idle': 'mwait'})
- config2 = tunable_groups.copy().assign({'idle': 'noidle'})
+def test_exp_trial_pending_many(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Start THREE trials and check that both are pending."""
+ config1 = tunable_groups.copy().assign({"idle": "mwait"})
+ config2 = tunable_groups.copy().assign({"idle": "noidle"})
trial_ids = {
exp_storage.new_trial(config1).trial_id,
exp_storage.new_trial(config2).trial_id,
@@ -73,12 +66,12 @@ def test_exp_trial_pending_many(exp_storage: Storage.Experiment,
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_pending_fail(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Start a trial, fail it, and and check that it is NOT pending.
- """
+def test_exp_trial_pending_fail(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Start a trial, fail it, and and check that it is NOT pending."""
trial = exp_storage.new_trial(tunable_groups)
trial.update(Status.FAILED, datetime.now(zone_info))
trials = list(exp_storage.pending_trials(datetime.now(zone_info), running=True))
@@ -86,12 +79,12 @@ def test_exp_trial_pending_fail(exp_storage: Storage.Experiment,
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_success(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Start a trial, finish it successfully, and and check that it is NOT pending.
- """
+def test_exp_trial_success(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Start a trial, finish it successfully, and and check that it is NOT pending."""
trial = exp_storage.new_trial(tunable_groups)
trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9})
trials = list(exp_storage.pending_trials(datetime.now(zone_info), running=True))
@@ -99,34 +92,36 @@ def test_exp_trial_success(exp_storage: Storage.Experiment,
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_update_categ(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Update the trial with multiple metrics, some of which are categorical.
- """
+def test_exp_trial_update_categ(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Update the trial with multiple metrics, some of which are categorical."""
trial = exp_storage.new_trial(tunable_groups)
trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"})
assert exp_storage.load() == (
[trial.trial_id],
- [{
- 'idle': 'halt',
- 'kernel_sched_latency_ns': '2000000',
- 'kernel_sched_migration_cost_ns': '-1',
- 'vmSize': 'Standard_B4ms'
- }],
+ [
+ {
+ "idle": "halt",
+ "kernel_sched_latency_ns": "2000000",
+ "kernel_sched_migration_cost_ns": "-1",
+ "vmSize": "Standard_B4ms",
+ }
+ ],
[{"score": "99.9", "benchmark": "test"}],
- [Status.SUCCEEDED]
+ [Status.SUCCEEDED],
)
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_update_twice(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
- """
- Update the trial status twice and receive an error.
- """
+def test_exp_trial_update_twice(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
+ """Update the trial status twice and receive an error."""
trial = exp_storage.new_trial(tunable_groups)
trial.update(Status.FAILED, datetime.now(zone_info))
with pytest.raises(RuntimeError):
@@ -134,11 +129,14 @@ def test_exp_trial_update_twice(exp_storage: Storage.Experiment,
@pytest.mark.parametrize(("zone_info"), ZONE_INFO)
-def test_exp_trial_pending_3(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- zone_info: Optional[tzinfo]) -> None:
+def test_exp_trial_pending_3(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ zone_info: Optional[tzinfo],
+) -> None:
"""
Start THREE trials, let one succeed, another one fail and keep one not updated.
+
Check that one is still pending another one can be loaded into the optimizer.
"""
score = 99.9
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
index 61b8ec8df4..d17a448b5e 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_bench sql storage.
-"""
+"""Tests for mlos_bench sql storage."""
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
index b86ebe6c18..8a9065e436 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
@@ -2,42 +2,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test fixtures for mlos_bench storage.
-"""
+"""Test fixtures for mlos_bench storage."""
from datetime import datetime
-from random import random, seed as rand_seed
+from random import random
+from random import seed as rand_seed
from typing import Generator, Optional
+import pytest
from pytz import UTC
-import pytest
-
from mlos_bench.environments.status import Status
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.sql.storage import SqlStorage
-from mlos_bench.optimizers.mock_optimizer import MockOptimizer
-from mlos_bench.tunables.tunable_groups import TunableGroups
-
from mlos_bench.tests import SEED
from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
@pytest.fixture
def storage() -> SqlStorage:
- """
- Test fixture for in-memory SQLite3 storage.
- """
+ """Test fixture for in-memory SQLite3 storage."""
return SqlStorage(
service=None,
config={
"drivername": "sqlite",
"database": ":memory:",
# "database": "mlos_bench.pytest.db",
- }
+ },
)
@@ -48,6 +43,7 @@ def exp_storage(
) -> Generator[SqlStorage.Experiment, None, None]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
+
Note: It has already entered the context upon return.
"""
with storage.experiment(
@@ -69,6 +65,7 @@ def exp_no_tunables_storage(
) -> Generator[SqlStorage.Experiment, None, None]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
+
Note: It has already entered the context upon return.
"""
empty_config: dict = {}
@@ -91,7 +88,9 @@ def mixed_numerics_exp_storage(
mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
"""
- Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3 storage.
+ Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3
+ storage.
+
Note: It has already entered the context upon return.
"""
with storage.experiment(
@@ -107,10 +106,11 @@ def mixed_numerics_exp_storage(
assert not exp._in_context
-def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment:
- """
- Generates data by doing a simulated run of the given experiment.
- """
+def _dummy_run_exp(
+ exp: SqlStorage.Experiment,
+ tunable_name: Optional[str],
+) -> SqlStorage.Experiment:
+ """Generates data by doing a simulated run of the given experiment."""
# Add some trials to that experiment.
# Note: we're just fabricating some made up function for the ML libraries to try and learn.
base_score = 10.0
@@ -120,24 +120,31 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S
(tunable_min, tunable_max) = tunable.range
tunable_range = tunable_max - tunable_min
rand_seed(SEED)
- opt = MockOptimizer(tunables=exp.tunables, config={
- "seed": SEED,
- # This should be the default, so we leave it omitted for now to test the default.
- # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params)
- # "start_with_defaults": True,
- })
+ opt = MockOptimizer(
+ tunables=exp.tunables,
+ config={
+ "seed": SEED,
+ # This should be the default, so we leave it omitted for now to test the default.
+            # But the test logic relies on this (e.g., trial 1 uses config 1, which
+            # holds the default values for the tunable params).
+ # "start_with_defaults": True,
+ },
+ )
assert opt.start_with_defaults
for config_i in range(CONFIG_COUNT):
tunables = opt.suggest()
for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT):
- trial = exp.new_trial(tunables=tunables.copy(), config={
- "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1,
- **{
- f"opt_{key}_{i}": val
- for (i, opt_target) in enumerate(exp.opt_targets.items())
- for (key, val) in zip(["target", "direction"], opt_target)
- }
- })
+ trial = exp.new_trial(
+ tunables=tunables.copy(),
+ config={
+ "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1,
+ **{
+ f"opt_{key}_{i}": val
+ for (i, opt_target) in enumerate(exp.opt_targets.items())
+ for (key, val) in zip(["target", "direction"], opt_target)
+ },
+ },
+ )
if exp.tunables:
assert trial.tunable_config_id == config_i + 1
else:
@@ -148,62 +155,72 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S
else:
tunable_value_norm = 0
timestamp = datetime.now(UTC)
- trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[
- (timestamp, "some-metric", tunable_value_norm + random() / 100),
- ])
- trial.update(Status.SUCCEEDED, timestamp, metrics={
- # Give some variance on the score.
- # And some influence from the tunable value.
- "score": tunable_value_norm + random() / 100
- })
+ trial.update_telemetry(
+ status=Status.RUNNING,
+ timestamp=timestamp,
+ metrics=[
+ (timestamp, "some-metric", tunable_value_norm + random() / 100),
+ ],
+ )
+ trial.update(
+ Status.SUCCEEDED,
+ timestamp,
+ metrics={
+ # Give some variance on the score.
+ # And some influence from the tunable value.
+ "score": tunable_value_norm
+ + random() / 100
+ },
+ )
return exp
@pytest.fixture
def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
- """
- Test fixture for Experiment using in-memory SQLite3 storage.
- """
+ """Test fixture for Experiment using in-memory SQLite3 storage."""
return _dummy_run_exp(exp_storage, tunable_name="kernel_sched_latency_ns")
@pytest.fixture
-def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
- """
- Test fixture for Experiment using in-memory SQLite3 storage.
- """
+def exp_no_tunables_storage_with_trials(
+ exp_no_tunables_storage: SqlStorage.Experiment,
+) -> SqlStorage.Experiment:
+ """Test fixture for Experiment using in-memory SQLite3 storage."""
assert not exp_no_tunables_storage.tunables
return _dummy_run_exp(exp_no_tunables_storage, tunable_name=None)
@pytest.fixture
-def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
- """
- Test fixture for Experiment using in-memory SQLite3 storage.
- """
+def mixed_numerics_exp_storage_with_trials(
+ mixed_numerics_exp_storage: SqlStorage.Experiment,
+) -> SqlStorage.Experiment:
+ """Test fixture for Experiment using in-memory SQLite3 storage."""
tunable = next(iter(mixed_numerics_exp_storage.tunables))[0]
return _dummy_run_exp(mixed_numerics_exp_storage, tunable_name=tunable.name)
@pytest.fixture
-def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData:
- """
- Test fixture for ExperimentData.
- """
+def exp_data(
+ storage: SqlStorage,
+ exp_storage_with_trials: SqlStorage.Experiment,
+) -> ExperimentData:
+ """Test fixture for ExperimentData."""
return storage.experiments[exp_storage_with_trials.experiment_id]
@pytest.fixture
-def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData:
- """
- Test fixture for ExperimentData with no tunable configs.
- """
+def exp_no_tunables_data(
+ storage: SqlStorage,
+ exp_no_tunables_storage_with_trials: SqlStorage.Experiment,
+) -> ExperimentData:
+ """Test fixture for ExperimentData with no tunable configs."""
return storage.experiments[exp_no_tunables_storage_with_trials.experiment_id]
@pytest.fixture
-def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData:
- """
- Test fixture for ExperimentData with mixed numerical tunable types.
- """
+def mixed_numerics_exp_data(
+ storage: SqlStorage,
+ mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment,
+) -> ExperimentData:
+ """Test fixture for ExperimentData with mixed numerical tunable types."""
return storage.experiments[mixed_numerics_exp_storage_with_trials.experiment_id]
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py
index ba965ed3c6..b5f4778a74 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for saving and retrieving additional parameters of pending trials.
-"""
+"""Unit tests for saving and retrieving additional parameters of pending trials."""
from datetime import datetime
from pytz import UTC
@@ -13,11 +11,8 @@ from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
-def test_exp_trial_pending(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups) -> None:
- """
- Schedule a trial and check that it is pending and has the right configuration.
- """
+def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
+ """Schedule a trial and check that it is pending and has the right configuration."""
config = {"location": "westus2", "num_repeats": 100}
trial = exp_storage.new_trial(tunable_groups, config=config)
(pending,) = list(exp_storage.pending_trials(datetime.now(UTC), running=True))
@@ -31,13 +26,11 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment,
}
-def test_exp_trial_configs(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups) -> None:
+def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
+ """Start multiple trials with two different configs and check that we store only two
+ config objects in the DB.
"""
- Start multiple trials with two different configs and check that
- we store only two config objects in the DB.
- """
- config1 = tunable_groups.copy().assign({'idle': 'mwait'})
+ config1 = tunable_groups.copy().assign({"idle": "mwait"})
trials1 = [
exp_storage.new_trial(config1),
exp_storage.new_trial(config1),
@@ -46,7 +39,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment,
assert trials1[0].tunable_config_id == trials1[1].tunable_config_id
assert trials1[0].tunable_config_id == trials1[2].tunable_config_id
- config2 = tunable_groups.copy().assign({'idle': 'halt'})
+ config2 = tunable_groups.copy().assign({"idle": "halt"})
trials2 = [
exp_storage.new_trial(config2),
exp_storage.new_trial(config2),
@@ -67,9 +60,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment,
def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None:
- """
- Schedule a trial that has an empty tunable groups config.
- """
+ """Schedule a trial that has an empty tunable groups config."""
empty_config: dict = {}
tunable_groups = TunableGroups(config=empty_config)
trial = exp_no_tunables_storage.new_trial(tunable_groups, config=empty_config)
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
index c3703c9a13..9fe59b426b 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for loading the trial metadata.
-"""
+"""Unit tests for loading the trial metadata."""
from datetime import datetime
@@ -15,9 +13,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData
def test_exp_trial_data(exp_data: ExperimentData) -> None:
- """
- Check expected return values for TrialData.
- """
+ """Check expected return values for TrialData."""
trial_id = 1
expected_config_id = 1
trial = exp_data.trials[trial_id]
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
index 21f857ae45..0a4d72480d 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
@@ -2,11 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for scheduling trials for some future time.
-"""
+"""Unit tests for scheduling trials for some future time."""
from datetime import datetime, timedelta
-
from typing import Iterator, Set
from pytz import UTC
@@ -17,16 +14,13 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]:
- """
- Extract trial IDs from a list of trials.
- """
+ """Extract trial IDs from a list of trials."""
return set(t.trial_id for t in trials)
-def test_schedule_trial(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups) -> None:
- """
- Schedule several trials for future execution and retrieve them later at certain timestamps.
+def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
+ """Schedule several trials for future execution and retrieve them later at certain
+ timestamps.
"""
timestamp = datetime.now(UTC)
timedelta_1min = timedelta(minutes=1)
@@ -45,16 +39,14 @@ def test_schedule_trial(exp_storage: Storage.Experiment,
# Scheduler side: get trials ready to run at certain timestamps:
# Pretend 1 minute has passed, get trials scheduled to run:
- pending_ids = _trial_ids(
- exp_storage.pending_trials(timestamp + timedelta_1min, running=False))
+ pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False))
assert pending_ids == {
trial_now1.trial_id,
trial_now2.trial_id,
}
# Get trials scheduled to run within the next 1 hour:
- pending_ids = _trial_ids(
- exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))
+ pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))
assert pending_ids == {
trial_now1.trial_id,
trial_now2.trial_id,
@@ -63,7 +55,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment,
# Get trials scheduled to run within the next 3 hours:
pending_ids = _trial_ids(
- exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False))
+ exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)
+ )
assert pending_ids == {
trial_now1.trial_id,
trial_now2.trial_id,
@@ -85,7 +78,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment,
# Get trials scheduled to run within the next 3 hours:
pending_ids = _trial_ids(
- exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False))
+ exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)
+ )
assert pending_ids == {
trial_1h.trial_id,
trial_2h.trial_id,
@@ -93,7 +87,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment,
# Get trials scheduled to run OR running within the next 3 hours:
pending_ids = _trial_ids(
- exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True))
+ exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)
+ )
assert pending_ids == {
trial_now1.trial_id,
trial_now2.trial_id,
@@ -115,7 +110,9 @@ def test_schedule_trial(exp_storage: Storage.Experiment,
assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]
# Get only trials completed after trial_now2:
- (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id)
+ (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(
+ last_trial_id=trial_now2.trial_id
+ )
assert trial_ids == [trial_1h.trial_id]
assert len(trial_configs) == len(trial_scores) == 1
assert trial_status == [Status.SUCCEEDED]
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py
index deea02128f..72f73724db 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py
@@ -2,22 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for saving and restoring the telemetry data.
-"""
+"""Unit tests for saving and restoring the telemetry data."""
from datetime import datetime, timedelta, tzinfo
from typing import Any, List, Optional, Tuple
+import pytest
from pytz import UTC
-import pytest
-
from mlos_bench.environments.status import Status
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
-from mlos_bench.util import nullable
-
from mlos_bench.tests import ZONE_INFO
+from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.util import nullable
# pylint: disable=redefined-outer-name
@@ -33,33 +29,34 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st
"""
timestamp1 = datetime.now(zone_info)
timestamp2 = timestamp1 + timedelta(seconds=1)
- return sorted([
- (timestamp1, "cpu_load", 10.1),
- (timestamp1, "memory", 20),
- (timestamp1, "setup", "prod"),
- (timestamp2, "cpu_load", 30.1),
- (timestamp2, "memory", 40),
- (timestamp2, "setup", "prod"),
- ])
+ return sorted(
+ [
+ (timestamp1, "cpu_load", 10.1),
+ (timestamp1, "memory", 20),
+ (timestamp1, "setup", "prod"),
+ (timestamp2, "cpu_load", 30.1),
+ (timestamp2, "memory", 40),
+ (timestamp2, "setup", "prod"),
+ ]
+ )
-def _telemetry_str(data: List[Tuple[datetime, str, Any]]
- ) -> List[Tuple[datetime, str, Optional[str]]]:
- """
- Convert telemetry values to strings.
- """
+def _telemetry_str(
+ data: List[Tuple[datetime, str, Any]],
+) -> List[Tuple[datetime, str, Optional[str]]]:
+ """Convert telemetry values to strings."""
# All retrieved timestamps should have been converted to UTC.
return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data]
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
-def test_update_telemetry(storage: Storage,
- exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- origin_zone_info: Optional[tzinfo]) -> None:
- """
- Make sure update_telemetry() and load_telemetry() methods work.
- """
+def test_update_telemetry(
+ storage: Storage,
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ origin_zone_info: Optional[tzinfo],
+) -> None:
+ """Make sure update_telemetry() and load_telemetry() methods work."""
telemetry_data = zoned_telemetry_data(origin_zone_info)
trial = exp_storage.new_trial(tunable_groups)
assert exp_storage.load_telemetry(trial.trial_id) == []
@@ -75,12 +72,12 @@ def test_update_telemetry(storage: Storage,
@pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO)
-def test_update_telemetry_twice(exp_storage: Storage.Experiment,
- tunable_groups: TunableGroups,
- origin_zone_info: Optional[tzinfo]) -> None:
- """
- Make sure update_telemetry() call is idempotent.
- """
+def test_update_telemetry_twice(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+ origin_zone_info: Optional[tzinfo],
+) -> None:
+ """Make sure update_telemetry() call is idempotent."""
telemetry_data = zoned_telemetry_data(origin_zone_info)
trial = exp_storage.new_trial(tunable_groups)
timestamp = datetime.now(origin_zone_info)
diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
index 3b57222822..755fc0205a 100644
--- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
@@ -2,19 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for loading the TunableConfigData.
-"""
+"""Unit tests for loading the TunableConfigData."""
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.tunables.tunable_groups import TunableGroups
-def test_trial_data_tunable_config_data(exp_data: ExperimentData,
- tunable_groups: TunableGroups) -> None:
- """
- Check expected return values for TunableConfigData.
- """
+def test_trial_data_tunable_config_data(
+ exp_data: ExperimentData,
+ tunable_groups: TunableGroups,
+) -> None:
+ """Check expected return values for TunableConfigData."""
trial_id = 1
expected_config_id = 1
trial = exp_data.trials[trial_id]
@@ -26,35 +24,31 @@ def test_trial_data_tunable_config_data(exp_data: ExperimentData,
def test_trial_metadata(exp_data: ExperimentData) -> None:
- """
- Check expected return values for TunableConfigData metadata.
- """
- assert exp_data.objectives == {'score': 'min'}
- for (trial_id, trial) in exp_data.trials.items():
+ """Check expected return values for TunableConfigData metadata."""
+ assert exp_data.objectives == {"score": "min"}
+ for trial_id, trial in exp_data.trials.items():
assert trial.metadata_dict == {
- 'opt_target_0': 'score',
- 'opt_direction_0': 'min',
- 'trial_number': trial_id,
+ "opt_target_0": "score",
+ "opt_direction_0": "min",
+ "trial_number": trial_id,
}
def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None:
- """
- Check expected return values for TunableConfigData.
- """
+ """Check expected return values for TunableConfigData."""
empty_config: dict = {}
for _trial_id, trial in exp_no_tunables_data.trials.items():
assert trial.tunable_config.config_dict == empty_config
def test_mixed_numerics_exp_trial_data(
- mixed_numerics_exp_data: ExperimentData,
- mixed_numerics_tunable_groups: TunableGroups) -> None:
- """
- Tests that data type conversions are retained when loading experiment data with
+ mixed_numerics_exp_data: ExperimentData,
+ mixed_numerics_tunable_groups: TunableGroups,
+) -> None:
+ """Tests that data type conversions are retained when loading experiment data with
mixed numeric tunable types.
"""
trial = next(iter(mixed_numerics_exp_data.trials.values()))
config = trial.tunable_config.config_dict
- for (tunable, _group) in mixed_numerics_tunable_groups:
+ for tunable, _group in mixed_numerics_tunable_groups:
assert isinstance(config[tunable.name], tunable.dtype)
diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py
index 736621a3fd..faa61e5286 100644
--- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py
@@ -2,14 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for loading the TunableConfigTrialGroupData.
-"""
+"""Unit tests for loading the TunableConfigTrialGroupData."""
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_experiment_data import ExperimentData
-
from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT
+from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
@@ -17,10 +14,15 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
trial_id = 1
trial = exp_data.trials[trial_id]
tunable_config_trial_group = trial.tunable_config_trial_group
- assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id
+ assert (
+ tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id
+ )
assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id
assert tunable_config_trial_group.tunable_config == trial.tunable_config
- assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group
+ assert (
+ tunable_config_trial_group
+ == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group
+ )
def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None:
@@ -50,7 +52,10 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData)
# And so on ...
-def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
+def test_tunable_config_trial_group_results_df(
+ exp_data: ExperimentData,
+ tunable_groups: TunableGroups,
+) -> None:
"""Tests the results_df property of the TunableConfigTrialGroup."""
tunable_config_id = 2
expected_group_id = 4
@@ -59,9 +64,14 @@ def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable
# We shouldn't have the results for the other configs, just this one.
expected_count = CONFIG_TRIAL_REPEAT_COUNT
assert len(results_df) == expected_count
- assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count
+ assert (
+ len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count
+ )
assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0
- assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count
+ assert (
+ len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)])
+ == expected_count
+ )
assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0
assert len(results_df["trial_id"].unique()) == expected_count
obj_target = next(iter(exp_data.objectives))
@@ -77,8 +87,14 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None:
tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id]
trials = tunable_config_trial_group.trials
assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT
- assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id
- for trial in trials.values())
- assert all(trial.tunable_config_id == tunable_config_id
- for trial in tunable_config_trial_group.trials.values())
- assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id]
+ assert all(
+ trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id
+ for trial in trials.values()
+ )
+ assert all(
+ trial.tunable_config_id == tunable_config_id
+ for trial in tunable_config_trial_group.trials.values()
+ )
+ assert (
+ exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id]
+ )
diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py
index cd279f4e5d..87a4dcb0ba 100644
--- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py
+++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py
@@ -2,20 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests various other test scenarios with alternative default (un-named) TZ info.
-"""
+"""Tests various other test scenarios with alternative default (un-named) TZ info."""
-from subprocess import run
import os
import sys
+from subprocess import run
from typing import Optional
import pytest
from mlos_bench.tests import ZONE_NAMES
-
DIRNAME = os.path.dirname(__file__)
TZ_TEST_FILES = [
DIRNAME + "/environments/local/composite_local_env_test.py",
@@ -25,13 +22,11 @@ TZ_TEST_FILES = [
]
-@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism")
+@pytest.mark.skipif(sys.platform == "win32", reason="TZ environment variable is a UNIXism")
@pytest.mark.parametrize(("tz_name"), ZONE_NAMES)
@pytest.mark.parametrize(("test_file"), TZ_TEST_FILES)
def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None:
- """
- Run the tests under alternative default (un-named) TZ info.
- """
+ """Run the tests under alternative default (un-named) TZ info."""
env = os.environ.copy()
if tz_name is None:
env.pop("TZ", None)
@@ -46,4 +41,6 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None:
if cmd.returncode != 0:
print(cmd.stdout.decode())
print(cmd.stderr.decode())
- raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'")
+ raise AssertionError(
+ f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'"
+ )
diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py
index 0350fff3bb..5f31eaef23 100644
--- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py
+++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py
@@ -2,15 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common fixtures for mock TunableGroups.
-"""
+"""Common fixtures for mock TunableGroups."""
from typing import Any, Dict
-import pytest
-
import json5 as json
+import pytest
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
@@ -73,9 +70,7 @@ TUNABLE_GROUPS_JSON = """
@pytest.fixture
def tunable_groups_config() -> Dict[str, Any]:
- """
- Fixture to get the JSON string for the tunable groups.
- """
+ """Fixture to get the JSON string for the tunable groups."""
conf = json.loads(TUNABLE_GROUPS_JSON)
assert isinstance(conf, dict)
ConfigSchema.TUNABLE_PARAMS.validate(conf)
@@ -120,24 +115,26 @@ def mixed_numerics_tunable_groups() -> TunableGroups:
tunable_groups : TunableGroups
A new TunableGroups object for testing.
"""
- tunables = TunableGroups({
- "mix-numerics": {
- "cost": 1,
- "params": {
- "int": {
- "description": "An integer",
- "type": "int",
- "default": 0,
- "range": [0, 100],
+ tunables = TunableGroups(
+ {
+ "mix-numerics": {
+ "cost": 1,
+ "params": {
+ "int": {
+ "description": "An integer",
+ "type": "int",
+ "default": 0,
+ "range": [0, 100],
+ },
+ "float": {
+ "description": "A float",
+ "type": "float",
+ "default": 0,
+ "range": [0, 1],
+ },
},
- "float": {
- "description": "A float",
- "type": "float",
- "default": 0,
- "range": [0, 1],
- },
- }
- },
- })
+ },
+ }
+ )
tunables.reset()
return tunables
diff --git a/mlos_bench/mlos_bench/tests/tunables/__init__.py b/mlos_bench/mlos_bench/tests/tunables/__init__.py
index 83c046e575..69ef4a9204 100644
--- a/mlos_bench/mlos_bench/tests/tunables/__init__.py
+++ b/mlos_bench/mlos_bench/tests/tunables/__init__.py
@@ -4,5 +4,6 @@
#
"""
Tests for mlos_bench.tunables.
+
Used to make mypy happy about multiple conftest.py modules.
"""
diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py
index 95de20d9b8..054e8c7d87 100644
--- a/mlos_bench/mlos_bench/tests/tunables/conftest.py
+++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test fixtures for individual Tunable objects.
-"""
+"""Test fixtures for individual Tunable objects."""
import pytest
@@ -25,12 +23,15 @@ def tunable_categorical() -> Tunable:
tunable : Tunable
An instance of a categorical Tunable.
"""
- return Tunable("vmSize", {
- "description": "Azure VM size",
- "type": "categorical",
- "default": "Standard_B4ms",
- "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"]
- })
+ return Tunable(
+ "vmSize",
+ {
+ "description": "Azure VM size",
+ "type": "categorical",
+ "default": "Standard_B4ms",
+ "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"],
+ },
+ )
@pytest.fixture
@@ -43,13 +44,16 @@ def tunable_int() -> Tunable:
tunable : Tunable
An instance of an integer Tunable.
"""
- return Tunable("kernel_sched_migration_cost_ns", {
- "description": "Cost of migrating the thread to another core",
- "type": "int",
- "default": 40000,
- "range": [0, 500000],
- "special": [-1] # Special value outside of the range
- })
+ return Tunable(
+ "kernel_sched_migration_cost_ns",
+ {
+ "description": "Cost of migrating the thread to another core",
+ "type": "int",
+ "default": 40000,
+ "range": [0, 500000],
+ "special": [-1], # Special value outside of the range
+ },
+ )
@pytest.fixture
@@ -62,9 +66,12 @@ def tunable_float() -> Tunable:
tunable : Tunable
An instance of a float Tunable.
"""
- return Tunable("chaos_monkey_prob", {
- "description": "Probability of spontaneous VM shutdown",
- "type": "float",
- "default": 0.01,
- "range": [0, 1]
- })
+ return Tunable(
+ "chaos_monkey_prob",
+ {
+ "description": "Probability of spontaneous VM shutdown",
+ "type": "float",
+ "default": 0.01,
+ "range": [0, 1],
+ },
+ )
diff --git a/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py b/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py
index 0b3e124779..50e9061222 100644
--- a/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py
+++ b/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py
@@ -2,23 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for empty tunable groups.
-"""
+"""Unit tests for empty tunable groups."""
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_empty_tunable_group() -> None:
- """
- Test __nonzero__ property of tunable groups.
- """
+ """Test __nonzero__ property of tunable groups."""
tunable_groups = TunableGroups(config={})
assert not tunable_groups
def test_non_empty_tunable_group(tunable_groups: TunableGroups) -> None:
- """
- Test __nonzero__ property of tunable groups.
- """
+ """Test __nonzero__ property of tunable groups."""
assert tunable_groups
diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py
index 0e910f3761..5ec31743bd 100644
--- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py
+++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py
@@ -2,17 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for deep copy of tunable objects and groups.
-"""
+"""Unit tests for deep copy of tunable objects and groups."""
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_categorical_types() -> None:
- """
- Check if we accept tunable categoricals as ints as well as strings and
- convert both to strings.
+ """Check if we accept tunable categoricals as ints as well as strings and convert
+ both to strings.
"""
tunable_params = {
"test-group": {
@@ -38,7 +35,7 @@ def test_tunable_categorical_types() -> None:
"values": ["a", "b", "c"],
"default": "a",
},
- }
+ },
}
}
tunable_groups = TunableGroups(tunable_params)
diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py
index de966536c4..fcbca29ed9 100644
--- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py
+++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py
@@ -2,28 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for checking tunable size properties.
-"""
+"""Unit tests for checking tunable size properties."""
import numpy as np
import pytest
from mlos_bench.tunables.tunable import Tunable
-
 # Note: these tests do *not* check the ConfigSpace conversions for those same Tunables.
# That is checked indirectly via grid_search_optimizer_test.py
+
def test_tunable_int_size_props() -> None:
- """Test tunable int size properties"""
+ """Test tunable int size properties."""
tunable = Tunable(
name="test",
config={
"type": "int",
"range": [1, 5],
"default": 3,
- })
+ },
+ )
assert tunable.span == 4
assert tunable.cardinality == 5
expected = [1, 2, 3, 4, 5]
@@ -32,14 +31,15 @@ def test_tunable_int_size_props() -> None:
def test_tunable_float_size_props() -> None:
- """Test tunable float size properties"""
+ """Test tunable float size properties."""
tunable = Tunable(
name="test",
config={
"type": "float",
"range": [1.5, 5],
"default": 3,
- })
+ },
+ )
assert tunable.span == 3.5
assert tunable.cardinality == np.inf
assert tunable.quantized_values is None
@@ -47,14 +47,15 @@ def test_tunable_float_size_props() -> None:
def test_tunable_categorical_size_props() -> None:
- """Test tunable categorical size properties"""
+ """Test tunable categorical size properties."""
tunable = Tunable(
name="test",
config={
"type": "categorical",
"values": ["a", "b", "c"],
"default": "a",
- })
+ },
+ )
with pytest.raises(AssertionError):
_ = tunable.span
assert tunable.cardinality == 3
@@ -64,15 +65,11 @@ def test_tunable_categorical_size_props() -> None:
def test_tunable_quantized_int_size_props() -> None:
- """Test quantized tunable int size properties"""
+ """Test quantized tunable int size properties."""
tunable = Tunable(
name="test",
- config={
- "type": "int",
- "range": [100, 1000],
- "default": 100,
- "quantization": 100
- })
+ config={"type": "int", "range": [100, 1000], "default": 100, "quantization": 100},
+ )
assert tunable.span == 900
assert tunable.cardinality == 10
expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
@@ -81,15 +78,11 @@ def test_tunable_quantized_int_size_props() -> None:
def test_tunable_quantized_float_size_props() -> None:
- """Test quantized tunable float size properties"""
+ """Test quantized tunable float size properties."""
tunable = Tunable(
name="test",
- config={
- "type": "float",
- "range": [0, 1],
- "default": 0,
- "quantization": .1
- })
+ config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1},
+ )
assert tunable.span == 1
assert tunable.cardinality == 11
expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py
index fcf0d5b9e5..8aa888ebca 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py
@@ -2,8 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for accessing values to the individual parameters within tunable groups.
+"""Unit tests for accessing values to the individual parameters within tunable
+groups.
"""
import pytest
@@ -13,9 +13,7 @@ from mlos_bench.tunables.tunable import Tunable
def test_categorical_access_to_numerical_tunable(tunable_int: Tunable) -> None:
- """
- Make sure we throw an error on accessing a numerical tunable as a categorical.
- """
+ """Make sure we throw an error on accessing a numerical tunable as a categorical."""
with pytest.raises(ValueError):
print(tunable_int.category)
with pytest.raises(AssertionError):
@@ -23,9 +21,7 @@ def test_categorical_access_to_numerical_tunable(tunable_int: Tunable) -> None:
def test_numerical_access_to_categorical_tunable(tunable_categorical: Tunable) -> None:
- """
- Make sure we throw an error on accessing a numerical tunable as a categorical.
- """
+ """Make sure we throw an error on accessing a numerical tunable as a categorical."""
with pytest.raises(ValueError):
print(tunable_categorical.numerical_value)
with pytest.raises(AssertionError):
@@ -33,15 +29,11 @@ def test_numerical_access_to_categorical_tunable(tunable_categorical: Tunable) -
def test_covariant_group_repr(covariant_group: CovariantTunableGroup) -> None:
- """
- Tests that the covariant group representation works as expected.
- """
+ """Tests that the covariant group representation works as expected."""
assert repr(covariant_group).startswith(f"{covariant_group.name}:")
def test_covariant_group_tunables(covariant_group: CovariantTunableGroup) -> None:
- """
- Tests that we can access the tunables in the covariant group.
- """
+ """Tests that we can access the tunables in the covariant group."""
for tunable in covariant_group.get_tunables():
assert isinstance(tunable, Tunable)
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py
index 6a91b14016..ccf76d07c8 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for checking tunable comparisons.
-"""
+"""Unit tests for checking tunable comparisons."""
import pytest
@@ -14,9 +12,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_int_value_lt(tunable_int: Tunable) -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
tunable_int_2 = tunable_int.copy()
tunable_int_2.numerical_value += 1
assert tunable_int.numerical_value < tunable_int_2.numerical_value
@@ -24,21 +20,18 @@ def test_tunable_int_value_lt(tunable_int: Tunable) -> None:
def test_tunable_int_name_lt(tunable_int: Tunable) -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
tunable_int_2 = tunable_int.copy()
- tunable_int_2._name = "aaa" # pylint: disable=protected-access
+ tunable_int_2._name = "aaa" # pylint: disable=protected-access
assert tunable_int_2 < tunable_int
def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
tunable_categorical_2 = tunable_categorical.copy()
new_value = [
- x for x in tunable_categorical.categories
+ x
+ for x in tunable_categorical.categories
if x != tunable_categorical.category and x is not None
][0]
assert tunable_categorical.category is not None
@@ -50,16 +43,14 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None:
def test_tunable_categorical_lt_null() -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
tunable_cat = Tunable(
name="same-name",
config={
"type": "categorical",
"values": ["floof", "fuzz"],
"default": "floof",
- }
+ },
)
tunable_dog = Tunable(
name="same-name",
@@ -67,22 +58,20 @@ def test_tunable_categorical_lt_null() -> None:
"type": "categorical",
"values": [None, "doggo"],
"default": None,
- }
+ },
)
assert tunable_dog < tunable_cat
def test_tunable_lt_same_name_different_type() -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
tunable_cat = Tunable(
name="same-name",
config={
"type": "categorical",
"values": ["floof", "fuzz"],
"default": "floof",
- }
+ },
)
tunable_int = Tunable(
name="same-name",
@@ -90,29 +79,23 @@ def test_tunable_lt_same_name_different_type() -> None:
"type": "int",
"range": [1, 3],
"default": 2,
- }
+ },
)
assert tunable_cat < tunable_int
def test_tunable_lt_different_object(tunable_int: Tunable) -> None:
- """
- Tests that the __lt__ operator works as expected.
- """
+ """Tests that the __lt__ operator works as expected."""
assert (tunable_int < "foo") is False
with pytest.raises(TypeError):
- assert "foo" < tunable_int # type: ignore[operator]
+ assert "foo" < tunable_int # type: ignore[operator]
def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None:
- """
- Tests that the __eq__ operator works as expected with other objects.
- """
+ """Tests that the __eq__ operator works as expected with other objects."""
assert tunable_groups != "foo"
def test_covariant_group_ne_object(covariant_group: CovariantTunableGroup) -> None:
- """
- Tests that the __eq__ operator works as expected with other objects.
- """
+ """Tests that the __eq__ operator works as expected with other objects."""
assert covariant_group != "foo"
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py
index f2da3ba60e..7403841f8d 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for checking tunable definition rules.
-"""
+"""Unit tests for checking tunable definition rules."""
import json5 as json
import pytest
@@ -13,18 +11,14 @@ from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName
def test_tunable_name() -> None:
- """
- Check that tunable name is valid.
- """
+ """Check that tunable name is valid."""
with pytest.raises(ValueError):
# ! characters are currently disallowed in tunable names
- Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0})
+ Tunable(name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0})
def test_categorical_required_params() -> None:
- """
- Check that required parameters are present for categorical tunables.
- """
+ """Check that required parameters are present for categorical tunables."""
json_config = """
{
"type": "categorical",
@@ -34,13 +28,11 @@ def test_categorical_required_params() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_categorical_weights() -> None:
- """
- Instantiate a categorical tunable with weights.
- """
+ """Instantiate a categorical tunable with weights."""
json_config = """
{
"type": "categorical",
@@ -50,14 +42,12 @@ def test_categorical_weights() -> None:
}
"""
config = json.loads(json_config)
- tunable = Tunable(name='test', config=config)
+ tunable = Tunable(name="test", config=config)
assert tunable.weights == [25, 25, 50]
def test_categorical_weights_wrong_count() -> None:
- """
- Try to instantiate a categorical tunable with incorrect number of weights.
- """
+ """Try to instantiate a categorical tunable with incorrect number of weights."""
json_config = """
{
"type": "categorical",
@@ -68,13 +58,11 @@ def test_categorical_weights_wrong_count() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_categorical_weights_wrong_values() -> None:
- """
- Try to instantiate a categorical tunable with invalid weights.
- """
+ """Try to instantiate a categorical tunable with invalid weights."""
json_config = """
{
"type": "categorical",
@@ -85,13 +73,11 @@ def test_categorical_weights_wrong_values() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_categorical_wrong_params() -> None:
- """
- Disallow range param for categorical tunables.
- """
+ """Disallow range param for categorical tunables."""
json_config = """
{
"type": "categorical",
@@ -102,13 +88,11 @@ def test_categorical_wrong_params() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_categorical_disallow_special_values() -> None:
- """
- Disallow special values for categorical values.
- """
+ """Disallow special values for categorical values."""
json_config = """
{
"type": "categorical",
@@ -119,66 +103,68 @@ def test_categorical_disallow_special_values() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_categorical_tunable_disallow_repeats() -> None:
- """
- Disallow duplicate values in categorical tunables.
- """
+ """Disallow duplicate values in categorical tunables."""
with pytest.raises(ValueError):
- Tunable(name='test', config={
- "type": "categorical",
- "values": ["foo", "bar", "foo"],
- "default": "foo",
- })
+ Tunable(
+ name="test",
+ config={
+ "type": "categorical",
+ "values": ["foo", "bar", "foo"],
+ "default": "foo",
+ },
+ )
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow null values as default for numerical tunables.
- """
+ """Disallow null values as default for numerical tunables."""
with pytest.raises(ValueError):
- Tunable(name=f'test_{tunable_type}', config={
- "type": tunable_type,
- "range": [0, 10],
- "default": None,
- })
+ Tunable(
+ name=f"test_{tunable_type}",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "default": None,
+ },
+ )
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow out of range values as default for numerical tunables.
- """
+ """Disallow out of range values as default for numerical tunables."""
with pytest.raises(ValueError):
- Tunable(name=f'test_{tunable_type}', config={
- "type": tunable_type,
- "range": [0, 10],
- "default": 11,
- })
+ Tunable(
+ name=f"test_{tunable_type}",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "default": 11,
+ },
+ )
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow values param for numerical tunables.
- """
+ """Disallow values param for numerical tunables."""
with pytest.raises(ValueError):
- Tunable(name=f'test_{tunable_type}', config={
- "type": tunable_type,
- "range": [0, 10],
- "values": ["foo", "bar"],
- "default": 0,
- })
+ Tunable(
+ name=f"test_{tunable_type}",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "values": ["foo", "bar"],
+ "default": 0,
+ },
+ )
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow null values param for numerical tunables.
- """
+ """Disallow null values param for numerical tunables."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -188,14 +174,12 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) -
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name=f'test_{tunable_type}', config=config)
+ Tunable(name=f"test_{tunable_type}", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow invalid range param for numerical tunables.
- """
+ """Disallow invalid range param for numerical tunables."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -205,14 +189,12 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) ->
"""
config = json.loads(json_config)
with pytest.raises(AssertionError):
- Tunable(name=f'test_{tunable_type}', config=config)
+ Tunable(name=f"test_{tunable_type}", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> None:
- """
- Disallow reverse range param for numerical tunables.
- """
+ """Disallow reverse range param for numerical tunables."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -222,14 +204,12 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) ->
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name=f'test_{tunable_type}', config=config)
+ Tunable(name=f"test_{tunable_type}", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights(tunable_type: TunableValueTypeName) -> None:
- """
- Instantiate a numerical tunable with weighted special values.
- """
+ """Instantiate a numerical tunable with weighted special values."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -241,7 +221,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None:
}}
"""
config = json.loads(json_config)
- tunable = Tunable(name='test', config=config)
+ tunable = Tunable(name="test", config=config)
assert tunable.special == [0]
assert tunable.weights == [0.1]
assert tunable.range_weight == 0.9
@@ -249,9 +229,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None:
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None:
- """
- Instantiate a numerical tunable with quantization.
- """
+ """Instantiate a numerical tunable with quantization."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -261,16 +239,14 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None:
}}
"""
config = json.loads(json_config)
- tunable = Tunable(name='test', config=config)
+ tunable = Tunable(name="test", config=config)
assert tunable.quantization == 10
assert not tunable.is_log
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_log(tunable_type: TunableValueTypeName) -> None:
- """
- Instantiate a numerical tunable with log scale.
- """
+ """Instantiate a numerical tunable with log scale."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -280,15 +256,13 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None:
}}
"""
config = json.loads(json_config)
- tunable = Tunable(name='test', config=config)
+ tunable = Tunable(name="test", config=config)
assert tunable.is_log
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> None:
- """
- Raise an error if special_weights are specified but no special values.
- """
+ """Raise an error if special_weights are specified but no special values."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -299,14 +273,13 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> None:
- """
- Instantiate a numerical tunable with non-normalized weights
- of the special values.
+ """Instantiate a numerical tunable with non-normalized weights of the special
+ values.
"""
json_config = f"""
{{
@@ -319,7 +292,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) ->
}}
"""
config = json.loads(json_config)
- tunable = Tunable(name='test', config=config)
+ tunable = Tunable(name="test", config=config)
assert tunable.special == [-1, 0]
assert tunable.weights == [0, 10] # Zero weights are ok
assert tunable.range_weight == 90
@@ -327,9 +300,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) ->
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> None:
- """
- Try to instantiate a numerical tunable with incorrect number of weights.
- """
+ """Try to instantiate a numerical tunable with incorrect number of weights."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -342,14 +313,12 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) -> None:
- """
- Try to instantiate a numerical tunable with weights but no range_weight.
- """
+ """Try to instantiate a numerical tunable with weights but no range_weight."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -361,14 +330,12 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) -
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) -> None:
- """
- Try to instantiate a numerical tunable with specials but no range_weight.
- """
+ """Try to instantiate a numerical tunable with specials but no range_weight."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -380,14 +347,12 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) -
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) -> None:
- """
- Try to instantiate a numerical tunable with specials but no range_weight.
- """
+ """Try to instantiate a numerical tunable with specials but no range_weight."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -398,14 +363,12 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName)
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> None:
- """
- Try to instantiate a numerical tunable with incorrect number of weights.
- """
+ """Try to instantiate a numerical tunable with incorrect number of weights."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -418,14 +381,12 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> None:
- """
- Instantiate a numerical tunable with invalid number of quantization points.
- """
+ """Instantiate a numerical tunable with invalid number of quantization points."""
json_config = f"""
{{
"type": "{tunable_type}",
@@ -436,13 +397,11 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test', config=config)
+ Tunable(name="test", config=config)
def test_bad_type() -> None:
- """
- Disallow bad types.
- """
+ """Disallow bad types."""
json_config = """
{
"type": "foo",
@@ -452,4 +411,4 @@ def test_bad_type() -> None:
"""
config = json.loads(json_config)
with pytest.raises(ValueError):
- Tunable(name='test_bad_type', config=config)
+ Tunable(name="test_bad_type", config=config)
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py
index deffcb6a46..54f08e1709 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for checking tunable parameters' distributions.
-"""
+"""Unit tests for checking tunable parameters' distributions."""
import json5 as json
import pytest
@@ -13,33 +11,31 @@ from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName
def test_categorical_distribution() -> None:
- """
- Try to instantiate a categorical tunable with distribution specified.
- """
+ """Try to instantiate a categorical tunable with distribution specified."""
with pytest.raises(ValueError):
- Tunable(name='test', config={
- "type": "categorical",
- "values": ["foo", "bar", "baz"],
- "distribution": {
- "type": "uniform"
+ Tunable(
+ name="test",
+ config={
+ "type": "categorical",
+ "values": ["foo", "bar", "baz"],
+ "distribution": {"type": "uniform"},
+ "default": "foo",
},
- "default": "foo"
- })
+ )
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> None:
- """
- Create a numeric Tunable with explicit uniform distribution.
- """
- tunable = Tunable(name="test", config={
- "type": tunable_type,
- "range": [0, 10],
- "distribution": {
- "type": "uniform"
+ """Create a numeric Tunable with explicit uniform distribution."""
+ tunable = Tunable(
+ name="test",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "distribution": {"type": "uniform"},
+ "default": 0,
},
- "default": 0
- })
+ )
assert tunable.is_numerical
assert tunable.distribution == "uniform"
assert not tunable.distribution_params
@@ -47,51 +43,39 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> None:
- """
- Create a numeric Tunable with explicit Gaussian distribution specified.
- """
- tunable = Tunable(name="test", config={
- "type": tunable_type,
- "range": [0, 10],
- "distribution": {
- "type": "normal",
- "params": {
- "mu": 0,
- "sigma": 1.0
- }
+ """Create a numeric Tunable with explicit Gaussian distribution specified."""
+ tunable = Tunable(
+ name="test",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}},
+ "default": 0,
},
- "default": 0
- })
+ )
assert tunable.distribution == "normal"
assert tunable.distribution_params == {"mu": 0, "sigma": 1.0}
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None:
- """
- Create a numeric Tunable with explicit Beta distribution specified.
- """
- tunable = Tunable(name="test", config={
- "type": tunable_type,
- "range": [0, 10],
- "distribution": {
- "type": "beta",
- "params": {
- "alpha": 2,
- "beta": 5
- }
+ """Create a numeric Tunable with explicit Beta distribution specified."""
+ tunable = Tunable(
+ name="test",
+ config={
+ "type": tunable_type,
+ "range": [0, 10],
+ "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}},
+ "default": 0,
},
- "default": 0
- })
+ )
assert tunable.distribution == "beta"
assert tunable.distribution_params == {"alpha": 2, "beta": 5}
@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_unsupported(tunable_type: str) -> None:
- """
- Create a numeric Tunable with unsupported distribution.
- """
+ """Create a numeric Tunable with unsupported distribution."""
json_config = f"""
{{
"type": "{tunable_type}",
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py
index c6fb5670f0..ae22094baa 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for checking the indexing rules for tunable groups.
-"""
+"""Tests for checking the indexing rules for tunable groups."""
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
-def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None:
- """
- Check that various types of indexing work for the tunable group.
- """
+def test_tunable_group_indexing(
+ tunable_groups: TunableGroups,
+ tunable_categorical: Tunable,
+) -> None:
+ """Check that various types of indexing work for the tunable group."""
# Check that the "in" operator works.
assert tunable_categorical in tunable_groups
assert tunable_categorical.name in tunable_groups
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py
index 55a485e951..c44fbfc866 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py
@@ -2,16 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for `TunableGroup.subgroup()` method.
-"""
+"""Tests for `TunableGroup.subgroup()` method."""
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None:
- """
- Check that the subgroup() method returns only a selection of tunable parameters.
+ """Check that the subgroup() method returns only a selection of tunable
+ parameters.
"""
tunables = tunable_groups.subgroup(["provision"])
- assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'}
+ assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"}
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py
index 8a9fba6d86..21f9de84d5 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for checking the is_updated flag for tunable groups.
-"""
+"""Tests for checking the is_updated flag for tunable groups."""
from mlos_bench.tunables.tunable_groups import TunableGroups
@@ -15,16 +13,14 @@ _TUNABLE_VALUES = {
def test_tunable_group_update(tunable_groups: TunableGroups) -> None:
- """
- Test that updating a tunable group raises the is_updated flag.
- """
+ """Test that updating a tunable group raises the is_updated flag."""
tunable_groups.assign(_TUNABLE_VALUES)
assert tunable_groups.is_updated()
def test_tunable_group_update_twice(tunable_groups: TunableGroups) -> None:
- """
- Test that updating a tunable group with the same values do *NOT* raises the is_updated flag.
+ """Test that updating a tunable group with the same values do *NOT* raises the
+ is_updated flag.
"""
tunable_groups.assign(_TUNABLE_VALUES)
assert tunable_groups.is_updated()
@@ -37,9 +33,7 @@ def test_tunable_group_update_twice(tunable_groups: TunableGroups) -> None:
def test_tunable_group_update_kernel(tunable_groups: TunableGroups) -> None:
- """
- Test that the is_updated flag is set only for the affected covariant group.
- """
+ """Test that the is_updated flag is set only for the affected covariant group."""
tunable_groups.assign(_TUNABLE_VALUES)
assert tunable_groups.is_updated()
assert tunable_groups.is_updated(["kernel"])
@@ -47,9 +41,7 @@ def test_tunable_group_update_kernel(tunable_groups: TunableGroups) -> None:
def test_tunable_group_update_boot(tunable_groups: TunableGroups) -> None:
- """
- Test that the is_updated flag is set only for the affected covariant group.
- """
+ """Test that the is_updated flag is set only for the affected covariant group."""
tunable_groups.assign(_TUNABLE_VALUES)
assert tunable_groups.is_updated()
assert not tunable_groups.is_updated(["boot"])
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
index 8d195dd5cf..9d267d4e16 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for unique references to tunables when they're loaded multiple times.
-"""
+"""Unit tests for unique references to tunables when they're loaded multiple times."""
import json5 as json
import pytest
@@ -13,9 +11,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_duplicate_merging_tunable_groups(tunable_groups_config: dict) -> None:
- """
- Check that the merging logic of tunable groups works as expected.
- """
+ """Check that the merging logic of tunable groups works as expected."""
parent_tunables = TunableGroups(tunable_groups_config)
# Pretend we loaded this one from disk another time.
@@ -63,9 +59,7 @@ def test_duplicate_merging_tunable_groups(tunable_groups_config: dict) -> None:
def test_overlapping_group_merge_tunable_groups(tunable_groups_config: dict) -> None:
- """
- Check that the merging logic of tunable groups works as expected.
- """
+ """Check that the merging logic of tunable groups works as expected."""
parent_tunables = TunableGroups(tunable_groups_config)
# This config should overlap with the parent config.
@@ -94,9 +88,7 @@ def test_overlapping_group_merge_tunable_groups(tunable_groups_config: dict) ->
def test_bad_extended_merge_tunable_group(tunable_groups_config: dict) -> None:
- """
- Check that the merging logic of tunable groups works as expected.
- """
+ """Check that the merging logic of tunable groups works as expected."""
parent_tunables = TunableGroups(tunable_groups_config)
# This config should overlap with the parent config.
@@ -125,9 +117,7 @@ def test_bad_extended_merge_tunable_group(tunable_groups_config: dict) -> None:
def test_good_extended_merge_tunable_group(tunable_groups_config: dict) -> None:
- """
- Check that the merging logic of tunable groups works as expected.
- """
+ """Check that the merging logic of tunable groups works as expected."""
parent_tunables = TunableGroups(tunable_groups_config)
# This config should overlap with the parent config.
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py
index 1024ba992b..91e387f92b 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py
@@ -2,30 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for converting tunable parameters with explicitly
-specified distributions to ConfigSpace.
+"""Unit tests for converting tunable parameters with explicitly specified distributions
+to ConfigSpace.
"""
import pytest
-
from ConfigSpace import (
- CategoricalHyperparameter,
BetaFloatHyperparameter,
BetaIntegerHyperparameter,
+ CategoricalHyperparameter,
NormalFloatHyperparameter,
NormalIntegerHyperparameter,
UniformFloatHyperparameter,
UniformIntegerHyperparameter,
)
-from mlos_bench.tunables.tunable import DistributionName
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.convert_configspace import (
special_param_names,
tunable_groups_to_configspace,
)
-
+from mlos_bench.tunables.tunable import DistributionName
+from mlos_bench.tunables.tunable_groups import TunableGroups
_CS_HYPERPARAMETER = {
("float", "beta"): BetaFloatHyperparameter,
@@ -38,37 +35,39 @@ _CS_HYPERPARAMETER = {
@pytest.mark.parametrize("param_type", ["int", "float"])
-@pytest.mark.parametrize("distr_name,distr_params", [
- ("normal", {"mu": 0.0, "sigma": 1.0}),
- ("beta", {"alpha": 2, "beta": 5}),
- ("uniform", {}),
-])
-def test_convert_numerical_distributions(param_type: str,
- distr_name: DistributionName,
- distr_params: dict) -> None:
- """
- Convert a numerical Tunable with explicit distribution to ConfigSpace.
- """
+@pytest.mark.parametrize(
+ "distr_name,distr_params",
+ [
+ ("normal", {"mu": 0.0, "sigma": 1.0}),
+ ("beta", {"alpha": 2, "beta": 5}),
+ ("uniform", {}),
+ ],
+)
+def test_convert_numerical_distributions(
+ param_type: str,
+ distr_name: DistributionName,
+ distr_params: dict,
+) -> None:
+ """Convert a numerical Tunable with explicit distribution to ConfigSpace."""
tunable_name = "x"
- tunable_groups = TunableGroups({
- "tunable_group": {
- "cost": 1,
- "params": {
- tunable_name: {
- "type": param_type,
- "range": [0, 100],
- "special": [-1, 0],
- "special_weights": [0.1, 0.2],
- "range_weight": 0.7,
- "distribution": {
- "type": distr_name,
- "params": distr_params
- },
- "default": 0
- }
+ tunable_groups = TunableGroups(
+ {
+ "tunable_group": {
+ "cost": 1,
+ "params": {
+ tunable_name: {
+ "type": param_type,
+ "range": [0, 100],
+ "special": [-1, 0],
+ "special_weights": [0.1, 0.2],
+ "range_weight": 0.7,
+ "distribution": {"type": distr_name, "params": distr_params},
+ "default": 0,
+ }
+ },
}
}
- })
+ )
(tunable, _group) = tunable_groups.get_tunable(tunable_name)
assert tunable.distribution == distr_name
@@ -84,5 +83,5 @@ def test_convert_numerical_distributions(param_type: str,
cs_param = space[tunable_name]
assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name])
- for (key, val) in distr_params.items():
+ for key, val in distr_params.items():
assert getattr(cs_param, key) == val
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py
index 42b24dd51e..dce3e366a6 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py
@@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for Tunable to ConfigSpace conversion.
-"""
+"""Unit tests for Tunable to ConfigSpace conversion."""
import pytest
-
from ConfigSpace import (
CategoricalHyperparameter,
ConfigurationSpace,
@@ -16,14 +13,14 @@ from ConfigSpace import (
UniformIntegerHyperparameter,
)
-from mlos_bench.tunables.tunable import Tunable
-from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.convert_configspace import (
TunableValueKind,
_tunable_to_configspace,
special_param_names,
tunable_groups_to_configspace,
)
+from mlos_bench.tunables.tunable import Tunable
+from mlos_bench.tunables.tunable_groups import TunableGroups
# pylint: disable=redefined-outer-name
@@ -31,25 +28,31 @@ from mlos_bench.optimizers.convert_configspace import (
@pytest.fixture
def configuration_space() -> ConfigurationSpace:
"""
- A test fixture that produces a mock ConfigurationSpace object
- matching the tunable_groups fixture.
+ A test fixture that produces a mock ConfigurationSpace object matching the
+ tunable_groups fixture.
Returns
-------
configuration_space : ConfigurationSpace
A new ConfigurationSpace object for testing.
"""
- (kernel_sched_migration_cost_ns_special,
- kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns")
+ (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = (
+ special_param_names("kernel_sched_migration_cost_ns")
+ )
- spaces = ConfigurationSpace(space={
- "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"],
- "idle": ["halt", "mwait", "noidle"],
- "kernel_sched_migration_cost_ns": (0, 500000),
- kernel_sched_migration_cost_ns_special: [-1, 0],
- kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE],
- "kernel_sched_latency_ns": (0, 1000000000),
- })
+ spaces = ConfigurationSpace(
+ space={
+ "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"],
+ "idle": ["halt", "mwait", "noidle"],
+ "kernel_sched_migration_cost_ns": (0, 500000),
+ kernel_sched_migration_cost_ns_special: [-1, 0],
+ kernel_sched_migration_cost_ns_type: [
+ TunableValueKind.SPECIAL,
+ TunableValueKind.RANGE,
+ ],
+ "kernel_sched_latency_ns": (0, 1000000000),
+ }
+ )
# NOTE: FLAML requires distribution to be uniform
spaces["vmSize"].default_value = "Standard_B4ms"
@@ -61,32 +64,34 @@ def configuration_space() -> ConfigurationSpace:
spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5)
spaces["kernel_sched_latency_ns"].default_value = 2000000
- spaces.add_condition(EqualsCondition(
- spaces[kernel_sched_migration_cost_ns_special],
- spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL))
- spaces.add_condition(EqualsCondition(
- spaces["kernel_sched_migration_cost_ns"],
- spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE))
+ spaces.add_condition(
+ EqualsCondition(
+ spaces[kernel_sched_migration_cost_ns_special],
+ spaces[kernel_sched_migration_cost_ns_type],
+ TunableValueKind.SPECIAL,
+ )
+ )
+ spaces.add_condition(
+ EqualsCondition(
+ spaces["kernel_sched_migration_cost_ns"],
+ spaces[kernel_sched_migration_cost_ns_type],
+ TunableValueKind.RANGE,
+ )
+ )
return spaces
-def _cmp_tunable_hyperparameter_categorical(
- tunable: Tunable, space: ConfigurationSpace) -> None:
- """
- Check if categorical Tunable and ConfigSpace Hyperparameter actually match.
- """
+def _cmp_tunable_hyperparameter_categorical(tunable: Tunable, space: ConfigurationSpace) -> None:
+ """Check if categorical Tunable and ConfigSpace Hyperparameter actually match."""
param = space[tunable.name]
assert isinstance(param, CategoricalHyperparameter)
assert set(param.choices) == set(tunable.categories)
assert param.default_value == tunable.value
-def _cmp_tunable_hyperparameter_numerical(
- tunable: Tunable, space: ConfigurationSpace) -> None:
- """
- Check if integer Tunable and ConfigSpace Hyperparameter actually match.
- """
+def _cmp_tunable_hyperparameter_numerical(tunable: Tunable, space: ConfigurationSpace) -> None:
+ """Check if integer Tunable and ConfigSpace Hyperparameter actually match."""
param = space[tunable.name]
assert isinstance(param, (UniformIntegerHyperparameter, UniformFloatHyperparameter))
assert (param.lower, param.upper) == tuple(tunable.range)
@@ -95,25 +100,19 @@ def _cmp_tunable_hyperparameter_numerical(
def test_tunable_to_configspace_categorical(tunable_categorical: Tunable) -> None:
- """
- Check the conversion of Tunable to CategoricalHyperparameter.
- """
+ """Check the conversion of Tunable to CategoricalHyperparameter."""
cs_param = _tunable_to_configspace(tunable_categorical)
_cmp_tunable_hyperparameter_categorical(tunable_categorical, cs_param)
def test_tunable_to_configspace_int(tunable_int: Tunable) -> None:
- """
- Check the conversion of Tunable to UniformIntegerHyperparameter.
- """
+ """Check the conversion of Tunable to UniformIntegerHyperparameter."""
cs_param = _tunable_to_configspace(tunable_int)
_cmp_tunable_hyperparameter_numerical(tunable_int, cs_param)
def test_tunable_to_configspace_float(tunable_float: Tunable) -> None:
- """
- Check the conversion of Tunable to UniformFloatHyperparameter.
- """
+ """Check the conversion of Tunable to UniformFloatHyperparameter."""
cs_param = _tunable_to_configspace(tunable_float)
_cmp_tunable_hyperparameter_numerical(tunable_float, cs_param)
@@ -128,18 +127,20 @@ _CMP_FUNC = {
def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> None:
"""
Check the conversion of TunableGroups to ConfigurationSpace.
+
Make sure that the corresponding Tunable and Hyperparameter objects match.
"""
space = tunable_groups_to_configspace(tunable_groups)
- for (tunable, _group) in tunable_groups:
+ for tunable, _group in tunable_groups:
_CMP_FUNC[tunable.type](tunable, space)
def test_tunable_groups_to_configspace(
- tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None:
- """
- Check the conversion of the entire TunableGroups collection
- to a single ConfigurationSpace object.
+ tunable_groups: TunableGroups,
+ configuration_space: ConfigurationSpace,
+) -> None:
+ """Check the conversion of the entire TunableGroups collection to a single
+ ConfigurationSpace object.
"""
space = tunable_groups_to_configspace(tunable_groups)
assert space == configuration_space
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py
index cbccd6bfe1..05f29a9064 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py
@@ -2,8 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for assigning values to the individual parameters within tunable groups.
+"""Unit tests for assigning values to the individual parameters within tunable
+groups.
"""
import json5 as json
@@ -14,132 +14,105 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None:
- """
- Make sure that bulk assignment fails for parameters
- that don't exist in the TunableGroups object.
+ """Make sure that bulk assignment fails for parameters that don't exist in the
+ TunableGroups object.
"""
with pytest.raises(KeyError):
- tunable_groups.assign({
- "vmSize": "Standard_B2ms",
- "idle": "mwait",
- "UnknownParam_1": 1,
- "UnknownParam_2": "invalid-value"
- })
+ tunable_groups.assign(
+ {
+ "vmSize": "Standard_B2ms",
+ "idle": "mwait",
+ "UnknownParam_1": 1,
+ "UnknownParam_2": "invalid-value",
+ }
+ )
def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None:
- """
- Regular assignment for categorical tunable.
- """
+ """Regular assignment for categorical tunable."""
# Must be one of: {"Standard_B2s", "Standard_B2ms", "Standard_B4ms"}
tunable_categorical.value = "Standard_B4ms"
assert not tunable_categorical.is_special
def test_tunables_assign_invalid_categorical(tunable_groups: TunableGroups) -> None:
- """
- Check parameter validation for categorical tunables.
- """
+ """Check parameter validation for categorical tunables."""
with pytest.raises(ValueError):
tunable_groups.assign({"vmSize": "InvalidSize"})
def test_tunables_assign_invalid_range(tunable_groups: TunableGroups) -> None:
- """
- Check parameter out-of-range validation for numerical tunables.
- """
+ """Check parameter out-of-range validation for numerical tunables."""
with pytest.raises(ValueError):
tunable_groups.assign({"kernel_sched_migration_cost_ns": -2})
def test_tunables_assign_coerce_str(tunable_groups: TunableGroups) -> None:
- """
- Check the conversion from strings when assigning to an integer parameter.
- """
+ """Check the conversion from strings when assigning to an integer parameter."""
tunable_groups.assign({"kernel_sched_migration_cost_ns": "10000"})
def test_tunables_assign_coerce_str_range_check(tunable_groups: TunableGroups) -> None:
- """
- Check the range when assigning to an integer tunable.
- """
+ """Check the range when assigning to an integer tunable."""
with pytest.raises(ValueError):
tunable_groups.assign({"kernel_sched_migration_cost_ns": "5500000"})
def test_tunables_assign_coerce_str_invalid(tunable_groups: TunableGroups) -> None:
- """
- Make sure we fail when assigning an invalid string to an integer tunable.
- """
+ """Make sure we fail when assigning an invalid string to an integer tunable."""
with pytest.raises(ValueError):
tunable_groups.assign({"kernel_sched_migration_cost_ns": "1.1"})
def test_tunable_assign_str_to_numerical(tunable_int: Tunable) -> None:
- """
- Check str to int coercion.
- """
+ """Check str to int coercion."""
with pytest.raises(ValueError):
tunable_int.numerical_value = "foo" # type: ignore[assignment]
def test_tunable_assign_int_to_numerical_value(tunable_int: Tunable) -> None:
- """
- Check numerical value assignment.
- """
+ """Check numerical value assignment."""
tunable_int.numerical_value = 10.0
assert tunable_int.numerical_value == 10
assert not tunable_int.is_special
def test_tunable_assign_float_to_numerical_value(tunable_float: Tunable) -> None:
- """
- Check numerical value assignment.
- """
+ """Check numerical value assignment."""
tunable_float.numerical_value = 0.1
assert tunable_float.numerical_value == 0.1
assert not tunable_float.is_special
def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None:
- """
- Check str to int coercion.
- """
+ """Check str to int coercion."""
tunable_int.value = "10"
- assert tunable_int.value == 10 # type: ignore[comparison-overlap]
+ assert tunable_int.value == 10 # type: ignore[comparison-overlap]
assert not tunable_int.is_special
def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None:
- """
- Check str to float coercion.
- """
+ """Check str to float coercion."""
tunable_float.value = "0.5"
- assert tunable_float.value == 0.5 # type: ignore[comparison-overlap]
+ assert tunable_float.value == 0.5 # type: ignore[comparison-overlap]
assert not tunable_float.is_special
def test_tunable_assign_float_to_int(tunable_int: Tunable) -> None:
- """
- Check float to int coercion.
- """
+ """Check float to int coercion."""
tunable_int.value = 10.0
assert tunable_int.value == 10
assert not tunable_int.is_special
def test_tunable_assign_float_to_int_fail(tunable_int: Tunable) -> None:
- """
- Check the invalid float to int coercion.
- """
+ """Check the invalid float to int coercion."""
with pytest.raises(ValueError):
tunable_int.value = 10.1
def test_tunable_assign_null_to_categorical() -> None:
- """
- Checks that we can use null/None in categorical tunables.
- """
+ """Checks that we can use null/None in categorical tunables."""
json_config = """
{
"name": "categorical_test",
@@ -149,38 +122,34 @@ def test_tunable_assign_null_to_categorical() -> None:
}
"""
config = json.loads(json_config)
- categorical_tunable = Tunable(name='categorical_test', config=config)
+ categorical_tunable = Tunable(name="categorical_test", config=config)
assert categorical_tunable
assert categorical_tunable.category == "foo"
categorical_tunable.value = None
assert categorical_tunable.value is None
- assert categorical_tunable.value != 'None'
+ assert categorical_tunable.value != "None"
assert categorical_tunable.category is None
def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None:
- """
- Checks that we can't use null/None in integer tunables.
- """
+ """Checks that we can't use null/None in integer tunables."""
with pytest.raises((TypeError, AssertionError)):
tunable_int.value = None
with pytest.raises((TypeError, AssertionError)):
- tunable_int.numerical_value = None # type: ignore[assignment]
+ tunable_int.numerical_value = None # type: ignore[assignment]
def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None:
- """
- Checks that we can't use null/None in float tunables.
- """
+ """Checks that we can't use null/None in float tunables."""
with pytest.raises((TypeError, AssertionError)):
tunable_float.value = None
with pytest.raises((TypeError, AssertionError)):
- tunable_float.numerical_value = None # type: ignore[assignment]
+ tunable_float.numerical_value = None # type: ignore[assignment]
def test_tunable_assign_special(tunable_int: Tunable) -> None:
- """
- Check the assignment of a special value outside of the range (but declared `special`).
+ """Check the assignment of a special value outside of the range (but declared
+ `special`).
"""
tunable_int.numerical_value = -1
assert tunable_int.numerical_value == -1
@@ -188,16 +157,16 @@ def test_tunable_assign_special(tunable_int: Tunable) -> None:
def test_tunable_assign_special_fail(tunable_int: Tunable) -> None:
- """
- Assign a value that is neither special nor in range and fail.
- """
+ """Assign a value that is neither special nor in range and fail."""
with pytest.raises(ValueError):
tunable_int.numerical_value = -2
def test_tunable_assign_special_with_coercion(tunable_int: Tunable) -> None:
"""
- Check the assignment of a special value outside of the range (but declared `special`).
+ Check the assignment of a special value outside of the range (but declared
+ `special`).
+
Check coercion from float to int.
"""
tunable_int.numerical_value = -1.0
@@ -207,7 +176,9 @@ def test_tunable_assign_special_with_coercion(tunable_int: Tunable) -> None:
def test_tunable_assign_special_with_coercion_str(tunable_int: Tunable) -> None:
"""
- Check the assignment of a special value outside of the range (but declared `special`).
+ Check the assignment of a special value outside of the range (but declared
+ `special`).
+
Check coercion from string to int.
"""
tunable_int.value = "-1"
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py
index 16bb42500c..c5395fcb16 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for deep copy of tunable objects and groups.
-"""
+"""Unit tests for deep copy of tunable objects and groups."""
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
from mlos_bench.tunables.tunable import Tunable, TunableValue
@@ -12,9 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
def test_copy_tunable_int(tunable_int: Tunable) -> None:
- """
- Check if deep copy works for Tunable object.
- """
+ """Check if deep copy works for Tunable object."""
tunable_copy = tunable_int.copy()
assert tunable_int == tunable_copy
tunable_copy.numerical_value += 200
@@ -22,9 +18,7 @@ def test_copy_tunable_int(tunable_int: Tunable) -> None:
def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None:
- """
- Check if deep copy works for TunableGroups object.
- """
+ """Check if deep copy works for TunableGroups object."""
tunable_groups_copy = tunable_groups.copy()
assert tunable_groups == tunable_groups_copy
tunable_groups_copy["vmSize"] = "Standard_B2ms"
@@ -34,9 +28,7 @@ def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None:
def test_copy_covariant_group(covariant_group: CovariantTunableGroup) -> None:
- """
- Check if deep copy works for TunableGroups object.
- """
+ """Check if deep copy works for TunableGroups object."""
covariant_group_copy = covariant_group.copy()
assert covariant_group == covariant_group_copy
tunable = next(iter(covariant_group.get_tunables()))
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py
index 672b16ab73..61514c605b 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py
@@ -2,57 +2,57 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests to make sure we always produce a string representation
-of a TunableGroup in canonical form.
+"""Unit tests to make sure we always produce a string representation of a TunableGroup
+in canonical form.
"""
from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_groups_str(tunable_groups: TunableGroups) -> None:
- """
- Check that we produce the same string representation of TunableGroups,
- regardless of the order in which we declare the covariant groups and
- tunables within each covariant group.
+ """Check that we produce the same string representation of TunableGroups, regardless
+ of the order in which we declare the covariant groups and tunables within each
+ covariant group.
"""
# Same as `tunable_groups` (defined in the `conftest.py` file), but in different order:
- tunables_other = TunableGroups({
- "kernel": {
- "cost": 1,
- "params": {
- "kernel_sched_latency_ns": {
- "type": "int",
- "default": 2000000,
- "range": [0, 1000000000]
+ tunables_other = TunableGroups(
+ {
+ "kernel": {
+ "cost": 1,
+ "params": {
+ "kernel_sched_latency_ns": {
+ "type": "int",
+ "default": 2000000,
+ "range": [0, 1000000000],
+ },
+ "kernel_sched_migration_cost_ns": {
+ "type": "int",
+ "default": -1,
+ "range": [0, 500000],
+ "special": [-1],
+ },
},
- "kernel_sched_migration_cost_ns": {
- "type": "int",
- "default": -1,
- "range": [0, 500000],
- "special": [-1]
- }
- }
- },
- "boot": {
- "cost": 300,
- "params": {
- "idle": {
- "type": "categorical",
- "default": "halt",
- "values": ["halt", "mwait", "noidle"]
- }
- }
- },
- "provision": {
- "cost": 1000,
- "params": {
- "vmSize": {
- "type": "categorical",
- "default": "Standard_B4ms",
- "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"]
- }
- }
- },
- })
+ },
+ "boot": {
+ "cost": 300,
+ "params": {
+ "idle": {
+ "type": "categorical",
+ "default": "halt",
+ "values": ["halt", "mwait", "noidle"],
+ }
+ },
+ },
+ "provision": {
+ "cost": 1000,
+ "params": {
+ "vmSize": {
+ "type": "categorical",
+ "default": "Standard_B4ms",
+ "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"],
+ }
+ },
+ },
+ }
+ )
assert str(tunable_groups) == str(tunables_other)
diff --git a/mlos_bench/mlos_bench/tests/util_git_test.py b/mlos_bench/mlos_bench/tests/util_git_test.py
index 54946fca6e..77fd2779c7 100644
--- a/mlos_bench/mlos_bench/tests/util_git_test.py
+++ b/mlos_bench/mlos_bench/tests/util_git_test.py
@@ -2,18 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for get_git_info utility function.
-"""
+"""Unit tests for get_git_info utility function."""
import re
from mlos_bench.util import get_git_info
def test_get_git_info() -> None:
- """
- Check that we can retrieve git info about the current repository correctly.
- """
+ """Check that we can retrieve git info about the current repository correctly."""
(git_repo, git_commit, rel_path) = get_git_info(__file__)
assert "mlos" in git_repo.lower()
assert re.match(r"[0-9a-f]{40}", git_commit) is not None
diff --git a/mlos_bench/mlos_bench/tests/util_nullable_test.py b/mlos_bench/mlos_bench/tests/util_nullable_test.py
index 28ed7fc92c..f0ca82eb6e 100644
--- a/mlos_bench/mlos_bench/tests/util_nullable_test.py
+++ b/mlos_bench/mlos_bench/tests/util_nullable_test.py
@@ -2,18 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for `nullable` utility function.
-"""
+"""Unit tests for `nullable` utility function."""
import pytest
from mlos_bench.util import nullable
def test_nullable_str() -> None:
- """
- Check that the `nullable` function works properly for `str`.
- """
+ """Check that the `nullable` function works properly for `str`."""
assert nullable(str, None) is None
assert nullable(str, "") is not None
assert nullable(str, "") == ""
@@ -22,9 +18,7 @@ def test_nullable_str() -> None:
def test_nullable_int() -> None:
- """
- Check that the `nullable` function works properly for `int`.
- """
+ """Check that the `nullable` function works properly for `int`."""
assert nullable(int, None) is None
assert nullable(int, 10) is not None
assert nullable(int, 10) == 10
@@ -32,9 +26,7 @@ def test_nullable_int() -> None:
def test_nullable_func() -> None:
- """
- Check that the `nullable` function works properly with `list.pop()` function.
- """
+ """Check that the `nullable` function works properly with `list.pop()` function."""
assert nullable(list.pop, None) is None
assert nullable(list.pop, [1, 2, 3]) == 3
diff --git a/mlos_bench/mlos_bench/tests/util_try_parse_test.py b/mlos_bench/mlos_bench/tests/util_try_parse_test.py
index b613c19694..d97acd0b8c 100644
--- a/mlos_bench/mlos_bench/tests/util_try_parse_test.py
+++ b/mlos_bench/mlos_bench/tests/util_try_parse_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for try_parse_val utility function.
-"""
+"""Unit tests for try_parse_val utility function."""
import math
@@ -12,9 +10,7 @@ from mlos_bench.util import try_parse_val
def test_try_parse_val() -> None:
- """
- Check that we can retrieve git info about the current repository correctly.
- """
+ """Check that we can retrieve git info about the current repository correctly."""
assert try_parse_val(None) is None
assert try_parse_val("1") == int(1)
assert try_parse_val("1.1") == float(1.1)
diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py
index 4191f37d89..c5f49e9202 100644
--- a/mlos_bench/mlos_bench/tunables/__init__.py
+++ b/mlos_bench/mlos_bench/tunables/__init__.py
@@ -2,15 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tunables classes for Environments in mlos_bench.
-"""
+"""Tunables classes for Environments in mlos_bench."""
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
__all__ = [
- 'Tunable',
- 'TunableValue',
- 'TunableGroups',
+ "Tunable",
+ "TunableValue",
+ "TunableGroups",
]
diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py
index beb8db15ec..b30c879d8f 100644
--- a/mlos_bench/mlos_bench/tunables/covariant_group.py
+++ b/mlos_bench/mlos_bench/tunables/covariant_group.py
@@ -2,11 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tunable parameter definition.
-"""
+"""Tunable parameter definition."""
import copy
-
from typing import Dict, Iterable, Union
from mlos_bench.tunables.tunable import Tunable, TunableValue
@@ -15,6 +12,7 @@ from mlos_bench.tunables.tunable import Tunable, TunableValue
class CovariantTunableGroup:
"""
A collection of tunable parameters.
+
Changing any of the parameters in the group incurs the same cost of the experiment.
"""
@@ -53,9 +51,9 @@ class CovariantTunableGroup:
@property
def cost(self) -> int:
"""
- Get the cost of changing the values in the covariant group.
- This value is a constant. Use `get_current_cost()` to get
- the cost given the group update status.
+ Get the cost of changing the values in the covariant group. This value is a
+ constant. Use `get_current_cost()` to get the cost given the group update
+ status.
Returns
-------
@@ -94,15 +92,17 @@ class CovariantTunableGroup:
return False
# TODO: May need to provide logic to relax the equality check on the
# tunables (e.g. "compatible" vs. "equal").
- return (self._name == other._name and
- self._cost == other._cost and
- self._is_updated == other._is_updated and
- self._tunables == other._tunables)
+ return (
+ self._name == other._name
+ and self._cost == other._cost
+ and self._is_updated == other._is_updated
+ and self._tunables == other._tunables
+ )
def equals_defaults(self, other: "CovariantTunableGroup") -> bool:
"""
- Checks to see if the other CovariantTunableGroup is the same, ignoring
- the current values of the two groups' Tunables.
+ Checks to see if the other CovariantTunableGroup is the same, ignoring the
+ current values of the two groups' Tunables.
Parameters
----------
@@ -127,7 +127,8 @@ class CovariantTunableGroup:
def is_defaults(self) -> bool:
"""
- Checks whether the currently assigned values of all tunables are at their defaults.
+ Checks whether the currently assigned values of all tunables are at their
+ defaults.
Returns
-------
@@ -136,9 +137,7 @@ class CovariantTunableGroup:
return all(tunable.is_default() for tunable in self._tunables.values())
def restore_defaults(self) -> None:
- """
- Restore all tunable parameters to their default values.
- """
+ """Restore all tunable parameters to their default values."""
for tunable in self._tunables.values():
if tunable.value != tunable.default:
self._is_updated = True
@@ -146,8 +145,10 @@ class CovariantTunableGroup:
def reset_is_updated(self) -> None:
"""
- Clear the update flag. That is, state that running an experiment with the
- current values of the tunables in this group has no extra cost.
+ Clear the update flag.
+
+ That is, state that running an experiment with the current values of the
+ tunables in this group has no extra cost.
"""
self._is_updated = False
@@ -174,9 +175,7 @@ class CovariantTunableGroup:
return self._cost if self._is_updated else 0
def get_names(self) -> Iterable[str]:
- """
- Get the names of all tunables in the group.
- """
+ """Get the names of all tunables in the group."""
return self._tunables.keys()
def get_tunable_values_dict(self) -> Dict[str, TunableValue]:
@@ -191,8 +190,8 @@ class CovariantTunableGroup:
def __repr__(self) -> str:
"""
- Produce a human-readable version of the CovariantTunableGroup
- (mostly for logging).
+ Produce a human-readable version of the CovariantTunableGroup (mostly for
+ logging).
Returns
-------
@@ -203,8 +202,8 @@ class CovariantTunableGroup:
def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable:
"""
- Access the entire Tunable in a group (not just its value).
- Throw KeyError if the tunable is not in the group.
+ Access the entire Tunable in a group (not just its value). Throw KeyError if the
+ tunable is not in the group.
Parameters
----------
@@ -220,7 +219,8 @@ class CovariantTunableGroup:
return self._tunables[name]
def get_tunables(self) -> Iterable[Tunable]:
- """Gets the set of tunables for this CovariantTunableGroup.
+ """
+ Gets the set of tunables for this CovariantTunableGroup.
Returns
-------
@@ -235,7 +235,13 @@ class CovariantTunableGroup:
def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
return self.get_tunable(tunable).value
- def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
- value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
+ def __setitem__(
+ self,
+ tunable: Union[str, Tunable],
+ tunable_value: Union[TunableValue, Tunable],
+ ) -> TunableValue:
+ value: TunableValue = (
+ tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
+ )
self._is_updated |= self.get_tunable(tunable).update(value)
return value
diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py
index 26eb719866..8f9bb48bff 100644
--- a/mlos_bench/mlos_bench/tunables/tunable.py
+++ b/mlos_bench/mlos_bench/tunables/tunable.py
@@ -2,48 +2,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tunable parameter definition.
-"""
-import copy
+"""Tunable parameter definition."""
import collections
+import copy
import logging
-
-from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Type, TypedDict, Union
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypedDict,
+ Union,
+)
import numpy as np
from mlos_bench.util import nullable
_LOG = logging.getLogger(__name__)
-
-
"""A tunable parameter value type alias."""
TunableValue = Union[int, float, Optional[str]]
-
"""Tunable value type."""
TunableValueType = Union[Type[int], Type[float], Type[str]]
-
"""
Tunable value type tuple.
+
For checking with isinstance()
"""
TunableValueTypeTuple = (int, float, str, type(None))
-
"""The string name of a tunable value type."""
TunableValueTypeName = Literal["int", "float", "categorical"]
-
-"""Tunable values dictionary type"""
+"""Tunable values dictionary type."""
TunableValuesDict = Dict[str, TunableValue]
-
-"""Tunable value distribution type"""
+"""Tunable value distribution type."""
DistributionName = Literal["uniform", "normal", "beta"]
class DistributionDict(TypedDict, total=False):
- """
- A typed dict for tunable parameters' distributions.
- """
+ """A typed dict for tunable parameters' distributions."""
type: DistributionName
params: Optional[Dict[str, float]]
@@ -74,9 +75,7 @@ class TunableDict(TypedDict, total=False):
class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-methods
- """
- A tunable parameter definition and its current value.
- """
+ """A tunable parameter definition and its current value."""
# Maps tunable types to their corresponding Python types by name.
_DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
@@ -96,7 +95,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
config : dict
Python dict that represents a Tunable (e.g., deserialized from JSON)
"""
- if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema
+ if not isinstance(name, str) or "!" in name: # TODO: Use a regex here and in JSON schema
raise ValueError(f"Invalid name of the tunable: {name}")
self._name = name
self._type: TunableValueTypeName = config["type"] # required
@@ -133,8 +132,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
self.value = self._default
def _sanity_check(self) -> None:
- """
- Check if the status of the Tunable is valid, and throw ValueError if it is not.
+ """Check if the status of the Tunable is valid, and throw ValueError if it is
+ not.
"""
if self.is_categorical:
self._sanity_check_categorical()
@@ -146,8 +145,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
raise ValueError(f"Invalid default value for tunable {self}: {self.default}")
def _sanity_check_categorical(self) -> None:
- """
- Check if the status of the categorical Tunable is valid, and throw ValueError if it is not.
+ """Check if the status of the categorical Tunable is valid, and throw ValueError
+ if it is not.
"""
# pylint: disable=too-complex
assert self.is_categorical
@@ -174,8 +173,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
raise ValueError(f"All weights must be non-negative: {self}")
def _sanity_check_numerical(self) -> None:
- """
- Check if the status of the numerical Tunable is valid, and throw ValueError if it is not.
+ """Check if the status of the numerical Tunable is valid, and throw ValueError
+ if it is not.
"""
# pylint: disable=too-complex,too-many-branches
assert self.is_numerical
@@ -191,10 +190,16 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
raise ValueError(f"Number of quantization points is <= 1: {self}")
if self.dtype == float:
if not isinstance(self._quantization, (float, int)):
- raise ValueError(f"Quantization of a float param should be a float or int: {self}")
+ raise ValueError(
+ f"Quantization of a float param should be a float or int: {self}"
+ )
if self._quantization <= 0:
raise ValueError(f"Number of quantization points is <= 0: {self}")
- if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}:
+ if self._distribution is not None and self._distribution not in {
+ "uniform",
+ "normal",
+ "beta",
+ }:
raise ValueError(f"Invalid distribution: {self}")
if self._distribution_params and self._distribution is None:
raise ValueError(f"Must specify the distribution: {self}")
@@ -219,7 +224,9 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
"""
# TODO? Add weights, specials, quantization, distribution?
if self.is_categorical:
- return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
+ return (
+ f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
+ )
return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"
def __eq__(self, other: object) -> bool:
@@ -240,12 +247,12 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
if not isinstance(other, Tunable):
return False
return bool(
- self._name == other._name and
- self._type == other._type and
- self._current_value == other._current_value
+ self._name == other._name
+ and self._type == other._type
+ and self._current_value == other._current_value
)
- def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements
+ def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements
"""
Compare the two Tunable objects. We mostly need this to create a canonical list
of tunable objects when hashing a TunableGroup.
@@ -292,29 +299,23 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def default(self) -> TunableValue:
- """
- Get the default value of the tunable.
- """
+ """Get the default value of the tunable."""
return self._default
def is_default(self) -> TunableValue:
- """
- Checks whether the currently assigned value of the tunable is at its default.
+ """Checks whether the currently assigned value of the tunable is at its
+ default.
"""
return self._default == self._current_value
@property
def value(self) -> TunableValue:
- """
- Get the current value of the tunable.
- """
+ """Get the current value of the tunable."""
return self._current_value
@value.setter
def value(self, value: TunableValue) -> TunableValue:
- """
- Set the current value of the tunable.
- """
+ """Set the current value of the tunable."""
# We need this coercion for the values produced by some optimizers
# (e.g., scikit-optimize) and for data restored from certain storage
# systems (where values can be strings).
@@ -325,18 +326,33 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
assert value is not None
coerced_value = self.dtype(value)
except Exception:
- _LOG.error("Impossible conversion: %s %s <- %s %s",
- self._type, self._name, type(value), value)
+ _LOG.error(
+ "Impossible conversion: %s %s <- %s %s",
+ self._type,
+ self._name,
+ type(value),
+ value,
+ )
raise
if self._type == "int" and isinstance(value, float) and value != coerced_value:
- _LOG.error("Loss of precision: %s %s <- %s %s",
- self._type, self._name, type(value), value)
+ _LOG.error(
+ "Loss of precision: %s %s <- %s %s",
+ self._type,
+ self._name,
+ type(value),
+ value,
+ )
raise ValueError(f"Loss of precision: {self._name}={value}")
if not self.is_valid(coerced_value):
- _LOG.error("Invalid assignment: %s %s <- %s %s",
- self._type, self._name, type(value), value)
+ _LOG.error(
+ "Invalid assignment: %s %s <- %s %s",
+ self._type,
+ self._name,
+ type(value),
+ value,
+ )
raise ValueError(f"Invalid value for the tunable: {self._name}={value}")
self._current_value = coerced_value
@@ -344,7 +360,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
def update(self, value: TunableValue) -> bool:
"""
- Assign the value to the tunable. Return True if it is a new value, False otherwise.
+ Assign the value to the tunable. Return True if it is a new value, False
+ otherwise.
Parameters
----------
@@ -388,21 +405,20 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
def in_range(self, value: Union[int, float, str, None]) -> bool:
"""
Check if the value is within the range of the tunable.
- Do *NOT* check for special values.
- Return False if the tunable or value is categorical or None.
+
+ Do *NOT* check for special values. Return False if the tunable or value is
+ categorical or None.
"""
return (
- isinstance(value, (float, int)) and
- self.is_numerical and
- self._range is not None and
- bool(self._range[0] <= value <= self._range[1])
+ isinstance(value, (float, int))
+ and self.is_numerical
+ and self._range is not None
+ and bool(self._range[0] <= value <= self._range[1])
)
@property
def category(self) -> Optional[str]:
- """
- Get the current value of the tunable as a number.
- """
+ """Get the current value of the tunable as a number."""
if self.is_categorical:
return nullable(str, self._current_value)
else:
@@ -410,9 +426,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@category.setter
def category(self, new_value: Optional[str]) -> Optional[str]:
- """
- Set the current value of the tunable.
- """
+ """Set the current value of the tunable."""
assert self.is_categorical
assert isinstance(new_value, (str, type(None)))
self.value = new_value
@@ -420,9 +434,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def numerical_value(self) -> Union[int, float]:
- """
- Get the current value of the tunable as a number.
- """
+ """Get the current value of the tunable as a number."""
assert self._current_value is not None
if self._type == "int":
return int(self._current_value)
@@ -433,9 +445,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@numerical_value.setter
def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
- """
- Set the current numerical value of the tunable.
- """
+ """Set the current numerical value of the tunable."""
# We need this coercion for the values produced by some optimizers
# (e.g., scikit-optimize) and for data restored from certain storage
# systems (where values can be strings).
@@ -445,9 +455,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def name(self) -> str:
- """
- Get the name / string ID of the tunable.
- """
+ """Get the name / string ID of the tunable."""
return self._name
@property
@@ -477,8 +485,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def weights(self) -> Optional[List[float]]:
"""
- Get the weights of the categories or special values of the tunable.
- Return None if there are none.
+ Get the weights of the categories or special values of the tunable. Return None
+ if there are none.
Returns
-------
@@ -490,8 +498,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def range_weight(self) -> Optional[float]:
"""
- Get weight of the range of the numeric tunable.
- Return None if there are no weights or a tunable is categorical.
+ Get weight of the range of the numeric tunable. Return None if there are no
+ weights or a tunable is categorical.
Returns
-------
@@ -615,10 +623,15 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
# Be sure to return python types instead of numpy types.
cardinality = self.cardinality
assert isinstance(cardinality, int)
- return (float(x) for x in np.linspace(start=num_range[0],
- stop=num_range[1],
- num=cardinality,
- endpoint=True))
+ return (
+ float(x)
+ for x in np.linspace(
+ start=num_range[0],
+ stop=num_range[1],
+ num=cardinality,
+ endpoint=True,
+ )
+ )
assert self.type == "int", f"Unhandled tunable type: {self}"
return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1))
@@ -682,8 +695,8 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def categories(self) -> List[Optional[str]]:
"""
- Get the list of all possible values of a categorical tunable.
- Return None if the tunable is not categorical.
+ Get the list of all possible values of a categorical tunable. Return None if the
+ tunable is not categorical.
Returns
-------
@@ -712,7 +725,9 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def meta(self) -> Dict[str, Any]:
"""
- Get the tunable's metadata. This is a free-form dictionary that can be used to
- store any additional information about the tunable (e.g., the unit information).
+ Get the tunable's metadata.
+
+ This is a free-form dictionary that can be used to store any additional
+ information about the tunable (e.g., the unit information).
"""
return self._meta
diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py
index f97bf9de7d..da56eb79ac 100644
--- a/mlos_bench/mlos_bench/tunables/tunable_groups.py
+++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py
@@ -2,22 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-TunableGroups definition.
-"""
+"""TunableGroups definition."""
import copy
-
from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
from mlos_bench.config.schemas import ConfigSchema
-from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
+from mlos_bench.tunables.tunable import Tunable, TunableValue
class TunableGroups:
- """
- A collection of covariant groups of tunable parameters.
- """
+ """A collection of covariant groups of tunable parameters."""
def __init__(self, config: Optional[dict] = None):
"""
@@ -31,9 +26,10 @@ class TunableGroups:
if config is None:
config = {}
ConfigSchema.TUNABLE_PARAMS.validate(config)
- self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup)
+ # Index (Tunable id -> CovariantTunableGroup)
+ self._index: Dict[str, CovariantTunableGroup] = {}
self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
- for (name, group_config) in config.items():
+ for name, group_config in config.items():
self._add_group(CovariantTunableGroup(name, group_config))
def __bool__(self) -> bool:
@@ -82,11 +78,15 @@ class TunableGroups:
----------
group : CovariantTunableGroup
"""
- assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}"
+ assert (
+ group.name not in self._tunable_groups
+ ), f"Duplicate covariant tunable group name {group.name} in {self}"
self._tunable_groups[group.name] = group
for tunable in group.get_tunables():
if tunable.name in self._index:
- raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}")
+ raise ValueError(
+ f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
+ )
self._index[tunable.name] = group
def merge(self, tunables: "TunableGroups") -> "TunableGroups":
@@ -120,8 +120,10 @@ class TunableGroups:
# Check that there's no overlap in the tunables.
# But allow for differing current values.
if not self._tunable_groups[group.name].equals_defaults(group):
- raise ValueError(f"Overlapping covariant tunable group name {group.name} " +
- "in {self._tunable_groups[group.name]} and {tunables}")
+ raise ValueError(
+ f"Overlapping covariant tunable group name {group.name} "
+ "in {self._tunable_groups[group.name]} and {tunables}"
+ )
return self
def __repr__(self) -> str:
@@ -133,32 +135,37 @@ class TunableGroups:
string : str
A human-readable version of the TunableGroups.
"""
- return "{ " + ", ".join(
- f"{group.name}::{tunable}"
- for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
- for tunable in sorted(group._tunables.values())) + " }"
+ return (
+ "{ "
+ + ", ".join(
+ f"{group.name}::{tunable}"
+ for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
+ for tunable in sorted(group._tunables.values())
+ )
+ + " }"
+ )
def __contains__(self, tunable: Union[str, Tunable]) -> bool:
- """
- Checks if the given name/tunable is in this tunable group.
- """
+ """Checks if the given name/tunable is in this tunable group."""
name: str = tunable.name if isinstance(tunable, Tunable) else tunable
return name in self._index
def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
- """
- Get the current value of a single tunable parameter.
- """
+ """Get the current value of a single tunable parameter."""
name: str = tunable.name if isinstance(tunable, Tunable) else tunable
return self._index[name][name]
- def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue:
- """
- Update the current value of a single tunable parameter.
- """
+ def __setitem__(
+ self,
+ tunable: Union[str, Tunable],
+ tunable_value: Union[TunableValue, Tunable],
+ ) -> TunableValue:
+ """Update the current value of a single tunable parameter."""
# Use double index to make sure we set the is_updated flag of the group
name: str = tunable.name if isinstance(tunable, Tunable) else tunable
- value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
+ value: TunableValue = (
+ tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
+ )
self._index[name][name] = value
return self._index[name][name]
@@ -176,8 +183,8 @@ class TunableGroups:
def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
"""
- Access the entire Tunable (not just its value) and its covariant group.
- Throw KeyError if the tunable is not found.
+ Access the entire Tunable (not just its value) and its covariant group. Throw
+ KeyError if the tunable is not found.
Parameters
----------
@@ -206,8 +213,8 @@ class TunableGroups:
def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
"""
- Select the covariance groups from the current set and create a new
- TunableGroups object that consists of those covariance groups.
+ Select the covariance groups from the current set and create a new TunableGroups
+ object that consists of those covariance groups.
Note: The new TunableGroup will include *references* (not copies) to
original ones, so each will get updated together.
@@ -233,10 +240,14 @@ class TunableGroups:
tunables._add_group(self._tunable_groups[name])
return tunables
- def get_param_values(self, group_names: Optional[Iterable[str]] = None,
- into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]:
+ def get_param_values(
+ self,
+ group_names: Optional[Iterable[str]] = None,
+ into_params: Optional[Dict[str, TunableValue]] = None,
+ ) -> Dict[str, TunableValue]:
"""
- Get the current values of the tunables that belong to the specified covariance groups.
+ Get the current values of the tunables that belong to the specified covariance
+ groups.
Parameters
----------
@@ -273,12 +284,15 @@ class TunableGroups:
is_updated : bool
True if any of the specified tunable groups has been updated, False otherwise.
"""
- return any(self._tunable_groups[name].is_updated()
- for name in (group_names or self.get_covariant_group_names()))
+ return any(
+ self._tunable_groups[name].is_updated()
+ for name in (group_names or self.get_covariant_group_names())
+ )
def is_defaults(self) -> bool:
"""
- Checks whether the currently assigned values of all tunables are at their defaults.
+ Checks whether the currently assigned values of all tunables are at their
+ defaults.
Returns
-------
@@ -300,7 +314,7 @@ class TunableGroups:
self : TunableGroups
Self-reference for chaining.
"""
- for name in (group_names or self.get_covariant_group_names()):
+ for name in group_names or self.get_covariant_group_names():
self._tunable_groups[name].restore_defaults()
return self
@@ -318,14 +332,14 @@ class TunableGroups:
self : TunableGroups
Self-reference for chaining.
"""
- for name in (group_names or self.get_covariant_group_names()):
+ for name in group_names or self.get_covariant_group_names():
self._tunable_groups[name].reset_is_updated()
return self
def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
"""
- In-place update the values of the tunables from the dictionary
- of (key, value) pairs.
+ In-place update the values of the tunables from the dictionary of (key, value)
+ pairs.
Parameters
----------
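For reference, a minimal usage sketch of the TunableGroups API reformatted above. The group/parameter config shape here is illustrative only (the actual layout is whatever ConfigSchema.TUNABLE_PARAMS validates); the calls themselves (item assignment, is_updated(), get_param_values(), restore_defaults(), reset_is_updated(), is_defaults()) are the methods touched in these hunks.

from mlos_bench.tunables.tunable_groups import TunableGroups

# Hypothetical single-group config; key names are assumptions for illustration.
tunables = TunableGroups({
    "kernel": {
        "cost": 1,
        "params": {
            "sched_latency_ns": {"type": "int", "default": 2000000, "range": [0, 1000000000]},
        },
    },
})
tunables["sched_latency_ns"] = 4000000           # __setitem__ also flags the group as updated
assert tunables.is_updated()
print(tunables.get_param_values())               # {'sched_latency_ns': 4000000}
tunables.restore_defaults().reset_is_updated()   # both return self for chaining
assert tunables.is_defaults()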
diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py
index eb7dd3990d..64d3600966 100644
--- a/mlos_bench/mlos_bench/util.py
+++ b/mlos_bench/mlos_bench/util.py
@@ -2,28 +2,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Various helper functions for mlos_bench.
-"""
+"""Various helper functions for mlos_bench."""
# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.
-from datetime import datetime
-import os
+import importlib
import json
import logging
-import importlib
+import os
import subprocess
-
+from datetime import datetime
from typing import (
- Any, Callable, Dict, Iterable, Literal, Mapping, Optional,
- Tuple, Type, TypeVar, TYPE_CHECKING, Union,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Literal,
+ Mapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
)
import pandas
import pytz
-
_LOG = logging.getLogger(__name__)
if TYPE_CHECKING:
@@ -40,8 +46,8 @@ BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"]
def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
"""
- Replaces all $name values in the destination config with the corresponding
- value from the source config.
+ Replaces all $name values in the destination config with the corresponding value
+ from the source config.
Parameters
----------
@@ -63,12 +69,15 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) ->
return dest
-def merge_parameters(*, dest: dict, source: Optional[dict] = None,
- required_keys: Optional[Iterable[str]] = None) -> dict:
+def merge_parameters(
+ *,
+ dest: dict,
+ source: Optional[dict] = None,
+ required_keys: Optional[Iterable[str]] = None,
+) -> dict:
"""
- Merge the source config dict into the destination config.
- Pick from the source configs *ONLY* the keys that are already present
- in the destination config.
+ Merge the source config dict into the destination config. Pick from the source
+ configs *ONLY* the keys that are already present in the destination config.
Parameters
----------
@@ -124,8 +133,10 @@ def path_join(*args: str, abs_path: bool = False) -> str:
return os.path.normpath(path).replace("\\", "/")
-def prepare_class_load(config: dict,
- global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]:
+def prepare_class_load(
+ config: dict,
+ global_config: Optional[Dict[str, Any]] = None,
+) -> Tuple[str, Dict[str, Any]]:
"""
Extract the class instantiation parameters from the configuration.
@@ -147,8 +158,9 @@ def prepare_class_load(config: dict,
merge_parameters(dest=class_config, source=global_config)
if _LOG.isEnabledFor(logging.DEBUG):
- _LOG.debug("Instantiating: %s with config:\n%s",
- class_name, json.dumps(class_config, indent=2))
+ _LOG.debug(
+ "Instantiating: %s with config:\n%s", class_name, json.dumps(class_config, indent=2)
+ )
return (class_name, class_config)
@@ -179,8 +191,12 @@ def get_class_from_name(class_name: str) -> type:
# FIXME: Technically, this should return a type "class_name" derived from "base_class".
-def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str,
- *args: Any, **kwargs: Any) -> BaseTypeVar:
+def instantiate_from_config(
+ base_class: Type[BaseTypeVar],
+ class_name: str,
+ *args: Any,
+ **kwargs: Any,
+) -> BaseTypeVar:
"""
Factory method for a new class instantiated from config.
@@ -214,8 +230,8 @@ def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str,
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
"""
- Check if all required parameters are present in the configuration.
- Raise ValueError if any of the parameters are missing.
+ Check if all required parameters are present in the configuration. Raise ValueError
+ if any of the parameters are missing.
Parameters
----------
@@ -230,7 +246,8 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s
if missing_params:
raise ValueError(
"The following parameters must be provided in the configuration"
- + f" or as command line arguments: {missing_params}")
+ + f" or as command line arguments: {missing_params}"
+ )
def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
@@ -249,11 +266,14 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
"""
dirname = os.path.dirname(path)
git_repo = subprocess.check_output(
- ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip()
+ ["git", "-C", dirname, "remote", "get-url", "origin"], text=True
+ ).strip()
git_commit = subprocess.check_output(
- ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip()
+ ["git", "-C", dirname, "rev-parse", "HEAD"], text=True
+ ).strip()
git_root = subprocess.check_output(
- ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip()
+ ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True
+ ).strip()
_LOG.debug("Current git branch: %s %s", git_repo, git_commit)
rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
return (git_repo, git_commit, rel_path.replace("\\", "/"))
@@ -347,10 +367,12 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) ->
raise ValueError(f"Invalid origin: {origin}")
-def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]:
- """
- A nullable version of utcify_timestamp.
- """
+def utcify_nullable_timestamp(
+ timestamp: Optional[datetime],
+ *,
+ origin: Literal["utc", "local"],
+) -> Optional[datetime]:
+ """A nullable version of utcify_timestamp."""
return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None
@@ -359,7 +381,11 @@ def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal[
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
-def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series:
+def datetime_parser(
+ datetime_col: pandas.Series,
+ *,
+ origin: Literal["utc", "local"],
+) -> pandas.Series:
"""
Attempt to convert a pandas column to a datetime format.
@@ -393,7 +419,7 @@ def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "loca
new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo)
assert new_datetime_col.dt.tz is not None
# And convert it to UTC.
- new_datetime_col = new_datetime_col.dt.tz_convert('UTC')
+ new_datetime_col = new_datetime_col.dt.tz_convert("UTC")
if new_datetime_col.isna().any():
raise ValueError(f"Invalid date format in the data: {datetime_col}")
if new_datetime_col.le(_MIN_TS).any():
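A short, hedged example of the timestamp helpers reflowed above: datetime_parser() normalizes a pandas column to timezone-aware UTC, localizing naive values according to `origin`, and utcify_nullable_timestamp() simply passes None through. The sample data is illustrative.

from datetime import datetime

import pandas

from mlos_bench.util import datetime_parser, utcify_nullable_timestamp

naive = pandas.Series([datetime(2024, 3, 1, 12, 0, 0)])
utc_col = datetime_parser(naive, origin="local")   # treat naive values as local time
assert utc_col.dt.tz is not None                   # result is tz-aware and converted to UTC

assert utcify_nullable_timestamp(None, origin="utc") is None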
diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py
index 96d3d2b6bf..ab6ab85d2d 100644
--- a/mlos_bench/mlos_bench/version.py
+++ b/mlos_bench/mlos_bench/version.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Version number for the mlos_bench package.
-"""
+"""Version number for the mlos_bench package."""
# NOTE: This should be managed by bumpversion.
-VERSION = '0.5.1'
+VERSION = "0.5.1"
if __name__ == "__main__":
print(VERSION)
diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py
index 2a5a7fe538..cb3d975d92 100644
--- a/mlos_bench/setup.py
+++ b/mlos_bench/setup.py
@@ -2,18 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Setup instructions for the mlos_bench package.
-"""
+"""Setup instructions for the mlos_bench package."""
# pylint: disable=duplicate-code
-from logging import warning
-from itertools import chain
-from typing import Dict, List
-
import os
import re
+from itertools import chain
+from logging import warning
+from typing import Dict, List
from setuptools import setup
@@ -22,15 +19,16 @@ PKG_NAME = "mlos_bench"
try:
ns: Dict[str, str] = {}
with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file:
- exec(version_file.read(), ns) # pylint: disable=exec-used
- VERSION = ns['VERSION']
+ exec(version_file.read(), ns) # pylint: disable=exec-used
+ VERSION = ns["VERSION"]
except OSError:
VERSION = "0.0.1-dev"
warning(f"version.py not found, using dummy VERSION={VERSION}")
try:
from setuptools_scm import get_version
- version = get_version(root='..', relative_to=__file__, fallback_version=VERSION)
+
+ version = get_version(root="..", relative_to=__file__, fallback_version=VERSION)
if version is not None:
VERSION = version
except ImportError:
@@ -48,62 +46,67 @@ except LookupError as e:
# be duplicated for now.
def _get_long_desc_from_readme(base_url: str) -> dict:
pkg_dir = os.path.dirname(__file__)
- readme_path = os.path.join(pkg_dir, 'README.md')
+ readme_path = os.path.join(pkg_dir, "README.md")
if not os.path.isfile(readme_path):
return {
- 'long_description': 'missing',
+ "long_description": "missing",
}
- jsonc_re = re.compile(r'```jsonc')
- link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)')
- with open(readme_path, mode='r', encoding='utf-8') as readme_fh:
+ jsonc_re = re.compile(r"```jsonc")
+ link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)")
+ with open(readme_path, mode="r", encoding="utf-8") as readme_fh:
lines = readme_fh.readlines()
# Tweak the lexers for local expansion by pygments instead of github's.
- lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines]
+ lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines]
        # Tweak source code links.
- lines = [jsonc_re.sub(r'```json', line) for line in lines]
+ lines = [jsonc_re.sub(r"```json", line) for line in lines]
return {
- 'long_description': ''.join(lines),
- 'long_description_content_type': 'text/markdown',
+ "long_description": "".join(lines),
+ "long_description_content_type": "text/markdown",
}
-extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass
+extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass
# Additional tools for extra functionality.
- 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'],
- 'ssh': ['asyncssh<2.15.0'], # FIXME: asyncssh 2.15.0 has a bug that breaks the tests
- 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'],
- 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'],
- 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'],
- 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it.
+ "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"],
+ "ssh": ["asyncssh<2.15.0"], # FIXME: asyncssh 2.15.0 has a bug that breaks the tests
+ "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"],
+ "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"],
+ "storage-sql-postgres": ["sqlalchemy", "psycopg2"],
+ # sqlite3 comes with python, so we don't need to install it.
+ "storage-sql-sqlite": ["sqlalchemy"],
# Transitive extra_requires from mlos-core.
- 'flaml': ['flaml[blendsearch]'],
- 'smac': ['smac'],
+ "flaml": ["flaml[blendsearch]"],
+ "smac": ["smac"],
}
# construct special 'full' extra that adds requirements for all built-in
# backend integrations and additional extra features.
-extra_requires['full'] = list(set(chain(*extra_requires.values())))
+extra_requires["full"] = list(set(chain(*extra_requires.values())))
-extra_requires['full-tests'] = extra_requires['full'] + [
- 'pytest',
- 'pytest-forked',
- 'pytest-xdist',
- 'pytest-cov',
- 'pytest-local-badge',
- 'pytest-lazy-fixtures',
- 'pytest-docker',
- 'fasteners',
+extra_requires["full-tests"] = extra_requires["full"] + [
+ "pytest",
+ "pytest-forked",
+ "pytest-xdist",
+ "pytest-cov",
+ "pytest-local-badge",
+ "pytest-lazy-fixtures",
+ "pytest-docker",
+ "fasteners",
]
setup(
version=VERSION,
install_requires=[
- 'mlos-core==' + VERSION,
- 'requests',
- 'json5',
- 'jsonschema>=4.18.0', 'referencing>=0.29.1',
+ "mlos-core==" + VERSION,
+ "requests",
+ "json5",
+ "jsonschema>=4.18.0",
+ "referencing>=0.29.1",
'importlib_resources;python_version<"3.10"',
- ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it.
+ ]
+ + extra_requires[
+ "storage-sql-sqlite"
+ ], # NOTE: For now sqlite is a fallback storage backend, so we always install it.
extras_require=extra_requires,
- **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'),
+ **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_bench"),
)
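The extras wiring above is easier to read in isolation. The following trimmed-down mirror of the same pattern (only a subset of the real extras) shows that "full" is the de-duplicated union of every other extra and "full-tests" layers the pytest tooling on top, so `pip install "mlos-bench[full]"` pulls in all optional backends at once.

from itertools import chain
from typing import Dict, List

extra_requires: Dict[str, List[str]] = {
    "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"],
    "ssh": ["asyncssh<2.15.0"],
    "storage-sql-sqlite": ["sqlalchemy"],
}
extra_requires["full"] = list(set(chain(*extra_requires.values())))
extra_requires["full-tests"] = extra_requires["full"] + ["pytest", "pytest-xdist"]

assert "sqlalchemy" in extra_requires["full"]
assert set(extra_requires["full"]) <= set(extra_requires["full-tests"])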
diff --git a/mlos_core/mlos_core/__init__.py b/mlos_core/mlos_core/__init__.py
index 3d816eb916..41d24af928 100644
--- a/mlos_core/mlos_core/__init__.py
+++ b/mlos_core/mlos_core/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Basic initializer module for the mlos_core package.
-"""
+"""Basic initializer module for the mlos_core package."""
diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py
index b00a9e8eb1..396bd5e212 100644
--- a/mlos_core/mlos_core/optimizers/__init__.py
+++ b/mlos_core/mlos_core/optimizers/__init__.py
@@ -2,28 +2,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Basic initializer module for the mlos_core optimizers.
-"""
+"""Basic initializer module for the mlos_core optimizers."""
from enum import Enum
from typing import List, Optional, TypeVar
import ConfigSpace
-from mlos_core.optimizers.optimizer import BaseOptimizer
-from mlos_core.optimizers.random_optimizer import RandomOptimizer
from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
from mlos_core.optimizers.flaml_optimizer import FlamlOptimizer
-from mlos_core.spaces.adapters import SpaceAdapterType, SpaceAdapterFactory
+from mlos_core.optimizers.optimizer import BaseOptimizer
+from mlos_core.optimizers.random_optimizer import RandomOptimizer
+from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType
__all__ = [
- 'SpaceAdapterType',
- 'OptimizerFactory',
- 'BaseOptimizer',
- 'RandomOptimizer',
- 'FlamlOptimizer',
- 'SmacOptimizer',
+ "SpaceAdapterType",
+ "OptimizerFactory",
+ "BaseOptimizer",
+ "RandomOptimizer",
+ "FlamlOptimizer",
+ "SmacOptimizer",
]
@@ -31,13 +29,13 @@ class OptimizerType(Enum):
"""Enumerate supported MlosCore optimizers."""
RANDOM = RandomOptimizer
- """An instance of RandomOptimizer class will be used"""
+    """An instance of the RandomOptimizer class will be used."""
FLAML = FlamlOptimizer
- """An instance of FlamlOptimizer class will be used"""
+    """An instance of the FlamlOptimizer class will be used."""
SMAC = SmacOptimizer
- """An instance of SmacOptimizer class will be used"""
+    """An instance of the SmacOptimizer class will be used."""
# To make mypy happy, we need to define a type variable for each optimizer type.
@@ -45,7 +43,7 @@ class OptimizerType(Enum):
# ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType])
# To address this, we add a test for complete coverage of the enum.
ConcreteOptimizer = TypeVar(
- 'ConcreteOptimizer',
+ "ConcreteOptimizer",
RandomOptimizer,
FlamlOptimizer,
SmacOptimizer,
@@ -55,21 +53,23 @@ DEFAULT_OPTIMIZER_TYPE = OptimizerType.FLAML
class OptimizerFactory:
- """Simple factory class for creating BaseOptimizer-derived objects"""
+ """Simple factory class for creating BaseOptimizer-derived objects."""
# pylint: disable=too-few-public-methods
@staticmethod
- def create(*,
- parameter_space: ConfigSpace.ConfigurationSpace,
- optimization_targets: List[str],
- optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE,
- optimizer_kwargs: Optional[dict] = None,
- space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY,
- space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var]
+ def create(
+ *,
+ parameter_space: ConfigSpace.ConfigurationSpace,
+ optimization_targets: List[str],
+ optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE,
+ optimizer_kwargs: Optional[dict] = None,
+ space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY,
+ space_adapter_kwargs: Optional[dict] = None,
+ ) -> ConcreteOptimizer: # type: ignore[type-var]
"""
- Create a new optimizer instance, given the parameter space, optimizer type,
- and potential optimizer options.
+ Create a new optimizer instance, given the parameter space, optimizer type, and
+ potential optimizer options.
Parameters
----------
@@ -107,7 +107,7 @@ class OptimizerFactory:
parameter_space=parameter_space,
optimization_targets=optimization_targets,
space_adapter=space_adapter,
- **optimizer_kwargs
+ **optimizer_kwargs,
)
return optimizer
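A hedged usage sketch of the keyword-only OptimizerFactory.create() signature above. The ConfigSpace construction is illustrative (any single-parameter space will do); suggest() returning a (config, metadata) pair matches the BaseOptimizer interface later in this patch.

import ConfigSpace

from mlos_core.optimizers import OptimizerFactory, OptimizerType

space = ConfigSpace.ConfigurationSpace(seed=42)
space.add_hyperparameter(ConfigSpace.UniformIntegerHyperparameter("x", lower=0, upper=100))

optimizer = OptimizerFactory.create(
    parameter_space=space,
    optimization_targets=["score"],
    optimizer_type=OptimizerType.RANDOM,   # or OptimizerType.FLAML / OptimizerType.SMAC
)
config_df, metadata = optimizer.suggest()
print(config_df)   # one-row DataFrame; columns are the parameter names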
diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py
index 55f0aa09eb..1a4fea7188 100644
--- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py
+++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py
@@ -2,15 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Basic initializer module for the mlos_core Bayesian optimizers.
-"""
+"""Basic initializer module for the mlos_core Bayesian optimizers."""
-from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import BaseBayesianOptimizer
+from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import (
+ BaseBayesianOptimizer,
+)
from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
-
__all__ = [
- 'BaseBayesianOptimizer',
- 'SmacOptimizer',
+ "BaseBayesianOptimizer",
+ "SmacOptimizer",
]
diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py
index 2de01637f8..a39a5516e8 100644
--- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py
+++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py
@@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the wrapper classes for base Bayesian optimizers.
-"""
+"""Contains the wrapper classes for base Bayesian optimizers."""
from abc import ABCMeta, abstractmethod
-
from typing import Optional
-import pandas as pd
import numpy.typing as npt
+import pandas as pd
from mlos_core.optimizers.optimizer import BaseOptimizer
@@ -20,31 +17,45 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta):
"""Abstract base class defining the interface for Bayesian optimization."""
@abstractmethod
- def surrogate_predict(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None) -> npt.NDArray:
- """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s).
+ def surrogate_predict(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ ) -> npt.NDArray:
+ """
+ Obtain a prediction from this Bayesian optimizer's surrogate model for the given
+ configuration(s).
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
context : pd.DataFrame
Not Yet Implemented.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
@abstractmethod
- def acquisition_function(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None) -> npt.NDArray:
- """Invokes the acquisition function from this Bayesian optimizer for the given configuration.
+ def acquisition_function(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ ) -> npt.NDArray:
+ """
+ Invokes the acquisition function from this Bayesian optimizer for the given
+ configuration.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
context : pd.DataFrame
Not Yet Implemented.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
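Nearly every docstring in this file (and in the optimizers below) repeats the same DataFrame convention, so here it is spelled out once with toy data: columns are parameter names, each row is one configuration, and scores share the configs' index.

import pandas as pd

configs = pd.DataFrame([{"x": 10, "flag": "on"}, {"x": 42, "flag": "off"}])
scores = pd.DataFrame({"score": [0.7, 0.3]})   # same index as configs
assert list(configs.columns) == ["x", "flag"]
assert len(configs) == len(scores)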
diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py
index aa948b8125..611dc04044 100644
--- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py
+++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py
@@ -4,42 +4,46 @@
#
"""
Contains the wrapper class for SMAC Bayesian optimizers.
+
See Also:
"""
from logging import warning
from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from warnings import warn
import ConfigSpace
import numpy.typing as npt
import pandas as pd
-from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import BaseBayesianOptimizer
+from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import (
+ BaseBayesianOptimizer,
+)
from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter
class SmacOptimizer(BaseBayesianOptimizer):
- """
- Wrapper class for SMAC based Bayesian optimization.
- """
+ """Wrapper class for SMAC based Bayesian optimization."""
- def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments
- parameter_space: ConfigSpace.ConfigurationSpace,
- optimization_targets: List[str],
- objective_weights: Optional[List[float]] = None,
- space_adapter: Optional[BaseSpaceAdapter] = None,
- seed: Optional[int] = 0,
- run_name: Optional[str] = None,
- output_directory: Optional[str] = None,
- max_trials: int = 100,
- n_random_init: Optional[int] = None,
- max_ratio: Optional[float] = None,
- use_default_config: bool = False,
- n_random_probability: float = 0.1):
+ def __init__(
+ self,
+ *, # pylint: disable=too-many-locals,too-many-arguments
+ parameter_space: ConfigSpace.ConfigurationSpace,
+ optimization_targets: List[str],
+ objective_weights: Optional[List[float]] = None,
+ space_adapter: Optional[BaseSpaceAdapter] = None,
+ seed: Optional[int] = 0,
+ run_name: Optional[str] = None,
+ output_directory: Optional[str] = None,
+ max_trials: int = 100,
+ n_random_init: Optional[int] = None,
+ max_ratio: Optional[float] = None,
+ use_default_config: bool = False,
+ n_random_probability: float = 0.1,
+ ):
"""
Instantiate a new SMAC optimizer wrapper.
@@ -59,18 +63,21 @@ class SmacOptimizer(BaseBayesianOptimizer):
seed : Optional[int]
By default SMAC uses a known seed (0) to keep results reproducible.
- However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC.
+ However, if a `None` seed is explicitly provided, we let a random seed
+ be produced by SMAC.
run_name : Optional[str]
Name of this run. This is used to easily distinguish across different runs.
If set to `None` (default), SMAC will generate a hash from metadata.
output_directory : Optional[str]
- The directory where SMAC output will saved. If set to `None` (default), a temporary dir will be used.
+            The directory where SMAC output will be saved. If set to `None` (default),
+ a temporary dir will be used.
max_trials : int
Maximum number of trials (i.e., function evaluations) to be run. Defaults to 100.
- Note that modifying this value directly affects the value of `n_random_init`, if latter is set to `None`.
+ Note that modifying this value directly affects the value of
+            `n_random_init`, if the latter is set to `None`.
n_random_init : Optional[int]
Number of points evaluated at start to bootstrap the optimizer.
@@ -114,7 +121,8 @@ class SmacOptimizer(BaseBayesianOptimizer):
self.trial_info_map: Dict[ConfigSpace.Configuration, TrialInfo] = {}
# The default when not specified is to use a known seed (0) to keep results reproducible.
- # However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC.
+ # However, if a `None` seed is explicitly provided, we let a random seed be
+ # produced by SMAC.
# https://automl.github.io/SMAC3/main/api/smac.scenario.html#smac.scenario.Scenario
seed = -1 if seed is None else seed
@@ -122,7 +130,8 @@ class SmacOptimizer(BaseBayesianOptimizer):
if output_directory is None:
# pylint: disable=consider-using-with
try:
- self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10
+ # Argument added in Python 3.10
+ self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True)
except TypeError:
self._temp_output_directory = TemporaryDirectory()
output_directory = self._temp_output_directory.name
@@ -144,8 +153,14 @@ class SmacOptimizer(BaseBayesianOptimizer):
seed=seed or -1, # if -1, SMAC will generate a random seed internally
n_workers=1, # Use a single thread for evaluating trials
)
- intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1)
- config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1)
+ intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(
+ scenario,
+ max_config_calls=1,
+ )
+ config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(
+ scenario,
+ retrain_after=1,
+ )
# TODO: When bulk registering prior configs to rewarm the optimizer,
# there is a way to inform SMAC's initial design that we have
@@ -156,39 +171,46 @@ class SmacOptimizer(BaseBayesianOptimizer):
# See Also: #488
initial_design_args: Dict[str, Union[list, int, float, Scenario]] = {
- 'scenario': scenario,
+ "scenario": scenario,
# Workaround a bug in SMAC that sets a default arg to a mutable
# value that can cause issues when multiple optimizers are
# instantiated with the use_default_config option within the same
# process that use different ConfigSpaces so that the second
# receives the default config from both as an additional config.
- 'additional_configs': []
+ "additional_configs": [],
}
if n_random_init is not None:
- initial_design_args['n_configs'] = n_random_init
+ initial_design_args["n_configs"] = n_random_init
if n_random_init > 0.25 * max_trials and max_ratio is None:
warning(
- 'Number of random initial configs (%d) is ' +
- 'greater than 25%% of max_trials (%d). ' +
- 'Consider setting max_ratio to avoid SMAC overriding n_random_init.',
+ "Number of random initial configs (%d) is "
+ + "greater than 25%% of max_trials (%d). "
+ + "Consider setting max_ratio to avoid SMAC overriding n_random_init.",
n_random_init,
max_trials,
)
if max_ratio is not None:
assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0
- initial_design_args['max_ratio'] = max_ratio
+ initial_design_args["max_ratio"] = max_ratio
# Use the default InitialDesign from SMAC.
# (currently SBOL instead of LatinHypercube due to better uniformity
# for initial sampling which results in lower overall samples required)
- initial_design = Optimizer_Smac.get_initial_design(**initial_design_args) # type: ignore[arg-type]
- # initial_design = LatinHypercubeInitialDesign(**initial_design_args) # type: ignore[arg-type]
+ initial_design = Optimizer_Smac.get_initial_design(
+ **initial_design_args, # type: ignore[arg-type]
+ )
+ # initial_design = LatinHypercubeInitialDesign(
+ # **initial_design_args, # type: ignore[arg-type]
+ # )
# Workaround a bug in SMAC that doesn't pass the seed to the random
# design when generated a random_design for itself via the
# get_random_design static method when random_design is None.
assert isinstance(n_random_probability, float) and n_random_probability >= 0
- random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed)
+ random_design = ProbabilityRandomDesign(
+ probability=n_random_probability,
+ seed=scenario.seed,
+ )
self.base_optimizer = Optimizer_Smac(
scenario,
@@ -198,7 +220,9 @@ class SmacOptimizer(BaseBayesianOptimizer):
random_design=random_design,
config_selector=config_selector,
multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm(
- scenario, objective_weights=self._objective_weights),
+ scenario,
+ objective_weights=self._objective_weights,
+ ),
overwrite=True,
logging_level=False, # Use the existing logger
)
@@ -210,9 +234,11 @@ class SmacOptimizer(BaseBayesianOptimizer):
@property
def n_random_init(self) -> int:
"""
- Gets the number of random samples to use to initialize the optimizer's search space sampling.
+ Gets the number of random samples to use to initialize the optimizer's search
+ space sampling.
- Note: This may not be equal to the value passed to the initializer, due to logic present in the SMAC.
+ Note: This may not be equal to the value passed to the initializer, due to
+        logic present in SMAC.
See Also: max_ratio
Returns
@@ -225,7 +251,8 @@ class SmacOptimizer(BaseBayesianOptimizer):
@staticmethod
def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None:
- """Dummy target function for SMAC optimizer.
+ """
+ Dummy target function for SMAC optimizer.
Since we only use the ask-and-tell interface, this is never called.
@@ -237,21 +264,31 @@ class SmacOptimizer(BaseBayesianOptimizer):
seed : int
Random seed to use for the target function. Not actually used.
"""
- # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API
- # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946
- raise RuntimeError('This function should never be called.')
+ # NOTE: Providing a target function when using the ask-and-tell interface is
+ # an imperfection of the API -- this is planned to be fixed in some future
+ # release: https://github.com/automl/SMAC3/issues/946
+ raise RuntimeError("This function should never be called.")
- def _register(self, *, configs: pd.DataFrame,
- scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
- """Registers the given configs and scores.
+ def _register(
+ self,
+ *,
+ configs: pd.DataFrame,
+ scores: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+ Registers the given configs and scores.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
scores : pd.DataFrame
- Scores from running the configs. The index is the same as the index of the configs.
+ Scores from running the configs. The index is the same as the index of
+ the configs.
context : pd.DataFrame
Not Yet Implemented.
@@ -259,24 +296,38 @@ class SmacOptimizer(BaseBayesianOptimizer):
metadata: pd.DataFrame
Not Yet Implemented.
"""
- from smac.runhistory import StatusType, TrialInfo, TrialValue # pylint: disable=import-outside-toplevel
+ from smac.runhistory import ( # pylint: disable=import-outside-toplevel
+ StatusType,
+ TrialInfo,
+ TrialValue,
+ )
if context is not None:
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning)
# Register each trial (one-by-one)
- for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()):
- # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance
+ for config, (_i, score) in zip(
+ self._to_configspace_configs(configs=configs), scores.iterrows()
+ ):
+ # Retrieve previously generated TrialInfo (returned by .ask()) or create
+ # new TrialInfo instance
info: TrialInfo = self.trial_info_map.get(
- config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed))
+ config,
+ TrialInfo(config=config, seed=self.base_optimizer.scenario.seed),
+ )
value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS)
self.base_optimizer.tell(info, value, save=False)
# Save optimizer once we register all configs
self.base_optimizer.optimizer.save()
- def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
- """Suggests a new configuration.
+ def _suggest(
+ self,
+ *,
+ context: Optional[pd.DataFrame] = None,
+ ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+ """
+ Suggests a new configuration.
Parameters
----------
@@ -292,7 +343,8 @@ class SmacOptimizer(BaseBayesianOptimizer):
Not yet implemented.
"""
if TYPE_CHECKING:
- from smac.runhistory import TrialInfo # pylint: disable=import-outside-toplevel,unused-import
+ # pylint: disable=import-outside-toplevel,unused-import
+ from smac.runhistory import TrialInfo
if context is not None:
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning)
@@ -302,16 +354,28 @@ class SmacOptimizer(BaseBayesianOptimizer):
self.optimizer_parameter_space.check_configuration(trial.config)
assert trial.config.config_space == self.optimizer_parameter_space
self.trial_info_map[trial.config] = trial
- config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys()))
+ config_df = pd.DataFrame(
+ [trial.config], columns=list(self.optimizer_parameter_space.keys())
+ )
return config_df, None
- def register_pending(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None,
- metadata: Optional[pd.DataFrame] = None) -> None:
+ def register_pending(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
raise NotImplementedError()
- def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray:
- from smac.utils.configspace import convert_configurations_to_array # pylint: disable=import-outside-toplevel
+ def surrogate_predict(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ ) -> npt.NDArray:
+ # pylint: disable=import-outside-toplevel
+ from smac.utils.configspace import convert_configurations_to_array
if context is not None:
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning)
@@ -321,16 +385,27 @@ class SmacOptimizer(BaseBayesianOptimizer):
# pylint: disable=protected-access
if len(self._observations) <= self.base_optimizer._initial_design._n_configs:
raise RuntimeError(
- 'Surrogate model can make predictions *only* after all initial points have been evaluated ' +
- f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}')
+ "Surrogate model can make predictions *only* after "
+ "all initial points have been evaluated "
+ f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}"
+ )
if self.base_optimizer._config_selector._model is None:
- raise RuntimeError('Surrogate model is not yet trained')
+ raise RuntimeError("Surrogate model is not yet trained")
- config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs))
+ config_array: npt.NDArray = convert_configurations_to_array(
+ self._to_configspace_configs(configs=configs)
+ )
mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array)
- return mean_predictions.reshape(-1,)
+ return mean_predictions.reshape(
+ -1,
+ )
- def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray:
+ def acquisition_function(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ ) -> npt.NDArray:
if context is not None:
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning)
if self._space_adapter:
@@ -338,23 +413,27 @@ class SmacOptimizer(BaseBayesianOptimizer):
# pylint: disable=protected-access
if self.base_optimizer._config_selector._acquisition_function is None:
- raise RuntimeError('Acquisition function is not yet initialized')
+ raise RuntimeError("Acquisition function is not yet initialized")
cs_configs: list = self._to_configspace_configs(configs=configs)
- return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,)
+ return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(
+ -1,
+ )
def cleanup(self) -> None:
- if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None:
+ if hasattr(self, "_temp_output_directory") and self._temp_output_directory is not None:
self._temp_output_directory.cleanup()
self._temp_output_directory = None
def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace.Configuration]:
- """Convert a dataframe of configs to a list of ConfigSpace configs.
+ """
+ Convert a dataframe of configs to a list of ConfigSpace configs.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
Returns
-------
@@ -363,5 +442,5 @@ class SmacOptimizer(BaseBayesianOptimizer):
"""
return [
ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict())
- for (_, config) in configs.astype('O').iterrows()
+ for (_, config) in configs.astype("O").iterrows()
]
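For context, an illustrative ask-and-tell loop around the SmacOptimizer wrapper above; the toy objective and ConfigSpace setup are assumptions, not part of this patch. suggest() hands back one configuration per call and register() feeds the observed scores back so the surrogate model can be retrained.

import ConfigSpace
import pandas as pd

from mlos_core.optimizers import OptimizerFactory, OptimizerType

space = ConfigSpace.ConfigurationSpace(seed=0)
space.add_hyperparameter(ConfigSpace.UniformFloatHyperparameter("x", lower=-5.0, upper=5.0))

optimizer = OptimizerFactory.create(
    parameter_space=space,
    optimization_targets=["loss"],
    optimizer_type=OptimizerType.SMAC,
    optimizer_kwargs={"seed": 0, "max_trials": 20, "n_random_init": 5},
)
for _ in range(10):
    config_df, _metadata = optimizer.suggest()
    loss = (float(config_df["x"].iloc[0]) - 1.0) ** 2    # toy objective
    optimizer.register(configs=config_df, scores=pd.DataFrame({"loss": [loss]}))

print(optimizer.get_best_observations())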
diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py
index 4f478db2bf..50def8bc80 100644
--- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py
+++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the FlamlOptimizer class.
-"""
+"""Contains the FlamlOptimizer class."""
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
from warnings import warn
@@ -13,9 +11,9 @@ import ConfigSpace
import numpy as np
import pandas as pd
-from mlos_core.util import normalize_config
from mlos_core.optimizers.optimizer import BaseOptimizer
from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
+from mlos_core.util import normalize_config
class EvaluatedSample(NamedTuple):
@@ -26,20 +24,22 @@ class EvaluatedSample(NamedTuple):
class FlamlOptimizer(BaseOptimizer):
- """
- Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning.
- """
+ """Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning."""
- # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics.
+ # The name of an internal objective attribute that is calculated as a weighted
+ # average of the user provided objective metrics.
_METRIC_NAME = "FLAML_score"
- def __init__(self, *, # pylint: disable=too-many-arguments
- parameter_space: ConfigSpace.ConfigurationSpace,
- optimization_targets: List[str],
- objective_weights: Optional[List[float]] = None,
- space_adapter: Optional[BaseSpaceAdapter] = None,
- low_cost_partial_config: Optional[dict] = None,
- seed: Optional[int] = None):
+ def __init__(
+ self,
+ *, # pylint: disable=too-many-arguments
+ parameter_space: ConfigSpace.ConfigurationSpace,
+ optimization_targets: List[str],
+ objective_weights: Optional[List[float]] = None,
+ space_adapter: Optional[BaseSpaceAdapter] = None,
+ low_cost_partial_config: Optional[dict] = None,
+ seed: Optional[int] = None,
+ ):
"""
Create an MLOS wrapper for FLAML.
@@ -59,10 +59,12 @@ class FlamlOptimizer(BaseOptimizer):
low_cost_partial_config : dict
A dictionary from a subset of controlled dimensions to the initial low-cost values.
- More info: https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune
+ More info:
+ https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune
seed : Optional[int]
- If provided, calls np.random.seed() with the provided value to set the seed globally at init.
+ If provided, calls np.random.seed() with the provided value to set the
+ seed globally at init.
"""
super().__init__(
parameter_space=parameter_space,
@@ -77,22 +79,35 @@ class FlamlOptimizer(BaseOptimizer):
np.random.seed(seed)
# pylint: disable=import-outside-toplevel
- from mlos_core.spaces.converters.flaml import configspace_to_flaml_space, FlamlDomain
+ from mlos_core.spaces.converters.flaml import (
+ FlamlDomain,
+ configspace_to_flaml_space,
+ )
- self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space)
+ self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(
+ self.optimizer_parameter_space
+ )
self.low_cost_partial_config = low_cost_partial_config
self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {}
self._suggested_config: Optional[dict]
- def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
- """Registers the given configs and scores.
+ def _register(
+ self,
+ *,
+ configs: pd.DataFrame,
+ scores: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+ Registers the given configs and scores.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
scores : pd.DataFrame
Scores from running the configs. The index is the same as the index of the configs.
@@ -108,9 +123,10 @@ class FlamlOptimizer(BaseOptimizer):
if metadata is not None:
warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning)
- for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()):
+ for (_, config), (_, score) in zip(configs.astype("O").iterrows(), scores.iterrows()):
cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration(
- self.optimizer_parameter_space, values=config.to_dict())
+ self.optimizer_parameter_space, values=config.to_dict()
+ )
if cs_config in self.evaluated_samples:
warn(f"Configuration {config} was already registered", UserWarning)
self.evaluated_samples[cs_config] = EvaluatedSample(
@@ -118,8 +134,13 @@ class FlamlOptimizer(BaseOptimizer):
score=float(np.average(score.astype(float), weights=self._objective_weights)),
)
- def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
- """Suggests a new configuration.
+ def _suggest(
+ self,
+ *,
+ context: Optional[pd.DataFrame] = None,
+ ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+ """
+ Suggests a new configuration.
Sampled at random using ConfigSpace.
@@ -141,26 +162,35 @@ class FlamlOptimizer(BaseOptimizer):
config: dict = self._get_next_config()
return pd.DataFrame(config, index=[0]), None
- def register_pending(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
+ def register_pending(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
raise NotImplementedError()
def _target_function(self, config: dict) -> Union[dict, None]:
- """Configuration evaluation function called by FLAML optimizer.
+ """
+ Configuration evaluation function called by FLAML optimizer.
- FLAML may suggest the same configuration multiple times (due to its warm-start mechanism).
- Once FLAML suggests an unseen configuration, we store it, and stop the optimization process.
+ FLAML may suggest the same configuration multiple times (due to its
+ warm-start mechanism). Once FLAML suggests an unseen configuration, we
+ store it, and stop the optimization process.
Parameters
----------
config: dict
Next configuration to be evaluated, as suggested by FLAML.
- This config is stored internally and is returned to user, via `.suggest()` method.
+            This config is stored internally and is returned to the user via the
+            `.suggest()` method.
Returns
-------
result: Union[dict, None]
- Dictionary with a single key, `FLAML_score`, if config already evaluated; `None` otherwise.
+ Dictionary with a single key, `FLAML_score`, if config already
+ evaluated; `None` otherwise.
"""
cs_config = normalize_config(self.optimizer_parameter_space, config)
if cs_config in self.evaluated_samples:
@@ -170,12 +200,17 @@ class FlamlOptimizer(BaseOptimizer):
return None # Returning None stops the process
def _get_next_config(self) -> dict:
- """Warm-starts a new instance of FLAML, and returns a recommended, unseen new configuration.
+ """
+ Warm-starts a new instance of FLAML, and returns a recommended, unseen new
+ configuration.
- Since FLAML does not provide an ask-and-tell interface, we need to create a new instance of FLAML
- each time we get asked for a new suggestion. This is suboptimal performance-wise, but works.
- To do so, we use any previously evaluated configs to bootstrap FLAML (i.e., warm-start).
- For more info: https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start
+ Since FLAML does not provide an ask-and-tell interface, we need to create a
+ new instance of FLAML each time we get asked for a new suggestion. This is
+ suboptimal performance-wise, but works.
+ To do so, we use any previously evaluated configs to bootstrap FLAML (i.e.,
+ warm-start).
+ For more info:
+ https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start
Returns
-------
@@ -197,16 +232,14 @@ class FlamlOptimizer(BaseOptimizer):
dict(normalize_config(self.optimizer_parameter_space, conf))
for conf in self.evaluated_samples
]
- evaluated_rewards = [
- s.score for s in self.evaluated_samples.values()
- ]
+ evaluated_rewards = [s.score for s in self.evaluated_samples.values()]
# Warm start FLAML optimizer
self._suggested_config = None
tune.run(
self._target_function,
config=self.flaml_parameter_space,
- mode='min',
+ mode="min",
metric=self._METRIC_NAME,
points_to_evaluate=points_to_evaluate,
evaluated_rewards=evaluated_rewards,
@@ -215,6 +248,6 @@ class FlamlOptimizer(BaseOptimizer):
verbose=0,
)
if self._suggested_config is None:
- raise RuntimeError('FLAML did not produce a suggestion')
+ raise RuntimeError("FLAML did not produce a suggestion")
return self._suggested_config # type: ignore[unreachable]
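The warm-start trick documented in _get_next_config() above, reduced to a standalone sketch (the tiny search space, the num_samples value, and all names are illustrative): each call starts a fresh flaml.tune.run(), replays previously evaluated points via points_to_evaluate / evaluated_rewards, and the target function returns None for the first unseen config, which both records the suggestion and stops the run.

from flaml import tune

seen = {1: 0.5}      # previously evaluated value of "x" mapped to its score
suggestion = {}

def _target(config: dict):
    if config["x"] in seen:
        return {"score": seen[config["x"]]}   # replay a known reward
    suggestion.update(config)                 # capture the unseen config ...
    return None                               # ... and stop the optimization

tune.run(
    _target,
    config={"x": tune.randint(0, 10)},
    mode="min",
    metric="score",
    points_to_evaluate=[{"x": 1}],
    evaluated_rewards=[0.5],
    num_samples=len(seen) + 1,
    verbose=0,
)
print(suggestion)    # the next configuration FLAML would have evaluated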
diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py
index 8fcf592a6c..4152e3c4c0 100644
--- a/mlos_core/mlos_core/optimizers/optimizer.py
+++ b/mlos_core/mlos_core/optimizers/optimizer.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the BaseOptimizer abstract class.
-"""
+"""Contains the BaseOptimizer abstract class."""
import collections
from abc import ABCMeta, abstractmethod
@@ -15,20 +13,21 @@ import numpy as np
import numpy.typing as npt
import pandas as pd
-from mlos_core.util import config_to_dataframe
from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
+from mlos_core.util import config_to_dataframe
class BaseOptimizer(metaclass=ABCMeta):
- """
- Optimizer abstract base class defining the basic interface.
- """
+ """Optimizer abstract base class defining the basic interface."""
- def __init__(self, *,
- parameter_space: ConfigSpace.ConfigurationSpace,
- optimization_targets: List[str],
- objective_weights: Optional[List[float]] = None,
- space_adapter: Optional[BaseSpaceAdapter] = None):
+ def __init__(
+ self,
+ *,
+ parameter_space: ConfigSpace.ConfigurationSpace,
+ optimization_targets: List[str],
+ objective_weights: Optional[List[float]] = None,
+ space_adapter: Optional[BaseSpaceAdapter] = None,
+ ):
"""
Create a new instance of the base optimizer.
@@ -44,8 +43,9 @@ class BaseOptimizer(metaclass=ABCMeta):
The space adapter class to employ for parameter space transformations.
"""
self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space
- self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \
+ self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = (
parameter_space if space_adapter is None else space_adapter.target_parameter_space
+ )
if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space:
raise ValueError("Given parameter space differs from the one given to space adapter")
@@ -68,14 +68,23 @@ class BaseOptimizer(metaclass=ABCMeta):
"""Get the space adapter instance (if any)."""
return self._space_adapter
- def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
- """Wrapper method, which employs the space adapter (if any), before registering the configs and scores.
+ def register(
+ self,
+ *,
+ configs: pd.DataFrame,
+ scores: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+ Wrapper method, which employs the space adapter (if any), before registering the
+ configs and scores.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
scores : pd.DataFrame
Scores from running the configs. The index is the same as the index of the configs.
@@ -87,47 +96,62 @@ class BaseOptimizer(metaclass=ABCMeta):
"""
# Do some input validation.
assert metadata is None or isinstance(metadata, pd.DataFrame)
- assert set(scores.columns) == set(self._optimization_targets), \
- "Mismatched optimization targets."
- assert self._has_context is None or self._has_context ^ (context is None), \
- "Context must always be added or never be added."
- assert len(configs) == len(scores), \
- "Mismatched number of configs and scores."
+ assert set(scores.columns) == set(
+ self._optimization_targets
+ ), "Mismatched optimization targets."
+ assert self._has_context is None or self._has_context ^ (
+ context is None
+ ), "Context must always be added or never be added."
+ assert len(configs) == len(scores), "Mismatched number of configs and scores."
if context is not None:
- assert len(configs) == len(context), \
- "Mismatched number of configs and context."
- assert configs.shape[1] == len(self.parameter_space.values()), \
- "Mismatched configuration shape."
+ assert len(configs) == len(context), "Mismatched number of configs and context."
+ assert configs.shape[1] == len(
+ self.parameter_space.values()
+ ), "Mismatched configuration shape."
self._observations.append((configs, scores, context))
self._has_context = context is not None
if self._space_adapter:
configs = self._space_adapter.inverse_transform(configs)
- assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \
- "Mismatched configuration shape after inverse transform."
+ assert configs.shape[1] == len(
+ self.optimizer_parameter_space.values()
+ ), "Mismatched configuration shape after inverse transform."
return self._register(configs=configs, scores=scores, context=context)
@abstractmethod
- def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
- """Registers the given configs and scores.
+ def _register(
+ self,
+ *,
+ configs: pd.DataFrame,
+ scores: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+ Registers the given configs and scores.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
scores : pd.DataFrame
Scores from running the configs. The index is the same as the index of the configs.
context : pd.DataFrame
Not Yet Implemented.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
- def suggest(self, *, context: Optional[pd.DataFrame] = None,
- defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+ def suggest(
+ self,
+ *,
+ context: Optional[pd.DataFrame] = None,
+ defaults: bool = False,
+ ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
"""
- Wrapper method, which employs the space adapter (if any), after suggesting a new configuration.
+ Wrapper method, which employs the space adapter (if any), after suggesting a new
+ configuration.
Parameters
----------
@@ -149,19 +173,27 @@ class BaseOptimizer(metaclass=ABCMeta):
configuration = self.space_adapter.inverse_transform(configuration)
else:
configuration, metadata = self._suggest(context=context)
- assert len(configuration) == 1, \
- "Suggest must return a single configuration."
- assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \
- "Optimizer suggested a configuration that does not match the expected parameter space."
+ assert len(configuration) == 1, "Suggest must return a single configuration."
+ assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), (
+ "Optimizer suggested a configuration that does "
+ "not match the expected parameter space."
+ )
if self._space_adapter:
configuration = self._space_adapter.transform(configuration)
- assert set(configuration.columns).issubset(set(self.parameter_space)), \
- "Space adapter produced a configuration that does not match the expected parameter space."
+ assert set(configuration.columns).issubset(set(self.parameter_space)), (
+ "Space adapter produced a configuration that does "
+ "not match the expected parameter space."
+ )
return configuration, metadata
@abstractmethod
- def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
- """Suggests a new configuration.
+ def _suggest(
+ self,
+ *,
+ context: Optional[pd.DataFrame] = None,
+ ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+ """
+ Suggests a new configuration.
Parameters
----------
@@ -176,26 +208,32 @@ class BaseOptimizer(metaclass=ABCMeta):
metadata : Optional[pd.DataFrame]
The metadata associated with the given configuration used for evaluations.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
@abstractmethod
- def register_pending(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None,
- metadata: Optional[pd.DataFrame] = None) -> None:
- """Registers the given configs as "pending".
- That is it say, it has been suggested by the optimizer, and an experiment trial has been started.
- This can be useful for executing multiple trials in parallel, retry logic, etc.
+ def register_pending(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+        Registers the given configs as "pending". That is to say, it has been suggested
+ by the optimizer, and an experiment trial has been started. This can be useful
+ for executing multiple trials in parallel, retry logic, etc.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
context : pd.DataFrame
Not Yet Implemented.
metadata : Optional[pd.DataFrame]
Not Yet Implemented.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]:
"""
@@ -210,15 +248,24 @@ class BaseOptimizer(metaclass=ABCMeta):
raise ValueError("No observations registered yet.")
configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True)
scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True)
- contexts = pd.concat([pd.DataFrame() if context is None else context
- for _, _, context in self._observations]).reset_index(drop=True)
+ contexts = pd.concat(
+ [
+ pd.DataFrame() if context is None else context
+ for _, _, context in self._observations
+ ]
+ ).reset_index(drop=True)
return (configs, scores, contexts if len(contexts.columns) > 0 else None)
- def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]:
+ def get_best_observations(
+ self,
+ *,
+ n_max: int = 1,
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]:
"""
- Get the N best observations so far as a triplet of DataFrames (config, score, context).
- Default is N=1. The columns are ordered in ASCENDING order of the optimization targets.
- The function uses `pandas.DataFrame.nsmallest(..., keep="first")` method under the hood.
+ Get the N best observations so far as a triplet of DataFrames (config, score,
+ context). Default is N=1. The columns are ordered in ASCENDING order of the
+        optimization targets. The function uses the
+        `pandas.DataFrame.nsmallest(..., keep="first")` method under the hood.
Parameters
----------
@@ -234,26 +281,26 @@ class BaseOptimizer(metaclass=ABCMeta):
raise ValueError("No observations registered yet.")
(configs, scores, contexts) = self.get_observations()
idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index
- return (configs.loc[idx], scores.loc[idx],
- None if contexts is None else contexts.loc[idx])
+ return (configs.loc[idx], scores.loc[idx], None if contexts is None else contexts.loc[idx])
def cleanup(self) -> None:
"""
- Remove temp files, release resources, etc. after use. Default is no-op.
- Redefine this method in optimizers that require cleanup.
+        Remove temp files, release resources, etc. after use.
+
+        Default is no-op. Redefine this method in optimizers that require cleanup.
"""
def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame:
- """
- Convert numpy array from one-hot encoding to a DataFrame
- with categoricals and ints in proper columns.
+ """Convert numpy array from one-hot encoding to a DataFrame with categoricals
+ and ints in proper columns.
"""
df_dict = collections.defaultdict(list)
for i in range(config.shape[0]):
j = 0
for param in self.optimizer_parameter_space.values():
if isinstance(param, ConfigSpace.CategoricalHyperparameter):
- for (offset, val) in enumerate(param.choices):
+ for offset, val in enumerate(param.choices):
if config[i][j + offset] == 1:
df_dict[param.name].append(val)
break
@@ -267,9 +314,7 @@ class BaseOptimizer(metaclass=ABCMeta):
return pd.DataFrame(df_dict)
def _to_1hot(self, *, config: Union[pd.DataFrame, pd.Series]) -> npt.NDArray:
- """
- Convert pandas DataFrame to one-hot-encoded numpy array.
- """
+ """Convert pandas DataFrame to one-hot-encoded numpy array."""
n_cols = 0
n_rows = config.shape[0] if config.ndim > 1 else 1
for param in self.optimizer_parameter_space.values():
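For reference, a minimal sketch (not part of the patch) of the one-hot layout that `_to_1hot()` / `_from_1hot()` work with, assuming the toy space used by the tests further below (float `x`, categorical `y` in `{"a", "b", "c"}`, integer `z`):

```python
# Illustrative only: how a single config row maps to the one-hot vector layout.
import numpy as np
import pandas as pd

choices = ["a", "b", "c"]
row = pd.Series({"y": "b", "x": 0.4, "z": 3})

# Columns follow the hyperparameter order of the space: x, one column per choice of y, then z.
one_hot = np.zeros(1 + len(choices) + 1)
one_hot[0] = row["x"]
one_hot[1 + choices.index(row["y"])] = 1.0
one_hot[-1] = row["z"]

print(one_hot)  # [0.4, 0.0, 1.0, 0.0, 3.0]
```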
diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py
index 0af785ef20..661a48a373 100644
--- a/mlos_core/mlos_core/optimizers/random_optimizer.py
+++ b/mlos_core/mlos_core/optimizers/random_optimizer.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the RandomOptimizer class.
-"""
+"""Contains the RandomOptimizer class."""
from typing import Optional, Tuple
from warnings import warn
@@ -15,8 +13,9 @@ from mlos_core.optimizers.optimizer import BaseOptimizer
class RandomOptimizer(BaseOptimizer):
- """Optimizer class that produces random suggestions.
- Useful for baseline comparison against Bayesian optimizers.
+ """
+ Optimizer class that produces random suggestions. Useful for baseline comparison
+ against Bayesian optimizers.
Parameters
----------
@@ -24,16 +23,24 @@ class RandomOptimizer(BaseOptimizer):
The parameter space to optimize.
"""
- def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
- """Registers the given configs and scores.
+ def _register(
+ self,
+ *,
+ configs: pd.DataFrame,
+ scores: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
+ """
+ Registers the given configs and scores.
Doesn't do anything on the RandomOptimizer except storing configs for logging.
Parameters
----------
configs : pd.DataFrame
- Dataframe of configs / parameters. The columns are parameter names and the rows are the configs.
+ Dataframe of configs / parameters. The columns are parameter names and
+ the rows are the configs.
scores : pd.DataFrame
Scores from running the configs. The index is the same as the index of the configs.
@@ -50,8 +57,13 @@ class RandomOptimizer(BaseOptimizer):
warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning)
# should we pop them from self.pending_observations?
- def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
- """Suggests a new configuration.
+ def _suggest(
+ self,
+ *,
+ context: Optional[pd.DataFrame] = None,
+ ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
+ """
+ Suggests a new configuration.
Sampled at random using ConfigSpace.
@@ -71,9 +83,17 @@ class RandomOptimizer(BaseOptimizer):
if context is not None:
# not sure how that works here?
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning)
- return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None
+ return (
+ pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]),
+ None,
+ )
- def register_pending(self, *, configs: pd.DataFrame,
- context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None:
+ def register_pending(
+ self,
+ *,
+ configs: pd.DataFrame,
+ context: Optional[pd.DataFrame] = None,
+ metadata: Optional[pd.DataFrame] = None,
+ ) -> None:
raise NotImplementedError()
# self._pending_observations.append((configs, context))
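A minimal usage sketch, not part of the patch, showing how `RandomOptimizer` fits the suggest/register loop; the objective function here is a made-up stand-in:

```python
# Illustrative sketch, assuming the mlos_core API shown in this diff.
import ConfigSpace as CS
import pandas as pd

from mlos_core.optimizers.random_optimizer import RandomOptimizer

space = CS.ConfigurationSpace(seed=1234)
space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1))

optimizer = RandomOptimizer(parameter_space=space, optimization_targets=["score"])
for _ in range(5):
    suggestion, _metadata = optimizer.suggest()
    # Hypothetical objective: score the suggestion and feed it back.
    scores = pd.DataFrame({"score": [suggestion["x"].iloc[0] ** 2]})
    optimizer.register(configs=suggestion, scores=scores)

print(optimizer.get_best_observations())
```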
diff --git a/mlos_core/mlos_core/spaces/__init__.py b/mlos_core/mlos_core/spaces/__init__.py
index d2a636ff1a..8de6887783 100644
--- a/mlos_core/mlos_core/spaces/__init__.py
+++ b/mlos_core/mlos_core/spaces/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Space adapters and converters init file.
-"""
+"""Space adapters and converters init file."""
diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py
index 2e2f585590..1645ac9cb4 100644
--- a/mlos_core/mlos_core/spaces/adapters/__init__.py
+++ b/mlos_core/mlos_core/spaces/adapters/__init__.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Basic initializer module for the mlos_core space adapters.
-"""
+"""Basic initializer module for the mlos_core space adapters."""
from enum import Enum
from typing import Optional, TypeVar
@@ -15,8 +13,8 @@ from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter
from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter
__all__ = [
- 'IdentityAdapter',
- 'LlamaTuneAdapter',
+ "IdentityAdapter",
+ "LlamaTuneAdapter",
]
@@ -24,33 +22,38 @@ class SpaceAdapterType(Enum):
"""Enumerate supported MlosCore space adapters."""
IDENTITY = IdentityAdapter
- """A no-op adapter will be used"""
+ """A no-op adapter will be used."""
LLAMATUNE = LlamaTuneAdapter
- """An instance of LlamaTuneAdapter class will be used"""
+ """An instance of LlamaTuneAdapter class will be used."""
# To make mypy happy, we need to define a type variable for each optimizer type.
# https://github.com/python/mypy/issues/12952
-# ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType])
+# ConcreteSpaceAdapter = TypeVar(
+# "ConcreteSpaceAdapter",
+# *[member.value for member in SpaceAdapterType],
+# )
# To address this, we add a test for complete coverage of the enum.
ConcreteSpaceAdapter = TypeVar(
- 'ConcreteSpaceAdapter',
+ "ConcreteSpaceAdapter",
IdentityAdapter,
LlamaTuneAdapter,
)
class SpaceAdapterFactory:
- """Simple factory class for creating BaseSpaceAdapter-derived objects"""
+ """Simple factory class for creating BaseSpaceAdapter-derived objects."""
# pylint: disable=too-few-public-methods
@staticmethod
- def create(*,
- parameter_space: ConfigSpace.ConfigurationSpace,
- space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY,
- space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var]
+ def create(
+ *,
+ parameter_space: ConfigSpace.ConfigurationSpace,
+ space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY,
+ space_adapter_kwargs: Optional[dict] = None,
+ ) -> ConcreteSpaceAdapter: # type: ignore[type-var]
"""
Create a new space adapter instance, given the parameter space and potential
space adapter options.
@@ -76,7 +79,7 @@ class SpaceAdapterFactory:
space_adapter: ConcreteSpaceAdapter = space_adapter_type.value(
orig_parameter_space=parameter_space,
- **space_adapter_kwargs
+ **space_adapter_kwargs,
)
return space_adapter
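A short sketch (illustrative only, assuming the factory API introduced above) of creating a LlamaTune adapter via `SpaceAdapterFactory.create`:

```python
import ConfigSpace as CS

from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType

space = CS.ConfigurationSpace(seed=1234)
for i in range(20):
    space.add_hyperparameter(CS.UniformIntegerHyperparameter(name=f"p{i}", lower=0, upper=100))

adapter = SpaceAdapterFactory.create(
    parameter_space=space,
    space_adapter_type=SpaceAdapterType.LLAMATUNE,
    space_adapter_kwargs={"num_low_dims": 4},
)
print(adapter.target_parameter_space)
```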
diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py
index 6c3a86fc8a..2d48a14c31 100644
--- a/mlos_core/mlos_core/spaces/adapters/adapter.py
+++ b/mlos_core/mlos_core/spaces/adapters/adapter.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the BaseSpaceAdapter abstract class.
-"""
+"""Contains the BaseSpaceAdapter abstract class."""
from abc import ABCMeta, abstractmethod
@@ -13,7 +11,8 @@ import pandas as pd
class BaseSpaceAdapter(metaclass=ABCMeta):
- """SpaceAdapter abstract class defining the basic interface.
+ """
+ SpaceAdapter abstract class defining the basic interface.
Parameters
----------
@@ -35,28 +34,27 @@ class BaseSpaceAdapter(metaclass=ABCMeta):
@property
def orig_parameter_space(self) -> ConfigSpace.ConfigurationSpace:
- """
- Original (user-provided) parameter space to explore.
- """
+ """Original (user-provided) parameter space to explore."""
return self._orig_parameter_space
@property
@abstractmethod
def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace:
- """
- Target parameter space that is fed to the underlying optimizer.
- """
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ """Target parameter space that is fed to the underlying optimizer."""
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
@abstractmethod
def transform(self, configuration: pd.DataFrame) -> pd.DataFrame:
- """Translates a configuration, which belongs to the target parameter space, to the original parameter space.
- This method is called by the `suggest` method of the `BaseOptimizer` class.
+ """
+ Translates a configuration, which belongs to the target parameter space, to the
+ original parameter space. This method is called by the `suggest` method of the
+ `BaseOptimizer` class.
Parameters
----------
configuration : pd.DataFrame
- Pandas dataframe with a single row. Column names are the parameter names of the target parameter space.
+ Pandas dataframe with a single row. Column names are the parameter names
+ of the target parameter space.
Returns
-------
@@ -64,24 +62,28 @@ class BaseSpaceAdapter(metaclass=ABCMeta):
Pandas dataframe with a single row, containing the translated configuration.
Column names are the parameter names of the original parameter space.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
@abstractmethod
def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame:
- """Translates a configuration, which belongs to the original parameter space, to the target parameter space.
- This method is called by the `register` method of the `BaseOptimizer` class, and performs the inverse operation
- of `BaseSpaceAdapter.transform` method.
+ """
+ Translates a configuration, which belongs to the original parameter space, to
+ the target parameter space. This method is called by the `register` method of
+ the `BaseOptimizer` class, and performs the inverse operation of
+ `BaseSpaceAdapter.transform` method.
Parameters
----------
configurations : pd.DataFrame
Dataframe of configurations / parameters, which belong to the original parameter space.
- The columns are the parameter names the original parameter space and the rows are the configurations.
+            The columns are the parameter names of the original parameter space and the
+ rows are the configurations.
Returns
-------
configurations : pd.DataFrame
Dataframe of the translated configurations / parameters.
- The columns are the parameter names of the target parameter space and the rows are the configurations.
+ The columns are the parameter names of the target parameter space and
+ the rows are the configurations.
"""
- pass # pylint: disable=unnecessary-pass # pragma: no cover
+ pass # pylint: disable=unnecessary-pass # pragma: no cover
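A contract sketch, using the trivial `IdentityAdapter` as an example; it is illustrative only and not part of the patch:

```python
# transform() maps a single-row DataFrame from the target space to the original space;
# inverse_transform() performs the opposite mapping. For the identity adapter both are no-ops.
import ConfigSpace as CS
import pandas as pd

from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter

space = CS.ConfigurationSpace(seed=1234)
space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10))

adapter = IdentityAdapter(orig_parameter_space=space)
config = pd.DataFrame({"z": [5]})
print(adapter.transform(config))          # same values for the identity adapter
print(adapter.inverse_transform(config))  # same values for the identity adapter
```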
diff --git a/mlos_core/mlos_core/spaces/adapters/identity_adapter.py b/mlos_core/mlos_core/spaces/adapters/identity_adapter.py
index ad79fa21c9..1e552110a2 100644
--- a/mlos_core/mlos_core/spaces/adapters/identity_adapter.py
+++ b/mlos_core/mlos_core/spaces/adapters/identity_adapter.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains the Identity (no-op) Space Adapter class.
-"""
+"""Contains the Identity (no-op) Space Adapter class."""
import ConfigSpace
import pandas as pd
@@ -13,7 +11,8 @@ from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
class IdentityAdapter(BaseSpaceAdapter):
- """Identity (no-op) SpaceAdapter class.
+ """
+ Identity (no-op) SpaceAdapter class.
Parameters
----------
diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py
index 554b1169f5..38d973a27f 100644
--- a/mlos_core/mlos_core/spaces/adapters/llamatune.py
+++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py
@@ -2,44 +2,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Implementation of LlamaTune space adapter.
-"""
+"""Implementation of LlamaTune space adapter."""
from typing import Dict, Optional
from warnings import warn
import ConfigSpace
-from ConfigSpace.hyperparameters import NumericalHyperparameter
import numpy as np
import numpy.typing as npt
import pandas as pd
+from ConfigSpace.hyperparameters import NumericalHyperparameter
from sklearn.preprocessing import MinMaxScaler
-from mlos_core.util import normalize_config
from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
+from mlos_core.util import normalize_config
-class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes
- """
- Implementation of LlamaTune, a set of parameter space transformation techniques,
+class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes
+ """Implementation of LlamaTune, a set of parameter space transformation techniques,
aimed at improving the sample-efficiency of the underlying optimizer.
"""
DEFAULT_NUM_LOW_DIMS = 16
- """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection"""
+ """Default number of dimensions in the low-dimensional search space, generated by
+ HeSBO projection.
+ """
- DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2
- """Default percentage of bias for each special parameter value"""
+ DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2
+ """Default percentage of bias for each special parameter value."""
DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000
- """Default number of (max) unique values of each parameter, when space discretization is used"""
+ """Default number of (max) unique values of each parameter, when space
+ discretization is used.
+ """
- def __init__(self, *,
- orig_parameter_space: ConfigSpace.ConfigurationSpace,
- num_low_dims: int = DEFAULT_NUM_LOW_DIMS,
- special_param_values: Optional[dict] = None,
- max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM,
- use_approximate_reverse_mapping: bool = False):
+ def __init__(
+ self,
+ *,
+ orig_parameter_space: ConfigSpace.ConfigurationSpace,
+ num_low_dims: int = DEFAULT_NUM_LOW_DIMS,
+ special_param_values: Optional[dict] = None,
+ max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM,
+ use_approximate_reverse_mapping: bool = False,
+ ):
"""
Create a space adapter that employs LlamaTune's techniques.
@@ -58,7 +62,10 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
super().__init__(orig_parameter_space=orig_parameter_space)
if num_low_dims >= len(orig_parameter_space):
- raise ValueError("Number of target config space dimensions should be less than those of original config space.")
+ raise ValueError(
+ "Number of target config space dimensions should be "
+ "less than those of original config space."
+ )
# Validate input special param values dict
special_param_values = special_param_values or {}
@@ -90,26 +97,36 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame:
target_configurations = []
- for (_, config) in configurations.astype('O').iterrows():
+ for _, config in configurations.astype("O").iterrows():
configuration = ConfigSpace.Configuration(
- self.orig_parameter_space, values=config.to_dict())
+ self.orig_parameter_space,
+ values=config.to_dict(),
+ )
target_config = self._suggested_configs.get(configuration, None)
- # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection
- # To (partly) support this operation, we keep track of the suggested low-dim point(s) along with the
- # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart.
+ # NOTE: HeSBO is a non-linear projection method, and does not inherently
+ # support inverse projection
+ # To (partly) support this operation, we keep track of the suggested
+ # low-dim point(s) along with the respective high-dim point; this way we
+ # can retrieve the low-dim point, from its high-dim counterpart.
if target_config is None:
- # Inherently it is not supported to register points, which were not suggested by the optimizer.
+ # Inherently it is not supported to register points, which were not
+ # suggested by the optimizer.
if configuration == self.orig_parameter_space.get_default_configuration():
# Default configuration should always be registerable.
pass
elif not self._use_approximate_reverse_mapping:
- raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. "
- "Approximate reverse mapping is currently disabled; thus *only* configurations suggested "
- "previously by the optimizer can be registered.")
+ raise ValueError(
+ f"{repr(configuration)}\n"
+ "The above configuration was not suggested by the optimizer. "
+ "Approximate reverse mapping is currently disabled; "
+ "thus *only* configurations suggested "
+ "previously by the optimizer can be registered."
+ )
- # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix.
- if getattr(self, '_pinv_matrix', None) is None:
+ # ...yet, we try to support that by implementing an approximate
+ # reverse mapping using pseudo-inverse matrix.
+ if getattr(self, "_pinv_matrix", None) is None:
self._try_generate_approx_inverse_mapping()
# Replace NaNs with zeros for inactive hyperparameters
@@ -118,19 +135,29 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
# NOTE: applying special value biasing is not possible
vector = self._config_scaler.inverse_transform([config_vector])[0]
target_config_vector = self._pinv_matrix.dot(vector)
- target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector)
+ target_config = ConfigSpace.Configuration(
+ self.target_parameter_space,
+ vector=target_config_vector,
+ )
target_configurations.append(target_config)
- return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys()))
+ return pd.DataFrame(
+ target_configurations, columns=list(self.target_parameter_space.keys())
+ )
def transform(self, configuration: pd.DataFrame) -> pd.DataFrame:
if len(configuration) != 1:
- raise ValueError("Configuration dataframe must contain exactly 1 row. "
- f"Found {len(configuration)} rows.")
+ raise ValueError(
+ "Configuration dataframe must contain exactly 1 row. "
+ f"Found {len(configuration)} rows."
+ )
target_values_dict = configuration.iloc[0].to_dict()
- target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict)
+ target_configuration = ConfigSpace.Configuration(
+ self.target_parameter_space,
+ values=target_values_dict,
+ )
orig_values_dict = self._transform(target_values_dict)
orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict)
@@ -138,10 +165,17 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
# Add to inverse dictionary -- needed for registering the performance later
self._suggested_configs[orig_configuration] = target_configuration
- return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys()))
+ return pd.DataFrame(
+ [list(orig_configuration.values())], columns=list(orig_configuration.keys())
+ )
- def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None:
- """Constructs the low-dimensional parameter (potentially discretized) search space.
+ def _construct_low_dim_space(
+ self,
+ num_low_dims: int,
+ max_unique_values_per_param: Optional[int],
+ ) -> None:
+ """
+ Constructs the low-dimensional parameter (potentially discretized) search space.
Parameters
----------
@@ -156,19 +190,27 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
q_scaler = None
if max_unique_values_per_param is None:
hyperparameters = [
- ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1)
+ ConfigSpace.UniformFloatHyperparameter(name=f"dim_{idx}", lower=-1, upper=1)
for idx in range(num_low_dims)
]
else:
- # Currently supported optimizers do not support defining a discretized space (like ConfigSpace does using `q` kwarg).
- # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters.
- # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection.
+ # Currently supported optimizers do not support defining a discretized
+ # space (like ConfigSpace does using `q` kwarg).
+ # Thus, to support space discretization, we define the low-dimensional
+ # space using integer hyperparameters.
+ # We also employ a scaler, which scales suggested values to [-1, 1]
+ # range, used by HeSBO projection.
hyperparameters = [
- ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param)
+ ConfigSpace.UniformIntegerHyperparameter(
+ name=f"dim_{idx}",
+ lower=1,
+ upper=max_unique_values_per_param,
+ )
for idx in range(num_low_dims)
]
- # Initialize quantized values scaler: from [0, max_unique_values_per_param] to (-1, 1) range
+ # Initialize quantized values scaler:
+ # from [0, max_unique_values_per_param] to (-1, 1) range
q_scaler = MinMaxScaler(feature_range=(-1, 1))
ones_vector = np.ones(num_low_dims)
max_value_vector = ones_vector * max_unique_values_per_param
@@ -178,13 +220,16 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
# Construct low-dimensional parameter search space
config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name)
- config_space.random = self._random_state # use same random state as in original parameter space
+ # use same random state as in original parameter space
+ config_space.random = self._random_state
config_space.add_hyperparameters(hyperparameters)
self._target_config_space = config_space
def _transform(self, configuration: dict) -> dict:
- """Projects a low-dimensional point (configuration) to the high-dimensional original parameter space,
- and then biases the resulting parameter values towards their special value(s) (if any).
+ """
+ Projects a low-dimensional point (configuration) to the high-dimensional
+ original parameter space, and then biases the resulting parameter values towards
+ their special value(s) (if any).
Parameters
----------
@@ -216,10 +261,10 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
# Clip value to force it to fall in [0, 1]
# NOTE: HeSBO projection ensures that theoretically but due to
# floating point ops nuances this is not always guaranteed
- value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name
+ value = max(0.0, min(1.0, norm_value)) # pylint: disable=redefined-loop-name
if isinstance(param, ConfigSpace.CategoricalHyperparameter):
- index = int(value * len(param.choices)) # truncate integer part
+ index = int(value * len(param.choices)) # truncate integer part
index = max(0, min(len(param.choices) - 1, index))
# NOTE: potential rounding here would be unfair to first & last values
orig_value = param.choices[index]
@@ -227,17 +272,25 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
if param.name in self._special_param_values_dict:
value = self._special_param_value_scaler(param, value)
- orig_value = param._transform(value) # pylint: disable=protected-access
+ orig_value = param._transform(value) # pylint: disable=protected-access
orig_value = max(param.lower, min(param.upper, orig_value))
else:
- raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.")
+ raise NotImplementedError(
+ "Only Categorical, Integer, and Float hyperparameters are currently supported."
+ )
original_config[param.name] = orig_value
return original_config
- def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float:
- """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those.
+ def _special_param_value_scaler(
+ self,
+ param: ConfigSpace.UniformIntegerHyperparameter,
+ input_value: float,
+ ) -> float:
+ """
+ Biases the special value(s) of this parameter, by shifting the normalized
+ `input_value` towards those.
Parameters
----------
@@ -255,7 +308,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
special_values_list = self._special_param_values_dict[param.name]
# Check if input value corresponds to some special value
- perc_sum = 0.
+ perc_sum = 0.0
ret: float
for special_value, biasing_perc in special_values_list:
perc_sum += biasing_perc
@@ -264,14 +317,17 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
return ret
# Scale input value uniformly to non-special values
- ret = param._inverse_transform( # pylint: disable=protected-access
- param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ ret = param._inverse_transform(
+ param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))
+ )
return ret
# pylint: disable=too-complex,too-many-branches
def _validate_special_param_values(self, special_param_values_dict: dict) -> None:
- """Checks that the user-provided dict of special parameter values is valid.
- And assigns it to the corresponding attribute.
+ """
+        Checks that the user-provided dict of special parameter values is valid and
+        assigns it to the corresponding attribute.
Parameters
----------
@@ -294,8 +350,10 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
hyperparameter = self.orig_parameter_space[param]
if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter):
- raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. "
- "Only Integer Hyperparameters are currently supported.")
+ raise NotImplementedError(
+ error_prefix + f"Parameter '{param}' is not supported. "
+ "Only Integer Hyperparameters are currently supported."
+ )
if isinstance(value, int):
# User specifies a single special value -- default biasing percentage is used
@@ -306,55 +364,98 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
elif isinstance(value, list) and value:
if all(isinstance(t, int) for t in value):
# User specifies list of special values
- tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value]
- elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value):
- # User specifies list of tuples; each tuple defines the special value and the biasing percentage
+ tuple_list = [
+ (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value
+ ]
+ elif all(
+ isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value
+ ):
+ # User specifies list of tuples; each tuple defines the special
+ # value and the biasing percentage
tuple_list = value
else:
- raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. "
- f"Special value list should contain either integers, or (special value, biasing %) tuples.")
+ raise ValueError(
+ error_prefix + f"Invalid format in value list for parameter '{param}'. "
+ f"Special value list should contain either integers, "
+ "or (special value, biasing %) tuples."
+ )
else:
- raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be "
- "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.")
+ raise ValueError(
+ error_prefix + f"Invalid format for parameter '{param}'. Dict value should be "
+ "an int, a (int, float) tuple, a list of integers, "
+ "or a list of (int, float) tuples."
+ )
# Are user-specified special values valid?
if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list):
- raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.")
+ raise ValueError(
+ error_prefix
+ + "One (or more) special values are outside of parameter "
+ + f"'{param}' value domain."
+ )
# Are user-provided special values unique?
if len(set(v for v, _ in tuple_list)) != len(tuple_list):
- raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.")
+ raise ValueError(
+ error_prefix
+ + "One (or more) special values are defined more than once "
+ + f"for parameter '{param}'."
+ )
# Are biasing percentages valid?
if not all(0 < perc < 1 for _, perc in tuple_list):
- raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: "
- "i.e., fall outside (0, 1) range.")
+ raise ValueError(
+ error_prefix
+ + f"One (or more) biasing percentages for parameter '{param}' are invalid: "
+ "i.e., fall outside (0, 1) range."
+ )
total_percentage = sum(perc for _, perc in tuple_list)
- if total_percentage >= 1.:
- raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.")
+ if total_percentage >= 1.0:
+ raise ValueError(
+ error_prefix
+ + f"Total special values percentage for parameter '{param}' surpass 100%."
+ )
# ... and reasonable?
if total_percentage >= 0.5:
- warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning)
+ warn(
+ f"Total special values percentage for parameter '{param}' exceeds 50%.",
+ UserWarning,
+ )
sanitized_dict[param] = tuple_list
self._special_param_values_dict = sanitized_dict
def _try_generate_approx_inverse_mapping(self) -> None:
- """Tries to generate an approximate reverse mapping: i.e., from high-dimensional space to the low-dimensional one.
- Reverse mapping is generated using the pseudo-inverse matrix, of original HeSBO projection matrix.
- This mapping can be potentially used to register configurations that were *not* previously suggested by the optimizer.
+ """Tries to generate an approximate reverse mapping:
+ i.e., from high-dimensional space to the low-dimensional one.
- NOTE: This method is experimental, and there is currently no guarantee that it works as expected.
+        Reverse mapping is generated using the pseudo-inverse of the original
+        HeSBO projection matrix.
+        This mapping can potentially be used to register configurations that were
+        *not* previously suggested by the optimizer.
+
+ NOTE: This method is experimental, and there is currently no guarantee that
+ it works as expected.
Raises
------
RuntimeError: if reverse mapping computation fails.
"""
- from scipy.linalg import pinv, LinAlgError # pylint: disable=import-outside-toplevel
+ from scipy.linalg import ( # pylint: disable=import-outside-toplevel
+ LinAlgError,
+ pinv,
+ )
- warn("Trying to register a configuration that was not previously suggested by the optimizer. " +
- "This inverse configuration transformation is typically not supported. " +
- "However, we will try to register this configuration using an *experimental* method.", UserWarning)
+ warn(
+ (
+ "Trying to register a configuration that was not "
+ "previously suggested by the optimizer.\n"
+ "This inverse configuration transformation is typically not supported.\n"
+ "However, we will try to register this configuration "
+ "using an *experimental* method."
+ ),
+ UserWarning,
+ )
orig_space_num_dims = len(list(self.orig_parameter_space.values()))
target_space_num_dims = len(list(self.target_parameter_space.values()))
@@ -368,5 +469,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-
try:
self._pinv_matrix = pinv(proj_matrix)
except LinAlgError as err:
- raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err
+ raise RuntimeError(
+ f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}"
+ ) from err
assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims)
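An illustrative sketch of the `special_param_values` formats accepted by the validation above; the `knob_*` parameter names are hypothetical integer parameters:

```python
import ConfigSpace as CS

from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter

space = CS.ConfigurationSpace(seed=42)
for i in range(8):
    space.add_hyperparameter(
        CS.UniformIntegerHyperparameter(name=f"knob_{i}", lower=0, upper=1000)
    )

adapter = LlamaTuneAdapter(
    orig_parameter_space=space,
    num_low_dims=2,
    special_param_values={
        "knob_0": 0,                        # single special value, default biasing percentage
        "knob_1": (100, 0.1),               # special value with an explicit biasing percentage
        "knob_2": [0, 1000],                # list of special values
        "knob_3": [(0, 0.1), (500, 0.2)],   # list of (value, biasing %) tuples
    },
)
```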
diff --git a/mlos_core/mlos_core/spaces/converters/__init__.py b/mlos_core/mlos_core/spaces/converters/__init__.py
index 8385a4938d..2360bda24f 100644
--- a/mlos_core/mlos_core/spaces/converters/__init__.py
+++ b/mlos_core/mlos_core/spaces/converters/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Space converters init file.
-"""
+"""Space converters init file."""
diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py
index 3935dbef6c..71370853e4 100644
--- a/mlos_core/mlos_core/spaces/converters/flaml.py
+++ b/mlos_core/mlos_core/spaces/converters/flaml.py
@@ -2,19 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Contains space converters for FLAML.
-"""
-
-from typing import Dict, TYPE_CHECKING
+"""Contains space converters for FLAML."""
import sys
+from typing import TYPE_CHECKING, Dict
import ConfigSpace
-import numpy as np
-
import flaml.tune
import flaml.tune.sample
+import numpy as np
if TYPE_CHECKING:
from ConfigSpace.hyperparameters import Hyperparameter
@@ -29,8 +25,11 @@ FlamlDomain: TypeAlias = flaml.tune.sample.Domain
FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain]
-def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]:
- """Converts a ConfigSpace.ConfigurationSpace to dict.
+def configspace_to_flaml_space(
+ config_space: ConfigSpace.ConfigurationSpace,
+) -> Dict[str, FlamlDomain]:
+ """
+ Converts a ConfigSpace.ConfigurationSpace to dict.
Parameters
----------
@@ -52,13 +51,21 @@ def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) ->
def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter):
# FIXME: upper isn't included in the range
- return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper)
+ return flaml_numeric_type[(type(parameter), parameter.log)](
+ parameter.lower,
+ parameter.upper,
+ )
elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
- return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1)
+ return flaml_numeric_type[(type(parameter), parameter.log)](
+ parameter.lower,
+ parameter.upper + 1,
+ )
elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
if len(np.unique(parameter.probabilities)) > 1:
- raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.")
- return flaml.tune.choice(parameter.choices) # TODO: set order?
+ raise ValueError(
+ "FLAML doesn't support categorical parameters with non-uniform probabilities."
+ )
+ return flaml.tune.choice(parameter.choices) # TODO: set order?
raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")
return {param.name: _one_parameter_convert(param) for param in config_space.values()}
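A quick sketch (illustrative only) of using `configspace_to_flaml_space` on a small space:

```python
import ConfigSpace as CS

from mlos_core.spaces.converters.flaml import configspace_to_flaml_space

space = CS.ConfigurationSpace(seed=1234)
space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0.0, upper=1.0))
space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10))
space.add_hyperparameter(CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"]))

flaml_space = configspace_to_flaml_space(space)
print(flaml_space)  # dict of FLAML domains keyed by parameter name
```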
diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py
index 6f74147ae9..cff9016da7 100644
--- a/mlos_core/mlos_core/tests/__init__.py
+++ b/mlos_core/mlos_core/tests/__init__.py
@@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Common functions for mlos_core Optimizer tests.
-"""
+"""Common functions for mlos_core Optimizer tests."""
import sys
-
from importlib import import_module
from pkgutil import walk_packages
from typing import List, Optional, Set, Type, TypeVar
@@ -22,16 +19,19 @@ else:
from typing_extensions import TypeAlias
-T = TypeVar('T')
+T = TypeVar("T")
def get_all_submodules(pkg: TypeAlias) -> List[str]:
"""
Imports all submodules for a package and returns their names.
+
Useful for dynamically enumerating subclasses.
"""
submodules = []
- for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None):
+ for _, submodule_name, _ in walk_packages(
+ pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None
+ ):
submodules.append(submodule_name)
return submodules
@@ -39,16 +39,18 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]:
def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]:
"""
Gets the set of all of the subclasses of the given class.
+
Useful for dynamically enumerating expected test cases.
"""
return set(cls.__subclasses__()).union(
- s for c in cls.__subclasses__() for s in _get_all_subclasses(c))
+ s for c in cls.__subclasses__() for s in _get_all_subclasses(c)
+ )
def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]:
"""
- Gets a sorted list of all of the concrete subclasses of the given class.
- Useful for dynamically enumerating expected test cases.
+ Gets a sorted list of all of the concrete subclasses of the given class. Useful for
+ dynamically enumerating expected test cases.
Note: For abstract types, mypy will complain at the call site.
Use "# type: ignore[type-abstract]" to suppress the warning.
@@ -58,5 +60,11 @@ def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) ->
pkg = import_module(pkg_name)
submodules = get_all_submodules(pkg)
assert submodules
- return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)],
- key=lambda c: (c.__module__, c.__name__))
+ return sorted(
+ [
+ subclass
+ for subclass in _get_all_subclasses(cls)
+ if not getattr(subclass, "__abstractmethods__", None)
+ ],
+ key=lambda c: (c.__module__, c.__name__),
+ )
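A brief sketch of how `get_all_concrete_subclasses` is typically used to enumerate test cases (illustrative only):

```python
from mlos_core.optimizers import BaseOptimizer
from mlos_core.tests import get_all_concrete_subclasses

optimizer_subclasses = get_all_concrete_subclasses(
    BaseOptimizer,  # type: ignore[type-abstract]
    pkg_name="mlos_core",
)
print([cls.__name__ for cls in optimizer_subclasses])
```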
diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py
index c1aaa710ac..65f0d9ab92 100644
--- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py
+++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py
@@ -2,40 +2,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for Bayesian Optimizers.
-"""
+"""Tests for Bayesian Optimizers."""
from typing import Optional, Type
-import pytest
-
-import pandas as pd
import ConfigSpace as CS
+import pandas as pd
+import pytest
from mlos_core.optimizers import BaseOptimizer, OptimizerType
from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer
@pytest.mark.filterwarnings("error:Not Implemented")
-@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [
- *[(member.value, {}) for member in OptimizerType],
-])
-def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace,
- optimizer_class: Type[BaseOptimizer],
- kwargs: Optional[dict]) -> None:
- """
- Make sure we raise warnings for the functionality that has not been implemented yet.
+@pytest.mark.parametrize(
+ ("optimizer_class", "kwargs"),
+ [
+ *[(member.value, {}) for member in OptimizerType],
+ ],
+)
+def test_context_not_implemented_warning(
+ configuration_space: CS.ConfigurationSpace,
+ optimizer_class: Type[BaseOptimizer],
+ kwargs: Optional[dict],
+) -> None:
+ """Make sure we raise warnings for the functionality that has not been implemented
+ yet.
"""
if kwargs is None:
kwargs = {}
optimizer = optimizer_class(
parameter_space=configuration_space,
- optimization_targets=['score'],
- **kwargs
+ optimization_targets=["score"],
+ **kwargs,
)
suggestion, _metadata = optimizer.suggest()
- scores = pd.DataFrame({'score': [1]})
+ scores = pd.DataFrame({"score": [1]})
context = pd.DataFrame([["something"]])
with pytest.raises(UserWarning):
diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py
index be1b658387..fe82ff92bb 100644
--- a/mlos_core/mlos_core/tests/optimizers/conftest.py
+++ b/mlos_core/mlos_core/tests/optimizers/conftest.py
@@ -2,26 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test fixtures for mlos_bench optimizers.
-"""
-
-import pytest
+"""Test fixtures for mlos_bench optimizers."""
import ConfigSpace as CS
+import pytest
@pytest.fixture
def configuration_space() -> CS.ConfigurationSpace:
- """
- Test fixture to produce a config space with all types of hyperparameters.
- """
+ """Test fixture to produce a config space with all types of hyperparameters."""
# Start defining a ConfigurationSpace for the Optimizer to search.
space = CS.ConfigurationSpace(seed=1234)
# Add a continuous input dimension between 0 and 1.
- space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1))
+ space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1))
# Add a categorical hyperparameter with 3 possible values.
- space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"]))
+ space.add_hyperparameter(CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"]))
# Add a discrete input dimension between 0 and 10.
- space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10))
+ space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10))
return space
diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py
index 8e10afa302..c910f60fc5 100644
--- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py
+++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py
@@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for one-hot encoding for certain optimizers.
-"""
+"""Tests for one-hot encoding for certain optimizers."""
-import pytest
-
-import pandas as pd
+import ConfigSpace as CS
import numpy as np
import numpy.typing as npt
-import ConfigSpace as CS
+import pandas as pd
+import pytest
from mlos_core.optimizers import BaseOptimizer, SmacOptimizer
@@ -22,104 +19,117 @@ from mlos_core.optimizers import BaseOptimizer, SmacOptimizer
def data_frame() -> pd.DataFrame:
"""
Toy data frame corresponding to the `configuration_space` hyperparameters.
+
The columns are deliberately *not* in alphabetic order.
"""
- return pd.DataFrame({
- 'y': ['a', 'b', 'c'],
- 'x': [0.1, 0.2, 0.3],
- 'z': [1, 5, 8],
- })
+ return pd.DataFrame(
+ {
+ "y": ["a", "b", "c"],
+ "x": [0.1, 0.2, 0.3],
+ "z": [1, 5, 8],
+ }
+ )
@pytest.fixture
def one_hot_data_frame() -> npt.NDArray:
"""
One-hot encoding of the `data_frame` above.
+
The columns follow the order of the hyperparameters in `configuration_space`.
"""
- return np.array([
- [0.1, 1.0, 0.0, 0.0, 1.0],
- [0.2, 0.0, 1.0, 0.0, 5.0],
- [0.3, 0.0, 0.0, 1.0, 8.0],
- ])
+ return np.array(
+ [
+ [0.1, 1.0, 0.0, 0.0, 1.0],
+ [0.2, 0.0, 1.0, 0.0, 5.0],
+ [0.3, 0.0, 0.0, 1.0, 8.0],
+ ]
+ )
@pytest.fixture
def series() -> pd.Series:
"""
Toy series corresponding to the `configuration_space` hyperparameters.
+
The columns are deliberately *not* in alphabetic order.
"""
- return pd.Series({
- 'y': 'b',
- 'x': 0.4,
- 'z': 3,
- })
+ return pd.Series(
+ {
+ "y": "b",
+ "x": 0.4,
+ "z": 3,
+ }
+ )
@pytest.fixture
def one_hot_series() -> npt.NDArray:
"""
One-hot encoding of the `series` above.
+
The columns follow the order of the hyperparameters in `configuration_space`.
"""
- return np.array([
- [0.4, 0.0, 1.0, 0.0, 3],
- ])
+ return np.array(
+ [
+ [0.4, 0.0, 1.0, 0.0, 3],
+ ]
+ )
@pytest.fixture
def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer:
"""
- Test fixture for the optimizer. Use it to test one-hot encoding/decoding.
+ Test fixture for the optimizer.
+
+ Use it to test one-hot encoding/decoding.
"""
return SmacOptimizer(
parameter_space=configuration_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
)
-def test_to_1hot_data_frame(optimizer: BaseOptimizer,
- data_frame: pd.DataFrame,
- one_hot_data_frame: npt.NDArray) -> None:
- """
- Toy problem to test one-hot encoding of dataframe.
- """
+def test_to_1hot_data_frame(
+ optimizer: BaseOptimizer,
+ data_frame: pd.DataFrame,
+ one_hot_data_frame: npt.NDArray,
+) -> None:
+ """Toy problem to test one-hot encoding of dataframe."""
assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame)
-def test_to_1hot_series(optimizer: BaseOptimizer,
- series: pd.Series, one_hot_series: npt.NDArray) -> None:
- """
- Toy problem to test one-hot encoding of series.
- """
+def test_to_1hot_series(
+ optimizer: BaseOptimizer,
+ series: pd.Series,
+ one_hot_series: npt.NDArray,
+) -> None:
+ """Toy problem to test one-hot encoding of series."""
assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series)
-def test_from_1hot_data_frame(optimizer: BaseOptimizer,
- data_frame: pd.DataFrame,
- one_hot_data_frame: npt.NDArray) -> None:
- """
- Toy problem to test one-hot decoding of dataframe.
- """
+def test_from_1hot_data_frame(
+ optimizer: BaseOptimizer,
+ data_frame: pd.DataFrame,
+ one_hot_data_frame: npt.NDArray,
+) -> None:
+ """Toy problem to test one-hot decoding of dataframe."""
assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict()
-def test_from_1hot_series(optimizer: BaseOptimizer,
- series: pd.Series,
- one_hot_series: npt.NDArray) -> None:
- """
- Toy problem to test one-hot decoding of series.
- """
+def test_from_1hot_series(
+ optimizer: BaseOptimizer,
+ series: pd.Series,
+ one_hot_series: npt.NDArray,
+) -> None:
+ """Toy problem to test one-hot decoding of series."""
one_hot_df = optimizer._from_1hot(config=one_hot_series)
assert one_hot_df.shape[0] == 1, f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)"
assert one_hot_df.iloc[0].to_dict() == series.to_dict()
def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame) -> None:
- """
- Round-trip test for one-hot-encoding and then decoding a data frame.
- """
+ """Round-trip test for one-hot-encoding and then decoding a data frame."""
df_round_trip = optimizer._from_1hot(config=optimizer._to_1hot(config=data_frame))
assert df_round_trip.x.to_numpy() == pytest.approx(data_frame.x)
assert (df_round_trip.y == data_frame.y).all()
@@ -127,28 +137,23 @@ def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFram
def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> None:
- """
- Round-trip test for one-hot-encoding and then decoding a series.
- """
+ """Round-trip test for one-hot-encoding and then decoding a series."""
series_round_trip = optimizer._from_1hot(config=optimizer._to_1hot(config=series))
assert series_round_trip.x.to_numpy() == pytest.approx(series.x)
assert (series_round_trip.y == series.y).all()
assert (series_round_trip.z == series.z).all()
-def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer,
- one_hot_data_frame: npt.NDArray) -> None:
- """
- Round-trip test for one-hot-decoding and then encoding of a numpy array.
- """
+def test_round_trip_reverse_data_frame(
+ optimizer: BaseOptimizer,
+ one_hot_data_frame: npt.NDArray,
+) -> None:
+ """Round-trip test for one-hot-decoding and then encoding of a numpy array."""
round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame))
assert round_trip == pytest.approx(one_hot_data_frame)
-def test_round_trip_reverse_series(optimizer: BaseOptimizer,
- one_hot_series: npt.NDArray) -> None:
- """
- Round-trip test for one-hot-decoding and then encoding of a numpy array.
- """
+def test_round_trip_reverse_series(optimizer: BaseOptimizer, one_hot_series: npt.NDArray) -> None:
+ """Round-trip test for one-hot-decoding and then encoding of a numpy array."""
round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_series))
assert round_trip == pytest.approx(one_hot_series)
diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py
index 22263b4c1d..748fd1cc82 100644
--- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py
+++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py
@@ -2,77 +2,85 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Test multi-target optimization.
-"""
+"""Test multi-target optimization."""
import logging
from typing import List, Optional, Type
+import ConfigSpace as CS
+import numpy as np
+import pandas as pd
import pytest
-import pandas as pd
-import numpy as np
-import ConfigSpace as CS
-
-from mlos_core.optimizers import OptimizerType, BaseOptimizer
-
+from mlos_core.optimizers import BaseOptimizer, OptimizerType
from mlos_core.tests import SEED
_LOG = logging.getLogger(__name__)
-@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [
- *[(member.value, {}) for member in OptimizerType],
-])
-def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None:
- """
- Make sure that the optimizer raises an error if the number of objective weights
+@pytest.mark.parametrize(
+ ("optimizer_class", "kwargs"),
+ [
+ *[(member.value, {}) for member in OptimizerType],
+ ],
+)
+def test_multi_target_opt_wrong_weights(
+ optimizer_class: Type[BaseOptimizer],
+ kwargs: dict,
+) -> None:
+ """Make sure that the optimizer raises an error if the number of objective weights
does not match the number of optimization targets.
"""
with pytest.raises(ValueError):
optimizer_class(
parameter_space=CS.ConfigurationSpace(seed=SEED),
- optimization_targets=['main_score', 'other_score'],
+ optimization_targets=["main_score", "other_score"],
objective_weights=[1],
- **kwargs
+ **kwargs,
)
-@pytest.mark.parametrize(('objective_weights'), [
- [2, 1],
- [0.5, 0.5],
- None,
-])
-@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [
- *[(member.value, {}) for member in OptimizerType],
-])
-def test_multi_target_opt(objective_weights: Optional[List[float]],
- optimizer_class: Type[BaseOptimizer],
- kwargs: dict) -> None:
- """
- Toy multi-target optimization problem to test the optimizers with
- mixed numeric types to ensure that original dtypes are retained.
+@pytest.mark.parametrize(
+ ("objective_weights"),
+ [
+ [2, 1],
+ [0.5, 0.5],
+ None,
+ ],
+)
+@pytest.mark.parametrize(
+ ("optimizer_class", "kwargs"),
+ [
+ *[(member.value, {}) for member in OptimizerType],
+ ],
+)
+def test_multi_target_opt(
+ objective_weights: Optional[List[float]],
+ optimizer_class: Type[BaseOptimizer],
+ kwargs: dict,
+) -> None:
+ """Toy multi-target optimization problem to test the optimizers with mixed numeric
+ types to ensure that original dtypes are retained.
"""
max_iterations = 10
def objective(point: pd.DataFrame) -> pd.DataFrame:
# mix of hyperparameters, optimal is to select the highest possible
- return pd.DataFrame({
- "main_score": point.x + point.y,
- "other_score": point.x ** 2 + point.y ** 2,
- })
+ return pd.DataFrame(
+ {
+ "main_score": point.x + point.y,
+ "other_score": point.x**2 + point.y**2,
+ }
+ )
input_space = CS.ConfigurationSpace(seed=SEED)
# add a mix of numeric datatypes
- input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5))
- input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0))
+ input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0))
optimizer = optimizer_class(
parameter_space=input_space,
- optimization_targets=['main_score', 'other_score'],
+ optimization_targets=["main_score", "other_score"],
objective_weights=objective_weights,
**kwargs,
)
@@ -87,27 +95,28 @@ def test_multi_target_opt(objective_weights: Optional[List[float]],
suggestion, metadata = optimizer.suggest()
assert isinstance(suggestion, pd.DataFrame)
assert metadata is None or isinstance(metadata, pd.DataFrame)
- assert set(suggestion.columns) == {'x', 'y'}
+ assert set(suggestion.columns) == {"x", "y"}
# Check suggestion values are the expected dtype
assert isinstance(suggestion.x.iloc[0], np.integer)
assert isinstance(suggestion.y.iloc[0], np.floating)
# Check that suggestion is in the space
test_configuration = CS.Configuration(
- optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict())
+ optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict()
+ )
# Raises an error if outside of configuration space
test_configuration.is_valid_configuration()
# Test registering the suggested configuration with a score.
observation = objective(suggestion)
assert isinstance(observation, pd.DataFrame)
- assert set(observation.columns) == {'main_score', 'other_score'}
+ assert set(observation.columns) == {"main_score", "other_score"}
optimizer.register(configs=suggestion, scores=observation)
(best_config, best_score, best_context) = optimizer.get_best_observations()
assert isinstance(best_config, pd.DataFrame)
assert isinstance(best_score, pd.DataFrame)
assert best_context is None
- assert set(best_config.columns) == {'x', 'y'}
- assert set(best_score.columns) == {'main_score', 'other_score'}
+ assert set(best_config.columns) == {"x", "y"}
+ assert set(best_score.columns) == {"main_score", "other_score"}
assert best_config.shape == (1, 2)
assert best_score.shape == (1, 2)
@@ -115,7 +124,7 @@ def test_multi_target_opt(objective_weights: Optional[List[float]],
assert isinstance(all_configs, pd.DataFrame)
assert isinstance(all_scores, pd.DataFrame)
assert all_contexts is None
- assert set(all_configs.columns) == {'x', 'y'}
- assert set(all_scores.columns) == {'main_score', 'other_score'}
+ assert set(all_configs.columns) == {"x", "y"}
+ assert set(all_scores.columns) == {"main_score", "other_score"}
assert all_configs.shape == (max_iterations, 2)
assert all_scores.shape == (max_iterations, 2)
diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py
index 8231e59feb..a6aa77087c 100644
--- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py
+++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py
@@ -2,47 +2,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for Bayesian Optimizers.
-"""
+"""Tests for Bayesian Optimizers."""
+import logging
from copy import deepcopy
from typing import List, Optional, Type
-import logging
+import ConfigSpace as CS
+import numpy as np
+import pandas as pd
import pytest
-import pandas as pd
-import numpy as np
-import ConfigSpace as CS
-
from mlos_core.optimizers import (
- OptimizerType, ConcreteOptimizer, OptimizerFactory, BaseOptimizer)
-
-from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer, SmacOptimizer
+ BaseOptimizer,
+ ConcreteOptimizer,
+ OptimizerFactory,
+ OptimizerType,
+)
+from mlos_core.optimizers.bayesian_optimizers import (
+ BaseBayesianOptimizer,
+ SmacOptimizer,
+)
from mlos_core.spaces.adapters import SpaceAdapterType
-
-from mlos_core.tests import get_all_concrete_subclasses, SEED
-
+from mlos_core.tests import SEED, get_all_concrete_subclasses
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
-@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [
- *[(member.value, {}) for member in OptimizerType],
-])
-def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace,
- optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None:
- """
- Test that we can create an optimizer and get a suggestion from it.
- """
+@pytest.mark.parametrize(
+ ("optimizer_class", "kwargs"),
+ [
+ *[(member.value, {}) for member in OptimizerType],
+ ],
+)
+def test_create_optimizer_and_suggest(
+ configuration_space: CS.ConfigurationSpace,
+ optimizer_class: Type[BaseOptimizer],
+ kwargs: Optional[dict],
+) -> None:
+ """Test that we can create an optimizer and get a suggestion from it."""
if kwargs is None:
kwargs = {}
optimizer = optimizer_class(
parameter_space=configuration_space,
- optimization_targets=['score'],
- **kwargs
+ optimization_targets=["score"],
+ **kwargs,
)
assert optimizer is not None
@@ -59,32 +64,38 @@ def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace
optimizer.register_pending(configs=suggestion, metadata=metadata)
-@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [
- *[(member.value, {}) for member in OptimizerType],
-])
-def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace,
- optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None:
- """
- Toy problem to test the optimizers.
- """
+@pytest.mark.parametrize(
+ ("optimizer_class", "kwargs"),
+ [
+ *[(member.value, {}) for member in OptimizerType],
+ ],
+)
+def test_basic_interface_toy_problem(
+ configuration_space: CS.ConfigurationSpace,
+ optimizer_class: Type[BaseOptimizer],
+ kwargs: Optional[dict],
+) -> None:
+ """Toy problem to test the optimizers."""
# pylint: disable=too-many-locals
max_iterations = 20
if kwargs is None:
kwargs = {}
if optimizer_class == OptimizerType.SMAC.value:
- # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100.
- # To avoid having to train more than 25 model iterations, we set a lower number of max iterations.
- kwargs['max_trials'] = max_iterations * 2
+ # SMAC sets the initial random samples as a percentage of the max
+ # iterations, which defaults to 100.
+ # To avoid having to train more than 25 model iterations, we set a lower
+ # number of max iterations.
+ kwargs["max_trials"] = max_iterations * 2
def objective(x: pd.Series) -> pd.DataFrame:
- return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)})
+ return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)})
# Emukit doesn't allow specifying a random state, so we set the global seed.
np.random.seed(SEED)
optimizer = optimizer_class(
parameter_space=configuration_space,
- optimization_targets=['score'],
- **kwargs
+ optimization_targets=["score"],
+ **kwargs,
)
with pytest.raises(ValueError, match="No observations"):
@@ -97,12 +108,12 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace,
suggestion, metadata = optimizer.suggest()
assert isinstance(suggestion, pd.DataFrame)
assert metadata is None or isinstance(metadata, pd.DataFrame)
- assert set(suggestion.columns) == {'x', 'y', 'z'}
+ assert set(suggestion.columns) == {"x", "y", "z"}
# check that suggestion is in the space
configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict())
# Raises an error if outside of configuration space
configuration.is_valid_configuration()
- observation = objective(suggestion['x'])
+ observation = objective(suggestion["x"])
assert isinstance(observation, pd.DataFrame)
optimizer.register(configs=suggestion, scores=observation, metadata=metadata)
@@ -110,8 +121,8 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace,
assert isinstance(best_config, pd.DataFrame)
assert isinstance(best_score, pd.DataFrame)
assert best_context is None
- assert set(best_config.columns) == {'x', 'y', 'z'}
- assert set(best_score.columns) == {'score'}
+ assert set(best_config.columns) == {"x", "y", "z"}
+ assert set(best_score.columns) == {"score"}
assert best_config.shape == (1, 3)
assert best_score.shape == (1, 1)
assert best_score.score.iloc[0] < -5
@@ -120,12 +131,13 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace,
assert isinstance(all_configs, pd.DataFrame)
assert isinstance(all_scores, pd.DataFrame)
assert all_contexts is None
- assert set(all_configs.columns) == {'x', 'y', 'z'}
- assert set(all_scores.columns) == {'score'}
+ assert set(all_configs.columns) == {"x", "y", "z"}
+ assert set(all_scores.columns) == {"score"}
assert all_configs.shape == (20, 3)
assert all_scores.shape == (20, 1)
- # It would be better to put this into bayesian_optimizer_test but then we'd have to refit the model
+ # It would be better to put this into bayesian_optimizer_test but then we'd have
+ # to refit the model
if isinstance(optimizer, BaseBayesianOptimizer):
pred_best = optimizer.surrogate_predict(configs=best_config)
assert pred_best.shape == (1,)
@@ -134,42 +146,48 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace,
assert pred_all.shape == (20,)
-@pytest.mark.parametrize(('optimizer_type'), [
- # Enumerate all supported Optimizers
- # *[member for member in OptimizerType],
- *list(OptimizerType),
-])
+@pytest.mark.parametrize(
+ ("optimizer_type"),
+ [
+ # Enumerate all supported Optimizers
+ # *[member for member in OptimizerType],
+ *list(OptimizerType),
+ ],
+)
def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None:
- """
- Test that all optimizer types are listed in the ConcreteOptimizer constraints.
- """
- assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member
+ """Test that all optimizer types are listed in the ConcreteOptimizer constraints."""
+ # pylint: disable=no-member
+ assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined]
-@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [
- # Default optimizer
- (None, {}),
- # Enumerate all supported Optimizers
- *[(member, {}) for member in OptimizerType],
- # Optimizer with non-empty kwargs argument
-])
-def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace,
- optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None:
- """
- Test that we can create an optimizer via a factory.
- """
+@pytest.mark.parametrize(
+ ("optimizer_type", "kwargs"),
+ [
+ # Default optimizer
+ (None, {}),
+ # Enumerate all supported Optimizers
+ *[(member, {}) for member in OptimizerType],
+ # Optimizer with non-empty kwargs argument
+ ],
+)
+def test_create_optimizer_with_factory_method(
+ configuration_space: CS.ConfigurationSpace,
+ optimizer_type: Optional[OptimizerType],
+ kwargs: Optional[dict],
+) -> None:
+ """Test that we can create an optimizer via a factory."""
if kwargs is None:
kwargs = {}
if optimizer_type is None:
optimizer = OptimizerFactory.create(
parameter_space=configuration_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_kwargs=kwargs,
)
else:
optimizer = OptimizerFactory.create(
parameter_space=configuration_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_type=optimizer_type,
optimizer_kwargs=kwargs,
)
@@ -185,20 +203,24 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat
assert myrepr.startswith(optimizer_type.value.__name__)
-@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [
- # Enumerate all supported Optimizers
- *[(member, {}) for member in OptimizerType],
- # Optimizer with non-empty kwargs argument
- (OptimizerType.SMAC, {
- # Test with default config.
- 'use_default_config': True,
- # 'n_random_init': 10,
- }),
-])
+@pytest.mark.parametrize(
+ ("optimizer_type", "kwargs"),
+ [
+ # Enumerate all supported Optimizers
+ *[(member, {}) for member in OptimizerType],
+ # Optimizer with non-empty kwargs argument
+ (
+ OptimizerType.SMAC,
+ {
+ # Test with default config.
+ "use_default_config": True,
+ # 'n_random_init': 10,
+ },
+ ),
+ ],
+)
def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None:
- """
- Toy problem to test the optimizers with llamatune space adapter.
- """
+ """Toy problem to test the optimizers with llamatune space adapter."""
# pylint: disable=too-complex,disable=too-many-statements,disable=too-many-locals
num_iters = 50
if kwargs is None:
@@ -212,8 +234,8 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
input_space = CS.ConfigurationSpace(seed=1234)
# Add two continuous inputs
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3))
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=3))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=3))
# Initialize an optimizer that uses LlamaTune space adapter
space_adapter_kwargs = {
@@ -236,7 +258,7 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
llamatune_optimizer: BaseOptimizer = OptimizerFactory.create(
parameter_space=input_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_type=optimizer_type,
optimizer_kwargs=llamatune_optimizer_kwargs,
space_adapter_type=SpaceAdapterType.LLAMATUNE,
@@ -245,7 +267,7 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
# Initialize an optimizer that uses the original space
optimizer: BaseOptimizer = OptimizerFactory.create(
parameter_space=input_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_type=optimizer_type,
optimizer_kwargs=optimizer_kwargs,
)
@@ -254,7 +276,7 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space
llamatune_n_random_init = 0
- opt_n_random_init = int(kwargs.get('n_random_init', 0))
+ opt_n_random_init = int(kwargs.get("n_random_init", 0))
if optimizer_type == OptimizerType.SMAC:
assert isinstance(optimizer, SmacOptimizer)
assert isinstance(llamatune_optimizer, SmacOptimizer)
@@ -275,8 +297,9 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
# loop for llamatune-optimizer
suggestion, metadata = llamatune_optimizer.suggest()
- _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0]
- assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space
+ _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0]
+ # optimizer explores 1-dimensional space
+ assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3.0, rel=1e-3)
observation = objective(suggestion)
llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata)
@@ -284,28 +307,33 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
best_observation = optimizer.get_best_observations()
llamatune_best_observation = llamatune_optimizer.get_best_observations()
- for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation):
+ for best_config, best_score, best_context in (best_observation, llamatune_best_observation):
assert isinstance(best_config, pd.DataFrame)
assert isinstance(best_score, pd.DataFrame)
assert best_context is None
- assert set(best_config.columns) == {'x', 'y'}
- assert set(best_score.columns) == {'score'}
+ assert set(best_config.columns) == {"x", "y"}
+ assert set(best_score.columns) == {"score"}
(best_config, best_score, _context) = best_observation
(llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation
- # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that
- assert best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or \
- best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0]
+    # LlamaTune's optimizer score should be better (i.e., lower) than the plain
+    # optimizer's, or close to it
+ assert (
+ best_score.score.iloc[0] > llamatune_best_score.score.iloc[0]
+ or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0]
+ )
# Retrieve and check all observations
- for (all_configs, all_scores, all_contexts) in (
- optimizer.get_observations(), llamatune_optimizer.get_observations()):
+ for all_configs, all_scores, all_contexts in (
+ optimizer.get_observations(),
+ llamatune_optimizer.get_observations(),
+ ):
assert isinstance(all_configs, pd.DataFrame)
assert isinstance(all_scores, pd.DataFrame)
assert all_contexts is None
- assert set(all_configs.columns) == {'x', 'y'}
- assert set(all_scores.columns) == {'score'}
+ assert set(all_configs.columns) == {"x", "y"}
+ assert set(all_scores.columns) == {"score"}
assert len(all_configs) == num_iters
assert len(all_scores) == num_iters
@@ -317,30 +345,36 @@ def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optiona
# Dynamically determine all of the optimizers we have implemented.
# Note: these must be sorted.
-optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract]
- pkg_name='mlos_core')
+optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(
+ BaseOptimizer, # type: ignore[type-abstract]
+ pkg_name="mlos_core",
+)
assert optimizer_subclasses
-@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses)
+@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses)
def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None:
- """
- Test that all optimizer classes are listed in the OptimizerType enum.
- """
+ """Test that all optimizer classes are listed in the OptimizerType enum."""
optimizer_type_classes = {member.value for member in OptimizerType}
assert optimizer_class in optimizer_type_classes
-@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [
- # Default optimizer
- (None, {}),
- # Enumerate all supported Optimizers
- *[(member, {}) for member in OptimizerType],
- # Optimizer with non-empty kwargs argument
-])
-def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None:
- """
- Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained.
+@pytest.mark.parametrize(
+ ("optimizer_type", "kwargs"),
+ [
+ # Default optimizer
+ (None, {}),
+ # Enumerate all supported Optimizers
+ *[(member, {}) for member in OptimizerType],
+ # Optimizer with non-empty kwargs argument
+ ],
+)
+def test_mixed_numerics_type_input_space_types(
+ optimizer_type: Optional[OptimizerType],
+ kwargs: Optional[dict],
+) -> None:
+ """Toy problem to test the optimizers with mixed numeric types to ensure that
+ original dtypes are retained.
"""
max_iterations = 10
if kwargs is None:
@@ -352,19 +386,19 @@ def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[Optimize
input_space = CS.ConfigurationSpace(seed=SEED)
# add a mix of numeric datatypes
- input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5))
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0))
+ input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0))
if optimizer_type is None:
optimizer = OptimizerFactory.create(
parameter_space=input_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_kwargs=kwargs,
)
else:
optimizer = OptimizerFactory.create(
parameter_space=input_space,
- optimization_targets=['score'],
+ optimization_targets=["score"],
optimizer_type=optimizer_type,
optimizer_kwargs=kwargs,
)
@@ -378,12 +412,14 @@ def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[Optimize
for _ in range(max_iterations):
suggestion, metadata = optimizer.suggest()
assert isinstance(suggestion, pd.DataFrame)
- assert (suggestion.columns == ['x', 'y']).all()
+ assert (suggestion.columns == ["x", "y"]).all()
# Check suggestion values are the expected dtype
- assert isinstance(suggestion['x'].iloc[0], np.integer)
- assert isinstance(suggestion['y'].iloc[0], np.floating)
+ assert isinstance(suggestion["x"].iloc[0], np.integer)
+ assert isinstance(suggestion["y"].iloc[0], np.floating)
# Check that suggestion is in the space
- test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict())
+ test_configuration = CS.Configuration(
+ optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict()
+ )
# Raises an error if outside of configuration space
test_configuration.is_valid_configuration()
# Test registering the suggested configuration with a score.
diff --git a/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py
index e7f3fb9d3e..b3b79ffadb 100644
--- a/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py
+++ b/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for random optimizer.
-"""
+"""Tests for random optimizer."""
diff --git a/mlos_core/mlos_core/tests/spaces/__init__.py b/mlos_core/mlos_core/tests/spaces/__init__.py
index 489802cb5a..a4112b6081 100644
--- a/mlos_core/mlos_core/tests/spaces/__init__.py
+++ b/mlos_core/mlos_core/tests/spaces/__init__.py
@@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Basic initializer module for the mlos_core.tests.spaces package.
-"""
+"""Basic initializer module for the mlos_core.tests.spaces package."""
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py
index 37b8aa3a69..5d394cf4e2 100644
--- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py
+++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for Identity space adapter.
-"""
+"""Tests for Identity space adapter."""
# pylint: disable=missing-function-docstring
@@ -15,27 +13,36 @@ from mlos_core.spaces.adapters import IdentityAdapter
def test_identity_adapter() -> None:
- """
- Tests identity adapter
- """
+ """Tests identity adapter."""
input_space = CS.ConfigurationSpace(seed=1234)
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100))
+ CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100)
+ )
input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100))
+ CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100)
+ )
input_space.add_hyperparameter(
- CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off']))
+ CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"])
+ )
adapter = IdentityAdapter(orig_parameter_space=input_space)
num_configs = 10
- for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive)
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ for sampled_config in input_space.sample_configuration(
+ size=num_configs
+ ): # pylint: disable=not-an-iterable # (false positive)
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
target_config_df = adapter.inverse_transform(sampled_config_df)
assert target_config_df.equals(sampled_config_df)
- target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict())
+ target_config = CS.Configuration(
+ adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()
+ )
assert target_config == sampled_config
orig_config_df = adapter.transform(target_config_df)
assert orig_config_df.equals(sampled_config_df)
- orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict())
+ orig_config = CS.Configuration(
+ adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()
+ )
assert orig_config == sampled_config
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py
index 661decc288..9d73b6ba8c 100644
--- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py
+++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py
@@ -2,18 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for LlamaTune space adapter.
-"""
+"""Tests for LlamaTune space adapter."""
# pylint: disable=missing-function-docstring
from typing import Any, Dict, Iterator, List, Set
-import pytest
-
import ConfigSpace as CS
import pandas as pd
+import pytest
from mlos_core.spaces.adapters import LlamaTuneAdapter
@@ -24,51 +21,59 @@ def construct_parameter_space(
n_categorical_params: int = 0,
seed: int = 1234,
) -> CS.ConfigurationSpace:
- """
- Helper function for construct an instance of `ConfigSpace.ConfigurationSpace`.
- """
+ """Helper function for construct an instance of `ConfigSpace.ConfigurationSpace`."""
input_space = CS.ConfigurationSpace(seed=seed)
for idx in range(n_continuous_params):
input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64))
+ CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64)
+ )
for idx in range(n_integer_params):
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256))
+ CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256)
+ )
for idx in range(n_categorical_params):
input_space.add_hyperparameter(
- CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)]))
+ CS.CategoricalHyperparameter(
+ name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)]
+ )
+ )
return input_space
-@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([
- (num_target_space_dims, param_space_kwargs)
- for num_target_space_dims in (2, 4)
- for num_orig_space_factor in (1.5, 4)
- for param_space_kwargs in (
- {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)},
- # Mix of all three types
- {
- 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3),
- },
- )
-]))
-def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals
- """
- Tests LlamaTune's low-to-high space projection method.
- """
+@pytest.mark.parametrize(
+ ("num_target_space_dims", "param_space_kwargs"),
+ (
+ [
+ (num_target_space_dims, param_space_kwargs)
+ for num_target_space_dims in (2, 4)
+ for num_orig_space_factor in (1.5, 4)
+ for param_space_kwargs in (
+ {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
+ # Mix of all three types
+ {
+ "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ },
+ )
+ ]
+ ),
+)
+def test_num_low_dims(
+ num_target_space_dims: int,
+ param_space_kwargs: dict,
+) -> None: # pylint: disable=too-many-locals
+ """Tests LlamaTune's low-to-high space projection method."""
input_space = construct_parameter_space(**param_space_kwargs)
# Number of target parameter space dimensions should be fewer than those of the original space
with pytest.raises(ValueError):
LlamaTuneAdapter(
- orig_parameter_space=input_space,
- num_low_dims=len(list(input_space.keys()))
+ orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys()))
)
# Enable only low-dimensional space projections
@@ -76,13 +81,15 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N
orig_parameter_space=input_space,
num_low_dims=num_target_space_dims,
special_param_values=None,
- max_unique_values_per_param=None
+ max_unique_values_per_param=None,
)
sampled_configs = adapter.target_parameter_space.sample_configuration(size=100)
for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive)
# Transform low-dim config to high-dim point/config
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
orig_config_df = adapter.transform(sampled_config_df)
# High-dim (i.e., original) config should be valid
@@ -93,35 +100,45 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N
target_config_df = adapter.inverse_transform(orig_config_df)
# Sampled config and this should be the same
- target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict())
+ target_config = CS.Configuration(
+ adapter.target_parameter_space,
+ values=target_config_df.iloc[0].to_dict(),
+ )
assert target_config == sampled_config
# Try inverse projection (i.e., high-to-low) for previously unseen configs
unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25)
- for unseen_sampled_config in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive)
- if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive)
+ for (
+ unseen_sampled_config
+ ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive)
+ if (
+ unseen_sampled_config in sampled_configs
+ ): # pylint: disable=unsupported-membership-test # (false positive)
continue
- unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()))
+ unseen_sampled_config_df = pd.DataFrame(
+ [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())
+ )
with pytest.raises(ValueError):
- _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type
+ _ = adapter.inverse_transform(
+ unseen_sampled_config_df
+ ) # pylint: disable=redefined-variable-type
def test_special_parameter_values_validation() -> None:
- """
- Tests LlamaTune's validation process of user-provided special parameter values dictionary.
+ """Tests LlamaTune's validation process of user-provided special parameter values
+ dictionary.
"""
input_space = CS.ConfigurationSpace(seed=1234)
input_space.add_hyperparameter(
- CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)]))
- input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100))
- input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100))
+ CS.CategoricalHyperparameter(name="str", choices=[f"choice_{idx}" for idx in range(5)])
+ )
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100))
+ input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100))
# Only UniformIntegerHyperparameters are currently supported
with pytest.raises(NotImplementedError):
- special_param_values_dict_1 = {'str': 'choice_1'}
+ special_param_values_dict_1 = {"str": "choice_1"}
LlamaTuneAdapter(
orig_parameter_space=input_space,
num_low_dims=2,
@@ -130,7 +147,7 @@ def test_special_parameter_values_validation() -> None:
)
with pytest.raises(NotImplementedError):
- special_param_values_dict_2 = {'cont': -1}
+ special_param_values_dict_2 = {"cont": -1}
LlamaTuneAdapter(
orig_parameter_space=input_space,
num_low_dims=2,
@@ -139,8 +156,8 @@ def test_special_parameter_values_validation() -> None:
)
# Special value should belong to parameter value domain
- with pytest.raises(ValueError, match='value domain'):
- special_param_values_dict = {'int': -1}
+ with pytest.raises(ValueError, match="value domain"):
+ special_param_values_dict = {"int": -1}
LlamaTuneAdapter(
orig_parameter_space=input_space,
num_low_dims=2,
@@ -150,15 +167,15 @@ def test_special_parameter_values_validation() -> None:
# Invalid dicts; ValueError should be thrown
invalid_special_param_values_dicts: List[Dict[str, Any]] = [
- {'int-Q': 0}, # parameter does not exist
- {'int': {0: 0.2}}, # invalid definition
- {'int': 0.2}, # invalid parameter value
- {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %)
- {'int': [0, 0]}, # duplicate special values
- {'int': []}, # empty list
- {'int': [{0: 0.2}]},
- {'int': [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct
- {'int': [(0, 0.1), (0, 0.2)]}, # duplicate special values
+ {"int-Q": 0}, # parameter does not exist
+ {"int": {0: 0.2}}, # invalid definition
+ {"int": 0.2}, # invalid parameter value
+ {"int": (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %)
+ {"int": [0, 0]}, # duplicate special values
+ {"int": []}, # empty list
+ {"int": [{0: 0.2}]},
+ {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct
+ {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values
]
for spv_dict in invalid_special_param_values_dicts:
with pytest.raises(ValueError):
@@ -171,13 +188,13 @@ def test_special_parameter_values_validation() -> None:
# Biasing percentage of special value(s) are invalid
invalid_special_param_values_dicts = [
- {'int': (0, 1.1)}, # >1 probability
- {'int': (0, 0)}, # Zero probability
- {'int': (0, -0.1)}, # Negative probability
- {'int': (0, 20)}, # 2,000% instead of 20%
- {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100%
- {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100%
- {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid.
+ {"int": (0, 1.1)}, # >1 probability
+ {"int": (0, 0)}, # Zero probability
+ {"int": (0, -0.1)}, # Negative probability
+ {"int": (0, 20)}, # 2,000% instead of 20%
+ {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100%
+ {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100%
+ {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid.
]
for spv_dict in invalid_special_param_values_dicts:
@@ -193,21 +210,26 @@ def test_special_parameter_values_validation() -> None:
def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]:
for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs):
# Transform low-dim config to high-dim config
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
orig_config_df = adapter.transform(sampled_config_df)
- orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict())
+ orig_config = CS.Configuration(
+ adapter.orig_parameter_space,
+ values=orig_config_df.iloc[0].to_dict(),
+ )
yield orig_config
-def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex
- """
- Tests LlamaTune's special parameter values biasing methodology
- """
+def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex
+ """Tests LlamaTune's special parameter values biasing methodology."""
input_space = CS.ConfigurationSpace(seed=1234)
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100))
+ CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100)
+ )
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100))
+ CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100)
+ )
num_configs = 400
bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE
@@ -215,10 +237,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co
# Single parameter; single special value
special_param_value_dicts: List[Dict[str, Any]] = [
- {'int_1': 0},
- {'int_1': (0, bias_percentage)},
- {'int_1': [0]},
- {'int_1': [(0, bias_percentage)]}
+ {"int_1": 0},
+ {"int_1": (0, bias_percentage)},
+ {"int_1": [0]},
+ {"int_1": [(0, bias_percentage)]},
]
for spv_dict in special_param_value_dicts:
@@ -230,13 +252,14 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co
)
special_value_occurrences = sum(
- 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0)
+ 1 for config in gen_random_configs(adapter, num_configs) if config["int_1"] == 0
+ )
assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences
# Single parameter; multiple special values
special_param_value_dicts = [
- {'int_1': [0, 1]},
- {'int_1': [(0, bias_percentage), (1, bias_percentage)]}
+ {"int_1": [0, 1]},
+ {"int_1": [(0, bias_percentage), (1, bias_percentage)]},
]
for spv_dict in special_param_value_dicts:
@@ -249,9 +272,9 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co
special_values_occurrences = {0: 0, 1: 0}
for config in gen_random_configs(adapter, num_configs):
- if config['int_1'] == 0:
+ if config["int_1"] == 0:
special_values_occurrences[0] += 1
- elif config['int_1'] == 1:
+ elif config["int_1"] == 1:
special_values_occurrences[1] += 1
assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0]
@@ -259,8 +282,8 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co
# Multiple parameters; multiple special values; different biasing percentage
spv_dict = {
- 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)],
- 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)]
+ "int_1": [(0, bias_percentage), (1, bias_percentage / 2)],
+ "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)],
}
adapter = LlamaTuneAdapter(
orig_parameter_space=input_space,
@@ -270,44 +293,54 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co
)
special_values_instances: Dict[str, Dict[int, int]] = {
- 'int_1': {0: 0, 1: 0},
- 'int_2': {2: 0, 100: 0},
+ "int_1": {0: 0, 1: 0},
+ "int_2": {2: 0, 100: 0},
}
for config in gen_random_configs(adapter, num_configs):
- if config['int_1'] == 0:
- special_values_instances['int_1'][0] += 1
- elif config['int_1'] == 1:
- special_values_instances['int_1'][1] += 1
+ if config["int_1"] == 0:
+ special_values_instances["int_1"][0] += 1
+ elif config["int_1"] == 1:
+ special_values_instances["int_1"][1] += 1
- if config['int_2'] == 2:
- special_values_instances['int_2'][2] += 1
- elif config['int_2'] == 100:
- special_values_instances['int_2'][100] += 1
+ if config["int_2"] == 2:
+ special_values_instances["int_2"][2] += 1
+ elif config["int_2"] == 100:
+ special_values_instances["int_2"][100] += 1
- assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0]
- assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_1'][1]
- assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2]
- assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100]
+ assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances["int_1"][0]
+ assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_1"][
+ 1
+ ]
+ assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_2"][
+ 2
+ ]
+ assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances[
+ "int_2"
+ ][100]
def test_max_unique_values_per_param() -> None:
- """
- Tests LlamaTune's parameter values discretization implementation.
- """
+ """Tests LlamaTune's parameter values discretization implementation."""
# Define config space with a mix of different parameter types
input_space = CS.ConfigurationSpace(seed=1234)
input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5))
+ CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5),
+ )
input_space.add_hyperparameter(
- CS.UniformFloatHyperparameter(name='cont_2', lower=1, upper=100))
+ CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100)
+ )
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10))
+ CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10)
+ )
input_space.add_hyperparameter(
- CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048))
+ CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048)
+ )
input_space.add_hyperparameter(
- CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off']))
+ CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"])
+ )
input_space.add_hyperparameter(
- CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)]))
+ CS.CategoricalHyperparameter(name="str_2", choices=[f"choice_{idx}" for idx in range(10)])
+ )
# Restrict the number of unique parameter values
num_configs = 200
@@ -330,25 +363,33 @@ def test_max_unique_values_per_param() -> None:
assert len(unique_values) <= max_unique_values_per_param
-@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([
- (num_target_space_dims, param_space_kwargs)
- for num_target_space_dims in (2, 4)
- for num_orig_space_factor in (1.5, 4)
- for param_space_kwargs in (
- {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)},
- # Mix of all three types
- {
- 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3),
- },
- )
-]))
-def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals
- """
- Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse.
+@pytest.mark.parametrize(
+ ("num_target_space_dims", "param_space_kwargs"),
+ (
+ [
+ (num_target_space_dims, param_space_kwargs)
+ for num_target_space_dims in (2, 4)
+ for num_orig_space_factor in (1.5, 4)
+ for param_space_kwargs in (
+ {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
+ # Mix of all three types
+ {
+ "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ },
+ )
+ ]
+ ),
+)
+def test_approx_inverse_mapping(
+ num_target_space_dims: int,
+ param_space_kwargs: dict,
+) -> None: # pylint: disable=too-many-locals
+ """Tests LlamaTune's approximate high-to-low space projection method, using pseudo-
+ inverse.
"""
input_space = construct_parameter_space(**param_space_kwargs)
@@ -361,9 +402,11 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs:
use_approximate_reverse_mapping=False,
)
- sampled_config = input_space.sample_configuration() # size=1)
+ sampled_config = input_space.sample_configuration() # size=1)
with pytest.raises(ValueError):
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
_ = adapter.inverse_transform(sampled_config_df)
# Enable low-dimensional space projection *and* reverse mapping
@@ -376,41 +419,67 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs:
)
# Warning should be printed the first time
- sampled_config = input_space.sample_configuration() # size=1)
+ sampled_config = input_space.sample_configuration() # size=1)
with pytest.warns(UserWarning):
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
target_config_df = adapter.inverse_transform(sampled_config_df)
# Low-dim (i.e., target) config should be valid
- target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict())
+ target_config = CS.Configuration(
+ adapter.target_parameter_space,
+ values=target_config_df.iloc[0].to_dict(),
+ )
adapter.target_parameter_space.check_configuration(target_config)
# Test inverse transform with 100 random configs
for _ in range(100):
- sampled_config = input_space.sample_configuration() # size=1)
- sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
+ sampled_config = input_space.sample_configuration() # size=1)
+ sampled_config_df = pd.DataFrame(
+ [sampled_config.values()], columns=list(sampled_config.keys())
+ )
target_config_df = adapter.inverse_transform(sampled_config_df)
# Low-dim (i.e., target) config should be valid
- target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict())
+ target_config = CS.Configuration(
+ adapter.target_parameter_space,
+ values=target_config_df.iloc[0].to_dict(),
+ )
adapter.target_parameter_space.check_configuration(target_config)
-@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), ([
- (num_low_dims, special_param_values, max_unique_values_per_param)
- for num_low_dims in (8, 16)
- for special_param_values in (
- {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]},
- {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]},
- )
- for max_unique_values_per_param in (50, 250)
-]))
-def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None:
- """
- Tests LlamaTune space adapter when all components are active.
- """
+@pytest.mark.parametrize(
+ ("num_low_dims", "special_param_values", "max_unique_values_per_param"),
+ (
+ [
+ (num_low_dims, special_param_values, max_unique_values_per_param)
+ for num_low_dims in (8, 16)
+ for special_param_values in (
+ {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]},
+ {
+ "int_1": (-1, 0.1),
+ "int_2": -1,
+ "int_3": (-1, 0.3),
+ "int_4": [(-1, 0.1), (0, 0.2)],
+ },
+ )
+ for max_unique_values_per_param in (50, 250)
+ ]
+ ),
+)
+def test_llamatune_pipeline(
+ num_low_dims: int,
+ special_param_values: dict,
+ max_unique_values_per_param: int,
+) -> None:
+ """Tests LlamaTune space adapter when all components are active."""
# pylint: disable=too-many-locals
# Define config space with a mix of different parameter types
- input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5)
+ input_space = construct_parameter_space(
+ n_continuous_params=10,
+ n_integer_params=10,
+ n_categorical_params=5,
+ )
adapter = LlamaTuneAdapter(
orig_parameter_space=input_space,
num_low_dims=num_low_dims,
@@ -419,13 +488,16 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u
)
special_value_occurrences = {
+ # pylint: disable=protected-access
param: {special_value: 0 for special_value, _ in tuples_list}
- for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access
+ for param, tuples_list in adapter._special_param_values_dict.items()
}
unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()}
num_configs = 1000
- for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable
+ for config in adapter.target_parameter_space.sample_configuration(
+ size=num_configs
+ ): # pylint: disable=not-an-iterable
# Transform low-dim config to high-dim point/config
sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys()))
orig_config_df = adapter.transform(sampled_config_df)
@@ -436,7 +508,10 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u
# Transform high-dim config back to low-dim
target_config_df = adapter.inverse_transform(orig_config_df)
# Sampled config and this should be the same
- target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict())
+ target_config = CS.Configuration(
+ adapter.target_parameter_space,
+ values=target_config_df.iloc[0].to_dict(),
+ )
assert target_config == config
for param, value in orig_config.items():
@@ -450,35 +525,49 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u
# Ensure that occurrences of special values do not significantly deviate from expected
eps = 0.2
- for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access
+ for (
+ param,
+ tuples_list,
+ ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access
for value, bias_percentage in tuples_list:
- assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value]
+ assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[
+ param
+ ][value]
# Ensure that number of unique values is less than the maximum number allowed
for _, unique_values in unique_values_dict.items():
assert len(unique_values) <= max_unique_values_per_param
-@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([
- (num_target_space_dims, param_space_kwargs)
- for num_target_space_dims in (2, 4)
- for num_orig_space_factor in (1.5, 4)
- for param_space_kwargs in (
- {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)},
- {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)},
- # Mix of all three types
- {
- 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3),
- 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3),
- },
- )
-]))
-def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None:
- """
- Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space.
+@pytest.mark.parametrize(
+ ("num_target_space_dims", "param_space_kwargs"),
+ (
+ [
+ (num_target_space_dims, param_space_kwargs)
+ for num_target_space_dims in (2, 4)
+ for num_orig_space_factor in (1.5, 4)
+ for param_space_kwargs in (
+ {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)},
+ {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
+ # Mix of all three types
+ {
+ "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3),
+ },
+ )
+ ]
+ ),
+)
+def test_deterministic_behavior_for_same_seed(
+ num_target_space_dims: int,
+ param_space_kwargs: dict,
+) -> None:
+ """Tests LlamaTune's space adapter deterministic behavior when given same seed in
+ the input parameter space.
"""
+
def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]:
input_space = construct_parameter_space(**param_space_kwargs, seed=seed)
@@ -491,7 +580,9 @@ def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_
use_approximate_reverse_mapping=False,
)
- sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100)
+ sample_configs: List[CS.Configuration] = (
+ adapter.target_parameter_space.sample_configuration(size=100)
+ )
return sample_configs
assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42)
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py
index 4f0c31538f..6dc35441dc 100644
--- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py
+++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py
@@ -2,58 +2,68 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for space adapter factory.
-"""
+"""Tests for space adapter factory."""
# pylint: disable=missing-function-docstring
from typing import List, Optional, Type
+import ConfigSpace as CS
import pytest
-import ConfigSpace as CS
-
-from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType, ConcreteSpaceAdapter
+from mlos_core.spaces.adapters import (
+ ConcreteSpaceAdapter,
+ SpaceAdapterFactory,
+ SpaceAdapterType,
+)
from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter
-
from mlos_core.tests import get_all_concrete_subclasses
-@pytest.mark.parametrize(('space_adapter_type'), [
- # Enumerate all supported SpaceAdapters
- # *[member for member in SpaceAdapterType],
- *list(SpaceAdapterType),
-])
+@pytest.mark.parametrize(
+ ("space_adapter_type"),
+ [
+ # Enumerate all supported SpaceAdapters
+ # *[member for member in SpaceAdapterType],
+ *list(SpaceAdapterType),
+ ],
+)
def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None:
- """
- Test that all optimizer types are listed in the ConcreteOptimizer constraints.
- """
+ """Test that all optimizer types are listed in the ConcreteOptimizer constraints."""
# pylint: disable=no-member
- assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined]
+ assert (
+ space_adapter_type.value
+ in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined]
+ )
-@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [
- # Default space adapter
- (None, {}),
- # Enumerate all supported Optimizers
- *[(member, {}) for member in SpaceAdapterType],
-])
-def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None:
+@pytest.mark.parametrize(
+ ("space_adapter_type", "kwargs"),
+ [
+ # Default space adapter
+ (None, {}),
+ # Enumerate all supported Optimizers
+ *[(member, {}) for member in SpaceAdapterType],
+ ],
+)
+def test_create_space_adapter_with_factory_method(
+ space_adapter_type: Optional[SpaceAdapterType],
+ kwargs: Optional[dict],
+) -> None:
# Start defining a ConfigurationSpace for the Optimizer to search.
input_space = CS.ConfigurationSpace(seed=1234)
# Add a single continuous input dimension between 0 and 1.
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1))
# Add a single continuous input dimension between 0 and 1.
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1))
+ input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=1))
# Adjust some kwargs for specific space adapters
if space_adapter_type is SpaceAdapterType.LLAMATUNE:
if kwargs is None:
kwargs = {}
- kwargs.setdefault('num_low_dims', 1)
+ kwargs.setdefault("num_low_dims", 1)
space_adapter: BaseSpaceAdapter
if space_adapter_type is None:
@@ -71,21 +81,24 @@ def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[S
assert space_adapter is not None
assert space_adapter.orig_parameter_space is not None
myrepr = repr(space_adapter)
- assert myrepr.startswith(space_adapter_type.value.__name__), \
- f"Expected {space_adapter_type.value.__name__} but got {myrepr}"
+ assert myrepr.startswith(
+ space_adapter_type.value.__name__
+ ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}"
# Dynamically determine all of the optimizers we have implemented.
# Note: these must be sorted.
-space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \
- get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract]
+space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses(
+ BaseSpaceAdapter, # type: ignore[type-abstract]
+ pkg_name="mlos_core",
+)
assert space_adapter_subclasses
-@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses)
+@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses)
def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None:
- """
- Test that all space adapter classes are listed in the SpaceAdapterType enum.
- """
- space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType}
+ """Test that all space adapter classes are listed in the SpaceAdapterType enum."""
+ space_adapter_type_classes = {
+ space_adapter_type.value for space_adapter_type in SpaceAdapterType
+ }
assert space_adapter_class in space_adapter_type_classes
diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py
index f77852594e..9d4c17f160 100644
--- a/mlos_core/mlos_core/tests/spaces/spaces_test.py
+++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py
@@ -2,28 +2,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Tests for mlos_core.spaces
-"""
+"""Tests for mlos_core.spaces."""
# pylint: disable=missing-function-docstring
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, List, NoReturn, Union
+import ConfigSpace as CS
+import flaml.tune.sample
import numpy as np
import numpy.typing as npt
import pytest
-
import scipy
-
-import ConfigSpace as CS
from ConfigSpace.hyperparameters import Hyperparameter, NormalIntegerHyperparameter
-import flaml.tune.sample
-
-from mlos_core.spaces.converters.flaml import configspace_to_flaml_space, FlamlDomain, FlamlSpace
-
+from mlos_core.spaces.converters.flaml import (
+ FlamlDomain,
+ FlamlSpace,
+ configspace_to_flaml_space,
+)
OptimizerSpace = Union[FlamlSpace, CS.ConfigurationSpace]
OptimizerParam = Union[FlamlDomain, Hyperparameter]
@@ -41,9 +39,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None:
assert np.isclose(frequencies.sum(), 1)
_f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies)
- assert np.isclose(kurtosis, -1.2, atol=.1)
- assert p_value > .3
- assert f_p_value > .5
+ assert np.isclose(kurtosis, -1.2, atol=0.1)
+ assert p_value > 0.3
+ assert f_p_value > 0.5
def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None:
@@ -67,16 +65,13 @@ def test_is_log_uniform() -> None:
def invalid_conversion_function(*args: Any) -> NoReturn:
- """
- A quick dummy function for the base class to make pylint happy.
- """
- raise NotImplementedError('subclass must override conversion_function')
+ """A quick dummy function for the base class to make pylint happy."""
+ raise NotImplementedError("subclass must override conversion_function")
class BaseConversion(metaclass=ABCMeta):
- """
- Base class for testing optimizer space conversions.
- """
+ """Base class for testing optimizer space conversions."""
+
conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function
@abstractmethod
@@ -116,9 +111,7 @@ class BaseConversion(metaclass=ABCMeta):
@abstractmethod
def test_dimensionality(self) -> None:
- """
- Check that the dimensionality of the converted space is correct.
- """
+ """Check that the dimensionality of the converted space is correct."""
def test_unsupported_hyperparameter(self) -> None:
input_space = CS.ConfigurationSpace()
@@ -150,8 +143,8 @@ class BaseConversion(metaclass=ABCMeta):
assert_is_uniform(uniform)
# Check that we get both ends of the sampled range returned to us.
- assert input_space['c'].lower in integer_uniform
- assert input_space['c'].upper in integer_uniform
+ assert input_space["c"].lower in integer_uniform
+ assert input_space["c"].upper in integer_uniform
# integer uniform
assert_is_uniform(integer_uniform)
@@ -165,29 +158,33 @@ class BaseConversion(metaclass=ABCMeta):
assert 35 < counts[1] < 65
def test_weighted_categorical(self) -> None:
- raise NotImplementedError('subclass must override')
+ raise NotImplementedError("subclass must override")
def test_log_int_spaces(self) -> None:
- raise NotImplementedError('subclass must override')
+ raise NotImplementedError("subclass must override")
def test_log_float_spaces(self) -> None:
- raise NotImplementedError('subclass must override')
+ raise NotImplementedError("subclass must override")
class TestFlamlConversion(BaseConversion):
- """
- Tests for ConfigSpace to Flaml parameter conversions.
- """
+ """Tests for ConfigSpace to Flaml parameter conversions."""
conversion_function = staticmethod(configspace_to_flaml_space)
- def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override]
+ def sample(
+ self,
+ config_space: FlamlSpace, # type: ignore[override]
+ n_samples: int = 1,
+ ) -> npt.NDArray:
assert isinstance(config_space, dict)
assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain)
- ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T
+ ret: npt.NDArray = np.array(
+ [domain.sample(size=n_samples) for domain in config_space.values()]
+ ).T
return ret
- def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override]
+ def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override]
assert isinstance(config_space, dict)
ret: List[str] = list(config_space.keys())
return ret
@@ -208,7 +205,9 @@ class TestFlamlConversion(BaseConversion):
def test_weighted_categorical(self) -> None:
np.random.seed(42)
input_space = CS.ConfigurationSpace()
- input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1]))
+ input_space.add_hyperparameter(
+ CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])
+ )
with pytest.raises(ValueError, match="non-uniform"):
configspace_to_flaml_space(input_space)
@@ -217,7 +216,9 @@ class TestFlamlConversion(BaseConversion):
np.random.seed(42)
# integer is supported
input_space = CS.ConfigurationSpace()
- input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True))
+ input_space.add_hyperparameter(
+ CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)
+ )
converted_space = configspace_to_flaml_space(input_space)
# test log integer sampling
@@ -235,7 +236,9 @@ class TestFlamlConversion(BaseConversion):
# continuous is supported
input_space = CS.ConfigurationSpace()
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True))
+ input_space.add_hyperparameter(
+ CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)
+ )
converted_space = configspace_to_flaml_space(input_space)
# test log integer sampling
@@ -245,6 +248,6 @@ class TestFlamlConversion(BaseConversion):
assert_is_log_uniform(float_log_uniform)
-if __name__ == '__main__':
+if __name__ == "__main__":
# For attaching debugger debugging:
pytest.main(["-vv", "-k", "test_log_int_spaces", __file__])
diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py
index 8acb654adf..0a66c5a837 100644
--- a/mlos_core/mlos_core/util.py
+++ b/mlos_core/mlos_core/util.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Internal helper functions for mlos_core package.
-"""
+"""Internal helper functions for mlos_core package."""
from typing import Union
-from ConfigSpace import Configuration, ConfigurationSpace
import pandas as pd
+from ConfigSpace import Configuration, ConfigurationSpace
def config_to_dataframe(config: Configuration) -> pd.DataFrame:
- """Converts a ConfigSpace config to a DataFrame
+ """
+ Converts a ConfigSpace config to a DataFrame.
Parameters
----------
@@ -28,7 +27,10 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame:
return pd.DataFrame([dict(config)])
-def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration:
+def normalize_config(
+ config_space: ConfigurationSpace,
+ config: Union[Configuration, dict],
+) -> Configuration:
"""
Convert a dictionary to a valid ConfigSpace configuration.
@@ -49,8 +51,6 @@ def normalize_config(config_space: ConfigurationSpace, config: Union[Configurati
"""
cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True)
return Configuration(
- config_space, values={
- key: cs_config[key]
- for key in config_space.get_active_hyperparameters(cs_config)
- }
+ config_space,
+ values={key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config)},
)
diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py
index 2362de7083..f8bc82063c 100644
--- a/mlos_core/mlos_core/version.py
+++ b/mlos_core/mlos_core/version.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Version number for the mlos_core package.
-"""
+"""Version number for the mlos_core package."""
# NOTE: This should be managed by bumpversion.
-VERSION = '0.5.1'
+VERSION = "0.5.1"
if __name__ == "__main__":
print(VERSION)
diff --git a/mlos_core/setup.py b/mlos_core/setup.py
index a828a5637f..e4c792e270 100644
--- a/mlos_core/setup.py
+++ b/mlos_core/setup.py
@@ -2,19 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Setup instructions for the mlos_core package.
-"""
+"""Setup instructions for the mlos_core package."""
# pylint: disable=duplicate-code
+import os
+import re
from itertools import chain
from logging import warning
from typing import Dict, List
-import os
-import re
-
from setuptools import setup
PKG_NAME = "mlos_core"
@@ -22,15 +19,16 @@ PKG_NAME = "mlos_core"
try:
ns: Dict[str, str] = {}
with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file:
- exec(version_file.read(), ns) # pylint: disable=exec-used
- VERSION = ns['VERSION']
+ exec(version_file.read(), ns) # pylint: disable=exec-used
+ VERSION = ns["VERSION"]
except OSError:
VERSION = "0.0.1-dev"
warning(f"version.py not found, using dummy VERSION={VERSION}")
try:
from setuptools_scm import get_version
- version = get_version(root='..', relative_to=__file__, fallback_version=VERSION)
+
+ version = get_version(root="..", relative_to=__file__, fallback_version=VERSION)
if version is not None:
VERSION = version
except ImportError:
@@ -50,52 +48,56 @@ except LookupError as e:
# we return nothing when the file is not available.
def _get_long_desc_from_readme(base_url: str) -> dict:
pkg_dir = os.path.dirname(__file__)
- readme_path = os.path.join(pkg_dir, 'README.md')
+ readme_path = os.path.join(pkg_dir, "README.md")
if not os.path.isfile(readme_path):
return {
- 'long_description': 'missing',
+ "long_description": "missing",
}
- jsonc_re = re.compile(r'```jsonc')
- link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)')
- with open(readme_path, mode='r', encoding='utf-8') as readme_fh:
+ jsonc_re = re.compile(r"```jsonc")
+ link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)")
+ with open(readme_path, mode="r", encoding="utf-8") as readme_fh:
lines = readme_fh.readlines()
# Tweak the lexers for local expansion by pygments instead of github's.
- lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines]
+ lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines]
        # Tweak source code links.
- lines = [jsonc_re.sub(r'```json', line) for line in lines]
+ lines = [jsonc_re.sub(r"```json", line) for line in lines]
return {
- 'long_description': ''.join(lines),
- 'long_description_content_type': 'text/markdown',
+ "long_description": "".join(lines),
+ "long_description_content_type": "text/markdown",
}
extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass
- 'flaml': ['flaml[blendsearch]'],
- 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0
+ "flaml": ["flaml[blendsearch]"],
+ "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0
}
# construct special 'full' extra that adds requirements for all built-in
# backend integrations and additional extra features.
-extra_requires['full'] = list(set(chain(*extra_requires.values())))
+extra_requires["full"] = list(set(chain(*extra_requires.values())))
-extra_requires['full-tests'] = extra_requires['full'] + [
- 'pytest',
- 'pytest-forked',
- 'pytest-xdist',
- 'pytest-cov',
- 'pytest-local-badge',
+extra_requires["full-tests"] = extra_requires["full"] + [
+ "pytest",
+ "pytest-forked",
+ "pytest-xdist",
+ "pytest-cov",
+ "pytest-local-badge",
]
setup(
version=VERSION,
install_requires=[
- 'scikit-learn>=1.2',
- 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released
- 'scipy>=1.3.2',
- 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710
- 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"',
+ "scikit-learn>=1.2",
+ # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which
+ # isn't currently released
+ "joblib>=1.1.1",
+ "scipy>=1.3.2",
+ "numpy>=1.24",
+ "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710
+ 'pandas >= 2.2.0;python_version>="3.9"',
+ 'Bottleneck > 1.3.5;python_version>="3.9"',
'pandas >= 1.0.3;python_version<"3.9"',
- 'ConfigSpace==0.7.1', # Temporarily restrict ConfigSpace version.
+ "ConfigSpace==0.7.1", # Temporarily restrict ConfigSpace version.
],
extras_require=extra_requires,
**_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"),
diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py
index a7ba74b1d7..ddf7727ec8 100644
--- a/mlos_viz/mlos_viz/__init__.py
+++ b/mlos_viz/mlos_viz/__init__.py
@@ -2,8 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-mlos_viz is a framework to help visualizing, explain, and gain insights from results
+"""mlos_viz is a framework to help visualizing, explain, and gain insights from results
from the mlos_bench framework for benchmarking and optimization automation.
"""
@@ -18,12 +17,10 @@ from mlos_viz.util import expand_results_data_args
class MlosVizMethod(Enum):
- """
- What method to use for visualizing the experiment results.
- """
+ """What method to use for visualizing the experiment results."""
DABL = "dabl"
- AUTO = DABL # use dabl as the current default
+ AUTO = DABL # use dabl as the current default
def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None:
@@ -38,18 +35,22 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO)
"""
base.ignore_plotter_warnings()
if plotter_method == MlosVizMethod.DABL:
- import mlos_viz.dabl # pylint: disable=import-outside-toplevel
+ import mlos_viz.dabl # pylint: disable=import-outside-toplevel
+
mlos_viz.dabl.ignore_plotter_warnings()
else:
raise NotImplementedError(f"Unhandled method: {plotter_method}")
-def plot(exp_data: Optional[ExperimentData] = None, *,
- results_df: Optional[pandas.DataFrame] = None,
- objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
- plotter_method: MlosVizMethod = MlosVizMethod.AUTO,
- filter_warnings: bool = True,
- **kwargs: Any) -> None:
+def plot(
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
+ plotter_method: MlosVizMethod = MlosVizMethod.AUTO,
+ filter_warnings: bool = True,
+ **kwargs: Any,
+) -> None:
"""
Plots the results of the experiment.
@@ -80,7 +81,8 @@ def plot(exp_data: Optional[ExperimentData] = None, *,
base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs)
    if plotter_method == MlosVizMethod.DABL:
- import mlos_viz.dabl # pylint: disable=import-outside-toplevel
+ import mlos_viz.dabl # pylint: disable=import-outside-toplevel
+
mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives)
else:
raise NotImplementedError(f"Unhandled method: {plotter_method}")
diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py
index 787315313a..0c6d58cd7f 100644
--- a/mlos_viz/mlos_viz/base.py
+++ b/mlos_viz/mlos_viz/base.py
@@ -2,28 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Base functions for visualizing, explain, and gain insights from results.
-"""
-
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
+"""Base functions for visualizing, explain, and gain insights from results."""
import re
import warnings
-
from importlib.metadata import version
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
-from matplotlib import pyplot as plt
import pandas
+import seaborn as sns
+from matplotlib import pyplot as plt
from pandas.api.types import is_numeric_dtype
from pandas.core.groupby.generic import SeriesGroupBy
-import seaborn as sns
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_viz.util import expand_results_data_args
-
-_SEABORN_VERS = version('seaborn')
+_SEABORN_VERS = version("seaborn")
def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]:
@@ -33,26 +28,30 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]:
Note: this only works with non-positional kwargs (e.g., those after a * arg).
"""
target_kwargs = {}
- for kword in target.__kwdefaults__: # or {} # intentionally omitted for now
+ for kword in target.__kwdefaults__: # or {} # intentionally omitted for now
if kword in kwargs:
target_kwargs[kword] = kwargs[kword]
return target_kwargs
def ignore_plotter_warnings() -> None:
- """
- Suppress some annoying warnings from third-party data visualization packages by
+ """Suppress some annoying warnings from third-party data visualization packages by
adding them to the warnings filter.
"""
warnings.filterwarnings("ignore", category=FutureWarning)
- if _SEABORN_VERS <= '0.13.1':
- warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas
- message="is_categorical_dtype is deprecated and will be removed in a future version.")
+ if _SEABORN_VERS <= "0.13.1":
+ warnings.filterwarnings(
+ "ignore",
+ category=DeprecationWarning,
+ module="seaborn", # but actually comes from pandas
+ message="is_categorical_dtype is deprecated and will be removed in a future version.",
+ )
-def _add_groupby_desc_column(results_df: pandas.DataFrame,
- groupby_columns: Optional[List[str]] = None,
- ) -> Tuple[pandas.DataFrame, List[str], str]:
+def _add_groupby_desc_column(
+ results_df: pandas.DataFrame,
+ groupby_columns: Optional[List[str]] = None,
+) -> Tuple[pandas.DataFrame, List[str], str]:
"""
Adds a group descriptor column to the results_df.
@@ -70,17 +69,19 @@ def _add_groupby_desc_column(results_df: pandas.DataFrame,
if groupby_columns is None:
groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"]
groupby_column = ",".join(groupby_columns)
- results_df[groupby_column] = results_df[groupby_columns].astype(str).apply(
- lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda
+ results_df[groupby_column] = (
+ results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1)
+ ) # pylint: disable=unnecessary-lambda
groupby_columns.append(groupby_column)
return (results_df, groupby_columns, groupby_column)
-def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None,
- *,
- results_df: Optional[pandas.DataFrame] = None,
- requested_result_cols: Optional[Iterable[str]] = None,
- ) -> pandas.DataFrame:
+def augment_results_df_with_config_trial_group_stats(
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ requested_result_cols: Optional[Iterable[str]] = None,
+) -> pandas.DataFrame:
# pylint: disable=too-complex
"""
Add a number of useful statistical measure columns to the results dataframe.
@@ -137,30 +138,47 @@ def augment_results_df_with_config_trial_group_stats(exp_data: Optional[Experime
raise ValueError(f"Not enough data: {len(results_groups)}")
if requested_result_cols is None:
- result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX))
+ result_cols = set(
+ col
+ for col in results_df.columns
+ if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)
+ )
else:
- result_cols = set(col for col in requested_result_cols
- if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns)
- result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols
- if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns))
+ result_cols = set(
+ col
+ for col in requested_result_cols
+ if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns
+ )
+ result_cols.update(
+ set(
+ ExperimentData.RESULT_COLUMN_PREFIX + col
+ for col in requested_result_cols
+                if ExperimentData.RESULT_COLUMN_PREFIX + col in results_df.columns
+ )
+ )
def compute_zscore_for_group_agg(
- results_groups_perf: "SeriesGroupBy",
- stats_df: pandas.DataFrame,
- result_col: str,
- agg: Union[Literal["mean"], Literal["var"], Literal["std"]]
+ results_groups_perf: "SeriesGroupBy",
+ stats_df: pandas.DataFrame,
+ result_col: str,
+ agg: Union[Literal["mean"], Literal["var"], Literal["std"]],
) -> None:
- results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating?
- # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe.
+ results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating?
+ # Compute the zscore of the chosen aggregate performance of each group into
+ # each row in the dataframe.
stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean()
stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std()
- stats_df[result_col + f".{agg}_zscore"] = \
- (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \
- / stats_df[result_col + f".{agg}_stddev"]
- stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True)
+ stats_df[result_col + f".{agg}_zscore"] = (
+ stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]
+ ) / stats_df[result_col + f".{agg}_stddev"]
+ stats_df.drop(
+ columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True
+ )
augmented_results_df = results_df
- augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count")
+ augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform(
+ "count"
+ )
for result_col in result_cols:
if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX):
continue
@@ -179,24 +197,25 @@ def augment_results_df_with_config_trial_group_stats(exp_data: Optional[Experime
compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var")
quantiles = [0.50, 0.75, 0.90, 0.95, 0.99]
- for quantile in quantiles: # TODO: can we do this in one pass?
+ for quantile in quantiles: # TODO: can we do this in one pass?
quantile_col = f"{result_col}.p{int(quantile * 100)}"
stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile)
augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1)
return augmented_results_df
-def limit_top_n_configs(exp_data: Optional[ExperimentData] = None,
- *,
- results_df: Optional[pandas.DataFrame] = None,
- objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
- top_n_configs: int = 10,
- method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean",
- ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]:
+def limit_top_n_configs(
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
+ top_n_configs: int = 10,
+ method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean",
+) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]:
# pylint: disable=too-many-locals
"""
- Utility function to process the results and determine the best performing
- configs including potential repeats to help assess variability.
+ Utility function to process the results and determine the best performing configs
+ including potential repeats to help assess variability.
Parameters
----------
@@ -205,24 +224,32 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None,
results_df : Optional[pandas.DataFrame]
The results dataframe to augment, by default None to use the results_df property.
objectives : Iterable[str], optional
- Which result column(s) to use for sorting the configs, and in which direction ("min" or "max").
+ Which result column(s) to use for sorting the configs, and in which
+ direction ("min" or "max").
By default None to automatically select the experiment objectives.
top_n_configs : int, optional
        How many configs to return, including the default, by default 10.
method: Literal["mean", "median", "p50", "p75", "p90", "p95", "p99"] = "mean",
- Which statistical method to use when sorting the config groups before determining the cutoff, by default "mean".
+ Which statistical method to use when sorting the config groups before
+ determining the cutoff, by default "mean".
Returns
-------
- (top_n_config_results_df, top_n_config_ids, orderby_cols) : Tuple[pandas.DataFrame, List[int], Dict[str, bool]]
- The filtered results dataframe, the config ids, and the columns used to order the configs.
+ (top_n_config_results_df, top_n_config_ids, orderby_cols) :
+ Tuple[pandas.DataFrame, List[int], Dict[str, bool]]
+ The filtered results dataframe, the config ids, and the columns used to
+ order the configs.
"""
# Do some input checking first.
if method not in ["mean", "median", "p50", "p75", "p90", "p95", "p99"]:
raise ValueError(f"Invalid method: {method}")
# Prepare the orderby columns.
- (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives)
+ (results_df, objs_cols) = expand_results_data_args(
+ exp_data,
+ results_df=results_df,
+ objectives=objectives,
+ )
assert isinstance(results_df, pandas.DataFrame)
# Augment the results dataframe with some useful stats.
@@ -235,13 +262,17 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None,
# results_df is not None and is in fact a DataFrame, so we periodically assert
# it in this func for now.
assert results_df is not None
- orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()}
+ orderby_cols: Dict[str, bool] = {
+ obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()
+ }
config_id_col = "tunable_config_id"
- group_id_col = "tunable_config_trial_group_id" # first trial_id per config group
+ group_id_col = "tunable_config_trial_group_id" # first trial_id per config group
trial_id_col = "trial_id"
- default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id
+ default_config_id = (
+ results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id
+ )
assert default_config_id is not None, "Failed to determine default config id."
# Filter out configs whose variance is too large.
@@ -253,16 +284,18 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None,
singletons_mask = results_df["tunable_config_trial_group_size"] == 1
else:
singletons_mask = results_df["tunable_config_trial_group_size"] > 1
- results_df = results_df.loc[(
- (results_df[f"{obj_col}.var_zscore"].abs() < 2)
- | (singletons_mask)
- | (results_df[config_id_col] == default_config_id)
- )]
+ results_df = results_df.loc[
+ (
+ (results_df[f"{obj_col}.var_zscore"].abs() < 2)
+ | (singletons_mask)
+ | (results_df[config_id_col] == default_config_id)
+ )
+ ]
assert results_df is not None
# Also, filter results that are worse than the default.
default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id]
- for (orderby_col, ascending) in orderby_cols.items():
+ for orderby_col, ascending in orderby_cols.items():
default_vals = default_config_results_df[orderby_col].unique()
assert len(default_vals) == 1
default_val = default_vals[0]
@@ -274,29 +307,39 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None,
# Now regroup and filter to the top-N configs by their group performance dimensions.
assert results_df is not None
- group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()]
- top_n_config_ids: List[int] = group_results_df.sort_values(
- by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist()
+ group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[
+ orderby_cols.keys()
+ ]
+ top_n_config_ids: List[int] = (
+ group_results_df.sort_values(
+ by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())
+ )
+ .head(top_n_configs)
+ .index.tolist()
+ )
# Remove the default config if it's included. We'll add it back later.
if default_config_id in top_n_config_ids:
top_n_config_ids.remove(default_config_id)
# Get just the top-n config results.
# Sort by the group ids.
- top_n_config_results_df = results_df.loc[(
- results_df[config_id_col].isin(top_n_config_ids)
- )].sort_values([group_id_col, config_id_col, trial_id_col])
+ top_n_config_results_df = results_df.loc[
+ (results_df[config_id_col].isin(top_n_config_ids))
+ ].sort_values([group_id_col, config_id_col, trial_id_col])
# Place the default config at the top of the list.
top_n_config_ids.insert(0, default_config_id)
- top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0)
+ top_n_config_results_df = pandas.concat(
+ [default_config_results_df, top_n_config_results_df],
+ axis=0,
+ )
return (top_n_config_results_df, top_n_config_ids, orderby_cols)
def plot_optimizer_trends(
- exp_data: Optional[ExperimentData] = None,
- *,
- results_df: Optional[pandas.DataFrame] = None,
- objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
) -> None:
"""
Plots the optimizer trends for the Experiment.
@@ -315,12 +358,16 @@ def plot_optimizer_trends(
(results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives)
(results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df)
- for (objective_column, ascending) in obj_cols.items():
+ for objective_column, ascending in obj_cols.items():
incumbent_column = objective_column + ".incumbent"
# Determine the mean of each config trial group to match the box plots.
- group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\
- .reset_index().sort_values(groupby_columns)
+ group_results_df = (
+ results_df.groupby(groupby_columns)[objective_column]
+ .mean()
+ .reset_index()
+ .sort_values(groupby_columns)
+ )
#
# Note: technically the optimizer (usually) uses the *first* result for a
# given config trial group before moving on to a new config (x-axis), so
@@ -358,24 +405,29 @@ def plot_optimizer_trends(
ax=axis,
)
- plt.yscale('log')
+ plt.yscale("log")
plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, ""))
plt.xlabel("Config Trial Group ID, Config ID")
plt.xticks(rotation=90, fontsize=8)
- plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "")
+ plt.title(
+ "Optimizer Trends for Experiment: " + exp_data.experiment_id
+ if exp_data is not None
+ else ""
+ )
plt.grid()
plt.show() # type: ignore[no-untyped-call]
-def plot_top_n_configs(exp_data: Optional[ExperimentData] = None,
- *,
- results_df: Optional[pandas.DataFrame] = None,
- objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
- with_scatter_plot: bool = False,
- **kwargs: Any,
- ) -> None:
+def plot_top_n_configs(
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
+ with_scatter_plot: bool = False,
+ **kwargs: Any,
+) -> None:
# pylint: disable=too-many-locals
"""
Plots the top-N configs along with the default config for the given ExperimentData.
@@ -403,12 +455,17 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None,
top_n_config_args["results_df"] = results_df
if "objectives" not in top_n_config_args:
top_n_config_args["objectives"] = objectives
- (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args)
+ (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(
+ exp_data=exp_data,
+ **top_n_config_args,
+ )
- (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df)
+ (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(
+ top_n_config_results_df,
+ )
top_n = len(top_n_config_results_df[groupby_column].unique()) - 1
- for (orderby_col, ascending) in orderby_cols.items():
+ for orderby_col, ascending in orderby_cols.items():
opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")
(_fig, axis) = plt.subplots()
sns.violinplot(
@@ -428,12 +485,12 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None,
plt.grid()
(xticks, xlabels) = plt.xticks()
# default should be in the first position based on top_n_configs() return
- xlabels[0] = "default" # type: ignore[call-overload]
- plt.xticks(xticks, xlabels) # type: ignore[arg-type]
+ xlabels[0] = "default" # type: ignore[call-overload]
+ plt.xticks(xticks, xlabels) # type: ignore[arg-type]
plt.xlabel("Config Trial Group, Config ID")
plt.xticks(rotation=90)
plt.ylabel(opt_tgt)
- plt.yscale('log')
+ plt.yscale("log")
extra_title = "(lower is better)" if ascending else "(lower is better)"
plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}")
plt.show() # type: ignore[no-untyped-call]
diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py
index 112bf70470..3f8ac640ad 100644
--- a/mlos_viz/mlos_viz/dabl.py
+++ b/mlos_viz/mlos_viz/dabl.py
@@ -2,25 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Small wrapper functions for dabl plotting functions via mlos_bench data.
-"""
-from typing import Dict, Optional, Literal
-
+"""Small wrapper functions for dabl plotting functions via mlos_bench data."""
import warnings
+from typing import Dict, Literal, Optional
import dabl
import pandas
from mlos_bench.storage.base_experiment_data import ExperimentData
-
from mlos_viz.util import expand_results_data_args
-def plot(exp_data: Optional[ExperimentData] = None, *,
- results_df: Optional[pandas.DataFrame] = None,
- objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
- ) -> None:
+def plot(
+ exp_data: Optional[ExperimentData] = None,
+ *,
+ results_df: Optional[pandas.DataFrame] = None,
+ objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
+) -> None:
"""
Plots the Experiment results data using dabl.
@@ -41,22 +39,57 @@ def plot(exp_data: Optional[ExperimentData] = None, *,
def ignore_plotter_warnings() -> None:
- """
- Add some filters to ignore warnings from the plotter.
- """
+ """Add some filters to ignore warnings from the plotter."""
# pylint: disable=import-outside-toplevel
warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format")
- warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers")
- warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated")
- warnings.filterwarnings("ignore", module="dabl", category=UserWarning,
- message="Missing values in target_col have been removed for regression")
+ warnings.filterwarnings(
+ "ignore",
+ module="dabl",
+ category=UserWarning,
+ message="Could not infer format",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ module="dabl",
+ category=UserWarning,
+ message="(Dropped|Discarding) .* outliers",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ module="dabl",
+ category=UserWarning,
+ message="Not plotting highly correlated",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ module="dabl",
+ category=UserWarning,
+ message="Missing values in target_col have been removed for regression",
+ )
from sklearn.exceptions import UndefinedMetricWarning
- warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined")
- warnings.filterwarnings("ignore", category=DeprecationWarning,
- message="is_categorical_dtype is deprecated and will be removed in a future version.")
- warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn",
- message="is_sparse is deprecated and will be removed in a future version.")
+
+ warnings.filterwarnings(
+ "ignore",
+ module="sklearn",
+ category=UndefinedMetricWarning,
+ message="Recall is ill-defined",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ category=DeprecationWarning,
+ message="is_categorical_dtype is deprecated and will be removed in a future version.",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ category=DeprecationWarning,
+ module="sklearn",
+ message="is_sparse is deprecated and will be removed in a future version.",
+ )
from matplotlib._api.deprecation import MatplotlibDeprecationWarning
- warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl",
- message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed")
+
+ warnings.filterwarnings(
+ "ignore",
+ category=MatplotlibDeprecationWarning,
+ module="dabl",
+ message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed",
+ )
diff --git a/mlos_viz/mlos_viz/tests/__init__.py b/mlos_viz/mlos_viz/tests/__init__.py
index d496cbe2b3..df64e0a313 100644
--- a/mlos_viz/mlos_viz/tests/__init__.py
+++ b/mlos_viz/mlos_viz/tests/__init__.py
@@ -2,15 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mlos_viz.
-"""
+"""Unit tests for mlos_viz."""
import sys
import seaborn # pylint: disable=unused-import # (used by patch) # noqa: unused
-
BASE_MATPLOTLIB_SHOW_PATCH = "mlos_viz.base.plt.show"
if sys.version_info >= (3, 11):
diff --git a/mlos_viz/mlos_viz/tests/conftest.py b/mlos_viz/mlos_viz/tests/conftest.py
index ad29489e2c..228609ba09 100644
--- a/mlos_viz/mlos_viz/tests/conftest.py
+++ b/mlos_viz/mlos_viz/tests/conftest.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Export test fixtures for mlos_viz.
-"""
+"""Export test fixtures for mlos_viz."""
from mlos_bench.tests import tunable_groups_fixtures
from mlos_bench.tests.storage.sql import fixtures as sql_storage_fixtures
diff --git a/mlos_viz/mlos_viz/tests/test_base_plot.py b/mlos_viz/mlos_viz/tests/test_base_plot.py
index 9fb33471e6..1dc283c891 100644
--- a/mlos_viz/mlos_viz/tests/test_base_plot.py
+++ b/mlos_viz/mlos_viz/tests/test_base_plot.py
@@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mlos_viz.
-"""
+"""Unit tests for mlos_viz."""
import warnings
-
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
from mlos_bench.storage.base_experiment_data import ExperimentData
-
-from mlos_viz.base import ignore_plotter_warnings, plot_optimizer_trends, plot_top_n_configs
-
+from mlos_viz.base import (
+ ignore_plotter_warnings,
+ plot_optimizer_trends,
+ plot_top_n_configs,
+)
from mlos_viz.tests import BASE_MATPLOTLIB_SHOW_PATCH
diff --git a/mlos_viz/mlos_viz/tests/test_dabl_plot.py b/mlos_viz/mlos_viz/tests/test_dabl_plot.py
index 36c83b12f2..7fcee4dfe9 100644
--- a/mlos_viz/mlos_viz/tests/test_dabl_plot.py
+++ b/mlos_viz/mlos_viz/tests/test_dabl_plot.py
@@ -2,18 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mlos_viz.dabl.plot.
-"""
+"""Unit tests for mlos_viz.dabl.plot."""
import warnings
-
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
from mlos_bench.storage.base_experiment_data import ExperimentData
-
from mlos_viz import dabl
-
from mlos_viz.tests import SEABORN_BOXPLOT_PATCH
diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py
index 0be7220f47..6d393dca6a 100644
--- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py
+++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py
@@ -2,19 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Unit tests for mlos_viz.
-"""
+"""Unit tests for mlos_viz."""
import random
import warnings
-
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
from mlos_bench.storage.base_experiment_data import ExperimentData
-
from mlos_viz import MlosVizMethod, plot
-
from mlos_viz.tests import BASE_MATPLOTLIB_SHOW_PATCH, SEABORN_BOXPLOT_PATCH
@@ -33,5 +28,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) ->
warnings.simplefilter("error")
random.seed(42)
plot(exp_data, filter_warnings=True)
- assert mock_show.call_count >= 2 # from the two base plots and anything dabl did
- assert mock_boxplot.call_count >= 1 # from anything dabl did
+ assert mock_show.call_count >= 2 # from the two base plots and anything dabl did
+ assert mock_boxplot.call_count >= 1 # from anything dabl did
diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py
index 744fe28648..cefc3080d9 100644
--- a/mlos_viz/mlos_viz/util.py
+++ b/mlos_viz/mlos_viz/util.py
@@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Utility functions for manipulating experiment results data.
-"""
+"""Utility functions for manipulating experiment results data."""
from typing import Dict, Literal, Optional, Tuple
import pandas
@@ -36,7 +34,8 @@ def expand_results_data_args(
Returns
-------
Tuple[pandas.DataFrame, Dict[str, bool]]
- The results dataframe and the objectives columns in the dataframe, plus whether or not they are in ascending order.
+ The results dataframe and the objectives columns in the dataframe, plus
+ whether or not they are in ascending order.
"""
# Prepare the orderby columns.
if results_df is None:
@@ -49,11 +48,14 @@ def expand_results_data_args(
raise ValueError("Must provide either exp_data or both results_df and objectives.")
objectives = exp_data.objectives
objs_cols: Dict[str, bool] = {}
- for (opt_tgt, opt_dir) in objectives.items():
+ for opt_tgt, opt_dir in objectives.items():
if opt_dir not in ["min", "max"]:
raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}")
ascending = opt_dir == "min"
- if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns:
+ if (
+ opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX)
+ and opt_tgt in results_df.columns
+ ):
objs_cols[opt_tgt] = ascending
elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns:
objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending
diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py
index 607c7cc014..6d75e70bb5 100644
--- a/mlos_viz/mlos_viz/version.py
+++ b/mlos_viz/mlos_viz/version.py
@@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Version number for the mlos_viz package.
-"""
+"""Version number for the mlos_viz package."""
# NOTE: This should be managed by bumpversion.
-VERSION = '0.5.1'
+VERSION = "0.5.1"
if __name__ == "__main__":
print(VERSION)
diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py
index 7c8c4deb07..4f5e8677d1 100644
--- a/mlos_viz/setup.py
+++ b/mlos_viz/setup.py
@@ -2,36 +2,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
-"""
-Setup instructions for the mlos_viz package.
-"""
+"""Setup instructions for the mlos_viz package."""
# pylint: disable=duplicate-code
-from logging import warning
-from itertools import chain
-from typing import Dict, List
-
import os
import re
+from itertools import chain
+from logging import warning
+from typing import Dict, List
from setuptools import setup
-
PKG_NAME = "mlos_viz"
try:
ns: Dict[str, str] = {}
with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file:
- exec(version_file.read(), ns) # pylint: disable=exec-used
- VERSION = ns['VERSION']
+ exec(version_file.read(), ns) # pylint: disable=exec-used
+ VERSION = ns["VERSION"]
except OSError:
VERSION = "0.0.1-dev"
warning(f"version.py not found, using dummy VERSION={VERSION}")
try:
from setuptools_scm import get_version
- version = get_version(root='..', relative_to=__file__, fallback_version=VERSION)
+
+ version = get_version(root="..", relative_to=__file__, fallback_version=VERSION)
if version is not None:
VERSION = version
except ImportError:
@@ -49,22 +46,22 @@ except LookupError as e:
# be duplicated for now.
def _get_long_desc_from_readme(base_url: str) -> dict:
pkg_dir = os.path.dirname(__file__)
- readme_path = os.path.join(pkg_dir, 'README.md')
+ readme_path = os.path.join(pkg_dir, "README.md")
if not os.path.isfile(readme_path):
return {
- 'long_description': 'missing',
+ "long_description": "missing",
}
- jsonc_re = re.compile(r'```jsonc')
- link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)')
- with open(readme_path, mode='r', encoding='utf-8') as readme_fh:
+ jsonc_re = re.compile(r"```jsonc")
+ link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)")
+ with open(readme_path, mode="r", encoding="utf-8") as readme_fh:
lines = readme_fh.readlines()
# Tweak the lexers for local expansion by pygments instead of github's.
- lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines]
+ lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines]
        # Tweak source code links.
- lines = [jsonc_re.sub(r'```json', line) for line in lines]
+ lines = [jsonc_re.sub(r"```json", line) for line in lines]
return {
- 'long_description': ''.join(lines),
- 'long_description_content_type': 'text/markdown',
+ "long_description": "".join(lines),
+ "long_description_content_type": "text/markdown",
}
@@ -72,23 +69,23 @@ extra_requires: Dict[str, List[str]] = {}
# construct special 'full' extra that adds requirements for all built-in
# backend integrations and additional extra features.
-extra_requires['full'] = list(set(chain(*extra_requires.values())))
+extra_requires["full"] = list(set(chain(*extra_requires.values())))
-extra_requires['full-tests'] = extra_requires['full'] + [
- 'pytest',
- 'pytest-forked',
- 'pytest-xdist',
- 'pytest-cov',
- 'pytest-local-badge',
+extra_requires["full-tests"] = extra_requires["full"] + [
+ "pytest",
+ "pytest-forked",
+ "pytest-xdist",
+ "pytest-cov",
+ "pytest-local-badge",
]
setup(
version=VERSION,
install_requires=[
- 'mlos-bench==' + VERSION,
- 'dabl>=0.2.6',
- 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341
+ "mlos-bench==" + VERSION,
+ "dabl>=0.2.6",
+ "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341
],
extras_require=extra_requires,
- **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'),
+ **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_viz"),
)
diff --git a/pyproject.toml b/pyproject.toml
index 65f1e5a02c..f70030a576 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
[tool.black]
-line-length = 88
+line-length = 99
target-version = ["py38", "py39", "py310", "py311", "py312"]
include = '\.pyi?$'
@@ -7,3 +7,12 @@ include = '\.pyi?$'
profile = "black"
py_version = 311
src_paths = ["mlos_core", "mlos_bench", "mlos_viz"]
+
+[tool.docformatter]
+recursive = true
+black = true
+style = "numpy"
+pre-summary-newline = true
+close-quotes-on-newline = true
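+# Intended effect, as reflected in the docstring changes elsewhere in this diff:
+# one-sentence docstrings collapse onto a single line, e.g.
+#   """Version number for the mlos_core package."""
+# while longer numpydoc-style docstrings start their summary on the line after the
+# opening quotes (pre-summary-newline) and wrap to black-compatible line lengths.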
+
+# TODO: move pylintrc and some setup.cfg configs here
diff --git a/setup.cfg b/setup.cfg
index 492d2de7f2..6f948f523a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,30 +2,32 @@
[pycodestyle]
count = True
+# E203: Whitespace before : (black incompatibility)
# W503: Line break occurred before a binary operator
# W504: Line break occurred after a binary operator
-ignore = W503,W504
+ignore = E203,W503,W504
format = pylint
# See Also: .editorconfig, .pylintrc
-max-line-length = 132
+max-line-length = 99
show-source = True
statistics = True
[pydocstyle]
-# D102: Missing docstring in public method (Avoids inheritence bug. Force checked in .pylintrc instead.)
+# D102: Missing docstring in public method (Avoids inheritance bug. Force checked in pylint instead.)
# D105: Missing docstring in magic method
# D107: Missing docstring in __init__
-# D200: One-line docstring should fit on one line with quotes
# D401: First line should be in imperative mood
# We have many docstrings that are too long to fit on one line, so we ignore these two rules:
# D205: 1 blank line required between summary line and description
# D400: First line should end with a period
-add_ignore = D102,D105,D107,D200,D401,D205,D400
+add_ignore = D102,D105,D107,D401,D205,D400
match = .+(?