зеркало из https://github.com/microsoft/MLOS.git
Enable black, isort, and doc formatters and checks (#774)
Follow on work to #766. This enabled both formatters and applies their changes to the repo. Additionally, since `black` does not make changes to comments nor docstrings, we also enable `docformatter` to reformat docstrings which better aligns with `pydocstyle` rules as well. Without this additional change (and some manual fixups), `pycodestyle` and `pylint` would still complain about line lengths, for instance. Finally, we make a minor adjustment to the max line length setting it to 99 (which is also accepted and mentioned in pep8) instead of 88 to avoid some comment (especially linter overrides) wrapping.
This commit is contained in:
Родитель
fd9c8f9935
Коммит
e40ac28317
|
@ -45,7 +45,6 @@
|
|||
// Adjust the python interpreter path to point to the conda environment
|
||||
"python.defaultInterpreterPath": "/opt/conda/envs/mlos/bin/python",
|
||||
"python.testing.pytestPath": "/opt/conda/envs/mlos/bin/pytest",
|
||||
"python.formatting.autopep8Path": "/opt/conda/envs/mlos/bin/autopep8",
|
||||
"python.linting.pylintPath": "/opt/conda/envs/mlos/bin/pylint",
|
||||
"pylint.path": [
|
||||
"/opt/conda/envs/mlos/bin/pylint"
|
||||
|
@ -71,9 +70,8 @@
|
|||
"lextudio.restructuredtext",
|
||||
"matangover.mypy",
|
||||
"ms-azuretools.vscode-docker",
|
||||
// TODO: Enable additional formatter extensions:
|
||||
//"ms-python.black-formatter",
|
||||
//"ms-python.isort",
|
||||
"ms-python.black-formatter",
|
||||
"ms-python.isort",
|
||||
"ms-python.pylint",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
|
|
|
@ -12,6 +12,9 @@ charset = utf-8
|
|||
# Note: this is not currently supported by all editors or their editorconfig plugins.
|
||||
max_line_length = 132
|
||||
|
||||
[*.py]
|
||||
max_line_length = 99
|
||||
|
||||
# Makefiles need tab indentation
|
||||
[{Makefile,*.mk}]
|
||||
indent_style = tab
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Ignore git directory (ripgrep)
|
||||
.git/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
|
|
@ -35,7 +35,7 @@ load-plugins=
|
|||
|
||||
[FORMAT]
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=132
|
||||
max-line-length=99
|
||||
|
||||
[MESSAGE CONTROL]
|
||||
disable=
|
||||
|
@ -48,5 +48,5 @@ disable=
|
|||
missing-raises-doc
|
||||
|
||||
[STRING]
|
||||
#check-quote-consistency=yes
|
||||
check-quote-consistency=yes
|
||||
check-str-concat-over-line-jumps=yes
|
||||
|
|
|
@ -14,9 +14,8 @@
|
|||
"lextudio.restructuredtext",
|
||||
"matangover.mypy",
|
||||
"ms-azuretools.vscode-docker",
|
||||
// TODO: Enable additional formatter extensions:
|
||||
//"ms-python.black-formatter",
|
||||
//"ms-python.isort",
|
||||
"ms-python.black-formatter",
|
||||
"ms-python.isort",
|
||||
"ms-python.pylint",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
|
|
|
@ -125,14 +125,10 @@
|
|||
],
|
||||
"esbonio.sphinx.confDir": "${workspaceFolder}/doc/source",
|
||||
"esbonio.sphinx.buildDir": "${workspaceFolder}/doc/build/",
|
||||
"autopep8.args": [
|
||||
"--experimental"
|
||||
],
|
||||
"[python]": {
|
||||
// TODO: Enable black formatter
|
||||
//"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
//"editor.formatOnSave": true,
|
||||
//"editor.formatOnSaveMode": "modifications"
|
||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.formatOnSaveMode": "modifications"
|
||||
},
|
||||
// See Also .vscode/launch.json for environment variable args to pytest during debug sessions.
|
||||
// For the rest, see setup.cfg
|
||||
|
|
103
Makefile
103
Makefile
|
@ -34,7 +34,7 @@ conda-env: build/conda-env.${CONDA_ENV_NAME}.build-stamp
|
|||
MLOS_CORE_CONF_FILES := mlos_core/pyproject.toml mlos_core/setup.py mlos_core/MANIFEST.in
|
||||
MLOS_BENCH_CONF_FILES := mlos_bench/pyproject.toml mlos_bench/setup.py mlos_bench/MANIFEST.in
|
||||
MLOS_VIZ_CONF_FILES := mlos_viz/pyproject.toml mlos_viz/setup.py mlos_viz/MANIFEST.in
|
||||
MLOS_GLOBAL_CONF_FILES := setup.cfg # pyproject.toml
|
||||
MLOS_GLOBAL_CONF_FILES := setup.cfg pyproject.toml
|
||||
|
||||
MLOS_PKGS := mlos_core mlos_bench mlos_viz
|
||||
MLOS_PKG_CONF_FILES := $(MLOS_CORE_CONF_FILES) $(MLOS_BENCH_CONF_FILES) $(MLOS_VIZ_CONF_FILES) $(MLOS_GLOBAL_CONF_FILES)
|
||||
|
@ -69,9 +69,9 @@ ifneq (,$(filter format,$(MAKECMDGOALS)))
|
|||
endif
|
||||
|
||||
build/format.${CONDA_ENV_NAME}.build-stamp: build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
|
||||
# TODO: enable isort and black formatters
|
||||
#build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
|
||||
#build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
|
||||
build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
|
||||
build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
|
||||
build/format.${CONDA_ENV_NAME}.build-stamp: build/docformatter.${CONDA_ENV_NAME}.build-stamp
|
||||
build/format.${CONDA_ENV_NAME}.build-stamp:
|
||||
touch $@
|
||||
|
||||
|
@ -111,8 +111,8 @@ build/isort.${CONDA_ENV_NAME}.build-stamp:
|
|||
# NOTE: when using pattern rules (involving %) we can only add one line of
|
||||
# prerequisities, so we use this pattern to compose the list as variables.
|
||||
|
||||
# Both isort and licenseheaders alter files, so only run one at a time, by
|
||||
# making licenseheaders an order-only prerequisite.
|
||||
# black, licenseheaders, isort, and docformatter all alter files, so only run
|
||||
# one at a time, by adding prerequisites, but only as necessary.
|
||||
ISORT_COMMON_PREREQS :=
|
||||
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
|
||||
ISORT_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
|
||||
|
@ -126,7 +126,7 @@ build/isort.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
|
|||
|
||||
build/isort.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_COMMON_PREREQS)
|
||||
# Reformat python file imports with isort.
|
||||
conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$?)
|
||||
conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
.PHONY: black
|
||||
|
@ -142,8 +142,8 @@ build/black.${CONDA_ENV_NAME}.build-stamp: build/black.mlos_viz.${CONDA_ENV_NAME
|
|||
build/black.${CONDA_ENV_NAME}.build-stamp:
|
||||
touch $@
|
||||
|
||||
# Both black, licenseheaders, and isort all alter files, so only run one at a time, by
|
||||
# making licenseheaders and isort an order-only prerequisite.
|
||||
# black, licenseheaders, isort, and docformatter all alter files, so only run
|
||||
# one at a time, by adding prerequisites, but only as necessary.
|
||||
BLACK_COMMON_PREREQS :=
|
||||
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
|
||||
BLACK_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
|
||||
|
@ -160,13 +160,52 @@ build/black.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
|
|||
|
||||
build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS)
|
||||
# Reformat python files with black.
|
||||
conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$?)
|
||||
conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
.PHONY: docformatter
|
||||
docformatter: build/docformatter.${CONDA_ENV_NAME}.build-stamp
|
||||
|
||||
ifneq (,$(filter docformatter,$(MAKECMDGOALS)))
|
||||
FORMAT_PREREQS += build/docformatter.${CONDA_ENV_NAME}.build-stamp
|
||||
endif
|
||||
|
||||
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp
|
||||
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp
|
||||
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp
|
||||
build/docformatter.${CONDA_ENV_NAME}.build-stamp:
|
||||
touch $@
|
||||
|
||||
# black, licenseheaders, isort, and docformatter all alter files, so only run
|
||||
# one at a time, by adding prerequisites, but only as necessary.
|
||||
DOCFORMATTER_COMMON_PREREQS :=
|
||||
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
|
||||
DOCFORMATTER_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
|
||||
endif
|
||||
ifneq (,$(filter format isort,$(MAKECMDGOALS)))
|
||||
DOCFORMATTER_COMMON_PREREQS += build/isort.${CONDA_ENV_NAME}.build-stamp
|
||||
endif
|
||||
ifneq (,$(filter format black,$(MAKECMDGOALS)))
|
||||
DOCFORMATTER_COMMON_PREREQS += build/black.${CONDA_ENV_NAME}.build-stamp
|
||||
endif
|
||||
DOCFORMATTER_COMMON_PREREQS += build/conda-env.${CONDA_ENV_NAME}.build-stamp
|
||||
DOCFORMATTER_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
|
||||
|
||||
build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
|
||||
build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
|
||||
build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
|
||||
|
||||
# docformatter returns non-zero when it changes anything so instead we ignore that
|
||||
# return code and just have it recheck itself immediately
|
||||
build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS)
|
||||
# Reformat python file docstrings with docformatter.
|
||||
conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$+) || true
|
||||
conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
|
||||
.PHONY: check
|
||||
check: pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
|
||||
# TODO: Enable isort and black checks
|
||||
#check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
|
||||
check: isort-check black-check docformatter-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
|
||||
|
||||
.PHONY: black-check
|
||||
black-check: build/black-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
|
||||
|
@ -185,7 +224,27 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
|
|||
build/black-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS)
|
||||
# Check for import sort order.
|
||||
# Note: if this fails use "make format" or "make black" to fix it.
|
||||
conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$?)
|
||||
conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
.PHONY: docformatter-check
|
||||
docformatter-check: build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
|
||||
docformatter-check: build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp
|
||||
docformatter-check: build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp
|
||||
|
||||
# Make sure docformatter format rules run before docformatter-check rules.
|
||||
build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
|
||||
build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
|
||||
build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
|
||||
|
||||
BLACK_CHECK_COMMON_PREREQS := build/conda-env.${CONDA_ENV_NAME}.build-stamp
|
||||
BLACK_CHECK_COMMON_PREREQS += $(FORMAT_PREREQS)
|
||||
BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
|
||||
|
||||
build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS)
|
||||
# Check for import sort order.
|
||||
# Note: if this fails use "make format" or "make docformatter" to fix it.
|
||||
conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
.PHONY: isort-check
|
||||
|
@ -204,7 +263,7 @@ ISORT_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
|
|||
|
||||
build/isort-check.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_CHECK_COMMON_PREREQS)
|
||||
# Note: if this fails use "make format" or "make isort" to fix it.
|
||||
conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$?)
|
||||
conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$+)
|
||||
touch $@
|
||||
|
||||
.PHONY: pycodestyle
|
||||
|
@ -723,7 +782,12 @@ clean-doc:
|
|||
|
||||
.PHONY: clean-format
|
||||
clean-format:
|
||||
# TODO: add black and isort rules
|
||||
rm -f build/black.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/black.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/docformatter.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/docformatter.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/isort.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/isort.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/licenseheaders-prereqs.${CONDA_ENV_NAME}.build-stamp
|
||||
|
||||
|
@ -733,6 +797,13 @@ clean-check:
|
|||
rm -f build/pylint.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/pylint.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/mypy.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/black-check.build-stamp
|
||||
rm -f build/black-check.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/black-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/docformatter-check.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/docformatter-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/isort-check.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/isort-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/pycodestyle.build-stamp
|
||||
rm -f build/pycodestyle.${CONDA_ENV_NAME}.build-stamp
|
||||
rm -f build/pycodestyle.mlos_*.${CONDA_ENV_NAME}.build-stamp
|
||||
|
|
|
@ -25,10 +25,10 @@ dependencies:
|
|||
# See comments in mlos.yml.
|
||||
#- gcc_linux-64
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
|
@ -25,10 +25,10 @@ dependencies:
|
|||
# See comments in mlos.yml.
|
||||
#- gcc_linux-64
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
|
@ -25,10 +25,10 @@ dependencies:
|
|||
# See comments in mlos.yml.
|
||||
#- gcc_linux-64
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
|
@ -25,10 +25,10 @@ dependencies:
|
|||
# See comments in mlos.yml.
|
||||
#- gcc_linux-64
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
|
@ -28,10 +28,10 @@ dependencies:
|
|||
# This also requires a more recent vs2015_runtime from conda-forge.
|
||||
- pyrfr>=0.9.0
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
|
@ -24,10 +24,10 @@ dependencies:
|
|||
# FIXME: https://github.com/microsoft/MLOS/issues/727
|
||||
- python<3.12
|
||||
- pip:
|
||||
- autopep8>=1.7.0
|
||||
- bump2version
|
||||
- check-jsonschema
|
||||
- isort
|
||||
- docformatter
|
||||
- licenseheaders
|
||||
- mypy
|
||||
- pandas-stubs
|
||||
|
|
21
conftest.py
21
conftest.py
|
@ -32,14 +32,23 @@ def pytest_configure(config: pytest.Config) -> None:
|
|||
Add some additional (global) configuration steps for pytest.
|
||||
"""
|
||||
# Workaround some issues loading emukit in certain environments.
|
||||
if os.environ.get('DISPLAY', None):
|
||||
if os.environ.get("DISPLAY", None):
|
||||
try:
|
||||
import matplotlib # pylint: disable=import-outside-toplevel
|
||||
matplotlib.rcParams['backend'] = 'agg'
|
||||
if is_master(config) or dict(getattr(config, 'workerinput', {}))['workerid'] == 'gw0':
|
||||
import matplotlib # pylint: disable=import-outside-toplevel
|
||||
|
||||
matplotlib.rcParams["backend"] = "agg"
|
||||
if is_master(config) or dict(getattr(config, "workerinput", {}))["workerid"] == "gw0":
|
||||
# Only warn once.
|
||||
warn(UserWarning('DISPLAY environment variable is set, which can cause problems in some setups (e.g. WSL). '
|
||||
+ f'Adjusting matplotlib backend to "{matplotlib.rcParams["backend"]}" to compensate.'))
|
||||
warn(
|
||||
UserWarning(
|
||||
(
|
||||
"DISPLAY environment variable is set, "
|
||||
"which can cause problems in some setups (e.g. WSL). "
|
||||
f"Adjusting matplotlib backend to '{matplotlib.rcParams['backend']}' "
|
||||
"to compensate."
|
||||
)
|
||||
)
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -87,10 +87,11 @@ autosummary_generate = True
|
|||
numpydoc_class_members_toctree = False
|
||||
|
||||
autodoc_default_options = {
|
||||
'members': True,
|
||||
'undoc-members': True,
|
||||
# Don't generate documentation for some (non-private) functions that are more for internal implementation use.
|
||||
'exclude-members': 'mlos_bench.util.check_required_params'
|
||||
"members": True,
|
||||
"undoc-members": True,
|
||||
# Don't generate documentation for some (non-private) functions that are more
|
||||
# for internal implementation use.
|
||||
"exclude-members": "mlos_bench.util.check_required_params",
|
||||
}
|
||||
|
||||
# Generate the plots for the gallery
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
mlos_bench is a framework to help automate benchmarking and and
|
||||
OS/application parameter autotuning.
|
||||
"""mlos_bench is a framework to help automate benchmarking and and OS/application
|
||||
parameter autotuning.
|
||||
"""
|
||||
|
|
|
@ -2,6 +2,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
mlos_bench.config
|
||||
"""
|
||||
"""mlos_bench.config."""
|
||||
|
|
|
@ -3,41 +3,36 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Script for post-processing FIO results for mlos_bench.
|
||||
"""
|
||||
"""Script for post-processing FIO results for mlos_bench."""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
|
||||
from typing import Any, Iterator, Tuple
|
||||
|
||||
import pandas
|
||||
|
||||
|
||||
def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]:
|
||||
"""
|
||||
Flatten every dict in the hierarchy and rename the keys with the dict path.
|
||||
"""
|
||||
"""Flatten every dict in the hierarchy and rename the keys with the dict path."""
|
||||
if isinstance(data, dict):
|
||||
for (key, val) in data.items():
|
||||
for key, val in data.items():
|
||||
yield from _flat_dict(val, f"{path}.{key}")
|
||||
else:
|
||||
yield (path, data)
|
||||
|
||||
|
||||
def _main(input_file: str, output_file: str, prefix: str) -> None:
|
||||
"""
|
||||
Convert FIO read data from JSON to tall CSV.
|
||||
"""
|
||||
with open(input_file, mode='r', encoding='utf-8') as fh_input:
|
||||
"""Convert FIO read data from JSON to tall CSV."""
|
||||
with open(input_file, mode="r", encoding="utf-8") as fh_input:
|
||||
json_data = json.load(fh_input)
|
||||
|
||||
data = list(itertools.chain(
|
||||
_flat_dict(json_data["jobs"][0], prefix),
|
||||
_flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util")
|
||||
))
|
||||
data = list(
|
||||
itertools.chain(
|
||||
_flat_dict(json_data["jobs"][0], prefix),
|
||||
_flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"),
|
||||
)
|
||||
)
|
||||
|
||||
tall_df = pandas.DataFrame(data, columns=["metric", "value"])
|
||||
tall_df.to_csv(output_file, index=False)
|
||||
|
@ -50,12 +45,18 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.")
|
||||
|
||||
parser.add_argument(
|
||||
"input", help="FIO benchmark results in JSON format (downloaded from a remote VM).")
|
||||
"input",
|
||||
help="FIO benchmark results in JSON format (downloaded from a remote VM).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).")
|
||||
"output",
|
||||
help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix", default="fio",
|
||||
help="Prefix of the metric IDs (default 'fio')")
|
||||
"--prefix",
|
||||
default="fio",
|
||||
help="Prefix of the metric IDs (default 'fio')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
_main(args.input, args.output, args.prefix)
|
||||
|
|
|
@ -9,22 +9,24 @@ Helper script to generate Redis config from tunable parameters JSON.
|
|||
Run: `./generate_redis_config.py ./input-params.json ./output-redis.cfg`
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_output: str) -> None:
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
|
||||
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in json.load(fh_tunables).items():
|
||||
line = f'{key} {val}'
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
|
||||
fname_output, "wt", encoding="utf-8", newline=""
|
||||
) as fh_config:
|
||||
for key, val in json.load(fh_tunables).items():
|
||||
line = f"{key} {val}"
|
||||
fh_config.write(line + "\n")
|
||||
print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="generate Redis config from tunable parameters JSON.")
|
||||
description="generate Redis config from tunable parameters JSON."
|
||||
)
|
||||
parser.add_argument("input", help="JSON file with tunable parameters.")
|
||||
parser.add_argument("output", help="Output Redis config file.")
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -3,9 +3,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Script for post-processing redis-benchmark results.
|
||||
"""
|
||||
"""Script for post-processing redis-benchmark results."""
|
||||
|
||||
import argparse
|
||||
|
||||
|
@ -13,26 +11,25 @@ import pandas as pd
|
|||
|
||||
|
||||
def _main(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Re-shape Redis benchmark CSV results from wide to long.
|
||||
"""
|
||||
"""Re-shape Redis benchmark CSV results from wide to long."""
|
||||
df_wide = pd.read_csv(input_file)
|
||||
|
||||
# Format the results from wide to long
|
||||
# The target is columns of metric and value to act as key-value pairs.
|
||||
df_long = (
|
||||
df_wide
|
||||
.melt(id_vars=["test"])
|
||||
df_wide.melt(id_vars=["test"])
|
||||
.assign(metric=lambda df: df["test"] + "_" + df["variable"])
|
||||
.drop(columns=["test", "variable"])
|
||||
.loc[:, ["metric", "value"]]
|
||||
)
|
||||
|
||||
# Add a default `score` metric to the end of the dataframe.
|
||||
df_long = pd.concat([
|
||||
df_long,
|
||||
pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]})
|
||||
])
|
||||
df_long = pd.concat(
|
||||
[
|
||||
df_long,
|
||||
pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}),
|
||||
]
|
||||
)
|
||||
|
||||
df_long.to_csv(output_file, index=False)
|
||||
print(f"Converted: {input_file} -> {output_file}")
|
||||
|
@ -41,8 +38,13 @@ def _main(input_file: str, output_file: str) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.")
|
||||
parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).")
|
||||
parser.add_argument("output", help="Converted Redis benchmark data" +
|
||||
" (to be consumed by OS Autotune framework).")
|
||||
parser.add_argument(
|
||||
"input",
|
||||
help="Redis benchmark results (downloaded from a remote VM).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output",
|
||||
help="Converted Redis benchmark data (to be consumed by OS Autotune framework).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_main(args.input, args.output)
|
||||
|
|
|
@ -6,16 +6,18 @@
|
|||
"""
|
||||
Python script to parse through JSON and create new config file.
|
||||
|
||||
This script will be run in the SCHEDULER.
|
||||
NEW_CFG will need to be copied over to the VM (/etc/default/grub.d).
|
||||
This script will be run in the SCHEDULER. NEW_CFG will need to be copied over to the VM
|
||||
(/etc/default/grub.d).
|
||||
"""
|
||||
import json
|
||||
|
||||
JSON_CONFIG_FILE = "config-boot-time.json"
|
||||
NEW_CFG = "zz-mlos-boot-params.cfg"
|
||||
|
||||
with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \
|
||||
open(NEW_CFG, 'w', encoding='UTF-8') as fh_config:
|
||||
with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open(
|
||||
NEW_CFG, "w", encoding="UTF-8"
|
||||
) as fh_config:
|
||||
for key, val in json.load(fh_json).items():
|
||||
fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$'
|
||||
f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n')
|
||||
fh_config.write(
|
||||
'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n'
|
||||
)
|
||||
|
|
|
@ -9,14 +9,15 @@ Helper script to generate GRUB config from tunable parameters JSON.
|
|||
Run: `./generate_grub_config.py ./input-boot-params.json ./output-grub.cfg`
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_output: str) -> None:
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
|
||||
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in json.load(fh_tunables).items():
|
||||
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
|
||||
fname_output, "wt", encoding="utf-8", newline=""
|
||||
) as fh_config:
|
||||
for key, val in json.load(fh_tunables).items():
|
||||
line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"'
|
||||
fh_config.write(line + "\n")
|
||||
print(line)
|
||||
|
@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate GRUB config from tunable parameters JSON.")
|
||||
description="Generate GRUB config from tunable parameters JSON."
|
||||
)
|
||||
parser.add_argument("input", help="JSON file with tunable parameters.")
|
||||
parser.add_argument("output", help="Output shell script to configure GRUB.")
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -6,11 +6,14 @@
|
|||
"""
|
||||
Helper script to generate a script to update kernel parameters from tunables JSON.
|
||||
|
||||
Run: `./generate_kernel_config_script.py ./kernel-params.json ./kernel-params-meta.json ./config-kernel.sh`
|
||||
Run:
|
||||
./generate_kernel_config_script.py \
|
||||
./kernel-params.json ./kernel-params-meta.json \
|
||||
./config-kernel.sh
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
|
||||
|
@ -22,7 +25,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
|
|||
tunables_meta = json.load(fh_meta)
|
||||
|
||||
with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
|
||||
for (key, val) in tunables_data.items():
|
||||
for key, val in tunables_data.items():
|
||||
meta = tunables_meta.get(key, {})
|
||||
name_prefix = meta.get("name_prefix", "")
|
||||
line = f'echo "{val}" > {name_prefix}{key}'
|
||||
|
@ -33,7 +36,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
|
|||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="generate a script to update kernel parameters from tunables JSON.")
|
||||
description="generate a script to update kernel parameters from tunables JSON."
|
||||
)
|
||||
|
||||
parser.add_argument("input", help="JSON file with tunable parameters.")
|
||||
parser.add_argument("meta", help="JSON file with tunable parameters metadata.")
|
||||
|
|
|
@ -2,14 +2,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A module for managing config schemas and their validation.
|
||||
"""
|
||||
|
||||
from mlos_bench.config.schemas.config_schemas import ConfigSchema, CONFIG_SCHEMA_DIR
|
||||
"""A module for managing config schemas and their validation."""
|
||||
|
||||
from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema
|
||||
|
||||
__all__ = [
|
||||
'ConfigSchema',
|
||||
'CONFIG_SCHEMA_DIR',
|
||||
"ConfigSchema",
|
||||
"CONFIG_SCHEMA_DIR",
|
||||
]
|
||||
|
|
|
@ -2,18 +2,17 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A simple class for describing where to find different config schemas and validating configs against them.
|
||||
"""A simple class for describing where to find different config schemas and validating
|
||||
configs against them.
|
||||
"""
|
||||
|
||||
import json # schema files are pure json - no comments
|
||||
import logging
|
||||
from enum import Enum
|
||||
from os import path, walk, environ
|
||||
from os import environ, path, walk
|
||||
from typing import Dict, Iterator, Mapping
|
||||
|
||||
import json # schema files are pure json - no comments
|
||||
import jsonschema
|
||||
|
||||
from referencing import Registry, Resource
|
||||
from referencing.jsonschema import DRAFT202012
|
||||
|
||||
|
@ -28,16 +27,21 @@ CONFIG_SCHEMA_DIR = path_join(path.dirname(__file__), abs_path=True)
|
|||
# It is used in `ConfigSchema.validate()` method below.
|
||||
# NOTE: this may cause pytest to fail if it's expecting exceptions
|
||||
# to be raised for invalid configs.
|
||||
_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION'
|
||||
_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower()
|
||||
in {'true', 'y', 'yes', 'on', '1'})
|
||||
_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION"
|
||||
_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in {
|
||||
"true",
|
||||
"y",
|
||||
"yes",
|
||||
"on",
|
||||
"1",
|
||||
}
|
||||
|
||||
|
||||
# Note: we separate out the SchemaStore from a class method on ConfigSchema
|
||||
# because of issues with mypy/pylint and non-Enum-member class members.
|
||||
class SchemaStore(Mapping):
|
||||
"""
|
||||
A simple class for storing schemas and subschemas for the validator to reference.
|
||||
"""A simple class for storing schemas and subschemas for the validator to
|
||||
reference.
|
||||
"""
|
||||
|
||||
# A class member mapping of schema id to schema object.
|
||||
|
@ -58,7 +62,9 @@ class SchemaStore(Mapping):
|
|||
|
||||
@classmethod
|
||||
def _load_schemas(cls) -> None:
|
||||
"""Loads all schemas and subschemas into the schema store for the validator to reference."""
|
||||
"""Loads all schemas and subschemas into the schema store for the validator to
|
||||
reference.
|
||||
"""
|
||||
if cls._SCHEMA_STORE:
|
||||
return
|
||||
for root, _, files in walk(CONFIG_SCHEMA_DIR):
|
||||
|
@ -78,13 +84,17 @@ class SchemaStore(Mapping):
|
|||
|
||||
@classmethod
|
||||
def _load_registry(cls) -> None:
|
||||
"""Also store them in a Registry object for referencing by recent versions of jsonschema."""
|
||||
"""Also store them in a Registry object for referencing by recent versions of
|
||||
jsonschema.
|
||||
"""
|
||||
if not cls._SCHEMA_STORE:
|
||||
cls._load_schemas()
|
||||
cls._REGISTRY = Registry().with_resources([
|
||||
(url, Resource.from_contents(schema, default_specification=DRAFT202012))
|
||||
for url, schema in cls._SCHEMA_STORE.items()
|
||||
])
|
||||
cls._REGISTRY = Registry().with_resources(
|
||||
[
|
||||
(url, Resource.from_contents(schema, default_specification=DRAFT202012))
|
||||
for url, schema in cls._SCHEMA_STORE.items()
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def registry(self) -> Registry:
|
||||
|
@ -98,9 +108,7 @@ SCHEMA_STORE = SchemaStore()
|
|||
|
||||
|
||||
class ConfigSchema(Enum):
|
||||
"""
|
||||
An enum to help describe schema types and help validate configs against them.
|
||||
"""
|
||||
"""An enum to help describe schema types and help validate configs against them."""
|
||||
|
||||
CLI = path_join(CONFIG_SCHEMA_DIR, "cli/cli-schema.json")
|
||||
GLOBALS = path_join(CONFIG_SCHEMA_DIR, "cli/globals-schema.json")
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Simple class to help with nested dictionary $var templating.
|
||||
"""
|
||||
"""Simple class to help with nested dictionary $var templating."""
|
||||
|
||||
from copy import deepcopy
|
||||
from string import Template
|
||||
|
@ -13,10 +11,8 @@ from typing import Any, Dict, Optional
|
|||
from mlos_bench.os_environ import environ
|
||||
|
||||
|
||||
class DictTemplater: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Simple class to help with nested dictionary $var templating.
|
||||
"""
|
||||
class DictTemplater: # pylint: disable=too-few-public-methods
|
||||
"""Simple class to help with nested dictionary $var templating."""
|
||||
|
||||
def __init__(self, source_dict: Dict[str, Any]):
|
||||
"""
|
||||
|
@ -32,9 +28,12 @@ class DictTemplater: # pylint: disable=too-few-public-methods
|
|||
# The source/target dictionary to expand.
|
||||
self._dict: Dict[str, Any] = {}
|
||||
|
||||
def expand_vars(self, *,
|
||||
extra_source_dict: Optional[Dict[str, Any]] = None,
|
||||
use_os_env: bool = False) -> Dict[str, Any]:
|
||||
def expand_vars(
|
||||
self,
|
||||
*,
|
||||
extra_source_dict: Optional[Dict[str, Any]] = None,
|
||||
use_os_env: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Expand the template variables in the destination dictionary.
|
||||
|
||||
|
@ -55,10 +54,13 @@ class DictTemplater: # pylint: disable=too-few-public-methods
|
|||
assert isinstance(self._dict, dict)
|
||||
return self._dict
|
||||
|
||||
def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any:
|
||||
"""
|
||||
Recursively expand $var strings in the currently operating dictionary.
|
||||
"""
|
||||
def _expand_vars(
|
||||
self,
|
||||
value: Any,
|
||||
extra_source_dict: Optional[Dict[str, Any]],
|
||||
use_os_env: bool,
|
||||
) -> Any:
|
||||
"""Recursively expand $var strings in the currently operating dictionary."""
|
||||
if isinstance(value, str):
|
||||
# First try to expand all $vars internally.
|
||||
value = Template(value).safe_substitute(self._dict)
|
||||
|
@ -71,7 +73,7 @@ class DictTemplater: # pylint: disable=too-few-public-methods
|
|||
elif isinstance(value, dict):
|
||||
# Note: we use a loop instead of dict comprehension in order to
|
||||
# allow secondary expansion of subsequent values immediately.
|
||||
for (key, val) in value.items():
|
||||
for key, val in value.items():
|
||||
value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
|
||||
elif isinstance(value, list):
|
||||
value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]
|
||||
|
|
|
@ -2,26 +2,22 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Tunable Environments for mlos_bench.
|
||||
"""
|
||||
"""Tunable Environments for mlos_bench."""
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
|
||||
from mlos_bench.environments.mock_env import MockEnv
|
||||
from mlos_bench.environments.remote.remote_env import RemoteEnv
|
||||
from mlos_bench.environments.composite_env import CompositeEnv
|
||||
from mlos_bench.environments.local.local_env import LocalEnv
|
||||
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
|
||||
from mlos_bench.environments.composite_env import CompositeEnv
|
||||
from mlos_bench.environments.mock_env import MockEnv
|
||||
from mlos_bench.environments.remote.remote_env import RemoteEnv
|
||||
from mlos_bench.environments.status import Status
|
||||
|
||||
__all__ = [
|
||||
'Status',
|
||||
|
||||
'Environment',
|
||||
'MockEnv',
|
||||
'RemoteEnv',
|
||||
'LocalEnv',
|
||||
'LocalFileShareEnv',
|
||||
'CompositeEnv',
|
||||
"Status",
|
||||
"Environment",
|
||||
"MockEnv",
|
||||
"RemoteEnv",
|
||||
"LocalEnv",
|
||||
"LocalFileShareEnv",
|
||||
"CompositeEnv",
|
||||
]
|
||||
|
|
|
@ -2,19 +2,28 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A hierarchy of benchmark environments.
|
||||
"""
|
||||
"""A hierarchy of benchmark environments."""
|
||||
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing_extensions import Literal
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pytz import UTC
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
from mlos_bench.dict_templater import DictTemplater
|
||||
|
@ -32,20 +41,19 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class Environment(metaclass=abc.ABCMeta):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
An abstract base of all benchmark environments.
|
||||
"""
|
||||
"""An abstract base of all benchmark environments."""
|
||||
|
||||
@classmethod
|
||||
def new(cls,
|
||||
*,
|
||||
env_name: str,
|
||||
class_name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> "Environment":
|
||||
def new(
|
||||
cls,
|
||||
*,
|
||||
env_name: str,
|
||||
class_name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> "Environment":
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
|
@ -83,16 +91,18 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service
|
||||
service=service,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment with a given config.
|
||||
|
||||
|
@ -123,24 +133,33 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
self._const_args: Dict[str, TunableValue] = config.get("const_args", {})
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Environment: '%s' Service: %s", name,
|
||||
self._service.pprint() if self._service else None)
|
||||
_LOG.debug(
|
||||
"Environment: '%s' Service: %s",
|
||||
name,
|
||||
self._service.pprint() if self._service else None,
|
||||
)
|
||||
|
||||
if tunables is None:
|
||||
_LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name)
|
||||
_LOG.warning(
|
||||
(
|
||||
"No tunables provided for %s. "
|
||||
"Tunable inheritance across composite environments may be broken."
|
||||
),
|
||||
name,
|
||||
)
|
||||
tunables = TunableGroups()
|
||||
|
||||
groups = self._expand_groups(
|
||||
config.get("tunable_params", []),
|
||||
(global_config or {}).get("tunable_params_map", {}))
|
||||
(global_config or {}).get("tunable_params_map", {}),
|
||||
)
|
||||
_LOG.debug("Tunable groups for: '%s' :: %s", name, groups)
|
||||
|
||||
self._tunable_params = tunables.subgroup(groups)
|
||||
|
||||
# If a parameter comes from the tunables, do not require it in the const_args or globals
|
||||
req_args = (
|
||||
set(config.get("required_args", [])) -
|
||||
set(self._tunable_params.get_param_values().keys())
|
||||
req_args = set(config.get("required_args", [])) - set(
|
||||
self._tunable_params.get_param_values().keys()
|
||||
)
|
||||
merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
|
||||
self._const_args = self._expand_vars(self._const_args, global_config or {})
|
||||
|
@ -149,14 +168,12 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
_LOG.debug("Parameters for '%s' :: %s", name, self._params)
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Config for: '%s'\n%s",
|
||||
name, json.dumps(self.config, indent=2))
|
||||
_LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))
|
||||
|
||||
def _validate_json_config(self, config: dict, name: str) -> None:
|
||||
"""
|
||||
Reconstructs a basic json config that this class might have been
|
||||
instantiated from in order to validate configs provided outside the
|
||||
file loading mechanism.
|
||||
"""Reconstructs a basic json config that this class might have been instantiated
|
||||
from in order to validate configs provided outside the file loading
|
||||
mechanism.
|
||||
"""
|
||||
json_config: dict = {
|
||||
"class": self.__class__.__module__ + "." + self.__class__.__name__,
|
||||
|
@ -168,8 +185,10 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
ConfigSchema.ENVIRONMENT.validate(json_config)
|
||||
|
||||
@staticmethod
|
||||
def _expand_groups(groups: Iterable[str],
|
||||
groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]:
|
||||
def _expand_groups(
|
||||
groups: Iterable[str],
|
||||
groups_exp: Dict[str, Union[str, Sequence[str]]],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Expand `$tunable_group` into actual names of the tunable groups.
|
||||
|
||||
|
@ -191,7 +210,12 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
if grp[:1] == "$":
|
||||
tunable_group_name = grp[1:]
|
||||
if tunable_group_name not in groups_exp:
|
||||
raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}")
|
||||
raise KeyError(
|
||||
(
|
||||
f"Expected tunable group name ${tunable_group_name} "
|
||||
"undefined in {groups_exp}"
|
||||
)
|
||||
)
|
||||
add_groups = groups_exp[tunable_group_name]
|
||||
res += [add_groups] if isinstance(add_groups, str) else add_groups
|
||||
else:
|
||||
|
@ -199,10 +223,11 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
return res
|
||||
|
||||
@staticmethod
|
||||
def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict:
|
||||
"""
|
||||
Expand `$var` into actual values of the variables.
|
||||
"""
|
||||
def _expand_vars(
|
||||
params: Dict[str, TunableValue],
|
||||
global_config: Dict[str, TunableValue],
|
||||
) -> dict:
|
||||
"""Expand `$var` into actual values of the variables."""
|
||||
return DictTemplater(params).expand_vars(extra_source_dict=global_config)
|
||||
|
||||
@property
|
||||
|
@ -210,10 +235,8 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
assert self._service is not None
|
||||
return self._service.config_loader_service
|
||||
|
||||
def __enter__(self) -> 'Environment':
|
||||
"""
|
||||
Enter the environment's benchmarking context.
|
||||
"""
|
||||
def __enter__(self) -> "Environment":
|
||||
"""Enter the environment's benchmarking context."""
|
||||
_LOG.debug("Environment START :: %s", self)
|
||||
assert not self._in_context
|
||||
if self._service:
|
||||
|
@ -221,12 +244,13 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
self._in_context = True
|
||||
return self
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
"""
|
||||
Exit the context of the benchmarking environment.
|
||||
"""
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Exit the context of the benchmarking environment."""
|
||||
ex_throw = None
|
||||
if ex_val is None:
|
||||
_LOG.debug("Environment END :: %s", self)
|
||||
|
@ -256,8 +280,8 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
|
||||
def pprint(self, indent: int = 4, level: int = 0) -> str:
|
||||
"""
|
||||
Pretty-print the environment configuration.
|
||||
For composite environments, print all children environments as well.
|
||||
Pretty-print the environment configuration. For composite environments, print
|
||||
all children environments as well.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -277,8 +301,8 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Plug tunable values into the base config. If the tunable group is unknown,
|
||||
ignore it (it might belong to another environment). This method should
|
||||
never mutate the original config or the tunables.
|
||||
ignore it (it might belong to another environment). This method should never
|
||||
mutate the original config or the tunables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -293,7 +317,8 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
return tunables.get_param_values(
|
||||
group_names=list(self._tunable_params.get_covariant_group_names()),
|
||||
into_params=self._const_args.copy())
|
||||
into_params=self._const_args.copy(),
|
||||
)
|
||||
|
||||
@property
|
||||
def tunable_params(self) -> TunableGroups:
|
||||
|
@ -310,21 +335,23 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
@property
|
||||
def parameters(self) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
|
||||
Note that before `.setup()` is called, all tunables will be set to None.
|
||||
Key/value pairs of all environment parameters (i.e., `const_args` and
|
||||
`tunable_params`). Note that before `.setup()` is called, all tunables will be
|
||||
set to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
parameters : Dict[str, TunableValue]
|
||||
Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
|
||||
Key/value pairs of all environment parameters
|
||||
(i.e., `const_args` and `tunable_params`).
|
||||
"""
|
||||
return self._params
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Set up a new benchmark environment, if necessary. This method must be
|
||||
idempotent, i.e., calling it several times in a row should be
|
||||
equivalent to a single call.
|
||||
idempotent, i.e., calling it several times in a row should be equivalent to a
|
||||
single call.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -353,10 +380,15 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
# (Derived classes still have to check `self._tunable_params.is_updated()`).
|
||||
is_updated = self._tunable_params.is_updated()
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, {
|
||||
name: self._tunable_params.is_updated([name])
|
||||
for name in self._tunable_params.get_covariant_group_names()
|
||||
})
|
||||
_LOG.debug(
|
||||
"Env '%s': Tunable groups reset = %s :: %s",
|
||||
self,
|
||||
is_updated,
|
||||
{
|
||||
name: self._tunable_params.is_updated([name])
|
||||
for name in self._tunable_params.get_covariant_group_names()
|
||||
},
|
||||
)
|
||||
else:
|
||||
_LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)
|
||||
|
||||
|
@ -371,9 +403,10 @@ class Environment(metaclass=abc.ABCMeta):
|
|||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Tear down the benchmark environment. This method must be idempotent,
|
||||
i.e., calling it several times in a row should be equivalent to a
|
||||
single call.
|
||||
Tear down the benchmark environment.
|
||||
|
||||
This method must be idempotent, i.e., calling it several times in a row should
|
||||
be equivalent to a single call.
|
||||
"""
|
||||
_LOG.info("Teardown %s", self)
|
||||
# Make sure we create a context before invoking setup/run/status/teardown
|
||||
|
|
|
@ -2,20 +2,18 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Composite benchmark environment.
|
||||
"""
|
||||
"""Composite benchmark environment."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
|
@ -23,17 +21,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CompositeEnv(Environment):
|
||||
"""
|
||||
Composite benchmark environment.
|
||||
"""
|
||||
"""Composite benchmark environment."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment with a given config.
|
||||
|
||||
|
@ -53,8 +51,13 @@ class CompositeEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
# By default, the Environment includes only the tunables explicitly specified
|
||||
# in the "tunable_params" section of the config. `CompositeEnv`, however, must
|
||||
|
@ -70,17 +73,27 @@ class CompositeEnv(Environment):
|
|||
# each CompositeEnv gets a copy of the original global config and adjusts it with
|
||||
# the `const_args` specific to it.
|
||||
global_config = (global_config or {}).copy()
|
||||
for (key, val) in self._const_args.items():
|
||||
for key, val in self._const_args.items():
|
||||
global_config.setdefault(key, val)
|
||||
|
||||
for child_config_file in config.get("include_children", []):
|
||||
for env in self._config_loader_service.load_environment_list(
|
||||
child_config_file, tunables, global_config, self._const_args, self._service):
|
||||
child_config_file,
|
||||
tunables,
|
||||
global_config,
|
||||
self._const_args,
|
||||
self._service,
|
||||
):
|
||||
self._add_child(env, tunables)
|
||||
|
||||
for child_config in config.get("children", []):
|
||||
env = self._config_loader_service.build_environment(
|
||||
child_config, tunables, global_config, self._const_args, self._service)
|
||||
child_config,
|
||||
tunables,
|
||||
global_config,
|
||||
self._const_args,
|
||||
self._service,
|
||||
)
|
||||
self._add_child(env, tunables)
|
||||
|
||||
_LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)
|
||||
|
@ -92,9 +105,12 @@ class CompositeEnv(Environment):
|
|||
self._child_contexts = [env.__enter__() for env in self._children]
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
ex_throw = None
|
||||
for env in reversed(self._children):
|
||||
try:
|
||||
|
@ -111,9 +127,7 @@ class CompositeEnv(Environment):
|
|||
|
||||
@property
|
||||
def children(self) -> List[Environment]:
|
||||
"""
|
||||
Return the list of child environments.
|
||||
"""
|
||||
"""Return the list of child environments."""
|
||||
return self._children
|
||||
|
||||
def pprint(self, indent: int = 4, level: int = 0) -> str:
|
||||
|
@ -132,12 +146,16 @@ class CompositeEnv(Environment):
|
|||
pretty : str
|
||||
Pretty-printed environment configuration.
|
||||
"""
|
||||
return super().pprint(indent, level) + '\n' + '\n'.join(
|
||||
child.pprint(indent, level + 1) for child in self._children)
|
||||
return (
|
||||
super().pprint(indent, level)
|
||||
+ "\n"
|
||||
+ "\n".join(child.pprint(indent, level + 1) for child in self._children)
|
||||
)
|
||||
|
||||
def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
|
||||
"""
|
||||
Add a new child environment to the composite environment.
|
||||
|
||||
This method is called from the constructor only.
|
||||
"""
|
||||
_LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
|
||||
|
@ -165,14 +183,16 @@ class CompositeEnv(Environment):
|
|||
"""
|
||||
assert self._in_context
|
||||
self._is_ready = super().setup(tunables, global_config) and all(
|
||||
env_context.setup(tunables, global_config) for env_context in self._child_contexts)
|
||||
env_context.setup(tunables, global_config) for env_context in self._child_contexts
|
||||
)
|
||||
return self._is_ready
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Tear down the children environments. This method is idempotent,
|
||||
i.e., calling it several times is equivalent to a single call.
|
||||
The environments are being torn down in the reverse order.
|
||||
Tear down the children environments.
|
||||
|
||||
This method is idempotent, i.e., calling it several times is equivalent to a
|
||||
single call. The environments are being torn down in the reverse order.
|
||||
"""
|
||||
assert self._in_context
|
||||
for env_context in reversed(self._child_contexts):
|
||||
|
@ -181,9 +201,9 @@ class CompositeEnv(Environment):
|
|||
|
||||
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
|
||||
"""
|
||||
Submit a new experiment to the environment.
|
||||
Return the result of the *last* child environment if successful,
|
||||
or the status of the last failed environment otherwise.
|
||||
Submit a new experiment to the environment. Return the result of the *last*
|
||||
child environment if successful, or the status of the last failed environment
|
||||
otherwise.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -238,5 +258,6 @@ class CompositeEnv(Environment):
|
|||
|
||||
final_status = final_status or status
|
||||
_LOG.info("Final status: %s :: %s", self, final_status)
|
||||
# Return the status and the timestamp of the last child environment or the first failed child environment.
|
||||
# Return the status and the timestamp of the last child environment or the
|
||||
# first failed child environment.
|
||||
return (final_status, timestamp, joint_telemetry)
|
||||
|
|
|
@ -2,14 +2,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Local Environments for mlos_bench.
|
||||
"""
|
||||
"""Local Environments for mlos_bench."""
|
||||
|
||||
from mlos_bench.environments.local.local_env import LocalEnv
|
||||
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
|
||||
|
||||
__all__ = [
|
||||
'LocalEnv',
|
||||
'LocalFileShareEnv',
|
||||
"LocalEnv",
|
||||
"LocalFileShareEnv",
|
||||
]
|
||||
|
|
|
@ -2,27 +2,23 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Scheduler-side benchmark environment to run scripts locally.
|
||||
"""
|
||||
"""Scheduler-side benchmark environment to run scripts locally."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from tempfile import TemporaryDirectory
|
||||
from contextlib import nullcontext
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union
|
||||
from typing_extensions import Literal
|
||||
|
||||
import pandas
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.environments.script_env import ScriptEnv
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
@ -34,17 +30,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class LocalEnv(ScriptEnv):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Scheduler-side Environment that runs scripts locally.
|
||||
"""
|
||||
"""Scheduler-side Environment that runs scripts locally."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for local execution.
|
||||
|
||||
|
@ -67,11 +63,17 @@ class LocalEnv(ScriptEnv):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
|
||||
"LocalEnv requires a service that supports local execution"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsLocalExec
|
||||
), "LocalEnv requires a service that supports local execution"
|
||||
self._local_exec_service: SupportsLocalExec = self._service
|
||||
|
||||
self._temp_dir: Optional[str] = None
|
||||
|
@ -85,16 +87,19 @@ class LocalEnv(ScriptEnv):
|
|||
|
||||
def __enter__(self) -> Environment:
|
||||
assert self._temp_dir is None and self._temp_dir_context is None
|
||||
self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir"))
|
||||
self._temp_dir_context = self._local_exec_service.temp_dir_context(
|
||||
self.config.get("temp_dir"),
|
||||
)
|
||||
self._temp_dir = self._temp_dir_context.__enter__()
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
"""
|
||||
Exit the context of the benchmarking environment.
|
||||
"""
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Exit the context of the benchmarking environment."""
|
||||
assert not (self._temp_dir is None or self._temp_dir_context is None)
|
||||
self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
|
||||
self._temp_dir = None
|
||||
|
@ -103,8 +108,8 @@ class LocalEnv(ScriptEnv):
|
|||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if the environment is ready and set up the application
|
||||
and benchmarks, if necessary.
|
||||
Check if the environment is ready and set up the application and benchmarks, if
|
||||
necessary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -139,10 +144,14 @@ class LocalEnv(ScriptEnv):
|
|||
fname = path_join(self._temp_dir, self._dump_meta_file)
|
||||
_LOG.debug("Dump tunables metadata to file: %s", fname)
|
||||
with open(fname, "wt", encoding="utf-8") as fh_meta:
|
||||
json.dump({
|
||||
tunable.name: tunable.meta
|
||||
for (tunable, _group) in self._tunable_params if tunable.meta
|
||||
}, fh_meta)
|
||||
json.dump(
|
||||
{
|
||||
tunable.name: tunable.meta
|
||||
for (tunable, _group) in self._tunable_params
|
||||
if tunable.meta
|
||||
},
|
||||
fh_meta,
|
||||
)
|
||||
|
||||
if self._script_setup:
|
||||
(return_code, _output) = self._local_exec(self._script_setup, self._temp_dir)
|
||||
|
@ -182,18 +191,28 @@ class LocalEnv(ScriptEnv):
|
|||
_LOG.debug("Not reading the data at: %s", self)
|
||||
return (Status.SUCCEEDED, timestamp, stdout_data)
|
||||
|
||||
data = self._normalize_columns(pandas.read_csv(
|
||||
self._config_loader_service.resolve_path(
|
||||
self._read_results_file, extra_paths=[self._temp_dir]),
|
||||
index_col=False,
|
||||
))
|
||||
data = self._normalize_columns(
|
||||
pandas.read_csv(
|
||||
self._config_loader_service.resolve_path(
|
||||
self._read_results_file,
|
||||
extra_paths=[self._temp_dir],
|
||||
),
|
||||
index_col=False,
|
||||
)
|
||||
)
|
||||
|
||||
_LOG.debug("Read data:\n%s", data)
|
||||
if list(data.columns) == ["metric", "value"]:
|
||||
_LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data))
|
||||
_LOG.info(
|
||||
"Local results have (metric,value) header and %d rows: assume long format",
|
||||
len(data),
|
||||
)
|
||||
data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list())
|
||||
# Try to convert string metrics to numbers.
|
||||
data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive)
|
||||
data = data.apply( # type: ignore[assignment] # (false positive)
|
||||
pandas.to_numeric,
|
||||
errors="coerce",
|
||||
).fillna(data)
|
||||
elif len(data) == 1:
|
||||
_LOG.info("Local results have 1 row: assume wide format")
|
||||
else:
|
||||
|
@ -205,14 +224,12 @@ class LocalEnv(ScriptEnv):
|
|||
|
||||
@staticmethod
|
||||
def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
|
||||
"""
|
||||
Strip trailing spaces from column names (Windows only).
|
||||
"""
|
||||
"""Strip trailing spaces from column names (Windows only)."""
|
||||
# Windows cmd interpretation of > redirect symbols can leave trailing spaces in
|
||||
# the final column, which leads to misnamed columns.
|
||||
# For now, we simply strip trailing spaces from column names to account for that.
|
||||
if sys.platform == 'win32':
|
||||
data.rename(str.rstrip, axis='columns', inplace=True)
|
||||
if sys.platform == "win32":
|
||||
data.rename(str.rstrip, axis="columns", inplace=True)
|
||||
return data
|
||||
|
||||
def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
|
||||
|
@ -224,24 +241,24 @@ class LocalEnv(ScriptEnv):
|
|||
assert self._temp_dir is not None
|
||||
try:
|
||||
fname = self._config_loader_service.resolve_path(
|
||||
self._read_telemetry_file, extra_paths=[self._temp_dir])
|
||||
self._read_telemetry_file,
|
||||
extra_paths=[self._temp_dir],
|
||||
)
|
||||
|
||||
# TODO: Use the timestamp of the CSV file as our status timestamp?
|
||||
|
||||
# FIXME: We should not be assuming that the only output file type is a CSV.
|
||||
|
||||
data = self._normalize_columns(
|
||||
pandas.read_csv(fname, index_col=False))
|
||||
data = self._normalize_columns(pandas.read_csv(fname, index_col=False))
|
||||
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
|
||||
|
||||
expected_col_names = ["timestamp", "metric", "value"]
|
||||
if len(data.columns) != len(expected_col_names):
|
||||
raise ValueError(f'Telemetry data must have columns {expected_col_names}')
|
||||
raise ValueError(f"Telemetry data must have columns {expected_col_names}")
|
||||
|
||||
if list(data.columns) != expected_col_names:
|
||||
# Assume no header - this is ok for telemetry data.
|
||||
data = pandas.read_csv(
|
||||
fname, index_col=False, names=expected_col_names)
|
||||
data = pandas.read_csv(fname, index_col=False, names=expected_col_names)
|
||||
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
|
||||
|
||||
except FileNotFoundError as ex:
|
||||
|
@ -250,15 +267,17 @@ class LocalEnv(ScriptEnv):
|
|||
|
||||
_LOG.debug("Read telemetry data:\n%s", data)
|
||||
col_dtypes: Mapping[int, Type] = {0: datetime}
|
||||
return (status, timestamp, [
|
||||
(pandas.Timestamp(ts).to_pydatetime(), metric, value)
|
||||
for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
|
||||
])
|
||||
return (
|
||||
status,
|
||||
timestamp,
|
||||
[
|
||||
(pandas.Timestamp(ts).to_pydatetime(), metric, value)
|
||||
for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
|
||||
],
|
||||
)
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up the local environment.
|
||||
"""
|
||||
"""Clean up the local environment."""
|
||||
if self._script_teardown:
|
||||
_LOG.info("Local teardown: %s", self)
|
||||
(return_code, _output) = self._local_exec(self._script_teardown)
|
||||
|
@ -285,7 +304,10 @@ class LocalEnv(ScriptEnv):
|
|||
env_params = self._get_env_params()
|
||||
_LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params)
|
||||
(return_code, stdout, stderr) = self._local_exec_service.local_exec(
|
||||
script, env=env_params, cwd=cwd)
|
||||
script,
|
||||
env=env_params,
|
||||
cwd=cwd,
|
||||
)
|
||||
if return_code != 0:
|
||||
_LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr)
|
||||
return (return_code, {"stdout": stdout, "stderr": stderr})
|
||||
|
|
|
@ -2,22 +2,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Scheduler-side Environment to run scripts locally
|
||||
and upload/download data to the shared storage.
|
||||
"""Scheduler-side Environment to run scripts locally and upload/download data to the
|
||||
shared storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Generator, Iterable, Mapping, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.local.local_env import LocalEnv
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
|
||||
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
|
@ -25,18 +23,19 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class LocalFileShareEnv(LocalEnv):
|
||||
"""
|
||||
Scheduler-side Environment that runs scripts locally
|
||||
and uploads/downloads data to the shared file storage.
|
||||
"""Scheduler-side Environment that runs scripts locally and uploads/downloads data
|
||||
to the shared file storage.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new application environment with a given config.
|
||||
|
||||
|
@ -60,34 +59,41 @@ class LocalFileShareEnv(LocalEnv):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
|
||||
"LocalEnv requires a service that supports local execution"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsLocalExec
|
||||
), "LocalEnv requires a service that supports local execution"
|
||||
self._local_exec_service: SupportsLocalExec = self._service
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
|
||||
"LocalEnv requires a service that supports file upload/download operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsFileShareOps
|
||||
), "LocalEnv requires a service that supports file upload/download operations"
|
||||
self._file_share_service: SupportsFileShareOps = self._service
|
||||
|
||||
self._upload = self._template_from_to("upload")
|
||||
self._download = self._template_from_to("download")
|
||||
|
||||
def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]:
|
||||
"""Convert a list of {"from": "...", "to": "..."} to a list of pairs of
|
||||
string.Template objects so that we can plug in self._params into it later.
|
||||
"""
|
||||
Convert a list of {"from": "...", "to": "..."} to a list of pairs
|
||||
of string.Template objects so that we can plug in self._params into it later.
|
||||
"""
|
||||
return [
|
||||
(Template(d['from']), Template(d['to']))
|
||||
for d in self.config.get(config_key, [])
|
||||
]
|
||||
return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]
|
||||
|
||||
@staticmethod
|
||||
def _expand(from_to: Iterable[Tuple[Template, Template]],
|
||||
params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
|
||||
def _expand(
|
||||
from_to: Iterable[Tuple[Template, Template]],
|
||||
params: Mapping[str, TunableValue],
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""
|
||||
Substitute $var parameters in from/to path templates.
|
||||
|
||||
Return a generator of (str, str) pairs of paths.
|
||||
"""
|
||||
return (
|
||||
|
@ -120,9 +126,15 @@ class LocalFileShareEnv(LocalEnv):
|
|||
assert self._temp_dir is not None
|
||||
params = self._get_env_params(restrict=False)
|
||||
params["PWD"] = self._temp_dir
|
||||
for (path_from, path_to) in self._expand(self._upload, params):
|
||||
self._file_share_service.upload(self._params, self._config_loader_service.resolve_path(
|
||||
path_from, extra_paths=[self._temp_dir]), path_to)
|
||||
for path_from, path_to in self._expand(self._upload, params):
|
||||
self._file_share_service.upload(
|
||||
self._params,
|
||||
self._config_loader_service.resolve_path(
|
||||
path_from,
|
||||
extra_paths=[self._temp_dir],
|
||||
),
|
||||
path_to,
|
||||
)
|
||||
return self._is_ready
|
||||
|
||||
def _download_files(self, ignore_missing: bool = False) -> None:
|
||||
|
@ -138,11 +150,16 @@ class LocalFileShareEnv(LocalEnv):
|
|||
assert self._temp_dir is not None
|
||||
params = self._get_env_params(restrict=False)
|
||||
params["PWD"] = self._temp_dir
|
||||
for (path_from, path_to) in self._expand(self._download, params):
|
||||
for path_from, path_to in self._expand(self._download, params):
|
||||
try:
|
||||
self._file_share_service.download(self._params,
|
||||
path_from, self._config_loader_service.resolve_path(
|
||||
path_to, extra_paths=[self._temp_dir]))
|
||||
self._file_share_service.download(
|
||||
self._params,
|
||||
path_from,
|
||||
self._config_loader_service.resolve_path(
|
||||
path_to,
|
||||
extra_paths=[self._temp_dir],
|
||||
),
|
||||
)
|
||||
except FileNotFoundError as ex:
|
||||
_LOG.warning("Cannot download: %s", path_from)
|
||||
if not ignore_missing:
|
||||
|
@ -153,8 +170,8 @@ class LocalFileShareEnv(LocalEnv):
|
|||
|
||||
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
|
||||
"""
|
||||
Download benchmark results from the shared storage
|
||||
and run post-processing scripts locally.
|
||||
Download benchmark results from the shared storage and run post-processing
|
||||
scripts locally.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -2,40 +2,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Scheduler-side environment to mock the benchmark results.
|
||||
"""
|
||||
"""Scheduler-side environment to mock the benchmark results."""
|
||||
|
||||
import random
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables import Tunable, TunableGroups, TunableValue
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockEnv(Environment):
|
||||
"""
|
||||
Scheduler-side environment to mock the benchmark results.
|
||||
"""
|
||||
"""Scheduler-side environment to mock the benchmark results."""
|
||||
|
||||
_NOISE_VAR = 0.2
|
||||
"""Variance of the Gaussian noise added to the benchmark value."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment that produces mock benchmark data.
|
||||
|
||||
|
@ -55,8 +53,13 @@ class MockEnv(Environment):
|
|||
service: Service
|
||||
An optional service object. Not used by this class.
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
seed = int(self.config.get("mock_env_seed", -1))
|
||||
self._random = random.Random(seed or None) if seed >= 0 else None
|
||||
self._range = self.config.get("mock_env_range")
|
||||
|
@ -81,9 +84,9 @@ class MockEnv(Environment):
|
|||
return result
|
||||
|
||||
# Simple convex function of all tunable parameters.
|
||||
score = numpy.mean(numpy.square([
|
||||
self._normalized(tunable) for (tunable, _group) in self._tunable_params
|
||||
]))
|
||||
score = numpy.mean(
|
||||
numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
|
||||
)
|
||||
|
||||
# Add noise and shift the benchmark value from [0, 1] to a given range.
|
||||
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0
|
||||
|
@ -97,15 +100,16 @@ class MockEnv(Environment):
|
|||
def _normalized(tunable: Tunable) -> float:
|
||||
"""
|
||||
Get the NORMALIZED value of a tunable.
|
||||
|
||||
That is, map current value to the [0, 1] range.
|
||||
"""
|
||||
val = None
|
||||
if tunable.is_categorical:
|
||||
val = (tunable.categories.index(tunable.category) /
|
||||
float(len(tunable.categories) - 1))
|
||||
val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1)
|
||||
elif tunable.is_numerical:
|
||||
val = ((tunable.numerical_value - tunable.range[0]) /
|
||||
float(tunable.range[1] - tunable.range[0]))
|
||||
val = (tunable.numerical_value - tunable.range[0]) / float(
|
||||
tunable.range[1] - tunable.range[0]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid parameter type: " + tunable.type)
|
||||
# Explicitly clip the value in case of numerical errors.
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Remote Tunable Environments for mlos_bench.
|
||||
"""
|
||||
"""Remote Tunable Environments for mlos_bench."""
|
||||
|
||||
from mlos_bench.environments.remote.host_env import HostEnv
|
||||
from mlos_bench.environments.remote.network_env import NetworkEnv
|
||||
|
@ -14,10 +12,10 @@ from mlos_bench.environments.remote.saas_env import SaaSEnv
|
|||
from mlos_bench.environments.remote.vm_env import VMEnv
|
||||
|
||||
__all__ = [
|
||||
'HostEnv',
|
||||
'NetworkEnv',
|
||||
'OSEnv',
|
||||
'RemoteEnv',
|
||||
'SaaSEnv',
|
||||
'VMEnv',
|
||||
"HostEnv",
|
||||
"NetworkEnv",
|
||||
"OSEnv",
|
||||
"RemoteEnv",
|
||||
"SaaSEnv",
|
||||
"VMEnv",
|
||||
]
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Remote host Environment.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
"""Remote host Environment."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.services.base_service import Service
|
||||
|
@ -19,17 +16,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class HostEnv(Environment):
|
||||
"""
|
||||
Remote host environment.
|
||||
"""
|
||||
"""Remote host environment."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for host operations.
|
||||
|
||||
|
@ -50,10 +47,17 @@ class HostEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM/host, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \
|
||||
"HostEnv requires a service that supports host provisioning operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsHostProvisioning
|
||||
), "HostEnv requires a service that supports host provisioning operations"
|
||||
self._host_service: SupportsHostProvisioning = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
|
@ -88,9 +92,7 @@ class HostEnv(Environment):
|
|||
return self._is_ready
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Shut down the Host and release it.
|
||||
"""
|
||||
"""Shut down the Host and release it."""
|
||||
_LOG.info("Host tear down: %s", self)
|
||||
(status, params) = self._host_service.deprovision_host(self._params)
|
||||
if status.is_pending():
|
||||
|
|
|
@ -2,17 +2,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Network Environment.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
"""Network Environment."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
|
||||
from mlos_bench.services.types.network_provisioner_type import (
|
||||
SupportsNetworkProvisioning,
|
||||
)
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -26,13 +25,15 @@ class NetworkEnv(Environment):
|
|||
but no real tuning is expected for it ... yet.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for network operations.
|
||||
|
||||
|
@ -53,14 +54,21 @@ class NetworkEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy a network, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
# Virtual networks can be used for more than one experiment, so by default
|
||||
# we don't attempt to deprovision them.
|
||||
self._deprovision_on_teardown = config.get("deprovision_on_teardown", False)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \
|
||||
"NetworkEnv requires a service that supports network provisioning"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsNetworkProvisioning
|
||||
), "NetworkEnv requires a service that supports network provisioning"
|
||||
self._network_service: SupportsNetworkProvisioning = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
|
@ -96,15 +104,16 @@ class NetworkEnv(Environment):
|
|||
return self._is_ready
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Shut down the Network and releases it.
|
||||
"""
|
||||
"""Shut down the Network and releases it."""
|
||||
if not self._deprovision_on_teardown:
|
||||
_LOG.info("Skipping Network deprovision: %s", self)
|
||||
return
|
||||
# Else
|
||||
_LOG.info("Network tear down: %s", self)
|
||||
(status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True)
|
||||
(status, params) = self._network_service.deprovision_network(
|
||||
self._params,
|
||||
ignore_errors=True,
|
||||
)
|
||||
if status.is_pending():
|
||||
(status, _) = self._network_service.wait_network_deployment(params, is_setup=False)
|
||||
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
OS-level remote Environment on Azure.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
"""OS-level remote Environment on Azure."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -21,17 +18,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class OSEnv(Environment):
|
||||
"""
|
||||
OS Level Environment for a host.
|
||||
"""
|
||||
"""OS Level Environment for a host."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for remote execution.
|
||||
|
||||
|
@ -54,14 +51,22 @@ class OSEnv(Environment):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsHostOps
|
||||
), "RemoteEnv requires a service that supports host operations"
|
||||
self._host_service: SupportsHostOps = self._service
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsOSOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsOSOps
|
||||
), "RemoteEnv requires a service that supports host operations"
|
||||
self._os_service: SupportsOSOps = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
|
@ -98,9 +103,7 @@ class OSEnv(Environment):
|
|||
return self._is_ready
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up and shut down the host without deprovisioning it.
|
||||
"""
|
||||
"""Clean up and shut down the host without deprovisioning it."""
|
||||
_LOG.info("OS tear down: %s", self)
|
||||
(status, params) = self._os_service.shutdown(self._params)
|
||||
if status.is_pending():
|
||||
|
|
|
@ -14,11 +14,11 @@ from typing import Dict, Iterable, Optional, Tuple
|
|||
|
||||
from pytz import UTC
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.environments.script_env import ScriptEnv
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.services.types.host_ops_type import SupportsHostOps
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
|
@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv):
|
|||
e.g. Application Environment
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for remote execution.
|
||||
|
||||
|
@ -61,24 +63,31 @@ class RemoteEnv(ScriptEnv):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a Host, VM, OS, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
self._wait_boot = self.config.get("wait_boot", False)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \
|
||||
"RemoteEnv requires a service that supports remote execution operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsRemoteExec
|
||||
), "RemoteEnv requires a service that supports remote execution operations"
|
||||
self._remote_exec_service: SupportsRemoteExec = self._service
|
||||
|
||||
if self._wait_boot:
|
||||
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsHostOps
|
||||
), "RemoteEnv requires a service that supports host operations"
|
||||
self._host_service: SupportsHostOps = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
Check if the environment is ready and set up the application
|
||||
and benchmarks on a remote host.
|
||||
Check if the environment is ready and set up the application and benchmarks on a
|
||||
remote host.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -143,9 +152,7 @@ class RemoteEnv(ScriptEnv):
|
|||
return (status, timestamp, output)
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
Clean up and shut down the remote environment.
|
||||
"""
|
||||
"""Clean up and shut down the remote environment."""
|
||||
if self._script_teardown:
|
||||
_LOG.info("Remote teardown: %s", self)
|
||||
(status, _timestamp, _output) = self._remote_exec(self._script_teardown)
|
||||
|
@ -170,7 +177,10 @@ class RemoteEnv(ScriptEnv):
|
|||
env_params = self._get_env_params()
|
||||
_LOG.debug("Submit script: %s with %s", self, env_params)
|
||||
(status, output) = self._remote_exec_service.remote_exec(
|
||||
script, config=self._params, env_params=env_params)
|
||||
script,
|
||||
config=self._params,
|
||||
env_params=env_params,
|
||||
)
|
||||
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)
|
||||
if status in {Status.PENDING, Status.SUCCEEDED}:
|
||||
(status, output) = self._remote_exec_service.get_remote_exec_results(output)
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Cloud-based (configurable) SaaS environment.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
"""Cloud-based (configurable) SaaS environment."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.services.base_service import Service
|
||||
|
@ -20,17 +17,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SaaSEnv(Environment):
|
||||
"""
|
||||
Cloud-based (configurable) SaaS environment.
|
||||
"""
|
||||
"""Cloud-based (configurable) SaaS environment."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for (configurable) cloud-based SaaS instance.
|
||||
|
||||
|
@ -51,15 +48,22 @@ class SaaSEnv(Environment):
|
|||
An optional service object
|
||||
(e.g., providing methods to configure the remote service).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
|
||||
"RemoteEnv requires a service that supports host operations"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsHostOps
|
||||
), "RemoteEnv requires a service that supports host operations"
|
||||
self._host_service: SupportsHostOps = self._service
|
||||
|
||||
assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \
|
||||
"SaaSEnv requires a service that supports remote host configuration API"
|
||||
assert self._service is not None and isinstance(
|
||||
self._service, SupportsRemoteConfig
|
||||
), "SaaSEnv requires a service that supports remote host configuration API"
|
||||
self._config_service: SupportsRemoteConfig = self._service
|
||||
|
||||
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
|
||||
|
@ -85,7 +89,9 @@ class SaaSEnv(Environment):
|
|||
return False
|
||||
|
||||
(status, _) = self._config_service.configure(
|
||||
self._params, self._tunable_params.get_param_values())
|
||||
self._params,
|
||||
self._tunable_params.get_param_values(),
|
||||
)
|
||||
if not status.is_succeeded():
|
||||
return False
|
||||
|
||||
|
@ -94,7 +100,7 @@ class SaaSEnv(Environment):
|
|||
return False
|
||||
|
||||
# Azure Flex DB instances currently require a VM reboot after reconfiguration.
|
||||
if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'):
|
||||
if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"):
|
||||
_LOG.info("Restarting: %s", self)
|
||||
(status, params) = self._host_service.restart_host(self._params)
|
||||
if status.is_pending():
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
"Remote" VM (Host) Environment.
|
||||
"""
|
||||
"""Remote VM (Host) Environment."""
|
||||
|
||||
import logging
|
||||
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base scriptable benchmark environment.
|
||||
"""
|
||||
"""Base scriptable benchmark environment."""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
|
@ -15,26 +13,25 @@ from mlos_bench.environments.base_environment import Environment
|
|||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
from mlos_bench.util import try_parse_val
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScriptEnv(Environment, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Base Environment that runs scripts for setup/run/teardown.
|
||||
"""
|
||||
"""Base Environment that runs scripts for setup/run/teardown."""
|
||||
|
||||
_RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
tunables: Optional[TunableGroups] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new environment for script execution.
|
||||
|
||||
|
@ -64,19 +61,29 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
|
|||
An optional service object (e.g., providing methods to
|
||||
deploy or reboot a VM, etc.).
|
||||
"""
|
||||
super().__init__(name=name, config=config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
super().__init__(
|
||||
name=name,
|
||||
config=config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
self._script_setup = self.config.get("setup")
|
||||
self._script_run = self.config.get("run")
|
||||
self._script_teardown = self.config.get("teardown")
|
||||
|
||||
self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
|
||||
self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {})
|
||||
self._shell_env_params_rename: Dict[str, str] = self.config.get(
|
||||
"shell_env_params_rename", {}
|
||||
)
|
||||
|
||||
results_stdout_pattern = self.config.get("results_stdout_pattern")
|
||||
self._results_stdout_pattern: Optional[re.Pattern[str]] = \
|
||||
re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None
|
||||
self._results_stdout_pattern: Optional[re.Pattern[str]] = (
|
||||
re.compile(results_stdout_pattern, flags=re.MULTILINE)
|
||||
if results_stdout_pattern
|
||||
else None
|
||||
)
|
||||
|
||||
def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
|
||||
"""
|
||||
|
@ -117,4 +124,6 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
|
|||
if not self._results_stdout_pattern:
|
||||
return {}
|
||||
_LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
|
||||
return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}
|
||||
return {
|
||||
key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)
|
||||
}
|
||||
|
|
|
@ -2,17 +2,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Enum for the status of the benchmark/environment.
|
||||
"""
|
||||
"""Enum for the status of the benchmark/environment."""
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class Status(enum.Enum):
|
||||
"""
|
||||
Enum for the status of the benchmark/environment.
|
||||
"""
|
||||
"""Enum for the status of the benchmark/environment."""
|
||||
|
||||
UNKNOWN = 0
|
||||
PENDING = 1
|
||||
|
@ -24,9 +20,7 @@ class Status(enum.Enum):
|
|||
TIMED_OUT = 7
|
||||
|
||||
def is_good(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is good.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is good."""
|
||||
return self in {
|
||||
Status.PENDING,
|
||||
Status.READY,
|
||||
|
@ -35,9 +29,8 @@ class Status(enum.Enum):
|
|||
}
|
||||
|
||||
def is_completed(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is
|
||||
one of {SUCCEEDED, CANCELED, FAILED, TIMED_OUT}.
|
||||
"""Check if the status of the benchmark/environment is one of {SUCCEEDED,
|
||||
CANCELED, FAILED, TIMED_OUT}.
|
||||
"""
|
||||
return self in {
|
||||
Status.SUCCEEDED,
|
||||
|
@ -47,37 +40,25 @@ class Status(enum.Enum):
|
|||
}
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is PENDING.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is PENDING."""
|
||||
return self == Status.PENDING
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is READY.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is READY."""
|
||||
return self == Status.READY
|
||||
|
||||
def is_succeeded(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is SUCCEEDED.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is SUCCEEDED."""
|
||||
return self == Status.SUCCEEDED
|
||||
|
||||
def is_failed(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is FAILED.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is FAILED."""
|
||||
return self == Status.FAILED
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is CANCELED.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is CANCELED."""
|
||||
return self == Status.CANCELED
|
||||
|
||||
def is_timed_out(self) -> bool:
|
||||
"""
|
||||
Check if the status of the benchmark/environment is TIMED_OUT.
|
||||
"""
|
||||
"""Check if the status of the benchmark/environment is TIMED_OUT."""
|
||||
return self == Status.FAILED
|
||||
|
|
|
@ -2,25 +2,23 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
EventLoopContext class definition.
|
||||
"""
|
||||
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Coroutine, Optional, TypeVar
|
||||
from threading import Lock as ThreadLock, Thread
|
||||
"""EventLoopContext class definition."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import Future
|
||||
from threading import Lock as ThreadLock
|
||||
from threading import Thread
|
||||
from typing import Any, Coroutine, Optional, TypeVar
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name
|
||||
CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name
|
||||
if sys.version_info >= (3, 9):
|
||||
FutureReturnType: TypeAlias = Future[CoroReturnType]
|
||||
else:
|
||||
|
@ -31,15 +29,15 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class EventLoopContext:
|
||||
"""
|
||||
EventLoopContext encapsulates a background thread for asyncio event
|
||||
loop processing as an aid for context managers.
|
||||
EventLoopContext encapsulates a background thread for asyncio event loop processing
|
||||
as an aid for context managers.
|
||||
|
||||
There is generally only expected to be one of these, either as a base
|
||||
class instance if it's specific to that functionality or for the full
|
||||
mlos_bench process to support parallel trial runners, for instance.
|
||||
There is generally only expected to be one of these, either as a base class instance
|
||||
if it's specific to that functionality or for the full mlos_bench process to support
|
||||
parallel trial runners, for instance.
|
||||
|
||||
It's enter() and exit() routines are expected to be called from the
|
||||
caller's context manager routines (e.g., __enter__ and __exit__).
|
||||
It's enter() and exit() routines are expected to be called from the caller's context
|
||||
manager routines (e.g., __enter__ and __exit__).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -49,17 +47,13 @@ class EventLoopContext:
|
|||
self._event_loop_thread_refcnt: int = 0
|
||||
|
||||
def _run_event_loop(self) -> None:
|
||||
"""
|
||||
Runs the asyncio event loop in a background thread.
|
||||
"""
|
||||
"""Runs the asyncio event loop in a background thread."""
|
||||
assert self._event_loop is not None
|
||||
asyncio.set_event_loop(self._event_loop)
|
||||
self._event_loop.run_forever()
|
||||
|
||||
def enter(self) -> None:
|
||||
"""
|
||||
Manages starting the background thread for event loop processing.
|
||||
"""
|
||||
"""Manages starting the background thread for event loop processing."""
|
||||
# Start the background thread if it's not already running.
|
||||
with self._event_loop_thread_lock:
|
||||
if not self._event_loop_thread:
|
||||
|
@ -74,9 +68,7 @@ class EventLoopContext:
|
|||
self._event_loop_thread_refcnt += 1
|
||||
|
||||
def exit(self) -> None:
|
||||
"""
|
||||
Manages cleaning up the background thread for event loop processing.
|
||||
"""
|
||||
"""Manages cleaning up the background thread for event loop processing."""
|
||||
with self._event_loop_thread_lock:
|
||||
self._event_loop_thread_refcnt -= 1
|
||||
assert self._event_loop_thread_refcnt >= 0
|
||||
|
@ -92,8 +84,8 @@ class EventLoopContext:
|
|||
|
||||
def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
|
||||
"""
|
||||
Runs the given coroutine in the background event loop thread and
|
||||
returns a Future that can be used to wait for the result.
|
||||
Runs the given coroutine in the background event loop thread and returns a
|
||||
Future that can be used to wait for the result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A helper class to load the configuration files, parse the command line parameters,
|
||||
and instantiate the main components of mlos_bench system.
|
||||
A helper class to load the configuration files, parse the command line parameters, and
|
||||
instantiate the main components of mlos_bench system.
|
||||
|
||||
It is used in `mlos_bench.run` module to run the benchmark/optimizer from the
|
||||
command line.
|
||||
|
@ -13,34 +13,26 @@ command line.
|
|||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
from mlos_bench.dict_templater import DictTemplater
|
||||
from mlos_bench.util import try_parse_val
|
||||
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
|
||||
from mlos_bench.optimizers.base_optimizer import Optimizer
|
||||
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
|
||||
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
|
||||
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.local.local_exec import LocalExecService
|
||||
from mlos_bench.services.config_persistence import ConfigPersistenceService
|
||||
|
||||
from mlos_bench.schedulers.base_scheduler import Scheduler
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.config_persistence import ConfigPersistenceService
|
||||
from mlos_bench.services.local.local_exec import LocalExecService
|
||||
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
|
||||
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import try_parse_val
|
||||
|
||||
_LOG_LEVEL = logging.INFO
|
||||
_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
|
||||
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
|
||||
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -48,22 +40,22 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class Launcher:
|
||||
# pylint: disable=too-few-public-methods,too-many-instance-attributes
|
||||
"""
|
||||
Command line launcher for mlos_bench and mlos_core.
|
||||
"""
|
||||
"""Command line launcher for mlos_bench and mlos_core."""
|
||||
|
||||
def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
|
||||
# pylint: disable=too-many-statements
|
||||
_LOG.info("Launch: %s", description)
|
||||
epilog = """
|
||||
Additional --key=value pairs can be specified to augment or override values listed in --globals.
|
||||
Other required_args values can also be pulled from shell environment variables.
|
||||
Additional --key=value pairs can be specified to augment or override
|
||||
values listed in --globals.
|
||||
Other required_args values can also be pulled from shell environment
|
||||
variables.
|
||||
|
||||
For additional details, please see the website or the README.md files in the source tree:
|
||||
For additional details, please see the website or the README.md files in
|
||||
the source tree:
|
||||
<https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=f"{description} : {long_text}",
|
||||
epilog=epilog)
|
||||
parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
|
||||
(args, args_rest) = self._parse_args(parser, argv)
|
||||
|
||||
# Bootstrap config loader: command line takes priority.
|
||||
|
@ -101,16 +93,18 @@ class Launcher:
|
|||
args_rest,
|
||||
{key: val for (key, val) in config.items() if key not in vars(args)},
|
||||
)
|
||||
# experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI.
|
||||
# experiment_id is generally taken from --globals files, but we also allow
|
||||
# overriding it on the CLI.
|
||||
# It's useful to keep it there explicitly mostly for the --help output.
|
||||
if args.experiment_id:
|
||||
self.global_config['experiment_id'] = args.experiment_id
|
||||
# trial_config_repeat_count is a scheduler property but it's convenient to set it via command line
|
||||
self.global_config["experiment_id"] = args.experiment_id
|
||||
# trial_config_repeat_count is a scheduler property but it's convenient to
|
||||
# set it via command line
|
||||
if args.trial_config_repeat_count:
|
||||
self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
|
||||
# Ensure that the trial_id is present since it gets used by some other
|
||||
# configs but is typically controlled by the run optimize loop.
|
||||
self.global_config.setdefault('trial_id', 1)
|
||||
self.global_config.setdefault("trial_id", 1)
|
||||
|
||||
self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
|
||||
assert isinstance(self.global_config, dict)
|
||||
|
@ -118,24 +112,31 @@ class Launcher:
|
|||
# --service cli args should override the config file values.
|
||||
service_files: List[str] = config.get("services", []) + (args.service or [])
|
||||
assert isinstance(self._parent_service, SupportsConfigLoading)
|
||||
self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service)
|
||||
self._parent_service = self._parent_service.load_services(
|
||||
service_files,
|
||||
self.global_config,
|
||||
self._parent_service,
|
||||
)
|
||||
|
||||
env_path = args.environment or config.get("environment")
|
||||
if not env_path:
|
||||
_LOG.error("No environment config specified.")
|
||||
parser.error("At least the Environment config must be specified." +
|
||||
" Run `mlos_bench --help` and consult `README.md` for more info.")
|
||||
parser.error(
|
||||
"At least the Environment config must be specified."
|
||||
+ " Run `mlos_bench --help` and consult `README.md` for more info."
|
||||
)
|
||||
self.root_env_config = self._config_loader.resolve_path(env_path)
|
||||
|
||||
self.environment: Environment = self._config_loader.load_environment(
|
||||
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service)
|
||||
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
|
||||
)
|
||||
_LOG.info("Init environment: %s", self.environment)
|
||||
|
||||
# NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
|
||||
self.tunables = self._init_tunable_values(
|
||||
args.random_init or config.get("random_init", False),
|
||||
config.get("random_seed") if args.random_seed is None else args.random_seed,
|
||||
config.get("tunable_values", []) + (args.tunable_values or [])
|
||||
config.get("tunable_values", []) + (args.tunable_values or []),
|
||||
)
|
||||
_LOG.info("Init tunables: %s", self.tunables)
|
||||
|
||||
|
@ -145,106 +146,170 @@ class Launcher:
|
|||
self.storage = self._load_storage(args.storage or config.get("storage"))
|
||||
_LOG.info("Init storage: %s", self.storage)
|
||||
|
||||
self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True))
|
||||
self.teardown: bool = (
|
||||
bool(args.teardown)
|
||||
if args.teardown is not None
|
||||
else bool(config.get("teardown", True))
|
||||
)
|
||||
self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
|
||||
_LOG.info("Init scheduler: %s", self.scheduler)
|
||||
|
||||
@property
|
||||
def config_loader(self) -> ConfigPersistenceService:
|
||||
"""
|
||||
Get the config loader service.
|
||||
"""
|
||||
"""Get the config loader service."""
|
||||
return self._config_loader
|
||||
|
||||
@property
|
||||
def service(self) -> Service:
|
||||
"""
|
||||
Get the parent service.
|
||||
"""
|
||||
"""Get the parent service."""
|
||||
return self._parent_service
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]:
|
||||
"""
|
||||
Parse the command line arguments.
|
||||
"""
|
||||
def _parse_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
argv: Optional[List[str]],
|
||||
) -> Tuple[argparse.Namespace, List[str]]:
|
||||
"""Parse the command line arguments."""
|
||||
parser.add_argument(
|
||||
'--config', required=False,
|
||||
help='Main JSON5 configuration file. Its keys are the same as the' +
|
||||
' command line options and can be overridden by the latter.\n' +
|
||||
'\n' +
|
||||
' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' +
|
||||
' for additional config examples for this and other arguments.')
|
||||
"--config",
|
||||
required=False,
|
||||
help="Main JSON5 configuration file. Its keys are the same as the"
|
||||
+ " command line options and can be overridden by the latter.\n"
|
||||
+ "\n"
|
||||
+ " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
|
||||
+ " for additional config examples for this and other arguments.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--log_file', '--log-file', required=False,
|
||||
help='Path to the log file. Use stdout if omitted.')
|
||||
"--log_file",
|
||||
"--log-file",
|
||||
required=False,
|
||||
help="Path to the log file. Use stdout if omitted.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--log_level', '--log-level', required=False, type=str,
|
||||
help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' +
|
||||
' Set to DEBUG for debug, WARNING for warnings only.')
|
||||
"--log_level",
|
||||
"--log-level",
|
||||
required=False,
|
||||
type=str,
|
||||
help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}."
|
||||
+ " Set to DEBUG for debug, WARNING for warnings only.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config_path', '--config-path', '--config-paths', '--config_paths',
|
||||
nargs="+", action='extend', required=False,
|
||||
help='One or more locations of JSON config files.')
|
||||
"--config_path",
|
||||
"--config-path",
|
||||
"--config-paths",
|
||||
"--config_paths",
|
||||
nargs="+",
|
||||
action="extend",
|
||||
required=False,
|
||||
help="One or more locations of JSON config files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--service', '--services',
|
||||
nargs='+', action='extend', required=False,
|
||||
help='Path to JSON file with the configuration of the service(s) for environment(s) to use.')
|
||||
"--service",
|
||||
"--services",
|
||||
nargs="+",
|
||||
action="extend",
|
||||
required=False,
|
||||
help=(
|
||||
"Path to JSON file with the configuration "
|
||||
"of the service(s) for environment(s) to use."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--environment', required=False,
|
||||
help='Path to JSON file with the configuration of the benchmarking environment(s).')
|
||||
"--environment",
|
||||
required=False,
|
||||
help="Path to JSON file with the configuration of the benchmarking environment(s).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--optimizer', required=False,
|
||||
help='Path to the optimizer configuration file. If omitted, run' +
|
||||
' a single trial with default (or specified in --tunable_values).')
|
||||
"--optimizer",
|
||||
required=False,
|
||||
help="Path to the optimizer configuration file. If omitted, run"
|
||||
+ " a single trial with default (or specified in --tunable_values).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int,
|
||||
help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.')
|
||||
"--trial_config_repeat_count",
|
||||
"--trial-config-repeat-count",
|
||||
required=False,
|
||||
type=int,
|
||||
help=(
|
||||
"Number of times to repeat each config. "
|
||||
"Default is 1 trial per config, though more may be advised."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--scheduler', required=False,
|
||||
help='Path to the scheduler configuration file. By default, use' +
|
||||
' a single worker synchronous scheduler.')
|
||||
"--scheduler",
|
||||
required=False,
|
||||
help="Path to the scheduler configuration file. By default, use"
|
||||
+ " a single worker synchronous scheduler.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--storage', required=False,
|
||||
help='Path to the storage configuration file.' +
|
||||
' If omitted, use the ephemeral in-memory SQL storage.')
|
||||
"--storage",
|
||||
required=False,
|
||||
help="Path to the storage configuration file."
|
||||
+ " If omitted, use the ephemeral in-memory SQL storage.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--random_init', '--random-init', required=False, default=False,
|
||||
dest='random_init', action='store_true',
|
||||
help='Initialize tunables with random values. (Before applying --tunable_values).')
|
||||
"--random_init",
|
||||
"--random-init",
|
||||
required=False,
|
||||
default=False,
|
||||
dest="random_init",
|
||||
action="store_true",
|
||||
help="Initialize tunables with random values. (Before applying --tunable_values).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--random_seed', '--random-seed', required=False, type=int,
|
||||
help='Seed to use with --random_init')
|
||||
"--random_seed",
|
||||
"--random-seed",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Seed to use with --random_init",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--tunable_values', '--tunable-values', nargs="+", action='extend', required=False,
|
||||
help='Path to one or more JSON files that contain values of the tunable' +
|
||||
' parameters. This can be used for a single trial (when no --optimizer' +
|
||||
' is specified) or as default values for the first run in optimization.')
|
||||
"--tunable_values",
|
||||
"--tunable-values",
|
||||
nargs="+",
|
||||
action="extend",
|
||||
required=False,
|
||||
help="Path to one or more JSON files that contain values of the tunable"
|
||||
+ " parameters. This can be used for a single trial (when no --optimizer"
|
||||
+ " is specified) or as default values for the first run in optimization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--globals', nargs="+", action='extend', required=False,
|
||||
help='Path to one or more JSON files that contain additional' +
|
||||
' [private] parameters of the benchmarking environment.')
|
||||
"--globals",
|
||||
nargs="+",
|
||||
action="extend",
|
||||
required=False,
|
||||
help="Path to one or more JSON files that contain additional"
|
||||
+ " [private] parameters of the benchmarking environment.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no_teardown', '--no-teardown', required=False, default=None,
|
||||
dest='teardown', action='store_false',
|
||||
help='Disable teardown of the environment after the benchmark.')
|
||||
"--no_teardown",
|
||||
"--no-teardown",
|
||||
required=False,
|
||||
default=None,
|
||||
dest="teardown",
|
||||
action="store_false",
|
||||
help="Disable teardown of the environment after the benchmark.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--experiment_id', '--experiment-id', required=False, default=None,
|
||||
"--experiment_id",
|
||||
"--experiment-id",
|
||||
required=False,
|
||||
default=None,
|
||||
help="""
|
||||
Experiment ID to use for the benchmark.
|
||||
If omitted, the value from the --cli config or --globals is used.
|
||||
|
@ -254,7 +319,7 @@ class Launcher:
|
|||
changes are made to config files, scripts, versions, etc.
|
||||
This is left as a manual operation as detection of what is
|
||||
"incompatible" is not easily automatable across systems.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
# By default we use the command line arguments, but allow the caller to
|
||||
|
@ -267,9 +332,7 @@ class Launcher:
|
|||
|
||||
@staticmethod
|
||||
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Helper function to parse global key/value pairs from the command line.
|
||||
"""
|
||||
"""Helper function to parse global key/value pairs from the command line."""
|
||||
_LOG.debug("Extra args: %s", cmdline)
|
||||
|
||||
config: Dict[str, TunableValue] = {}
|
||||
|
@ -296,16 +359,17 @@ class Launcher:
|
|||
_LOG.debug("Parsed config: %s", config)
|
||||
return config
|
||||
|
||||
def _load_config(self,
|
||||
args_globals: Iterable[str],
|
||||
config_path: Iterable[str],
|
||||
args_rest: Iterable[str],
|
||||
global_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _load_config(
|
||||
self,
|
||||
args_globals: Iterable[str],
|
||||
config_path: Iterable[str],
|
||||
args_rest: Iterable[str],
|
||||
global_config: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Get key/value pairs of the global configuration parameters from the specified
|
||||
config files (if any) and command line arguments.
|
||||
"""
|
||||
Get key/value pairs of the global configuration parameters
|
||||
from the specified config files (if any) and command line arguments.
|
||||
"""
|
||||
for config_file in (args_globals or []):
|
||||
for config_file in args_globals or []:
|
||||
conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
|
||||
assert isinstance(conf, dict)
|
||||
global_config.update(conf)
|
||||
|
@ -314,19 +378,24 @@ class Launcher:
|
|||
global_config["config_path"] = config_path
|
||||
return global_config
|
||||
|
||||
def _init_tunable_values(self, random_init: bool, seed: Optional[int],
|
||||
args_tunables: Optional[str]) -> TunableGroups:
|
||||
"""
|
||||
Initialize the tunables and load key/value pairs of the tunable values
|
||||
from given JSON files, if specified.
|
||||
def _init_tunable_values(
|
||||
self,
|
||||
random_init: bool,
|
||||
seed: Optional[int],
|
||||
args_tunables: Optional[str],
|
||||
) -> TunableGroups:
|
||||
"""Initialize the tunables and load key/value pairs of the tunable values from
|
||||
given JSON files, if specified.
|
||||
"""
|
||||
tunables = self.environment.tunable_params
|
||||
_LOG.debug("Init tunables: default = %s", tunables)
|
||||
|
||||
if random_init:
|
||||
tunables = MockOptimizer(
|
||||
tunables=tunables, service=None,
|
||||
config={"start_with_defaults": False, "seed": seed}).suggest()
|
||||
tunables=tunables,
|
||||
service=None,
|
||||
config={"start_with_defaults": False, "seed": seed},
|
||||
).suggest()
|
||||
_LOG.debug("Init tunables: random = %s", tunables)
|
||||
|
||||
if args_tunables is not None:
|
||||
|
@ -340,50 +409,64 @@ class Launcher:
|
|||
|
||||
def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
|
||||
"""
|
||||
Instantiate the Optimizer object from JSON config file, if specified
|
||||
in the --optimizer command line option. If config file not specified,
|
||||
create a one-shot optimizer to run a single benchmark trial.
|
||||
Instantiate the Optimizer object from JSON config file, if specified in the
|
||||
--optimizer command line option.
|
||||
|
||||
If config file not specified, create a one-shot optimizer to run a single
|
||||
benchmark trial.
|
||||
"""
|
||||
if args_optimizer is None:
|
||||
# global_config may contain additional properties, so we need to
|
||||
# strip those out before instantiating the basic oneshot optimizer.
|
||||
config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS}
|
||||
return OneShotOptimizer(
|
||||
self.tunables, config=config, service=self._parent_service)
|
||||
config = {
|
||||
key: val
|
||||
for key, val in self.global_config.items()
|
||||
if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
|
||||
}
|
||||
return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
|
||||
class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
|
||||
assert isinstance(class_config, Dict)
|
||||
optimizer = self._config_loader.build_optimizer(tunables=self.tunables,
|
||||
service=self._parent_service,
|
||||
config=class_config,
|
||||
global_config=self.global_config)
|
||||
optimizer = self._config_loader.build_optimizer(
|
||||
tunables=self.tunables,
|
||||
service=self._parent_service,
|
||||
config=class_config,
|
||||
global_config=self.global_config,
|
||||
)
|
||||
return optimizer
|
||||
|
||||
def _load_storage(self, args_storage: Optional[str]) -> Storage:
|
||||
"""
|
||||
Instantiate the Storage object from JSON file provided in the --storage
|
||||
command line parameter. If omitted, create an ephemeral in-memory SQL
|
||||
storage instead.
|
||||
Instantiate the Storage object from JSON file provided in the --storage command
|
||||
line parameter.
|
||||
|
||||
If omitted, create an ephemeral in-memory SQL storage instead.
|
||||
"""
|
||||
if args_storage is None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from mlos_bench.storage.sql.storage import SqlStorage
|
||||
return SqlStorage(service=self._parent_service,
|
||||
config={
|
||||
"drivername": "sqlite",
|
||||
"database": ":memory:",
|
||||
"lazy_schema_create": True,
|
||||
})
|
||||
|
||||
return SqlStorage(
|
||||
service=self._parent_service,
|
||||
config={
|
||||
"drivername": "sqlite",
|
||||
"database": ":memory:",
|
||||
"lazy_schema_create": True,
|
||||
},
|
||||
)
|
||||
class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
|
||||
assert isinstance(class_config, Dict)
|
||||
storage = self._config_loader.build_storage(service=self._parent_service,
|
||||
config=class_config,
|
||||
global_config=self.global_config)
|
||||
storage = self._config_loader.build_storage(
|
||||
service=self._parent_service,
|
||||
config=class_config,
|
||||
global_config=self.global_config,
|
||||
)
|
||||
return storage
|
||||
|
||||
def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
|
||||
"""
|
||||
Instantiate the Scheduler object from JSON file provided in the --scheduler
|
||||
command line parameter.
|
||||
|
||||
Create a simple synchronous single-threaded scheduler if omitted.
|
||||
"""
|
||||
# Set `teardown` for scheduler only to prevent conflicts with other configs.
|
||||
|
@ -392,6 +475,7 @@ class Launcher:
|
|||
if args_scheduler is None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
|
||||
|
||||
return SyncScheduler(
|
||||
# All config values can be overridden from global config
|
||||
config={
|
||||
|
|
|
@ -2,18 +2,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Interfaces and wrapper classes for optimizers to be used in Autotune.
|
||||
"""
|
||||
"""Interfaces and wrapper classes for optimizers to be used in Autotune."""
|
||||
|
||||
from mlos_bench.optimizers.base_optimizer import Optimizer
|
||||
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
|
||||
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
|
||||
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
|
||||
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
|
||||
|
||||
__all__ = [
|
||||
'Optimizer',
|
||||
'MockOptimizer',
|
||||
'OneShotOptimizer',
|
||||
'MlosCoreOptimizer',
|
||||
"Optimizer",
|
||||
"MockOptimizer",
|
||||
"OneShotOptimizer",
|
||||
"MlosCoreOptimizer",
|
||||
]
|
||||
|
|
|
@ -2,34 +2,32 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base class for an interface between the benchmarking framework
|
||||
and mlos_core optimizers.
|
||||
"""Base class for an interface between the benchmarking framework and mlos_core
|
||||
optimizers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from distutils.util import strtobool # pylint: disable=deprecated-module
|
||||
|
||||
from distutils.util import strtobool # pylint: disable=deprecated-module
|
||||
from types import TracebackType
|
||||
from typing import Dict, Optional, Sequence, Tuple, Type, Union
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ConfigSpace import ConfigurationSpace
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
An abstract interface between the benchmarking framework and mlos_core optimizers.
|
||||
class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
|
||||
"""An abstract interface between the benchmarking framework and mlos_core
|
||||
optimizers.
|
||||
"""
|
||||
|
||||
# See Also: mlos_bench/mlos_bench/config/schemas/optimizers/optimizer-schema.json
|
||||
|
@ -40,13 +38,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
"start_with_defaults",
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new optimizer for the given configuration space defined by the tunables.
|
||||
Create a new optimizer for the given configuration space defined by the
|
||||
tunables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -68,19 +69,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
self._seed = int(config.get("seed", 42))
|
||||
self._in_context = False
|
||||
|
||||
experiment_id = self._global_config.get('experiment_id')
|
||||
experiment_id = self._global_config.get("experiment_id")
|
||||
self.experiment_id = str(experiment_id).strip() if experiment_id else None
|
||||
|
||||
self._iter = 0
|
||||
# If False, use the optimizer to suggest the initial configuration;
|
||||
# if True (default), use the already initialized values for the first iteration.
|
||||
self._start_with_defaults: bool = bool(
|
||||
strtobool(str(self._config.pop('start_with_defaults', True))))
|
||||
self._max_iter = int(self._config.pop('max_suggestions', 100))
|
||||
strtobool(str(self._config.pop("start_with_defaults", True)))
|
||||
)
|
||||
self._max_iter = int(self._config.pop("max_suggestions", 100))
|
||||
|
||||
opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'})
|
||||
opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"})
|
||||
self._opt_targets: Dict[str, Literal[1, -1]] = {}
|
||||
for (opt_target, opt_dir) in opt_targets.items():
|
||||
for opt_target, opt_dir in opt_targets.items():
|
||||
if opt_dir == "min":
|
||||
self._opt_targets[opt_target] = 1
|
||||
elif opt_dir == "max":
|
||||
|
@ -89,10 +91,9 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}")
|
||||
|
||||
def _validate_json_config(self, config: dict) -> None:
|
||||
"""
|
||||
Reconstructs a basic json config that this class might have been
|
||||
instantiated from in order to validate configs provided outside the
|
||||
file loading mechanism.
|
||||
"""Reconstructs a basic json config that this class might have been instantiated
|
||||
from in order to validate configs provided outside the file loading
|
||||
mechanism.
|
||||
"""
|
||||
json_config: dict = {
|
||||
"class": self.__class__.__module__ + "." + self.__class__.__name__,
|
||||
|
@ -108,21 +109,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
)
|
||||
return f"{self.name}({opt_targets},config={self._config})"
|
||||
|
||||
def __enter__(self) -> 'Optimizer':
|
||||
"""
|
||||
Enter the optimizer's context.
|
||||
"""
|
||||
def __enter__(self) -> "Optimizer":
|
||||
"""Enter the optimizer's context."""
|
||||
_LOG.debug("Optimizer START :: %s", self)
|
||||
assert not self._in_context
|
||||
self._in_context = True
|
||||
return self
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
"""
|
||||
Exit the context of the optimizer.
|
||||
"""
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Exit the context of the optimizer."""
|
||||
if ex_val is None:
|
||||
_LOG.debug("Optimizer END :: %s", self)
|
||||
else:
|
||||
|
@ -154,15 +154,14 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
|
||||
@property
|
||||
def seed(self) -> int:
|
||||
"""
|
||||
The random seed for the optimizer.
|
||||
"""
|
||||
"""The random seed for the optimizer."""
|
||||
return self._seed
|
||||
|
||||
@property
|
||||
def start_with_defaults(self) -> bool:
|
||||
"""
|
||||
Return True if the optimizer should start with the default values.
|
||||
|
||||
Note: This parameter is mutable and will be reset to False after the
|
||||
defaults are first suggested.
|
||||
"""
|
||||
|
@ -198,16 +197,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
The name of the optimizer. We save this information in
|
||||
mlos_bench storage to track the source of each configuration.
|
||||
The name of the optimizer.
|
||||
|
||||
We save this information in mlos_bench storage to track the source of each
|
||||
configuration.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def targets(self) -> Dict[str, Literal['min', 'max']]:
|
||||
"""
|
||||
A dictionary of {target: direction} of optimization targets.
|
||||
"""
|
||||
def targets(self) -> Dict[str, Literal["min", "max"]]:
|
||||
"""A dictionary of {target: direction} of optimization targets."""
|
||||
return {
|
||||
opt_target: "min" if opt_dir == 1 else "max"
|
||||
for (opt_target, opt_dir) in self._opt_targets.items()
|
||||
|
@ -215,16 +214,18 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
|
||||
@property
|
||||
def supports_preload(self) -> bool:
|
||||
"""
|
||||
Return True if the optimizer supports pre-loading the data from previous experiments.
|
||||
"""Return True if the optimizer supports pre-loading the data from previous
|
||||
experiments.
|
||||
"""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def bulk_register(self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
def bulk_register(
|
||||
self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Pre-load the optimizer with the bulk data from previous experiments.
|
||||
|
||||
|
@ -242,8 +243,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
is_not_empty : bool
|
||||
True if there is data to register, false otherwise.
|
||||
"""
|
||||
_LOG.info("Update the optimizer with: %d configs, %d scores, %d status values",
|
||||
len(configs or []), len(scores or []), len(status or []))
|
||||
_LOG.info(
|
||||
"Update the optimizer with: %d configs, %d scores, %d status values",
|
||||
len(configs or []),
|
||||
len(scores or []),
|
||||
len(status or []),
|
||||
)
|
||||
if len(configs or []) != len(scores or []):
|
||||
raise ValueError("Numbers of configs and scores do not match.")
|
||||
if status is not None and len(configs or []) != len(status or []):
|
||||
|
@ -256,9 +261,8 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
|
||||
def suggest(self) -> TunableGroups:
|
||||
"""
|
||||
Generate the next suggestion.
|
||||
Base class' implementation increments the iteration count
|
||||
and returns the current values of the tunables.
|
||||
Generate the next suggestion. Base class' implementation increments the
|
||||
iteration count and returns the current values of the tunables.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -272,8 +276,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
return self._tunables.copy()
|
||||
|
||||
@abstractmethod
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
|
||||
def register(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None,
|
||||
) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
Register the observation for the given configuration.
|
||||
|
||||
|
@ -294,18 +302,25 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
Benchmark scores extracted (and possibly transformed)
|
||||
from the dataframe that's being MINIMIZED.
|
||||
"""
|
||||
_LOG.info("Iteration %d :: Register: %s = %s score: %s",
|
||||
self._iter, tunables, status, score)
|
||||
_LOG.info(
|
||||
"Iteration %d :: Register: %s = %s score: %s",
|
||||
self._iter,
|
||||
tunables,
|
||||
status,
|
||||
score,
|
||||
)
|
||||
if status.is_succeeded() == (score is None): # XOR
|
||||
raise ValueError("Status and score must be consistent.")
|
||||
return self._get_scores(status, score)
|
||||
|
||||
def _get_scores(self, status: Status,
|
||||
scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]]
|
||||
) -> Optional[Dict[str, float]]:
|
||||
def _get_scores(
|
||||
self,
|
||||
status: Status,
|
||||
scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]],
|
||||
) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
Extract a scalar benchmark score from the dataframe.
|
||||
Change the sign if we are maximizing.
|
||||
Extract a scalar benchmark score from the dataframe. Change the sign if we are
|
||||
maximizing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -331,7 +346,7 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
|
||||
assert scores is not None
|
||||
target_metrics: Dict[str, float] = {}
|
||||
for (opt_target, opt_dir) in self._opt_targets.items():
|
||||
for opt_target, opt_dir in self._opt_targets.items():
|
||||
val = scores[opt_target]
|
||||
assert val is not None
|
||||
target_metrics[opt_target] = float(val) * opt_dir
|
||||
|
@ -341,12 +356,15 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
|
|||
def not_converged(self) -> bool:
|
||||
"""
|
||||
Return True if not converged, False otherwise.
|
||||
|
||||
Base implementation just checks the iteration count.
|
||||
"""
|
||||
return self._iter < self._max_iter
|
||||
|
||||
@abstractmethod
|
||||
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
def get_best_observation(
|
||||
self,
|
||||
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
"""
|
||||
Get the best observation so far.
|
||||
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Functions to convert TunableGroups to ConfigSpace for use with the mlos_core optimizers.
|
||||
"""Functions to convert TunableGroups to ConfigSpace for use with the mlos_core
|
||||
optimizers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ConfigSpace import (
|
||||
|
@ -21,9 +20,10 @@ from ConfigSpace import (
|
|||
Normal,
|
||||
Uniform,
|
||||
)
|
||||
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import try_parse_val, nullable
|
||||
from mlos_bench.util import nullable, try_parse_val
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
@ -31,6 +31,7 @@ _LOG = logging.getLogger(__name__)
|
|||
class TunableValueKind:
|
||||
"""
|
||||
Enum for the kind of the tunable value (special or not).
|
||||
|
||||
It is not a true enum because ConfigSpace wants string values.
|
||||
"""
|
||||
|
||||
|
@ -40,15 +41,16 @@ class TunableValueKind:
|
|||
|
||||
|
||||
def _normalize_weights(weights: List[float]) -> List[float]:
|
||||
"""
|
||||
Helper function for normalizing weights to probabilities.
|
||||
"""
|
||||
"""Helper function for normalizing weights to probabilities."""
|
||||
total = sum(weights)
|
||||
return [w / total for w in weights]
|
||||
|
||||
|
||||
def _tunable_to_configspace(
|
||||
tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace:
|
||||
tunable: Tunable,
|
||||
group_name: Optional[str] = None,
|
||||
cost: int = 0,
|
||||
) -> ConfigurationSpace:
|
||||
"""
|
||||
Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects,
|
||||
wrapped in a ConfigurationSpace for composability.
|
||||
|
@ -71,14 +73,17 @@ def _tunable_to_configspace(
|
|||
meta = {"group": group_name, "cost": cost} # {"scaling": ""}
|
||||
|
||||
if tunable.type == "categorical":
|
||||
return ConfigurationSpace({
|
||||
tunable.name: CategoricalHyperparameter(
|
||||
name=tunable.name,
|
||||
choices=tunable.categories,
|
||||
weights=_normalize_weights(tunable.weights) if tunable.weights else None,
|
||||
default_value=tunable.default,
|
||||
meta=meta)
|
||||
})
|
||||
return ConfigurationSpace(
|
||||
{
|
||||
tunable.name: CategoricalHyperparameter(
|
||||
name=tunable.name,
|
||||
choices=tunable.categories,
|
||||
weights=_normalize_weights(tunable.weights) if tunable.weights else None,
|
||||
default_value=tunable.default,
|
||||
meta=meta,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
distribution: Union[Uniform, Normal, Beta, None] = None
|
||||
if tunable.distribution == "uniform":
|
||||
|
@ -86,12 +91,12 @@ def _tunable_to_configspace(
|
|||
elif tunable.distribution == "normal":
|
||||
distribution = Normal(
|
||||
mu=tunable.distribution_params["mu"],
|
||||
sigma=tunable.distribution_params["sigma"]
|
||||
sigma=tunable.distribution_params["sigma"],
|
||||
)
|
||||
elif tunable.distribution == "beta":
|
||||
distribution = Beta(
|
||||
alpha=tunable.distribution_params["alpha"],
|
||||
beta=tunable.distribution_params["beta"]
|
||||
beta=tunable.distribution_params["beta"],
|
||||
)
|
||||
elif tunable.distribution is not None:
|
||||
raise TypeError(f"Invalid Distribution Type: {tunable.distribution}")
|
||||
|
@ -103,22 +108,26 @@ def _tunable_to_configspace(
|
|||
log=bool(tunable.is_log),
|
||||
q=nullable(int, tunable.quantization),
|
||||
distribution=distribution,
|
||||
default=(int(tunable.default)
|
||||
if tunable.in_range(tunable.default) and tunable.default is not None
|
||||
else None),
|
||||
meta=meta
|
||||
default=(
|
||||
int(tunable.default)
|
||||
if tunable.in_range(tunable.default) and tunable.default is not None
|
||||
else None
|
||||
),
|
||||
meta=meta,
|
||||
)
|
||||
elif tunable.type == "float":
|
||||
range_hp = Float(
|
||||
name=tunable.name,
|
||||
bounds=tunable.range,
|
||||
log=bool(tunable.is_log),
|
||||
q=tunable.quantization, # type: ignore[arg-type]
|
||||
q=tunable.quantization, # type: ignore[arg-type]
|
||||
distribution=distribution, # type: ignore[arg-type]
|
||||
default=(float(tunable.default)
|
||||
if tunable.in_range(tunable.default) and tunable.default is not None
|
||||
else None),
|
||||
meta=meta
|
||||
default=(
|
||||
float(tunable.default)
|
||||
if tunable.in_range(tunable.default) and tunable.default is not None
|
||||
else None
|
||||
),
|
||||
meta=meta,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid Parameter Type: {tunable.type}")
|
||||
|
@ -136,31 +145,38 @@ def _tunable_to_configspace(
|
|||
# Create three hyperparameters: one for regular values,
|
||||
# one for special values, and one to choose between the two.
|
||||
(special_name, type_name) = special_param_names(tunable.name)
|
||||
conf_space = ConfigurationSpace({
|
||||
tunable.name: range_hp,
|
||||
special_name: CategoricalHyperparameter(
|
||||
name=special_name,
|
||||
choices=tunable.special,
|
||||
weights=special_weights,
|
||||
default_value=tunable.default if tunable.default in tunable.special else None,
|
||||
meta=meta
|
||||
),
|
||||
type_name: CategoricalHyperparameter(
|
||||
name=type_name,
|
||||
choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
|
||||
weights=switch_weights,
|
||||
default_value=TunableValueKind.SPECIAL,
|
||||
),
|
||||
})
|
||||
conf_space.add_condition(EqualsCondition(
|
||||
conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL))
|
||||
conf_space.add_condition(EqualsCondition(
|
||||
conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE))
|
||||
conf_space = ConfigurationSpace(
|
||||
{
|
||||
tunable.name: range_hp,
|
||||
special_name: CategoricalHyperparameter(
|
||||
name=special_name,
|
||||
choices=tunable.special,
|
||||
weights=special_weights,
|
||||
default_value=tunable.default if tunable.default in tunable.special else None,
|
||||
meta=meta,
|
||||
),
|
||||
type_name: CategoricalHyperparameter(
|
||||
name=type_name,
|
||||
choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
|
||||
weights=switch_weights,
|
||||
default_value=TunableValueKind.SPECIAL,
|
||||
),
|
||||
}
|
||||
)
|
||||
conf_space.add_condition(
|
||||
EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)
|
||||
)
|
||||
conf_space.add_condition(
|
||||
EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)
|
||||
)
|
||||
|
||||
return conf_space
|
||||
|
||||
|
||||
def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace:
|
||||
def tunable_groups_to_configspace(
|
||||
tunables: TunableGroups,
|
||||
seed: Optional[int] = None,
|
||||
) -> ConfigurationSpace:
|
||||
"""
|
||||
Convert TunableGroups to hyperparameters in ConfigurationSpace.
|
||||
|
||||
|
@ -178,11 +194,16 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] =
|
|||
A new ConfigurationSpace instance that corresponds to the input TunableGroups.
|
||||
"""
|
||||
space = ConfigurationSpace(seed=seed)
|
||||
for (tunable, group) in tunables:
|
||||
for tunable, group in tunables:
|
||||
space.add_configuration_space(
|
||||
prefix="", delimiter="",
|
||||
prefix="",
|
||||
delimiter="",
|
||||
configuration_space=_tunable_to_configspace(
|
||||
tunable, group.name, group.get_current_cost()))
|
||||
tunable,
|
||||
group.name,
|
||||
group.get_current_cost(),
|
||||
),
|
||||
)
|
||||
return space
|
||||
|
||||
|
||||
|
@ -201,7 +222,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
|
|||
A ConfigSpace Configuration.
|
||||
"""
|
||||
values: Dict[str, TunableValue] = {}
|
||||
for (tunable, _group) in tunables:
|
||||
for tunable, _group in tunables:
|
||||
if tunable.special:
|
||||
(special_name, type_name) = special_param_names(tunable.name)
|
||||
if tunable.value in tunable.special:
|
||||
|
@ -219,13 +240,11 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
|
|||
def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
|
||||
"""
|
||||
Remove the fields that correspond to special values in ConfigSpace.
|
||||
|
||||
In particular, remove and keys suffixes added by `special_param_names`.
|
||||
"""
|
||||
data = data.copy()
|
||||
specials = [
|
||||
special_param_name_strip(k)
|
||||
for k in data.keys() if special_param_name_is_temp(k)
|
||||
]
|
||||
specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)]
|
||||
for k in specials:
|
||||
(special_name, type_name) = special_param_names(k)
|
||||
if data[type_name] == TunableValueKind.SPECIAL:
|
||||
|
@ -240,8 +259,8 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
|
|||
|
||||
def special_param_names(name: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate the names of the auxiliary hyperparameters that correspond
|
||||
to a tunable that can have special values.
|
||||
Generate the names of the auxiliary hyperparameters that correspond to a tunable
|
||||
that can have special values.
|
||||
|
||||
NOTE: `!` characters are currently disallowed in Tunable names in order handle this logic.
|
||||
|
||||
|
|
|
@ -2,38 +2,35 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Grid search optimizer for mlos_bench.
|
||||
"""
|
||||
"""Grid search optimizer for mlos_bench."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, Optional, Sequence, Set, Tuple
|
||||
|
||||
from typing import Dict, Iterable, Set, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import ConfigSpace
|
||||
import numpy as np
|
||||
from ConfigSpace.util import generate_grid
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
|
||||
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
|
||||
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
|
||||
from mlos_bench.services.base_service import Service
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GridSearchOptimizer(TrackBestOptimizer):
|
||||
"""
|
||||
Grid search optimizer.
|
||||
"""
|
||||
"""Grid search optimizer."""
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(tunables, config, global_config, service)
|
||||
|
||||
# Track the grid as a set of tuples of tunable values and reconstruct the
|
||||
|
@ -52,11 +49,21 @@ class GridSearchOptimizer(TrackBestOptimizer):
|
|||
def _sanity_check(self) -> None:
|
||||
size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables])
|
||||
if size == np.inf:
|
||||
raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}")
|
||||
raise ValueError(
|
||||
f"Unquantized tunables are not supported for grid search: {self._tunables}"
|
||||
)
|
||||
if size > 10000:
|
||||
_LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables)
|
||||
_LOG.warning(
|
||||
"Large number %d of config points requested for grid search: %s",
|
||||
size,
|
||||
self._tunables,
|
||||
)
|
||||
if size > self._max_iter:
|
||||
_LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter)
|
||||
_LOG.warning(
|
||||
"Grid search size %d, is greater than max iterations %d",
|
||||
size,
|
||||
self._max_iter,
|
||||
)
|
||||
|
||||
def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
|
||||
"""
|
||||
|
@ -69,12 +76,14 @@ class GridSearchOptimizer(TrackBestOptimizer):
|
|||
# names instead of the order given by TunableGroups.
|
||||
configs = [
|
||||
configspace_data_to_tunable_values(dict(config))
|
||||
for config in
|
||||
generate_grid(self.config_space, {
|
||||
tunable.name: int(tunable.cardinality)
|
||||
for (tunable, _group) in self._tunables
|
||||
if tunable.quantization or tunable.type == "int"
|
||||
})
|
||||
for config in generate_grid(
|
||||
self.config_space,
|
||||
{
|
||||
tunable.name: int(tunable.cardinality)
|
||||
for (tunable, _group) in self._tunables
|
||||
if tunable.quantization or tunable.type == "int"
|
||||
},
|
||||
)
|
||||
]
|
||||
names = set(tuple(configs.keys()) for configs in configs)
|
||||
assert len(names) == 1
|
||||
|
@ -104,15 +113,17 @@ class GridSearchOptimizer(TrackBestOptimizer):
|
|||
# See NOTEs above.
|
||||
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)
|
||||
|
||||
def bulk_register(self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
def bulk_register(
|
||||
self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None,
|
||||
) -> bool:
|
||||
if not super().bulk_register(configs, scores, status):
|
||||
return False
|
||||
if status is None:
|
||||
status = [Status.SUCCEEDED] * len(configs)
|
||||
for (params, score, trial_status) in zip(configs, scores, status):
|
||||
for params, score, trial_status in zip(configs, scores, status):
|
||||
tunables = self._tunables.copy().assign(params)
|
||||
self.register(tunables, trial_status, score)
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
|
@ -121,9 +132,7 @@ class GridSearchOptimizer(TrackBestOptimizer):
|
|||
return True
|
||||
|
||||
def suggest(self) -> TunableGroups:
|
||||
"""
|
||||
Generate the next grid search suggestion.
|
||||
"""
|
||||
"""Generate the next grid search suggestion."""
|
||||
tunables = super().suggest()
|
||||
if self._start_with_defaults:
|
||||
_LOG.info("Use default values for the first trial")
|
||||
|
@ -153,20 +162,35 @@ class GridSearchOptimizer(TrackBestOptimizer):
|
|||
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
|
||||
return tunables
|
||||
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
|
||||
def register(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None,
|
||||
) -> Optional[Dict[str, float]]:
|
||||
registered_score = super().register(tunables, status, score)
|
||||
try:
|
||||
config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()))
|
||||
config = dict(
|
||||
ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
|
||||
)
|
||||
self._suggested_configs.remove(tuple(config.values()))
|
||||
except KeyError:
|
||||
_LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables)
|
||||
_LOG.warning(
|
||||
(
|
||||
"Attempted to remove missing config "
|
||||
"(previously registered?) from suggested set: %s"
|
||||
),
|
||||
tunables,
|
||||
)
|
||||
return registered_score
|
||||
|
||||
def not_converged(self) -> bool:
|
||||
if self._iter > self._max_iter:
|
||||
if bool(self._pending_configs):
|
||||
_LOG.warning("Exceeded max iterations, but still have %d pending configs: %s",
|
||||
len(self._pending_configs), list(self._pending_configs.keys()))
|
||||
_LOG.warning(
|
||||
"Exceeded max iterations, but still have %d pending configs: %s",
|
||||
len(self._pending_configs),
|
||||
list(self._pending_configs.keys()),
|
||||
)
|
||||
return False
|
||||
return bool(self._pending_configs)
|
||||
|
|
|
@ -2,72 +2,78 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A wrapper for mlos_core optimizers for mlos_bench.
|
||||
"""
|
||||
"""A wrapper for mlos_core optimizers for mlos_bench."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Dict, Optional, Sequence, Tuple, Type, Union
|
||||
from typing_extensions import Literal
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from mlos_core.optimizers import (
|
||||
BaseOptimizer, OptimizerType, OptimizerFactory, SpaceAdapterType, DEFAULT_OPTIMIZER_TYPE
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizers.base_optimizer import Optimizer
|
||||
|
||||
from mlos_bench.optimizers.convert_configspace import (
|
||||
TunableValueKind,
|
||||
configspace_data_to_tunable_values,
|
||||
special_param_names,
|
||||
)
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_core.optimizers import (
|
||||
DEFAULT_OPTIMIZER_TYPE,
|
||||
BaseOptimizer,
|
||||
OptimizerFactory,
|
||||
OptimizerType,
|
||||
SpaceAdapterType,
|
||||
)
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MlosCoreOptimizer(Optimizer):
|
||||
"""
|
||||
A wrapper class for the mlos_core optimizers.
|
||||
"""
|
||||
"""A wrapper class for the mlos_core optimizers."""
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(tunables, config, global_config, service)
|
||||
|
||||
opt_type = getattr(OptimizerType, self._config.pop(
|
||||
'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name))
|
||||
opt_type = getattr(
|
||||
OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name)
|
||||
)
|
||||
|
||||
if opt_type == OptimizerType.SMAC:
|
||||
output_directory = self._config.get('output_directory')
|
||||
output_directory = self._config.get("output_directory")
|
||||
if output_directory is not None:
|
||||
# If output_directory is specified, turn it into an absolute path.
|
||||
self._config['output_directory'] = os.path.abspath(output_directory)
|
||||
self._config["output_directory"] = os.path.abspath(output_directory)
|
||||
else:
|
||||
_LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.")
|
||||
_LOG.warning(
|
||||
(
|
||||
"SMAC optimizer output_directory was null. "
|
||||
"SMAC will use a temporary directory."
|
||||
)
|
||||
)
|
||||
|
||||
# Make sure max_trials >= max_iterations.
|
||||
if 'max_trials' not in self._config:
|
||||
self._config['max_trials'] = self._max_iter
|
||||
assert int(self._config['max_trials']) >= self._max_iter, \
|
||||
f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
|
||||
if "max_trials" not in self._config:
|
||||
self._config["max_trials"] = self._max_iter
|
||||
assert (
|
||||
int(self._config["max_trials"]) >= self._max_iter
|
||||
), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
|
||||
|
||||
if 'run_name' not in self._config and self.experiment_id:
|
||||
self._config['run_name'] = self.experiment_id
|
||||
if "run_name" not in self._config and self.experiment_id:
|
||||
self._config["run_name"] = self.experiment_id
|
||||
|
||||
space_adapter_type = self._config.pop('space_adapter_type', None)
|
||||
space_adapter_config = self._config.pop('space_adapter_config', {})
|
||||
space_adapter_type = self._config.pop("space_adapter_type", None)
|
||||
space_adapter_config = self._config.pop("space_adapter_config", {})
|
||||
|
||||
if space_adapter_type is not None:
|
||||
space_adapter_type = getattr(SpaceAdapterType, space_adapter_type)
|
||||
|
@ -81,9 +87,12 @@ class MlosCoreOptimizer(Optimizer):
|
|||
space_adapter_kwargs=space_adapter_config,
|
||||
)
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
self._opt.cleanup()
|
||||
return super().__exit__(ex_type, ex_val, ex_tb)
|
||||
|
||||
|
@ -91,10 +100,12 @@ class MlosCoreOptimizer(Optimizer):
|
|||
def name(self) -> str:
|
||||
return f"{self.__class__.__name__}:{self._opt.__class__.__name__}"
|
||||
|
||||
def bulk_register(self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
def bulk_register(
|
||||
self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None,
|
||||
) -> bool:
|
||||
|
||||
if not super().bulk_register(configs, scores, status):
|
||||
return False
|
||||
|
@ -102,7 +113,8 @@ class MlosCoreOptimizer(Optimizer):
|
|||
df_configs = self._to_df(configs) # Impute missing values, if necessary
|
||||
|
||||
df_scores = self._adjust_signs_df(
|
||||
pd.DataFrame([{} if score is None else score for score in scores]))
|
||||
pd.DataFrame([{} if score is None else score for score in scores])
|
||||
)
|
||||
|
||||
opt_targets = list(self._opt_targets)
|
||||
if status is not None:
|
||||
|
@ -126,17 +138,15 @@ class MlosCoreOptimizer(Optimizer):
|
|||
return True
|
||||
|
||||
def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
In-place adjust the signs of the scores for MINIMIZATION problem.
|
||||
"""
|
||||
for (opt_target, opt_dir) in self._opt_targets.items():
|
||||
"""In-place adjust the signs of the scores for MINIMIZATION problem."""
|
||||
for opt_target, opt_dir in self._opt_targets.items():
|
||||
df_scores[opt_target] *= opt_dir
|
||||
return df_scores
|
||||
|
||||
def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame:
|
||||
"""
|
||||
Select from past trials only the columns required in this experiment and
|
||||
impute default values for the tunables that are missing in the dataframe.
|
||||
Select from past trials only the columns required in this experiment and impute
|
||||
default values for the tunables that are missing in the dataframe.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -151,7 +161,7 @@ class MlosCoreOptimizer(Optimizer):
|
|||
df_configs = pd.DataFrame(configs)
|
||||
tunables_names = list(self._tunables.get_param_values().keys())
|
||||
missing_cols = set(tunables_names).difference(df_configs.columns)
|
||||
for (tunable, _group) in self._tunables:
|
||||
for tunable, _group in self._tunables:
|
||||
if tunable.name in missing_cols:
|
||||
df_configs[tunable.name] = tunable.default
|
||||
else:
|
||||
|
@ -183,22 +193,34 @@ class MlosCoreOptimizer(Optimizer):
|
|||
df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults)
|
||||
self._start_with_defaults = False
|
||||
_LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config)
|
||||
return tunables.assign(
|
||||
configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
|
||||
return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
|
||||
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
|
||||
registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION
|
||||
def register(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None,
|
||||
) -> Optional[Dict[str, float]]:
|
||||
registered_score = super().register(
|
||||
tunables,
|
||||
status,
|
||||
score,
|
||||
) # Sign-adjusted for MINIMIZATION
|
||||
if status.is_completed():
|
||||
assert registered_score is not None
|
||||
df_config = self._to_df([tunables.get_param_values()])
|
||||
_LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config)
|
||||
# TODO: Specify (in the config) which metrics to pass to the optimizer.
|
||||
# Issue: https://github.com/microsoft/MLOS/issues/745
|
||||
self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float))
|
||||
self._opt.register(
|
||||
configs=df_config,
|
||||
scores=pd.DataFrame([registered_score], dtype=float),
|
||||
)
|
||||
return registered_score
|
||||
|
||||
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
def get_best_observation(
|
||||
self,
|
||||
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
(df_config, df_score, _df_context) = self._opt.get_best_observations()
|
||||
if len(df_config) == 0:
|
||||
return (None, None)
|
||||
|
|
|
@ -2,35 +2,31 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Mock optimizer for mlos_bench.
|
||||
"""
|
||||
"""Mock optimizer for mlos_bench."""
|
||||
|
||||
import random
|
||||
import logging
|
||||
|
||||
import random
|
||||
from typing import Callable, Dict, Optional, Sequence
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import Tunable, TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockOptimizer(TrackBestOptimizer):
|
||||
"""
|
||||
Mock optimizer to test the Environment API.
|
||||
"""
|
||||
"""Mock optimizer to test the Environment API."""
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(tunables, config, global_config, service)
|
||||
rnd = random.Random(self.seed)
|
||||
self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
|
||||
|
@ -39,15 +35,17 @@ class MockOptimizer(TrackBestOptimizer):
|
|||
"int": lambda tunable: rnd.randint(*tunable.range),
|
||||
}
|
||||
|
||||
def bulk_register(self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None) -> bool:
|
||||
def bulk_register(
|
||||
self,
|
||||
configs: Sequence[dict],
|
||||
scores: Sequence[Optional[Dict[str, TunableValue]]],
|
||||
status: Optional[Sequence[Status]] = None,
|
||||
) -> bool:
|
||||
if not super().bulk_register(configs, scores, status):
|
||||
return False
|
||||
if status is None:
|
||||
status = [Status.SUCCEEDED] * len(configs)
|
||||
for (params, score, trial_status) in zip(configs, scores, status):
|
||||
for params, score, trial_status in zip(configs, scores, status):
|
||||
tunables = self._tunables.copy().assign(params)
|
||||
self.register(tunables, trial_status, score)
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
|
@ -56,15 +54,13 @@ class MockOptimizer(TrackBestOptimizer):
|
|||
return True
|
||||
|
||||
def suggest(self) -> TunableGroups:
|
||||
"""
|
||||
Generate the next (random) suggestion.
|
||||
"""
|
||||
"""Generate the next (random) suggestion."""
|
||||
tunables = super().suggest()
|
||||
if self._start_with_defaults:
|
||||
_LOG.info("Use default tunable values")
|
||||
self._start_with_defaults = False
|
||||
else:
|
||||
for (tunable, _group) in tunables:
|
||||
for tunable, _group in tunables:
|
||||
tunable.value = self._random[tunable.type](tunable)
|
||||
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
|
||||
return tunables
|
||||
|
|
|
@ -2,16 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
No-op optimizer for mlos_bench that proposes a single configuration.
|
||||
"""
|
||||
"""No-op optimizer for mlos_bench that proposes a single configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
@ -19,24 +17,25 @@ _LOG = logging.getLogger(__name__)
|
|||
class OneShotOptimizer(MockOptimizer):
|
||||
"""
|
||||
No-op optimizer that proposes a single configuration and returns.
|
||||
|
||||
Explicit configs (partial or full) are possible using configuration files.
|
||||
"""
|
||||
|
||||
# TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(tunables, config, global_config, service)
|
||||
_LOG.info("Run a single iteration for: %s", self._tunables)
|
||||
self._max_iter = 1 # Always run for just one iteration.
|
||||
|
||||
def suggest(self) -> TunableGroups:
|
||||
"""
|
||||
Always produce the same (initial) suggestion.
|
||||
"""
|
||||
"""Always produce the same (initial) suggestion."""
|
||||
tunables = super().suggest()
|
||||
self._start_with_defaults = True
|
||||
return tunables
|
||||
|
|
|
@ -2,40 +2,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Mock optimizer for mlos_bench.
|
||||
"""
|
||||
"""Mock optimizer for mlos_bench."""
|
||||
|
||||
import logging
|
||||
from abc import ABCMeta
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
from mlos_bench.optimizers.base_optimizer import Optimizer
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
|
||||
"""
|
||||
Base Optimizer class that keeps track of the best score and configuration.
|
||||
"""
|
||||
"""Base Optimizer class that keeps track of the best score and configuration."""
|
||||
|
||||
def __init__(self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(tunables, config, global_config, service)
|
||||
self._best_config: Optional[TunableGroups] = None
|
||||
self._best_score: Optional[Dict[str, float]] = None
|
||||
|
||||
def register(self, tunables: TunableGroups, status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
|
||||
def register(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
status: Status,
|
||||
score: Optional[Dict[str, TunableValue]] = None,
|
||||
) -> Optional[Dict[str, float]]:
|
||||
registered_score = super().register(tunables, status, score)
|
||||
if status.is_succeeded() and self._is_better(registered_score):
|
||||
self._best_score = registered_score
|
||||
|
@ -43,13 +44,11 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
|
|||
return registered_score
|
||||
|
||||
def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool:
|
||||
"""
|
||||
Compare the optimization scores to the best ones so far lexicographically.
|
||||
"""
|
||||
"""Compare the optimization scores to the best ones so far lexicographically."""
|
||||
if self._best_score is None:
|
||||
return True
|
||||
assert registered_score is not None
|
||||
for (opt_target, best_score) in self._best_score.items():
|
||||
for opt_target, best_score in self._best_score.items():
|
||||
score = registered_score[opt_target]
|
||||
if score < best_score:
|
||||
return True
|
||||
|
@ -57,7 +56,9 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
|
|||
return False
|
||||
return False
|
||||
|
||||
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
def get_best_observation(
|
||||
self,
|
||||
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
|
||||
if self._best_score is None:
|
||||
return (None, None)
|
||||
score = self._get_scores(Status.SUCCEEDED, self._best_score)
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Simple platform agnostic abstraction for the OS environment variables.
|
||||
Meant as a replacement for os.environ vs nt.environ.
|
||||
Simple platform agnostic abstraction for the OS environment variables. Meant as a
|
||||
replacement for os.environ vs nt.environ.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
@ -22,16 +22,18 @@ else:
|
|||
from typing_extensions import TypeAlias
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object
|
||||
# pylint: disable=protected-access,disable=unsubscriptable-object
|
||||
EnvironType: TypeAlias = os._Environ[str]
|
||||
else:
|
||||
EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
|
||||
EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
|
||||
|
||||
# Handle case sensitivity differences between platforms.
|
||||
# https://stackoverflow.com/a/19023293
|
||||
if sys.platform == 'win32':
|
||||
import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
|
||||
if sys.platform == "win32":
|
||||
import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
|
||||
|
||||
environ: EnvironType = nt.environ
|
||||
else:
|
||||
environ: EnvironType = os.environ
|
||||
|
||||
__all__ = ['environ']
|
||||
__all__ = ["environ"]
|
||||
|
|
|
@ -20,8 +20,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
|
|||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _main(argv: Optional[List[str]] = None
|
||||
) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
|
||||
def _main(
|
||||
argv: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
|
||||
|
||||
launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv)
|
||||
|
||||
|
|
|
@ -2,14 +2,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Interfaces and implementations of the optimization loop scheduling policies.
|
||||
"""
|
||||
"""Interfaces and implementations of the optimization loop scheduling policies."""
|
||||
|
||||
from mlos_bench.schedulers.base_scheduler import Scheduler
|
||||
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
|
||||
|
||||
__all__ = [
|
||||
'Scheduler',
|
||||
'SyncScheduler',
|
||||
"Scheduler",
|
||||
"SyncScheduler",
|
||||
]
|
||||
|
|
|
@ -2,20 +2,17 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base class for the optimization loop scheduling policies.
|
||||
"""
|
||||
"""Base class for the optimization loop scheduling policies."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
from typing_extensions import Literal
|
||||
|
||||
from pytz import UTC
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.optimizers.base_optimizer import Optimizer
|
||||
|
@ -28,22 +25,23 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class Scheduler(metaclass=ABCMeta):
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Base class for the optimization loop scheduling policies.
|
||||
"""
|
||||
"""Base class for the optimization loop scheduling policies."""
|
||||
|
||||
def __init__(self, *,
|
||||
config: Dict[str, Any],
|
||||
global_config: Dict[str, Any],
|
||||
environment: Environment,
|
||||
optimizer: Optimizer,
|
||||
storage: Storage,
|
||||
root_env_config: str):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Dict[str, Any],
|
||||
global_config: Dict[str, Any],
|
||||
environment: Environment,
|
||||
optimizer: Optimizer,
|
||||
storage: Storage,
|
||||
root_env_config: str,
|
||||
):
|
||||
"""
|
||||
Create a new instance of the scheduler. The constructor of this
|
||||
and the derived classes is called by the persistence service
|
||||
after reading the class JSON configuration. Other objects like
|
||||
the Environment and Optimizer are provided by the Launcher.
|
||||
Create a new instance of the scheduler. The constructor of this and the derived
|
||||
classes is called by the persistence service after reading the class JSON
|
||||
configuration. Other objects like the Environment and Optimizer are provided by
|
||||
the Launcher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -61,8 +59,11 @@ class Scheduler(metaclass=ABCMeta):
|
|||
Path to the root environment configuration.
|
||||
"""
|
||||
self.global_config = global_config
|
||||
config = merge_parameters(dest=config.copy(), source=global_config,
|
||||
required_keys=["experiment_id", "trial_id"])
|
||||
config = merge_parameters(
|
||||
dest=config.copy(),
|
||||
source=global_config,
|
||||
required_keys=["experiment_id", "trial_id"],
|
||||
)
|
||||
|
||||
self._experiment_id = config["experiment_id"].strip()
|
||||
self._trial_id = int(config["trial_id"])
|
||||
|
@ -72,7 +73,9 @@ class Scheduler(metaclass=ABCMeta):
|
|||
|
||||
self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
|
||||
if self._trial_config_repeat_count <= 0:
|
||||
raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}")
|
||||
raise ValueError(
|
||||
f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
|
||||
)
|
||||
|
||||
self._do_teardown = bool(config.get("teardown", True))
|
||||
|
||||
|
@ -96,10 +99,8 @@ class Scheduler(metaclass=ABCMeta):
|
|||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __enter__(self) -> 'Scheduler':
|
||||
"""
|
||||
Enter the scheduler's context.
|
||||
"""
|
||||
def __enter__(self) -> "Scheduler":
|
||||
"""Enter the scheduler's context."""
|
||||
_LOG.debug("Scheduler START :: %s", self)
|
||||
assert self.experiment is None
|
||||
self.environment.__enter__()
|
||||
|
@ -118,13 +119,13 @@ class Scheduler(metaclass=ABCMeta):
|
|||
).__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
"""
|
||||
Exit the context of the scheduler.
|
||||
"""
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""Exit the context of the scheduler."""
|
||||
if ex_val is None:
|
||||
_LOG.debug("Scheduler END :: %s", self)
|
||||
else:
|
||||
|
@ -139,12 +140,14 @@ class Scheduler(metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the optimization loop.
|
||||
"""
|
||||
"""Start the optimization loop."""
|
||||
assert self.experiment is not None
|
||||
_LOG.info("START: Experiment: %s Env: %s Optimizer: %s",
|
||||
self.experiment, self.environment, self.optimizer)
|
||||
_LOG.info(
|
||||
"START: Experiment: %s Env: %s Optimizer: %s",
|
||||
self.experiment,
|
||||
self.environment,
|
||||
self.optimizer,
|
||||
)
|
||||
if _LOG.isEnabledFor(logging.INFO):
|
||||
_LOG.info("Root Environment:\n%s", self.environment.pprint())
|
||||
|
||||
|
@ -155,6 +158,7 @@ class Scheduler(metaclass=ABCMeta):
|
|||
def teardown(self) -> None:
|
||||
"""
|
||||
Tear down the environment.
|
||||
|
||||
Call it after the completion of the `.start()` in the scheduler context.
|
||||
"""
|
||||
assert self.experiment is not None
|
||||
|
@ -162,17 +166,13 @@ class Scheduler(metaclass=ABCMeta):
|
|||
self.environment.teardown()
|
||||
|
||||
def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
|
||||
"""
|
||||
Get the best observation from the optimizer.
|
||||
"""
|
||||
"""Get the best observation from the optimizer."""
|
||||
(best_score, best_config) = self.optimizer.get_best_observation()
|
||||
_LOG.info("Env: %s best score: %s", self.environment, best_score)
|
||||
return (best_score, best_config)
|
||||
|
||||
def load_config(self, config_id: int) -> TunableGroups:
|
||||
"""
|
||||
Load the existing tunable configuration from the storage.
|
||||
"""
|
||||
"""Load the existing tunable configuration from the storage."""
|
||||
assert self.experiment is not None
|
||||
tunable_values = self.experiment.load_tunable_config(config_id)
|
||||
tunables = self.environment.tunable_params.assign(tunable_values)
|
||||
|
@ -183,9 +183,11 @@ class Scheduler(metaclass=ABCMeta):
|
|||
|
||||
def _schedule_new_optimizer_suggestions(self) -> bool:
|
||||
"""
|
||||
Optimizer part of the loop. Load the results of the executed trials
|
||||
into the optimizer, suggest new configurations, and add them to the queue.
|
||||
Return True if optimization is not over, False otherwise.
|
||||
Optimizer part of the loop.
|
||||
|
||||
Load the results of the executed trials into the optimizer, suggest new
|
||||
configurations, and add them to the queue. Return True if optimization is not
|
||||
over, False otherwise.
|
||||
"""
|
||||
assert self.experiment is not None
|
||||
(trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
|
||||
|
@ -201,33 +203,38 @@ class Scheduler(metaclass=ABCMeta):
|
|||
return not_done
|
||||
|
||||
def schedule_trial(self, tunables: TunableGroups) -> None:
|
||||
"""
|
||||
Add a configuration to the queue of trials.
|
||||
"""
|
||||
"""Add a configuration to the queue of trials."""
|
||||
for repeat_i in range(1, self._trial_config_repeat_count + 1):
|
||||
self._add_trial_to_queue(tunables, config={
|
||||
# Add some additional metadata to track for the trial such as the
|
||||
# optimizer config used.
|
||||
# Note: these values are unfortunately mutable at the moment.
|
||||
# Consider them as hints of what the config was the trial *started*.
|
||||
# It is possible that the experiment configs were changed
|
||||
# between resuming the experiment (since that is not currently
|
||||
# prevented).
|
||||
"optimizer": self.optimizer.name,
|
||||
"repeat_i": repeat_i,
|
||||
"is_defaults": tunables.is_defaults(),
|
||||
**{
|
||||
f"opt_{key}_{i}": val
|
||||
for (i, opt_target) in enumerate(self.optimizer.targets.items())
|
||||
for (key, val) in zip(["target", "direction"], opt_target)
|
||||
}
|
||||
})
|
||||
self._add_trial_to_queue(
|
||||
tunables,
|
||||
config={
|
||||
# Add some additional metadata to track for the trial such as the
|
||||
# optimizer config used.
|
||||
# Note: these values are unfortunately mutable at the moment.
|
||||
# Consider them as hints of what the config was the trial *started*.
|
||||
# It is possible that the experiment configs were changed
|
||||
# between resuming the experiment (since that is not currently
|
||||
# prevented).
|
||||
"optimizer": self.optimizer.name,
|
||||
"repeat_i": repeat_i,
|
||||
"is_defaults": tunables.is_defaults(),
|
||||
**{
|
||||
f"opt_{key}_{i}": val
|
||||
for (i, opt_target) in enumerate(self.optimizer.targets.items())
|
||||
for (key, val) in zip(["target", "direction"], opt_target)
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _add_trial_to_queue(self, tunables: TunableGroups,
|
||||
ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _add_trial_to_queue(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a configuration to the queue of trials.
|
||||
|
||||
A wrapper for the `Experiment.new_trial` method.
|
||||
"""
|
||||
assert self.experiment is not None
|
||||
|
@ -236,7 +243,9 @@ class Scheduler(metaclass=ABCMeta):
|
|||
|
||||
def _run_schedule(self, running: bool = False) -> None:
|
||||
"""
|
||||
Scheduler part of the loop. Check for pending trials in the queue and run them.
|
||||
Scheduler part of the loop.
|
||||
|
||||
Check for pending trials in the queue and run them.
|
||||
"""
|
||||
assert self.experiment is not None
|
||||
for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
|
||||
|
@ -245,6 +254,7 @@ class Scheduler(metaclass=ABCMeta):
|
|||
def not_done(self) -> bool:
|
||||
"""
|
||||
Check the stopping conditions.
|
||||
|
||||
By default, stop when the optimizer converges or max limit of trials reached.
|
||||
"""
|
||||
return self.optimizer.not_converged() and (
|
||||
|
@ -254,7 +264,9 @@ class Scheduler(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def run_trial(self, trial: Storage.Trial) -> None:
|
||||
"""
|
||||
Set up and run a single trial. Save the results in the storage.
|
||||
Set up and run a single trial.
|
||||
|
||||
Save the results in the storage.
|
||||
"""
|
||||
assert self.experiment is not None
|
||||
self._trial_count += 1
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A simple single-threaded synchronous optimization loop implementation.
|
||||
"""
|
||||
"""A simple single-threaded synchronous optimization loop implementation."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
@ -19,14 +17,10 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SyncScheduler(Scheduler):
|
||||
"""
|
||||
A simple single-threaded synchronous optimization loop implementation.
|
||||
"""
|
||||
"""A simple single-threaded synchronous optimization loop implementation."""
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the optimization loop.
|
||||
"""
|
||||
"""Start the optimization loop."""
|
||||
super().start()
|
||||
|
||||
is_warm_up = self.optimizer.supports_preload
|
||||
|
@ -42,7 +36,9 @@ class SyncScheduler(Scheduler):
|
|||
|
||||
def run_trial(self, trial: Storage.Trial) -> None:
|
||||
"""
|
||||
Set up and run a single trial. Save the results in the storage.
|
||||
Set up and run a single trial.
|
||||
|
||||
Save the results in the storage.
|
||||
"""
|
||||
super().run_trial(trial)
|
||||
|
||||
|
@ -53,7 +49,8 @@ class SyncScheduler(Scheduler):
|
|||
trial.update(Status.FAILED, datetime.now(UTC))
|
||||
return
|
||||
|
||||
(status, timestamp, results) = self.environment.run() # Block and wait for the final result.
|
||||
# Block and wait for the final result.
|
||||
(status, timestamp, results) = self.environment.run()
|
||||
_LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results)
|
||||
|
||||
# In async mode (TODO), poll the environment for status and telemetry
|
||||
|
|
|
@ -2,17 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Services for implementing Environments for mlos_bench.
|
||||
"""
|
||||
"""Services for implementing Environments for mlos_bench."""
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.base_fileshare import FileShareService
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.local.local_exec import LocalExecService
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Service',
|
||||
'FileShareService',
|
||||
'LocalExecService',
|
||||
"Service",
|
||||
"FileShareService",
|
||||
"LocalExecService",
|
||||
]
|
||||
|
|
|
@ -2,12 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base class for remote file shares.
|
||||
"""
|
||||
"""Base class for remote file shares."""
|
||||
|
||||
import logging
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
|
@ -18,14 +15,15 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
|
||||
"""
|
||||
An abstract base of all file shares.
|
||||
"""
|
||||
"""An abstract base of all file shares."""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new file share with a given config.
|
||||
|
||||
|
@ -43,12 +41,20 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [self.upload, self.download])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(methods, [self.upload, self.download]),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
def download(
|
||||
self,
|
||||
params: dict,
|
||||
remote_path: str,
|
||||
local_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Downloads contents from a remote share path to a local path.
|
||||
|
||||
|
@ -66,11 +72,22 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
|
|||
if True (the default), download the entire directory tree.
|
||||
"""
|
||||
params = params or {}
|
||||
_LOG.info("Download from File Share %s recursively: %s -> %s (%s)",
|
||||
"" if recursive else "non", remote_path, local_path, params)
|
||||
_LOG.info(
|
||||
"Download from File Share %s recursively: %s -> %s (%s)",
|
||||
"" if recursive else "non",
|
||||
remote_path,
|
||||
local_path,
|
||||
params,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
def upload(
|
||||
self,
|
||||
params: dict,
|
||||
local_path: str,
|
||||
remote_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Uploads contents from a local path to remote share path.
|
||||
|
||||
|
@ -87,5 +104,10 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
|
|||
if True (the default), upload the entire directory tree.
|
||||
"""
|
||||
params = params or {}
|
||||
_LOG.info("Upload to File Share %s recursively: %s -> %s (%s)",
|
||||
"" if recursive else "non", local_path, remote_path, params)
|
||||
_LOG.info(
|
||||
"Upload to File Share %s recursively: %s -> %s (%s)",
|
||||
"" if recursive else "non",
|
||||
local_path,
|
||||
remote_path,
|
||||
params,
|
||||
)
|
||||
|
|
|
@ -2,15 +2,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base class for the service mix-ins.
|
||||
"""
|
||||
"""Base class for the service mix-ins."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
|
@ -21,16 +19,16 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Service:
|
||||
"""
|
||||
An abstract base of all Environment Services and used to build up mix-ins.
|
||||
"""
|
||||
"""An abstract base of all Environment Services and used to build up mix-ins."""
|
||||
|
||||
@classmethod
|
||||
def new(cls,
|
||||
class_name: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None) -> "Service":
|
||||
def new(
|
||||
cls,
|
||||
class_name: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None,
|
||||
) -> "Service":
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -57,11 +55,13 @@ class Service:
|
|||
assert issubclass(cls, Service)
|
||||
return instantiate_from_config(cls, class_name, config, global_config, parent)
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new service with a given config.
|
||||
|
||||
|
@ -101,12 +101,15 @@ class Service:
|
|||
_LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)
|
||||
|
||||
@staticmethod
|
||||
def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
|
||||
local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]:
|
||||
def merge_methods(
|
||||
ext_methods: Union[Dict[str, Callable], List[Callable], None],
|
||||
local_methods: Union[Dict[str, Callable], List[Callable]],
|
||||
) -> Dict[str, Callable]:
|
||||
"""
|
||||
Merge methods from the external caller with the local ones.
|
||||
This function is usually called by the derived class constructor
|
||||
just before invoking the constructor of the base class.
|
||||
|
||||
This function is usually called by the derived class constructor just before
|
||||
invoking the constructor of the base class.
|
||||
"""
|
||||
if isinstance(local_methods, dict):
|
||||
local_methods = local_methods.copy()
|
||||
|
@ -138,9 +141,12 @@ class Service:
|
|||
self._in_context = True
|
||||
return self
|
||||
|
||||
def __exit__(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""
|
||||
Exit the Service mix-in context.
|
||||
|
||||
|
@ -170,21 +176,24 @@ class Service:
|
|||
"""
|
||||
Enters the context for this particular Service instance.
|
||||
|
||||
Called by the base __enter__ method of the Service class so it can be
|
||||
used with mix-ins and overridden by subclasses.
|
||||
Called by the base __enter__ method of the Service class so it can be used with
|
||||
mix-ins and overridden by subclasses.
|
||||
"""
|
||||
assert not self._in_context
|
||||
self._in_context = True
|
||||
return self
|
||||
|
||||
def _exit_context(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def _exit_context(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""
|
||||
Exits the context for this particular Service instance.
|
||||
|
||||
Called by the base __enter__ method of the Service class so it can be
|
||||
used with mix-ins and overridden by subclasses.
|
||||
Called by the base __enter__ method of the Service class so it can be used with
|
||||
mix-ins and overridden by subclasses.
|
||||
"""
|
||||
# pylint: disable=unused-argument
|
||||
assert self._in_context
|
||||
|
@ -192,13 +201,13 @@ class Service:
|
|||
return False
|
||||
|
||||
def _validate_json_config(self, config: dict) -> None:
|
||||
"""
|
||||
Reconstructs a basic json config that this class might have been
|
||||
instantiated from in order to validate configs provided outside the
|
||||
file loading mechanism.
|
||||
"""Reconstructs a basic json config that this class might have been instantiated
|
||||
from in order to validate configs provided outside the file loading
|
||||
mechanism.
|
||||
"""
|
||||
if self.__class__ == Service:
|
||||
# Skip over the case where instantiate a bare base Service class in order to build up a mix-in.
|
||||
# Skip over the case where instantiate a bare base Service class in
|
||||
# order to build up a mix-in.
|
||||
assert config == {}
|
||||
return
|
||||
json_config: dict = {
|
||||
|
@ -212,9 +221,7 @@ class Service:
|
|||
return f"{self.__class__.__name__}@{hex(id(self))}"
|
||||
|
||||
def pprint(self) -> str:
|
||||
"""
|
||||
Produce a human-readable string listing all public methods of the service.
|
||||
"""
|
||||
"""Produce a human-readable string listing all public methods of the service."""
|
||||
return f"{self} ::\n" + "\n".join(
|
||||
f' "{key}": {getattr(val, "__self__", "stand-alone")}'
|
||||
for (key, val) in self._service_methods.items()
|
||||
|
@ -265,10 +272,11 @@ class Service:
|
|||
# Unfortunately, by creating a set, we may destroy the ability to
|
||||
# preserve the context enter/exit order, but hopefully it doesn't
|
||||
# matter.
|
||||
svc_method.__self__ for _, svc_method in self._service_methods.items()
|
||||
svc_method.__self__
|
||||
for _, svc_method in self._service_methods.items()
|
||||
# Note: some methods are actually stand alone functions, so we need
|
||||
# to filter them out.
|
||||
if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
|
||||
if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
|
||||
}
|
||||
|
||||
def export(self) -> Dict[str, Callable]:
|
||||
|
|
|
@ -2,22 +2,28 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Helper functions to load, instantiate, and serialize Python objects
|
||||
that encapsulate benchmark environments, tunable parameters, and
|
||||
service functions.
|
||||
"""Helper functions to load, instantiate, and serialize Python objects that encapsulate
|
||||
benchmark environments, tunable parameters, and service functions.
|
||||
"""
|
||||
|
||||
import json # For logging only
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import json # For logging only
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
import json5 # To read configs with comments and other JSON5 syntax features
|
||||
from jsonschema import ValidationError, SchemaError
|
||||
import json5 # To read configs with comments and other JSON5 syntax features
|
||||
from jsonschema import SchemaError, ValidationError
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
|
@ -26,7 +32,12 @@ from mlos_bench.services.base_service import Service
|
|||
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import instantiate_from_config, merge_parameters, path_join, preprocess_dynamic_configs
|
||||
from mlos_bench.util import (
|
||||
instantiate_from_config,
|
||||
merge_parameters,
|
||||
path_join,
|
||||
preprocess_dynamic_configs,
|
||||
)
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_resources import files
|
||||
|
@ -34,25 +45,27 @@ else:
|
|||
from importlib.resources import files
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.schedulers.base_scheduler import Scheduler
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigPersistenceService(Service, SupportsConfigLoading):
|
||||
"""
|
||||
Collection of methods to deserialize the Environment, Service, and TunableGroups objects.
|
||||
"""Collection of methods to deserialize the Environment, Service, and TunableGroups
|
||||
objects.
|
||||
"""
|
||||
|
||||
BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/")
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of config persistence service.
|
||||
|
||||
|
@ -69,17 +82,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
self.resolve_path,
|
||||
self.load_config,
|
||||
self.prepare_class_load,
|
||||
self.build_service,
|
||||
self.build_environment,
|
||||
self.load_services,
|
||||
self.load_environment,
|
||||
self.load_environment_list,
|
||||
])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(
|
||||
methods,
|
||||
[
|
||||
self.resolve_path,
|
||||
self.load_config,
|
||||
self.prepare_class_load,
|
||||
self.build_service,
|
||||
self.build_environment,
|
||||
self.load_services,
|
||||
self.load_environment,
|
||||
self.load_environment_list,
|
||||
],
|
||||
),
|
||||
)
|
||||
self._config_loader_service = self
|
||||
|
||||
|
@ -107,11 +125,10 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
"""
|
||||
return list(self._config_path) # make a copy to avoid modifications
|
||||
|
||||
def resolve_path(self, file_path: str,
|
||||
extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
"""
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
|
||||
If `_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
|
||||
`_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -138,14 +155,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
_LOG.debug("Path not resolved: %s", file_path)
|
||||
return file_path
|
||||
|
||||
def load_config(self,
|
||||
json_file_name: str,
|
||||
schema_type: Optional[ConfigSchema],
|
||||
) -> Dict[str, Any]:
|
||||
def load_config(
|
||||
self,
|
||||
json_file_name: str,
|
||||
schema_type: Optional[ConfigSchema],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load JSON config file. Search for a file relative to `_config_path`
|
||||
if the input path is not absolute.
|
||||
This method is exported to be used as a service.
|
||||
Load JSON config file. Search for a file relative to `_config_path` if the input
|
||||
path is not absolute. This method is exported to be used as a service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -161,16 +178,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
"""
|
||||
json_file_name = self.resolve_path(json_file_name)
|
||||
_LOG.info("Load config: %s", json_file_name)
|
||||
with open(json_file_name, mode='r', encoding='utf-8') as fh_json:
|
||||
with open(json_file_name, mode="r", encoding="utf-8") as fh_json:
|
||||
config = json5.load(fh_json)
|
||||
if schema_type is not None:
|
||||
try:
|
||||
schema_type.validate(config)
|
||||
except (ValidationError, SchemaError) as ex:
|
||||
_LOG.error("Failed to validate config %s against schema type %s at %s",
|
||||
json_file_name, schema_type.name, schema_type.value)
|
||||
raise ValueError(f"Failed to validate config {json_file_name} against " +
|
||||
f"schema type {schema_type.name} at {schema_type.value}") from ex
|
||||
_LOG.error(
|
||||
"Failed to validate config %s against schema type %s at %s",
|
||||
json_file_name,
|
||||
schema_type.name,
|
||||
schema_type.value,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to validate config {json_file_name} against "
|
||||
f"schema type {schema_type.name} at {schema_type.value}"
|
||||
) from ex
|
||||
if isinstance(config, dict) and config.get("$schema"):
|
||||
# Remove $schema attributes from the config after we've validated
|
||||
# them to avoid passing them on to other objects
|
||||
|
@ -181,15 +204,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
del config["$schema"]
|
||||
else:
|
||||
_LOG.warning("Config %s is not validated against a schema.", json_file_name)
|
||||
return config # type: ignore[no-any-return]
|
||||
return config # type: ignore[no-any-return]
|
||||
|
||||
def prepare_class_load(self, config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
def prepare_class_load(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Extract the class instantiation parameters from the configuration.
|
||||
Mix-in the global parameters and resolve the local file system paths,
|
||||
where it is required.
|
||||
Extract the class instantiation parameters from the configuration. Mix-in the
|
||||
global parameters and resolve the local file system paths, where it is required.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -228,19 +253,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
raise ValueError(f"Parameter {key} must be a string or a list")
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Instantiating: %s with config:\n%s",
|
||||
class_name, json.dumps(class_config, indent=2))
|
||||
_LOG.debug(
|
||||
"Instantiating: %s with config:\n%s",
|
||||
class_name,
|
||||
json.dumps(class_config, indent=2),
|
||||
)
|
||||
|
||||
return (class_name, class_config)
|
||||
|
||||
def build_optimizer(self, *,
|
||||
tunables: TunableGroups,
|
||||
service: Service,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None) -> Optimizer:
|
||||
def build_optimizer(
|
||||
self,
|
||||
*,
|
||||
tunables: TunableGroups,
|
||||
service: Service,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Optimizer:
|
||||
"""
|
||||
Instantiation of mlos_bench Optimizer
|
||||
that depend on Service and TunableGroups.
|
||||
Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups.
|
||||
|
||||
A class *MUST* have a constructor that takes four named arguments:
|
||||
(tunables, config, global_config, service)
|
||||
|
@ -266,18 +296,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
if tunables_path is not None:
|
||||
tunables = self._load_tunables(tunables_path, tunables)
|
||||
(class_name, class_config) = self.prepare_class_load(config, global_config)
|
||||
inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract]
|
||||
tunables=tunables,
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
service=service)
|
||||
inst = instantiate_from_config(
|
||||
Optimizer, # type: ignore[type-abstract]
|
||||
class_name,
|
||||
tunables=tunables,
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
service=service,
|
||||
)
|
||||
_LOG.info("Created: Optimizer %s", inst)
|
||||
return inst
|
||||
|
||||
def build_storage(self, *,
|
||||
service: Service,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None) -> "Storage":
|
||||
def build_storage(
|
||||
self,
|
||||
*,
|
||||
service: Service,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
) -> "Storage":
|
||||
"""
|
||||
Instantiation of mlos_bench Storage objects.
|
||||
|
||||
|
@ -296,21 +332,29 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
A new instance of the Storage class.
|
||||
"""
|
||||
(class_name, class_config) = self.prepare_class_load(config, global_config)
|
||||
from mlos_bench.storage.base_storage import Storage # pylint: disable=import-outside-toplevel
|
||||
inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract]
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
service=service)
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
|
||||
inst = instantiate_from_config(
|
||||
Storage, # type: ignore[type-abstract]
|
||||
class_name,
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
service=service,
|
||||
)
|
||||
_LOG.info("Created: Storage %s", inst)
|
||||
return inst
|
||||
|
||||
def build_scheduler(self, *,
|
||||
config: Dict[str, Any],
|
||||
global_config: Dict[str, Any],
|
||||
environment: Environment,
|
||||
optimizer: Optimizer,
|
||||
storage: "Storage",
|
||||
root_env_config: str) -> "Scheduler":
|
||||
def build_scheduler(
|
||||
self,
|
||||
*,
|
||||
config: Dict[str, Any],
|
||||
global_config: Dict[str, Any],
|
||||
environment: Environment,
|
||||
optimizer: Optimizer,
|
||||
storage: "Storage",
|
||||
root_env_config: str,
|
||||
) -> "Scheduler":
|
||||
"""
|
||||
Instantiation of mlos_bench Scheduler.
|
||||
|
||||
|
@ -335,23 +379,30 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
A new instance of the Scheduler.
|
||||
"""
|
||||
(class_name, class_config) = self.prepare_class_load(config, global_config)
|
||||
from mlos_bench.schedulers.base_scheduler import Scheduler # pylint: disable=import-outside-toplevel
|
||||
inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract]
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
environment=environment,
|
||||
optimizer=optimizer,
|
||||
storage=storage,
|
||||
root_env_config=root_env_config)
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from mlos_bench.schedulers.base_scheduler import Scheduler
|
||||
|
||||
inst = instantiate_from_config(
|
||||
Scheduler, # type: ignore[type-abstract]
|
||||
class_name,
|
||||
config=class_config,
|
||||
global_config=global_config,
|
||||
environment=environment,
|
||||
optimizer=optimizer,
|
||||
storage=storage,
|
||||
root_env_config=root_env_config,
|
||||
)
|
||||
_LOG.info("Created: Scheduler %s", inst)
|
||||
return inst
|
||||
|
||||
def build_environment(self, # pylint: disable=too-many-arguments
|
||||
config: Dict[str, Any],
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None) -> Environment:
|
||||
def build_environment(
|
||||
self, # pylint: disable=too-many-arguments
|
||||
config: Dict[str, Any],
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> Environment:
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
|
@ -391,16 +442,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
tunables = self._load_tunables(env_tunables_path, tunables)
|
||||
|
||||
_LOG.debug("Creating env: %s :: %s", env_name, env_class)
|
||||
env = Environment.new(env_name=env_name, class_name=env_class,
|
||||
config=env_config, global_config=global_config,
|
||||
tunables=tunables, service=service)
|
||||
env = Environment.new(
|
||||
env_name=env_name,
|
||||
class_name=env_class,
|
||||
config=env_config,
|
||||
global_config=global_config,
|
||||
tunables=tunables,
|
||||
service=service,
|
||||
)
|
||||
|
||||
_LOG.info("Created env: %s :: %s", env_name, env)
|
||||
return env
|
||||
|
||||
def _build_standalone_service(self, config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
def _build_standalone_service(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -425,9 +484,12 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
_LOG.info("Created service: %s", service)
|
||||
return service
|
||||
|
||||
def _build_composite_service(self, config_list: Iterable[Dict[str, Any]],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
def _build_composite_service(
|
||||
self,
|
||||
config_list: Iterable[Dict[str, Any]],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -453,18 +515,21 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
service.register(parent.export())
|
||||
|
||||
for config in config_list:
|
||||
service.register(self._build_standalone_service(
|
||||
config, global_config, service).export())
|
||||
service.register(
|
||||
self._build_standalone_service(config, global_config, service).export()
|
||||
)
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Created mix-in service: %s", service)
|
||||
|
||||
return service
|
||||
|
||||
def build_service(self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
def build_service(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
) -> Service:
|
||||
"""
|
||||
Factory method for a new service with a given config.
|
||||
|
||||
|
@ -486,8 +551,7 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
services from the list plus the parent mix-in.
|
||||
"""
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Build service from config:\n%s",
|
||||
json.dumps(config, indent=2))
|
||||
_LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2))
|
||||
|
||||
assert isinstance(config, dict)
|
||||
config_list: List[Dict[str, Any]]
|
||||
|
@ -502,12 +566,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
|
||||
return self._build_composite_service(config_list, global_config, parent)
|
||||
|
||||
def load_environment(self, # pylint: disable=too-many-arguments
|
||||
json_file_name: str,
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None) -> Environment:
|
||||
def load_environment(
|
||||
self, # pylint: disable=too-many-arguments
|
||||
json_file_name: str,
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> Environment:
|
||||
"""
|
||||
Load and build new environment from the config file.
|
||||
|
||||
|
@ -534,12 +600,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
assert isinstance(config, dict)
|
||||
return self.build_environment(config, tunables, global_config, parent_args, service)
|
||||
|
||||
def load_environment_list(self, # pylint: disable=too-many-arguments
|
||||
json_file_name: str,
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None) -> List[Environment]:
|
||||
def load_environment_list(
|
||||
self, # pylint: disable=too-many-arguments
|
||||
json_file_name: str,
|
||||
tunables: TunableGroups,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional[Service] = None,
|
||||
) -> List[Environment]:
|
||||
"""
|
||||
Load and build a list of environments from the config file.
|
||||
|
||||
|
@ -564,16 +632,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
A list of new benchmarking environments.
|
||||
"""
|
||||
config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT)
|
||||
return [
|
||||
self.build_environment(config, tunables, global_config, parent_args, service)
|
||||
]
|
||||
return [self.build_environment(config, tunables, global_config, parent_args, service)]
|
||||
|
||||
def load_services(self, json_file_names: Iterable[str],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None) -> Service:
|
||||
def load_services(
|
||||
self,
|
||||
json_file_names: Iterable[str],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
) -> Service:
|
||||
"""
|
||||
Read the configuration files and bundle all service methods
|
||||
from those configs into a single Service object.
|
||||
Read the configuration files and bundle all service methods from those configs
|
||||
into a single Service object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -589,16 +658,18 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
|
|||
service : Service
|
||||
A collection of service methods.
|
||||
"""
|
||||
_LOG.info("Load services: %s parent: %s",
|
||||
json_file_names, parent.__class__.__name__)
|
||||
_LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__)
|
||||
service = Service({}, global_config, parent)
|
||||
for fname in json_file_names:
|
||||
config = self.load_config(fname, ConfigSchema.SERVICE)
|
||||
service.register(self.build_service(config, global_config, service).export())
|
||||
return service
|
||||
|
||||
def _load_tunables(self, json_file_names: Iterable[str],
|
||||
parent: TunableGroups) -> TunableGroups:
|
||||
def _load_tunables(
|
||||
self,
|
||||
json_file_names: Iterable[str],
|
||||
parent: TunableGroups,
|
||||
) -> TunableGroups:
|
||||
"""
|
||||
Load a collection of tunable parameters from JSON files into the parent
|
||||
TunableGroup.
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Local scheduler side Services for mlos_bench.
|
||||
"""
|
||||
"""Local scheduler side Services for mlos_bench."""
|
||||
|
||||
from mlos_bench.services.local.local_exec import LocalExecService
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LocalExecService',
|
||||
"LocalExecService",
|
||||
]
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Helper functions to run scripts and commands locally on the scheduler side.
|
||||
"""
|
||||
"""Helper functions to run scripts and commands locally on the scheduler side."""
|
||||
|
||||
import errno
|
||||
import logging
|
||||
|
@ -12,10 +10,18 @@ import os
|
|||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from string import Template
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from mlos_bench.os_environ import environ
|
||||
|
@ -31,9 +37,9 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
def split_cmdline(cmdline: str) -> Iterable[List[str]]:
|
||||
"""
|
||||
A single command line may contain multiple commands separated by
|
||||
special characters (e.g., &&, ||, etc.) so further split the
|
||||
commandline into an array of subcommand arrays.
|
||||
A single command line may contain multiple commands separated by special characters
|
||||
(e.g., &&, ||, etc.) so further split the commandline into an array of subcommand
|
||||
arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -66,16 +72,20 @@ def split_cmdline(cmdline: str) -> Iterable[List[str]]:
|
|||
|
||||
class LocalExecService(TempDirContextService, SupportsLocalExec):
|
||||
"""
|
||||
Collection of methods to run scripts and commands in an external process
|
||||
on the node acting as the scheduler. Can be useful for data processing
|
||||
due to reduced dependency management complications vs the target environment.
|
||||
Collection of methods to run scripts and commands in an external process on the node
|
||||
acting as the scheduler.
|
||||
|
||||
Can be useful for data processing due to reduced dependency management complications
|
||||
vs the target environment.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of a service to run scripts locally.
|
||||
|
||||
|
@ -92,14 +102,19 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [self.local_exec])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(methods, [self.local_exec]),
|
||||
)
|
||||
self.abort_on_error = self.config.get("abort_on_error", True)
|
||||
|
||||
def local_exec(self, script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, "TunableValue"]] = None,
|
||||
cwd: Optional[str] = None) -> Tuple[int, str, str]:
|
||||
def local_exec(
|
||||
self,
|
||||
script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, "TunableValue"]] = None,
|
||||
cwd: Optional[str] = None,
|
||||
) -> Tuple[int, str, str]:
|
||||
"""
|
||||
Execute the script lines from `script_lines` in a local process.
|
||||
|
||||
|
@ -141,8 +156,8 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
|
||||
def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]:
|
||||
"""
|
||||
Resolves local script path (first token) in the (sub)command line
|
||||
tokens to its full path.
|
||||
Resolves local script path (first token) in the (sub)command line tokens to its
|
||||
full path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -167,9 +182,12 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
subcmd_tokens.insert(0, sys.executable)
|
||||
return subcmd_tokens
|
||||
|
||||
def _local_exec_script(self, script_line: str,
|
||||
env_params: Optional[Mapping[str, "TunableValue"]],
|
||||
cwd: str) -> Tuple[int, str, str]:
|
||||
def _local_exec_script(
|
||||
self,
|
||||
script_line: str,
|
||||
env_params: Optional[Mapping[str, "TunableValue"]],
|
||||
cwd: str,
|
||||
) -> Tuple[int, str, str]:
|
||||
"""
|
||||
Execute the script from `script_path` in a local process.
|
||||
|
||||
|
@ -198,7 +216,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
if env_params:
|
||||
env = {key: str(val) for (key, val) in env_params.items()}
|
||||
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
# A hack to run Python on Windows with env variables set:
|
||||
env_copy = environ.copy()
|
||||
env_copy["PYTHONPATH"] = ""
|
||||
|
@ -206,7 +224,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
env = env_copy
|
||||
|
||||
try:
|
||||
if sys.platform != 'win32':
|
||||
if sys.platform != "win32":
|
||||
cmd = [" ".join(cmd)]
|
||||
|
||||
_LOG.info("Run: %s", cmd)
|
||||
|
@ -214,8 +232,15 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
|
|||
_LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env))
|
||||
_LOG.debug("Current working dir: %s", cwd)
|
||||
|
||||
proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True,
|
||||
text=True, check=False, capture_output=True)
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env or None,
|
||||
cwd=cwd,
|
||||
shell=True,
|
||||
text=True,
|
||||
check=False,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
_LOG.debug("Run: return code = %d", proc.returncode)
|
||||
return (proc.returncode, proc.stdout, proc.stderr)
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Helper functions to work with temp files locally on the scheduler side.
|
||||
"""
|
||||
"""Helper functions to work with temp files locally on the scheduler side."""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
|
@ -21,21 +19,23 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
class TempDirContextService(Service, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
A *base* service class that provides a method to create a temporary
|
||||
directory context for local scripts.
|
||||
A *base* service class that provides a method to create a temporary directory
|
||||
context for local scripts.
|
||||
|
||||
It is inherited by LocalExecService and MockLocalExecService.
|
||||
This class is not supposed to be used as a standalone service.
|
||||
It is inherited by LocalExecService and MockLocalExecService. This class is not
|
||||
supposed to be used as a standalone service.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of a service that provides temporary directory context
|
||||
for local exec service.
|
||||
Create a new instance of a service that provides temporary directory context for
|
||||
local exec service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -50,8 +50,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [self.temp_dir_context])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(methods, [self.temp_dir_context]),
|
||||
)
|
||||
self._temp_dir = self.config.get("temp_dir")
|
||||
if self._temp_dir:
|
||||
|
@ -61,7 +63,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
|
|||
self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir)
|
||||
_LOG.info("%s: temp dir: %s", self, self._temp_dir)
|
||||
|
||||
def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]:
|
||||
def temp_dir_context(
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
) -> Union[TemporaryDirectory, nullcontext]:
|
||||
"""
|
||||
Create a temp directory or use the provided path.
|
||||
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Azure-specific benchmark environments for mlos_bench.
|
||||
"""
|
||||
"""Azure-specific benchmark environments for mlos_bench."""
|
||||
|
||||
from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
|
||||
from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService
|
||||
|
@ -12,11 +10,10 @@ from mlos_bench.services.remote.azure.azure_network_services import AzureNetwork
|
|||
from mlos_bench.services.remote.azure.azure_saas import AzureSaaSConfigService
|
||||
from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AzureAuthService',
|
||||
'AzureFileShareService',
|
||||
'AzureNetworkService',
|
||||
'AzureSaaSConfigService',
|
||||
'AzureVMService',
|
||||
"AzureAuthService",
|
||||
"AzureFileShareService",
|
||||
"AzureNetworkService",
|
||||
"AzureSaaSConfigService",
|
||||
"AzureVMService",
|
||||
]
|
||||
|
|
|
@ -2,19 +2,16 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection Service functions for managing VMs on Azure.
|
||||
"""
|
||||
"""A collection Service functions for managing VMs on Azure."""
|
||||
|
||||
import logging
|
||||
from base64 import b64decode
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from pytz import UTC
|
||||
|
||||
import azure.identity as azure_id
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
from pytz import UTC
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.types.authenticator_type import SupportsAuth
|
||||
|
@ -24,17 +21,17 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AzureAuthService(Service, SupportsAuth):
|
||||
"""
|
||||
Helper methods to get access to Azure services.
|
||||
"""
|
||||
"""Helper methods to get access to Azure services."""
|
||||
|
||||
_REQ_INTERVAL = 300 # = 5 min
|
||||
_REQ_INTERVAL = 300 # = 5 min
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of Azure authentication services proxy.
|
||||
|
||||
|
@ -51,11 +48,16 @@ class AzureAuthService(Service, SupportsAuth):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
self.get_access_token,
|
||||
self.get_auth_headers,
|
||||
])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(
|
||||
methods,
|
||||
[
|
||||
self.get_access_token,
|
||||
self.get_auth_headers,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# This parameter can come from command line as strings, so conversion is needed.
|
||||
|
@ -71,12 +73,13 @@ class AzureAuthService(Service, SupportsAuth):
|
|||
# Verify info required for SP auth early
|
||||
if "spClientId" in self.config:
|
||||
check_required_params(
|
||||
self.config, {
|
||||
self.config,
|
||||
{
|
||||
"spClientId",
|
||||
"keyVaultName",
|
||||
"certName",
|
||||
"tenant",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _init_sp(self) -> None:
|
||||
|
@ -105,12 +108,14 @@ class AzureAuthService(Service, SupportsAuth):
|
|||
cert_bytes = b64decode(secret.value)
|
||||
|
||||
# Reauthenticate as the service principal.
|
||||
self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes)
|
||||
self._cred = azure_id.CertificateCredential(
|
||||
tenant_id=tenant_id,
|
||||
client_id=sp_client_id,
|
||||
certificate_data=cert_bytes,
|
||||
)
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""
|
||||
Get the access token from Azure CLI, if expired.
|
||||
"""
|
||||
"""Get the access token from Azure CLI, if expired."""
|
||||
# Ensure we are logged as the Service Principal, if provided
|
||||
if "spClientId" in self.config:
|
||||
self._init_sp()
|
||||
|
@ -126,7 +131,5 @@ class AzureAuthService(Service, SupportsAuth):
|
|||
return self._access_token
|
||||
|
||||
def get_auth_headers(self) -> dict:
|
||||
"""
|
||||
Get the authorization part of HTTP headers for REST API calls.
|
||||
"""
|
||||
"""Get the authorization part of HTTP headers for REST API calls."""
|
||||
return {"Authorization": "Bearer " + self.get_access_token()}
|
||||
|
|
|
@ -2,15 +2,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base class for certain Azure Services classes that do deployments.
|
||||
"""
|
||||
"""Base class for certain Azure Services classes that do deployments."""
|
||||
|
||||
import abc
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
|
@ -26,33 +23,34 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Helper methods to manage and deploy Azure resources via REST APIs.
|
||||
"""
|
||||
"""Helper methods to manage and deploy Azure resources via REST APIs."""
|
||||
|
||||
_POLL_INTERVAL = 4 # seconds
|
||||
_POLL_TIMEOUT = 300 # seconds
|
||||
_REQUEST_TIMEOUT = 5 # seconds
|
||||
_POLL_INTERVAL = 4 # seconds
|
||||
_POLL_TIMEOUT = 300 # seconds
|
||||
_REQUEST_TIMEOUT = 5 # seconds
|
||||
_REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request
|
||||
_REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
|
||||
# Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
|
||||
_REQUEST_RETRY_BACKOFF_FACTOR = 0.3
|
||||
|
||||
# Azure Resources Deployment REST API as described in
|
||||
# https://docs.microsoft.com/en-us/rest/api/resources/deployments
|
||||
|
||||
_URL_DEPLOY = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Resources" +
|
||||
"/deployments/{deployment_name}" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Resources"
|
||||
"/deployments/{deployment_name}"
|
||||
"?api-version=2022-05-01"
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of an Azure Services proxy.
|
||||
|
||||
|
@ -70,38 +68,49 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
"""
|
||||
super().__init__(config, global_config, parent, methods)
|
||||
|
||||
check_required_params(self.config, [
|
||||
"subscription",
|
||||
"resourceGroup",
|
||||
])
|
||||
check_required_params(
|
||||
self.config,
|
||||
[
|
||||
"subscription",
|
||||
"resourceGroup",
|
||||
],
|
||||
)
|
||||
|
||||
# These parameters can come from command line as strings, so conversion is needed.
|
||||
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
|
||||
self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
|
||||
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
|
||||
self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES))
|
||||
self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR))
|
||||
self._total_retries = int(
|
||||
self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
|
||||
)
|
||||
self._backoff_factor = float(
|
||||
self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
|
||||
)
|
||||
|
||||
self._deploy_template = {}
|
||||
self._deploy_params = {}
|
||||
if self.config.get("deploymentTemplatePath") is not None:
|
||||
# TODO: Provide external schema validation?
|
||||
template = self.config_loader_service.load_config(
|
||||
self.config['deploymentTemplatePath'], schema_type=None)
|
||||
self.config["deploymentTemplatePath"],
|
||||
schema_type=None,
|
||||
)
|
||||
assert template is not None and isinstance(template, dict)
|
||||
self._deploy_template = template
|
||||
|
||||
# Allow for recursive variable expansion as we do with global params and const_args.
|
||||
deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config)
|
||||
deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
|
||||
extra_source_dict=global_config
|
||||
)
|
||||
self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
|
||||
else:
|
||||
_LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.")
|
||||
_LOG.info(
|
||||
"No deploymentTemplatePath provided. Deployment services will be unavailable.",
|
||||
)
|
||||
|
||||
@property
|
||||
def deploy_params(self) -> dict:
|
||||
"""
|
||||
Get the deployment parameters.
|
||||
"""
|
||||
"""Get the deployment parameters."""
|
||||
return self._deploy_params
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -122,24 +131,24 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError("Should be overridden by subclass.")
|
||||
|
||||
def _get_session(self, params: dict) -> requests.Session:
|
||||
"""
|
||||
Get a session object that includes automatic retries and headers for REST API calls.
|
||||
"""Get a session object that includes automatic retries and headers for REST API
|
||||
calls.
|
||||
"""
|
||||
total_retries = params.get("requestTotalRetries", self._total_retries)
|
||||
backoff_factor = params.get("requestBackoffFactor", self._backoff_factor)
|
||||
session = requests.Session()
|
||||
session.mount(
|
||||
"https://",
|
||||
HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)))
|
||||
HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
|
||||
)
|
||||
session.headers.update(self._get_headers())
|
||||
return session
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""
|
||||
Get the headers for the REST API calls.
|
||||
"""
|
||||
assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
|
||||
"Authorization service not provided. Include service-auth.jsonc?"
|
||||
"""Get the headers for the REST API calls."""
|
||||
assert self._parent is not None and isinstance(
|
||||
self._parent, SupportsAuth
|
||||
), "Authorization service not provided. Include service-auth.jsonc?"
|
||||
return self._parent.get_auth_headers()
|
||||
|
||||
@staticmethod
|
||||
|
@ -235,9 +244,11 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
return (Status.FAILED, {})
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Response: %s\n%s", response,
|
||||
json.dumps(response.json(), indent=2)
|
||||
if response.content else "")
|
||||
_LOG.debug(
|
||||
"Response: %s\n%s",
|
||||
response,
|
||||
json.dumps(response.json(), indent=2) if response.content else "",
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
|
@ -252,15 +263,16 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
|
||||
def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
|
||||
FAILED. Return TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
is_setup : bool
|
||||
If True, wait for resource being deployed; otherwise, wait for successful deprovisioning.
|
||||
If True, wait for resource being deployed; otherwise, wait for
|
||||
successful deprovisioning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -270,15 +282,22 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
|
||||
"""
|
||||
params = self._set_default_params(params)
|
||||
_LOG.info("Wait for %s to %s", params.get("deploymentName"),
|
||||
"provision" if is_setup else "deprovision")
|
||||
_LOG.info(
|
||||
"Wait for %s to %s",
|
||||
params.get("deploymentName"),
|
||||
"provision" if is_setup else "deprovision",
|
||||
)
|
||||
return self._wait_while(self._check_deployment, Status.PENDING, params)
|
||||
|
||||
def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
|
||||
loop_status: Status, params: dict) -> Tuple[Status, dict]:
|
||||
def _wait_while(
|
||||
self,
|
||||
func: Callable[[dict], Tuple[Status, dict]],
|
||||
loop_status: Status,
|
||||
params: dict,
|
||||
) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Invoke `func` periodically while the status is equal to `loop_status`.
|
||||
Return TIMED_OUT when timing out.
|
||||
Invoke `func` periodically while the status is equal to `loop_status`. Return
|
||||
TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -297,12 +316,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
"""
|
||||
params = self._set_default_params(params)
|
||||
config = merge_parameters(
|
||||
dest=self.config.copy(), source=params, required_keys=["deploymentName"])
|
||||
dest=self.config.copy(),
|
||||
source=params,
|
||||
required_keys=["deploymentName"],
|
||||
)
|
||||
|
||||
poll_period = params.get("pollInterval", self._poll_interval)
|
||||
|
||||
_LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s",
|
||||
config["deploymentName"], loop_status, poll_period, self._poll_timeout)
|
||||
_LOG.debug(
|
||||
"Wait for %s status %s :: poll %.2f timeout %d s",
|
||||
config["deploymentName"],
|
||||
loop_status,
|
||||
poll_period,
|
||||
self._poll_timeout,
|
||||
)
|
||||
|
||||
ts_timeout = time.time() + self._poll_timeout
|
||||
poll_delay = poll_period
|
||||
|
@ -326,10 +353,10 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
_LOG.warning("Request timed out: %s", params)
|
||||
return (Status.TIMED_OUT, {})
|
||||
|
||||
def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements
|
||||
def _check_deployment(self, params: dict) -> Tuple[Status, dict]:
|
||||
# pylint: disable=too-many-return-statements
|
||||
"""
|
||||
Check if Azure deployment exists.
|
||||
Return SUCCEEDED if true, PENDING otherwise.
|
||||
Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -352,7 +379,7 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"deploymentName",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
_LOG.info("Check deployment: %s", config["deploymentName"])
|
||||
|
@ -413,13 +440,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
if not self._deploy_template:
|
||||
raise ValueError(f"Missing deployment template: {self}")
|
||||
params = self._set_default_params(params)
|
||||
config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"])
|
||||
config = merge_parameters(
|
||||
dest=self.config.copy(),
|
||||
source=params,
|
||||
required_keys=["deploymentName"],
|
||||
)
|
||||
_LOG.info("Deploy: %s :: %s", config["deploymentName"], params)
|
||||
|
||||
params = merge_parameters(dest=self._deploy_params.copy(), source=params)
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Deploy: %s merged params ::\n%s",
|
||||
config["deploymentName"], json.dumps(params, indent=2))
|
||||
_LOG.debug(
|
||||
"Deploy: %s merged params ::\n%s",
|
||||
config["deploymentName"],
|
||||
json.dumps(params, indent=2),
|
||||
)
|
||||
|
||||
url = self._URL_DEPLOY.format(
|
||||
subscription=config["subscription"],
|
||||
|
@ -432,22 +466,29 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
|
|||
"mode": "Incremental",
|
||||
"template": self._deploy_template,
|
||||
"parameters": {
|
||||
key: {"value": val} for (key, val) in params.items()
|
||||
key: {"value": val}
|
||||
for (key, val) in params.items()
|
||||
if key in self._deploy_template.get("parameters", {})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))
|
||||
|
||||
response = requests.put(url, json=json_req,
|
||||
headers=self._get_headers(), timeout=self._request_timeout)
|
||||
response = requests.put(
|
||||
url,
|
||||
json=json_req,
|
||||
headers=self._get_headers(),
|
||||
timeout=self._request_timeout,
|
||||
)
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Response: %s\n%s", response,
|
||||
json.dumps(response.json(), indent=2)
|
||||
if response.content else "")
|
||||
_LOG.debug(
|
||||
"Response: %s\n%s",
|
||||
response,
|
||||
json.dumps(response.json(), indent=2) if response.content else "",
|
||||
)
|
||||
else:
|
||||
_LOG.info("Response: %s", response)
|
||||
|
||||
|
|
|
@ -2,37 +2,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection FileShare functions for interacting with Azure File Shares.
|
||||
"""
|
||||
"""A collection FileShare functions for interacting with Azure File Shares."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from azure.storage.fileshare import ShareClient
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.storage.fileshare import ShareClient
|
||||
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.base_fileshare import FileShareService
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.util import check_required_params
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureFileShareService(FileShareService):
|
||||
"""
|
||||
Helper methods for interacting with Azure File Share
|
||||
"""
|
||||
"""Helper methods for interacting with Azure File Share."""
|
||||
|
||||
_SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new file share Service for Azure environments with a given config.
|
||||
|
||||
|
@ -50,16 +47,19 @@ class AzureFileShareService(FileShareService):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [self.upload, self.download])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(methods, [self.upload, self.download]),
|
||||
)
|
||||
|
||||
check_required_params(
|
||||
self.config, {
|
||||
self.config,
|
||||
{
|
||||
"storageAccountName",
|
||||
"storageFileShareName",
|
||||
"storageAccountKey",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self._share_client = ShareClient.from_share_url(
|
||||
|
@ -70,7 +70,13 @@ class AzureFileShareService(FileShareService):
|
|||
credential=self.config["storageAccountKey"],
|
||||
)
|
||||
|
||||
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
def download(
|
||||
self,
|
||||
params: dict,
|
||||
remote_path: str,
|
||||
local_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
super().download(params, remote_path, local_path, recursive)
|
||||
dir_client = self._share_client.get_directory_client(remote_path)
|
||||
if dir_client.exists():
|
||||
|
@ -95,16 +101,21 @@ class AzureFileShareService(FileShareService):
|
|||
# Translate into non-Azure exception:
|
||||
raise FileNotFoundError(f"Cannot download: {remote_path}") from ex
|
||||
|
||||
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
def upload(
|
||||
self,
|
||||
params: dict,
|
||||
local_path: str,
|
||||
remote_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
super().upload(params, local_path, remote_path, recursive)
|
||||
self._upload(local_path, remote_path, recursive, set())
|
||||
|
||||
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
|
||||
"""
|
||||
Upload contents from a local path to an Azure file share.
|
||||
This method is called from `.upload()` above. We need it to avoid exposing
|
||||
the `seen` parameter and to make `.upload()` match the base class' virtual
|
||||
method.
|
||||
Upload contents from a local path to an Azure file share. This method is called
|
||||
from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
|
||||
make `.upload()` match the base class' virtual method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -143,8 +154,8 @@ class AzureFileShareService(FileShareService):
|
|||
|
||||
def _remote_makedirs(self, remote_path: str) -> None:
|
||||
"""
|
||||
Create remote directories for the entire path.
|
||||
Succeeds even some or all directories along the path already exist.
|
||||
Create remote directories for the entire path. Succeeds even some or all
|
||||
directories along the path already exist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -2,47 +2,48 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection Service functions for managing virtual networks on Azure.
|
||||
"""
|
||||
"""A collection Service functions for managing virtual networks on Azure."""
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
|
||||
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
|
||||
from mlos_bench.services.remote.azure.azure_deployment_services import (
|
||||
AzureDeploymentService,
|
||||
)
|
||||
from mlos_bench.services.types.network_provisioner_type import (
|
||||
SupportsNetworkProvisioning,
|
||||
)
|
||||
from mlos_bench.util import merge_parameters
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
|
||||
"""
|
||||
Helper methods to manage Virtual Networks on Azure.
|
||||
"""
|
||||
"""Helper methods to manage Virtual Networks on Azure."""
|
||||
|
||||
# Azure Compute REST API calls as described in
|
||||
# https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
|
||||
|
||||
# From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
|
||||
# From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 # pylint: disable=line-too-long # noqa
|
||||
_URL_DEPROVISION = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Network" +
|
||||
"/virtualNetwork/{vnet_name}" +
|
||||
"/delete" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Network"
|
||||
"/virtualNetwork/{vnet_name}"
|
||||
"/delete"
|
||||
"?api-version=2023-05-01"
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of Azure Network services proxy.
|
||||
|
||||
|
@ -59,25 +60,35 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
# SupportsNetworkProvisioning
|
||||
self.provision_network,
|
||||
self.deprovision_network,
|
||||
self.wait_network_deployment,
|
||||
])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(
|
||||
methods,
|
||||
[
|
||||
# SupportsNetworkProvisioning
|
||||
self.provision_network,
|
||||
self.deprovision_network,
|
||||
self.wait_network_deployment,
|
||||
],
|
||||
),
|
||||
)
|
||||
if not self._deploy_template:
|
||||
raise ValueError("AzureNetworkService requires a deployment template:\n"
|
||||
+ f"config={config}\nglobal_config={global_config}")
|
||||
raise ValueError(
|
||||
"AzureNetworkService requires a deployment template:\n"
|
||||
+ f"config={config}\nglobal_config={global_config}"
|
||||
)
|
||||
|
||||
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
|
||||
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
|
||||
# Try and provide a semi sane default for the deploymentName if not provided
|
||||
# since this is a common way to set the deploymentName and can same some
|
||||
# config work for the caller.
|
||||
if "vnetName" in params and "deploymentName" not in params:
|
||||
params["deploymentName"] = f"{params['vnetName']}-deployment"
|
||||
_LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
|
||||
_LOG.info(
|
||||
"deploymentName missing from params. Defaulting to '%s'.",
|
||||
params["deploymentName"],
|
||||
)
|
||||
return params
|
||||
|
||||
def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
|
||||
|
@ -148,15 +159,18 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
|
|||
"resourceGroup",
|
||||
"deploymentName",
|
||||
"vnetName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Deprovision Network: %s", config["vnetName"])
|
||||
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
|
||||
(status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vnet_name=config["vnetName"],
|
||||
))
|
||||
(status, results) = self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_DEPROVISION.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vnet_name=config["vnetName"],
|
||||
),
|
||||
)
|
||||
if ignore_errors and status == Status.FAILED:
|
||||
_LOG.warning("Ignoring error: %s", results)
|
||||
status = Status.SUCCEEDED
|
||||
|
|
|
@ -2,11 +2,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection Service functions for configuring SaaS instances on Azure.
|
||||
"""
|
||||
"""A collection Service functions for configuring SaaS instances on Azure."""
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
|
@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
||||
"""
|
||||
Helper methods to configure Azure Flex services.
|
||||
"""
|
||||
"""Helper methods to configure Azure Flex services."""
|
||||
|
||||
_REQUEST_TIMEOUT = 5 # seconds
|
||||
|
||||
|
@ -33,20 +28,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
# https://learn.microsoft.com/en-us/rest/api/mariadb/configurations
|
||||
|
||||
_URL_CONFIGURE = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/{provider}" +
|
||||
"/{server_type}/{vm_name}" +
|
||||
"/{update}" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/{provider}"
|
||||
"/{server_type}/{vm_name}"
|
||||
"/{update}"
|
||||
"?api-version={api_version}"
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of Azure services proxy.
|
||||
|
||||
|
@ -63,18 +60,20 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
self.configure,
|
||||
self.is_config_pending
|
||||
])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(methods, [self.configure, self.is_config_pending]),
|
||||
)
|
||||
|
||||
check_required_params(self.config, {
|
||||
"subscription",
|
||||
"resourceGroup",
|
||||
"provider",
|
||||
})
|
||||
check_required_params(
|
||||
self.config,
|
||||
{
|
||||
"subscription",
|
||||
"resourceGroup",
|
||||
"provider",
|
||||
},
|
||||
)
|
||||
|
||||
# Provide sane defaults for known DB providers.
|
||||
provider = self.config.get("provider")
|
||||
|
@ -118,8 +117,7 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
# These parameters can come from command line as strings, so conversion is needed.
|
||||
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
|
||||
|
||||
def configure(self, config: Dict[str, Any],
|
||||
params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Update the parameters of an Azure DB service.
|
||||
|
||||
|
@ -157,33 +155,39 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
If "isConfigPendingReboot" is set to True, rebooting a VM is necessary.
|
||||
Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED}
|
||||
"""
|
||||
config = merge_parameters(
|
||||
dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
url = self._url_config_get.format(vm_name=config["vmName"])
|
||||
_LOG.debug("Request: GET %s", url)
|
||||
response = requests.put(
|
||||
url, headers=self._get_headers(), timeout=self._request_timeout)
|
||||
response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout)
|
||||
_LOG.debug("Response: %s :: %s", response, response.text)
|
||||
if response.status_code == 504:
|
||||
return (Status.TIMED_OUT, {})
|
||||
if response.status_code != 200:
|
||||
return (Status.FAILED, {})
|
||||
# Currently, Azure Flex servers require a VM reboot.
|
||||
return (Status.SUCCEEDED, {"isConfigPendingReboot": any(
|
||||
{'False': False, 'True': True}[val['properties']['isConfigPendingRestart']]
|
||||
for val in response.json()['value']
|
||||
)})
|
||||
return (
|
||||
Status.SUCCEEDED,
|
||||
{
|
||||
"isConfigPendingReboot": any(
|
||||
{"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]]
|
||||
for val in response.json()["value"]
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""
|
||||
Get the headers for the REST API calls.
|
||||
"""
|
||||
assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
|
||||
"Authorization service not provided. Include service-auth.jsonc?"
|
||||
"""Get the headers for the REST API calls."""
|
||||
assert self._parent is not None and isinstance(
|
||||
self._parent, SupportsAuth
|
||||
), "Authorization service not provided. Include service-auth.jsonc?"
|
||||
return self._parent.get_auth_headers()
|
||||
|
||||
def _config_one(self, config: Dict[str, Any],
|
||||
param_name: str, param_value: Any) -> Tuple[Status, dict]:
|
||||
def _config_one(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
param_name: str,
|
||||
param_value: Any,
|
||||
) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Update a single parameter of the Azure DB service.
|
||||
|
||||
|
@ -202,13 +206,15 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
config = merge_parameters(
|
||||
dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
|
||||
_LOG.debug("Request: PUT %s", url)
|
||||
response = requests.put(url, headers=self._get_headers(),
|
||||
json={"properties": {"value": str(param_value)}},
|
||||
timeout=self._request_timeout)
|
||||
response = requests.put(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json={"properties": {"value": str(param_value)}},
|
||||
timeout=self._request_timeout,
|
||||
)
|
||||
_LOG.debug("Response: %s :: %s", response, response.text)
|
||||
if response.status_code == 504:
|
||||
return (Status.TIMED_OUT, {})
|
||||
|
@ -216,11 +222,10 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
return (Status.SUCCEEDED, {})
|
||||
return (Status.FAILED, {})
|
||||
|
||||
def _config_many(self, config: Dict[str, Any],
|
||||
params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Update the parameters of an Azure DB service one-by-one.
|
||||
(If batch API is not available for it).
|
||||
Update the parameters of an Azure DB service one-by-one. (If batch API is not
|
||||
available for it).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -235,14 +240,13 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
for (param_name, param_value) in params.items():
|
||||
for param_name, param_value in params.items():
|
||||
(status, result) = self._config_one(config, param_name, param_value)
|
||||
if not status.is_succeeded():
|
||||
return (status, result)
|
||||
return (Status.SUCCEEDED, {})
|
||||
|
||||
def _config_batch(self, config: Dict[str, Any],
|
||||
params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Batch update the parameters of an Azure DB service.
|
||||
|
||||
|
@ -259,19 +263,21 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
|
|||
A pair of Status and result. The result is always {}.
|
||||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
config = merge_parameters(
|
||||
dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
|
||||
url = self._url_config_set.format(vm_name=config["vmName"])
|
||||
json_req = {
|
||||
"value": [
|
||||
{"name": key, "properties": {"value": str(val)}}
|
||||
for (key, val) in params.items()
|
||||
{"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
|
||||
],
|
||||
# "resetAllToDefault": "True"
|
||||
}
|
||||
_LOG.debug("Request: POST %s", url)
|
||||
response = requests.post(url, headers=self._get_headers(),
|
||||
json=json_req, timeout=self._request_timeout)
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json=json_req,
|
||||
timeout=self._request_timeout,
|
||||
)
|
||||
_LOG.debug("Response: %s :: %s", response, response.text)
|
||||
if response.status_code == 504:
|
||||
return (Status.TIMED_OUT, {})
|
||||
|
|
|
@ -2,33 +2,36 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection Service functions for managing VMs on Azure.
|
||||
"""
|
||||
"""A collection Service functions for managing VMs on Azure."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
|
||||
from mlos_bench.services.remote.azure.azure_deployment_services import (
|
||||
AzureDeploymentService,
|
||||
)
|
||||
from mlos_bench.services.types.host_ops_type import SupportsHostOps
|
||||
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
|
||||
from mlos_bench.services.types.os_ops_type import SupportsOSOps
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.util import merge_parameters
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec):
|
||||
"""
|
||||
Helper methods to manage VMs on Azure.
|
||||
"""
|
||||
class AzureVMService(
|
||||
AzureDeploymentService,
|
||||
SupportsHostProvisioning,
|
||||
SupportsHostOps,
|
||||
SupportsOSOps,
|
||||
SupportsRemoteExec,
|
||||
):
|
||||
"""Helper methods to manage VMs on Azure."""
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
|
||||
|
@ -37,34 +40,34 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start
|
||||
_URL_START = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Compute" +
|
||||
"/virtualMachines/{vm_name}" +
|
||||
"/start" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Compute"
|
||||
"/virtualMachines/{vm_name}"
|
||||
"/start"
|
||||
"?api-version=2022-03-01"
|
||||
)
|
||||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off
|
||||
_URL_STOP = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Compute" +
|
||||
"/virtualMachines/{vm_name}" +
|
||||
"/powerOff" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Compute"
|
||||
"/virtualMachines/{vm_name}"
|
||||
"/powerOff"
|
||||
"?api-version=2022-03-01"
|
||||
)
|
||||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate
|
||||
_URL_DEALLOCATE = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Compute" +
|
||||
"/virtualMachines/{vm_name}" +
|
||||
"/deallocate" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Compute"
|
||||
"/virtualMachines/{vm_name}"
|
||||
"/deallocate"
|
||||
"?api-version=2022-03-01"
|
||||
)
|
||||
|
||||
|
@ -76,42 +79,44 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete
|
||||
# _URL_DEPROVISION = (
|
||||
# "https://management.azure.com" +
|
||||
# "/subscriptions/{subscription}" +
|
||||
# "/resourceGroups/{resource_group}" +
|
||||
# "/providers/Microsoft.Compute" +
|
||||
# "/virtualMachines/{vm_name}" +
|
||||
# "/delete" +
|
||||
# "https://management.azure.com"
|
||||
# "/subscriptions/{subscription}"
|
||||
# "/resourceGroups/{resource_group}"
|
||||
# "/providers/Microsoft.Compute"
|
||||
# "/virtualMachines/{vm_name}"
|
||||
# "/delete"
|
||||
# "?api-version=2022-03-01"
|
||||
# )
|
||||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart
|
||||
_URL_REBOOT = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Compute" +
|
||||
"/virtualMachines/{vm_name}" +
|
||||
"/restart" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Compute"
|
||||
"/virtualMachines/{vm_name}"
|
||||
"/restart"
|
||||
"?api-version=2022-03-01"
|
||||
)
|
||||
|
||||
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command
|
||||
_URL_REXEC_RUN = (
|
||||
"https://management.azure.com" +
|
||||
"/subscriptions/{subscription}" +
|
||||
"/resourceGroups/{resource_group}" +
|
||||
"/providers/Microsoft.Compute" +
|
||||
"/virtualMachines/{vm_name}" +
|
||||
"/runCommand" +
|
||||
"https://management.azure.com"
|
||||
"/subscriptions/{subscription}"
|
||||
"/resourceGroups/{resource_group}"
|
||||
"/providers/Microsoft.Compute"
|
||||
"/virtualMachines/{vm_name}"
|
||||
"/runCommand"
|
||||
"?api-version=2022-03-01"
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of Azure VM services proxy.
|
||||
|
||||
|
@ -128,26 +133,31 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
New methods to register with the service.
|
||||
"""
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
# SupportsHostProvisioning
|
||||
self.provision_host,
|
||||
self.deprovision_host,
|
||||
self.deallocate_host,
|
||||
self.wait_host_deployment,
|
||||
# SupportsHostOps
|
||||
self.start_host,
|
||||
self.stop_host,
|
||||
self.restart_host,
|
||||
self.wait_host_operation,
|
||||
# SupportsOSOps
|
||||
self.shutdown,
|
||||
self.reboot,
|
||||
self.wait_os_operation,
|
||||
# SupportsRemoteExec
|
||||
self.remote_exec,
|
||||
self.get_remote_exec_results,
|
||||
])
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(
|
||||
methods,
|
||||
[
|
||||
# SupportsHostProvisioning
|
||||
self.provision_host,
|
||||
self.deprovision_host,
|
||||
self.deallocate_host,
|
||||
self.wait_host_deployment,
|
||||
# SupportsHostOps
|
||||
self.start_host,
|
||||
self.stop_host,
|
||||
self.restart_host,
|
||||
self.wait_host_operation,
|
||||
# SupportsOSOps
|
||||
self.shutdown,
|
||||
self.reboot,
|
||||
self.wait_os_operation,
|
||||
# SupportsRemoteExec
|
||||
self.remote_exec,
|
||||
self.get_remote_exec_results,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# As a convenience, allow reading customData out of a file, rather than
|
||||
|
@ -156,19 +166,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
# can be done using the `base64()` string function inside the ARM template.
|
||||
self._custom_data_file = self.config.get("customDataFile", None)
|
||||
if self._custom_data_file:
|
||||
if self._deploy_params.get('customData', None):
|
||||
if self._deploy_params.get("customData", None):
|
||||
raise ValueError("Both customDataFile and customData are specified.")
|
||||
self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file)
|
||||
with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh:
|
||||
self._custom_data_file = self.config_loader_service.resolve_path(
|
||||
self._custom_data_file
|
||||
)
|
||||
with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh:
|
||||
self._deploy_params["customData"] = custom_data_fh.read()
|
||||
|
||||
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
|
||||
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
|
||||
# Try and provide a semi sane default for the deploymentName if not provided
|
||||
# since this is a common way to set the deploymentName and can same some
|
||||
# config work for the caller.
|
||||
if "vmName" in params and "deploymentName" not in params:
|
||||
params["deploymentName"] = f"{params['vmName']}-deployment"
|
||||
_LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
|
||||
_LOG.info(
|
||||
"deploymentName missing from params. Defaulting to '%s'.",
|
||||
params["deploymentName"],
|
||||
)
|
||||
return params
|
||||
|
||||
def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
|
||||
|
@ -263,20 +278,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"resourceGroup",
|
||||
"deploymentName",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Deprovision VM: %s", config["vmName"])
|
||||
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
|
||||
# TODO: Properly deprovision *all* resources specified in the ARM template.
|
||||
return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
))
|
||||
return self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_DEPROVISION.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
),
|
||||
)
|
||||
|
||||
def deallocate_host(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Deallocates the VM on Azure by shutting it down then releasing the compute resources.
|
||||
Deallocates the VM on Azure by shutting it down then releasing the compute
|
||||
resources.
|
||||
|
||||
Note: This can cause the VM to arrive on a new host node when its
|
||||
restarted, which may have different performance characteristics.
|
||||
|
@ -300,14 +319,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Deallocate VM: %s", config["vmName"])
|
||||
return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
))
|
||||
return self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_DEALLOCATE.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
),
|
||||
)
|
||||
|
||||
def start_host(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
|
@ -332,14 +354,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Start VM: %s :: %s", config["vmName"], params)
|
||||
return self._azure_rest_api_post_helper(config, self._URL_START.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
))
|
||||
return self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_START.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
),
|
||||
)
|
||||
|
||||
def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
|
||||
"""
|
||||
|
@ -366,14 +391,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Stop VM: %s", config["vmName"])
|
||||
return self._azure_rest_api_post_helper(config, self._URL_STOP.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
))
|
||||
return self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_STOP.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
),
|
||||
)
|
||||
|
||||
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
|
||||
return self.stop_host(params, force)
|
||||
|
@ -403,20 +431,27 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
_LOG.info("Reboot VM: %s", config["vmName"])
|
||||
return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
))
|
||||
return self._azure_rest_api_post_helper(
|
||||
config,
|
||||
self._URL_REBOOT.format(
|
||||
subscription=config["subscription"],
|
||||
resource_group=config["resourceGroup"],
|
||||
vm_name=config["vmName"],
|
||||
),
|
||||
)
|
||||
|
||||
def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
|
||||
return self.restart_host(params, force)
|
||||
|
||||
def remote_exec(self, script: Iterable[str], config: dict,
|
||||
env_params: dict) -> Tuple[Status, dict]:
|
||||
def remote_exec(
|
||||
self,
|
||||
script: Iterable[str],
|
||||
config: dict,
|
||||
env_params: dict,
|
||||
) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Run a command on Azure VM.
|
||||
|
||||
|
@ -446,7 +481,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
"subscription",
|
||||
"resourceGroup",
|
||||
"vmName",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
if _LOG.isEnabledFor(logging.INFO):
|
||||
|
@ -455,7 +490,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
json_req = {
|
||||
"commandId": "RunShellScript",
|
||||
"script": list(script),
|
||||
"parameters": [{"name": key, "value": val} for (key, val) in env_params.items()]
|
||||
"parameters": [{"name": key, "value": val} for (key, val) in env_params.items()],
|
||||
}
|
||||
|
||||
url = self._URL_REXEC_RUN.format(
|
||||
|
@ -468,12 +503,18 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
_LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2))
|
||||
|
||||
response = requests.post(
|
||||
url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout)
|
||||
url,
|
||||
json=json_req,
|
||||
headers=self._get_headers(),
|
||||
timeout=self._request_timeout,
|
||||
)
|
||||
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("Response: %s\n%s", response,
|
||||
json.dumps(response.json(), indent=2)
|
||||
if response.content else "")
|
||||
_LOG.debug(
|
||||
"Response: %s\n%s",
|
||||
response,
|
||||
json.dumps(response.json(), indent=2) if response.content else "",
|
||||
)
|
||||
else:
|
||||
_LOG.info("Response: %s", response)
|
||||
|
||||
|
@ -481,10 +522,10 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
|
|||
# TODO: extract the results from JSON response
|
||||
return (Status.SUCCEEDED, config)
|
||||
elif response.status_code == 202:
|
||||
return (Status.PENDING, {
|
||||
**config,
|
||||
"asyncResultsUrl": response.headers.get("Azure-AsyncOperation")
|
||||
})
|
||||
return (
|
||||
Status.PENDING,
|
||||
{**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")},
|
||||
)
|
||||
else:
|
||||
_LOG.error("Response: %s :: %s", response, response.text)
|
||||
# _LOG.error("Bad Request:\n%s", response.request.body)
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
#
|
||||
"""SSH remote service."""
|
||||
|
||||
from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
|
||||
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
|
||||
from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
|
||||
|
||||
__all__ = [
|
||||
"SshHostService",
|
||||
|
|
|
@ -2,16 +2,13 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection functions for interacting with SSH servers as file shares.
|
||||
"""
|
||||
"""A collection functions for interacting with SSH servers as file shares."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Tuple, Union
|
||||
|
||||
import logging
|
||||
|
||||
from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection
|
||||
from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp
|
||||
|
||||
from mlos_bench.services.base_fileshare import FileShareService
|
||||
from mlos_bench.services.remote.ssh.ssh_service import SshService
|
||||
|
@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CopyMode(Enum):
|
||||
"""
|
||||
Copy mode enum.
|
||||
"""
|
||||
"""Copy mode enum."""
|
||||
|
||||
DOWNLOAD = 1
|
||||
UPLOAD = 2
|
||||
|
@ -32,17 +27,23 @@ class CopyMode(Enum):
|
|||
class SshFileShareService(FileShareService, SshService):
|
||||
"""A collection of functions for interacting with SSH servers as file shares."""
|
||||
|
||||
async def _start_file_copy(self, params: dict, mode: CopyMode,
|
||||
local_path: str, remote_path: str,
|
||||
recursive: bool = True) -> None:
|
||||
async def _start_file_copy(
|
||||
self,
|
||||
params: dict,
|
||||
mode: CopyMode,
|
||||
local_path: str,
|
||||
remote_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
# pylint: disable=too-many-arguments
|
||||
"""
|
||||
Starts a file copy operation
|
||||
Starts a file copy operation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
|
||||
Flat dictionary of (key, value) pairs of parameters (used for
|
||||
establishing the connection).
|
||||
mode : CopyMode
|
||||
Whether to download or upload the file.
|
||||
local_path : str
|
||||
|
@ -74,40 +75,70 @@ class SshFileShareService(FileShareService, SshService):
|
|||
raise ValueError(f"Unknown copy mode: {mode}")
|
||||
return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)
|
||||
|
||||
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
def download(
|
||||
self,
|
||||
params: dict,
|
||||
remote_path: str,
|
||||
local_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
params = merge_parameters(
|
||||
dest=self.config.copy(),
|
||||
source=params,
|
||||
required_keys=[
|
||||
"ssh_hostname",
|
||||
]
|
||||
],
|
||||
)
|
||||
super().download(params, remote_path, local_path, recursive)
|
||||
file_copy_future = self._run_coroutine(
|
||||
self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive))
|
||||
self._start_file_copy(
|
||||
params,
|
||||
CopyMode.DOWNLOAD,
|
||||
local_path,
|
||||
remote_path,
|
||||
recursive,
|
||||
)
|
||||
)
|
||||
try:
|
||||
file_copy_future.result()
|
||||
except (OSError, SFTPError) as ex:
|
||||
_LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex)
|
||||
_LOG.error(
|
||||
"Failed to download %s to %s from %s: %s",
|
||||
remote_path,
|
||||
local_path,
|
||||
params,
|
||||
ex,
|
||||
)
|
||||
if isinstance(ex, SFTPNoSuchFile) or (
|
||||
isinstance(ex, SFTPFailure) and ex.code == 4
|
||||
and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory"))
|
||||
isinstance(ex, SFTPFailure)
|
||||
and ex.code == 4
|
||||
and any(
|
||||
msg.lower() in ex.reason.lower()
|
||||
for msg in ("File not found", "No such file or directory")
|
||||
)
|
||||
):
|
||||
_LOG.warning("File %s does not exist on %s", remote_path, params)
|
||||
raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
|
||||
raise ex
|
||||
|
||||
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
def upload(
|
||||
self,
|
||||
params: dict,
|
||||
local_path: str,
|
||||
remote_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
params = merge_parameters(
|
||||
dest=self.config.copy(),
|
||||
source=params,
|
||||
required_keys=[
|
||||
"ssh_hostname",
|
||||
]
|
||||
],
|
||||
)
|
||||
super().upload(params, local_path, remote_path, recursive)
|
||||
file_copy_future = self._run_coroutine(
|
||||
self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive))
|
||||
self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)
|
||||
)
|
||||
try:
|
||||
file_copy_future.result()
|
||||
except (OSError, SFTPError) as ex:
|
||||
|
|
|
@ -2,39 +2,36 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection Service functions for managing hosts via SSH.
|
||||
"""
|
||||
"""A collection Service functions for managing hosts via SSH."""
|
||||
|
||||
import logging
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import logging
|
||||
|
||||
from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError
|
||||
from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.services.remote.ssh.ssh_service import SshService
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.services.types.os_ops_type import SupportsOSOps
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
from mlos_bench.util import merge_parameters
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
||||
"""
|
||||
Helper methods to manage machines via SSH.
|
||||
"""
|
||||
"""Helper methods to manage machines via SSH."""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
"""
|
||||
Create a new instance of an SSH Service.
|
||||
|
||||
|
@ -53,24 +50,36 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
# Same methods are also provided by the AzureVMService class
|
||||
# pylint: disable=duplicate-code
|
||||
super().__init__(
|
||||
config, global_config, parent,
|
||||
self.merge_methods(methods, [
|
||||
self.shutdown,
|
||||
self.reboot,
|
||||
self.wait_os_operation,
|
||||
self.remote_exec,
|
||||
self.get_remote_exec_results,
|
||||
]))
|
||||
config,
|
||||
global_config,
|
||||
parent,
|
||||
self.merge_methods(
|
||||
methods,
|
||||
[
|
||||
self.shutdown,
|
||||
self.reboot,
|
||||
self.wait_os_operation,
|
||||
self.remote_exec,
|
||||
self.get_remote_exec_results,
|
||||
],
|
||||
),
|
||||
)
|
||||
self._shell = self.config.get("ssh_shell", "/bin/bash")
|
||||
|
||||
async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess:
|
||||
async def _run_cmd(
|
||||
self,
|
||||
params: dict,
|
||||
script: Iterable[str],
|
||||
env_params: dict,
|
||||
) -> SSHCompletedProcess:
|
||||
"""
|
||||
Runs a command asynchronously on a host via SSH.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
|
||||
Flat dictionary of (key, value) pairs of parameters (used for
|
||||
establishing the connection).
|
||||
cmd : str
|
||||
Command(s) to run via shell.
|
||||
|
||||
|
@ -83,19 +92,29 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
# Script should be an iterable of lines, not an iterable string.
|
||||
script = [script]
|
||||
connection, _ = await self._get_client_connection(params)
|
||||
# Note: passing environment variables to SSH servers is typically restricted to just some LC_* values.
|
||||
# Note: passing environment variables to SSH servers is typically restricted
|
||||
# to just some LC_* values.
|
||||
# Handle transferring environment variables by making a script to set them.
|
||||
env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
|
||||
script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()]
|
||||
script_lines = env_script_lines + [
|
||||
line_split for line in script for line_split in line.splitlines()
|
||||
]
|
||||
# Note: connection.run() uses "exec" with a shell by default.
|
||||
script_str = '\n'.join(script_lines)
|
||||
script_str = "\n".join(script_lines)
|
||||
_LOG.debug("Running script on %s:\n%s", connection, script_str)
|
||||
return await connection.run(script_str,
|
||||
check=False,
|
||||
timeout=self._request_timeout,
|
||||
env=env_params)
|
||||
return await connection.run(
|
||||
script_str,
|
||||
check=False,
|
||||
timeout=self._request_timeout,
|
||||
env=env_params,
|
||||
)
|
||||
|
||||
def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]:
|
||||
def remote_exec(
|
||||
self,
|
||||
script: Iterable[str],
|
||||
config: dict,
|
||||
env_params: dict,
|
||||
) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Start running a command on remote host OS.
|
||||
|
||||
|
@ -122,9 +141,15 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
source=config,
|
||||
required_keys=[
|
||||
"ssh_hostname",
|
||||
]
|
||||
],
|
||||
)
|
||||
config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
|
||||
self._run_cmd(
|
||||
config,
|
||||
script,
|
||||
env_params,
|
||||
)
|
||||
)
|
||||
config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params))
|
||||
return (Status.PENDING, config)
|
||||
|
||||
def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
|
||||
|
@ -155,7 +180,11 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
|
||||
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
|
||||
return (
|
||||
Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED,
|
||||
(
|
||||
Status.SUCCEEDED
|
||||
if result.exit_status == 0 and result.returncode == 0
|
||||
else Status.FAILED
|
||||
),
|
||||
{
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
|
@ -167,7 +196,8 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
return (Status.FAILED, {"result": result})
|
||||
|
||||
def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
|
||||
"""_summary_
|
||||
"""
|
||||
_summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -187,9 +217,9 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
source=params,
|
||||
required_keys=[
|
||||
"ssh_hostname",
|
||||
]
|
||||
],
|
||||
)
|
||||
cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list])
|
||||
cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list])
|
||||
script = rf"""
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
sudo=$(command -v sudo)
|
||||
|
@ -224,10 +254,10 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
cmd_opts_list = [
|
||||
'shutdown -h now',
|
||||
'poweroff',
|
||||
'halt -p',
|
||||
'systemctl poweroff',
|
||||
"shutdown -h now",
|
||||
"poweroff",
|
||||
"halt -p",
|
||||
"systemctl poweroff",
|
||||
]
|
||||
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
|
||||
|
||||
|
@ -249,18 +279,18 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
|
|||
Status is one of {PENDING, SUCCEEDED, FAILED}
|
||||
"""
|
||||
cmd_opts_list = [
|
||||
'shutdown -r now',
|
||||
'reboot',
|
||||
'halt --reboot',
|
||||
'systemctl reboot',
|
||||
'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1',
|
||||
"shutdown -r now",
|
||||
"reboot",
|
||||
"halt --reboot",
|
||||
"systemctl reboot",
|
||||
"kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
|
||||
]
|
||||
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
|
||||
|
||||
def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
|
||||
"""
|
||||
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
|
||||
TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -2,25 +2,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
A collection functions for interacting with SSH servers as file shares.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta
|
||||
from asyncio import Event as CoroEvent, Lock as CoroLock
|
||||
from warnings import warn
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from threading import current_thread
|
||||
"""A collection functions for interacting with SSH servers as file shares."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABCMeta
|
||||
from asyncio import Event as CoroEvent
|
||||
from asyncio import Lock as CoroLock
|
||||
from threading import current_thread
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from warnings import warn
|
||||
|
||||
import asyncssh
|
||||
from asyncssh.connection import SSHClientConnection
|
||||
|
||||
from mlos_bench.event_loop_context import (
|
||||
CoroReturnType,
|
||||
EventLoopContext,
|
||||
FutureReturnType,
|
||||
)
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType
|
||||
from mlos_bench.util import nullable
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -30,13 +43,13 @@ class SshClient(asyncssh.SSHClient):
|
|||
"""
|
||||
Wrapper around SSHClient to help provide connection caching and reconnect logic.
|
||||
|
||||
Used by the SshService to try and maintain a single connection to hosts,
|
||||
handle reconnects if possible, and use that to run commands rather than
|
||||
reconnect for each command.
|
||||
Used by the SshService to try and maintain a single connection to hosts, handle
|
||||
reconnects if possible, and use that to run commands rather than reconnect for each
|
||||
command.
|
||||
"""
|
||||
|
||||
_CONNECTION_PENDING = 'INIT'
|
||||
_CONNECTION_LOST = 'LOST'
|
||||
_CONNECTION_PENDING = "INIT"
|
||||
_CONNECTION_LOST = "LOST"
|
||||
|
||||
def __init__(self, *args: tuple, **kwargs: dict):
|
||||
self._connection_id: str = SshClient._CONNECTION_PENDING
|
||||
|
@ -50,12 +63,16 @@ class SshClient(asyncssh.SSHClient):
|
|||
@staticmethod
|
||||
def id_from_connection(connection: SSHClientConnection) -> str:
|
||||
"""Gets a unique id repr for the connection."""
|
||||
return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
return f"{connection._username}@{connection._host}:{connection._port}"
|
||||
|
||||
@staticmethod
|
||||
def id_from_params(connect_params: dict) -> str:
|
||||
"""Gets a unique id repr for the connection."""
|
||||
return f"{connect_params.get('username')}@{connect_params['host']}:{connect_params.get('port')}"
|
||||
return (
|
||||
f"{connect_params.get('username')}@{connect_params['host']}"
|
||||
f":{connect_params.get('port')}"
|
||||
)
|
||||
|
||||
def connection_made(self, conn: SSHClientConnection) -> None:
|
||||
"""
|
||||
|
@ -64,8 +81,12 @@ class SshClient(asyncssh.SSHClient):
|
|||
Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
|
||||
"""
|
||||
self._conn_event.clear()
|
||||
_LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \
|
||||
# pylint: disable=protected-access
|
||||
_LOG.debug(
|
||||
"%s: Connection made by %s: %s",
|
||||
current_thread().name,
|
||||
conn._options.env, # pylint: disable=protected-access
|
||||
conn,
|
||||
)
|
||||
self._connection_id = SshClient.id_from_connection(conn)
|
||||
self._connection = conn
|
||||
self._conn_event.set()
|
||||
|
@ -75,18 +96,26 @@ class SshClient(asyncssh.SSHClient):
|
|||
self._conn_event.clear()
|
||||
_LOG.debug("%s: %s", current_thread().name, "connection_lost")
|
||||
if exc is None:
|
||||
_LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc)
|
||||
_LOG.debug(
|
||||
"%s: gracefully disconnected ssh from %s: %s",
|
||||
current_thread().name,
|
||||
self._connection_id,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
_LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc)
|
||||
_LOG.debug(
|
||||
"%s: ssh connection lost on %s: %s",
|
||||
current_thread().name,
|
||||
self._connection_id,
|
||||
exc,
|
||||
)
|
||||
self._connection_id = SshClient._CONNECTION_LOST
|
||||
self._connection = None
|
||||
self._conn_event.set()
|
||||
return super().connection_lost(exc)
|
||||
|
||||
async def connection(self) -> Optional[SSHClientConnection]:
|
||||
"""
|
||||
Waits for and returns the SSHClientConnection to be established or lost.
|
||||
"""
|
||||
"""Waits for and returns the SSHClientConnection to be established or lost."""
|
||||
_LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
|
||||
await self._conn_event.wait()
|
||||
_LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
|
||||
|
@ -96,6 +125,7 @@ class SshClient(asyncssh.SSHClient):
|
|||
class SshClientCache:
|
||||
"""
|
||||
Manages a cache of SshClient connections.
|
||||
|
||||
Note: Only one per event loop thread supported.
|
||||
See additional details in SshService comments.
|
||||
"""
|
||||
|
@ -114,6 +144,7 @@ class SshClientCache:
|
|||
def enter(self) -> None:
|
||||
"""
|
||||
Manages the cache lifecycle with reference counting.
|
||||
|
||||
To be used in the __enter__ method of a caller's context manager.
|
||||
"""
|
||||
self._refcnt += 1
|
||||
|
@ -121,6 +152,7 @@ class SshClientCache:
|
|||
def exit(self) -> None:
|
||||
"""
|
||||
Manages the cache lifecycle with reference counting.
|
||||
|
||||
To be used in the __exit__ method of a caller's context manager.
|
||||
"""
|
||||
self._refcnt -= 1
|
||||
|
@ -130,7 +162,10 @@ class SshClientCache:
|
|||
warn(RuntimeWarning("SshClientCache lock was still held on exit."))
|
||||
self._cache_lock.release()
|
||||
|
||||
async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]:
|
||||
async def get_client_connection(
|
||||
self,
|
||||
connect_params: dict,
|
||||
) -> Tuple[SSHClientConnection, SshClient]:
|
||||
"""
|
||||
Gets a (possibly cached) client connection.
|
||||
|
||||
|
@ -153,13 +188,21 @@ class SshClientCache:
|
|||
_LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
|
||||
connection = await client.connection()
|
||||
if not connection:
|
||||
_LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
|
||||
_LOG.debug(
|
||||
"%s: Removing stale client connection %s from cache.",
|
||||
current_thread().name,
|
||||
connection_id,
|
||||
)
|
||||
self._cache.pop(connection_id)
|
||||
# Try to reconnect next.
|
||||
else:
|
||||
_LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
|
||||
if connection_id not in self._cache:
|
||||
_LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
|
||||
_LOG.debug(
|
||||
"%s: Establishing client connection to %s",
|
||||
current_thread().name,
|
||||
connection_id,
|
||||
)
|
||||
connection, client = await asyncssh.create_connection(SshClient, **connect_params)
|
||||
assert isinstance(client, SshClient)
|
||||
self._cache[connection_id] = (connection, client)
|
||||
|
@ -167,18 +210,14 @@ class SshClientCache:
|
|||
return self._cache[connection_id]
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""
|
||||
Closes all cached connections.
|
||||
"""
|
||||
for (connection, _) in self._cache.values():
|
||||
"""Closes all cached connections."""
|
||||
for connection, _ in self._cache.values():
|
||||
connection.close()
|
||||
self._cache = {}
|
||||
|
||||
|
||||
class SshService(Service, metaclass=ABCMeta):
|
||||
"""
|
||||
Base class for SSH services.
|
||||
"""
|
||||
"""Base class for SSH services."""
|
||||
|
||||
# AsyncSSH requires an asyncio event loop to be running to work.
|
||||
# However, running that event loop blocks the main thread.
|
||||
|
@ -210,21 +249,23 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
|
||||
_REQUEST_TIMEOUT: Optional[float] = None # seconds
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional[Service] = None,
|
||||
methods: Union[Dict[str, Callable], List[Callable], None] = None,
|
||||
):
|
||||
super().__init__(config, global_config, parent, methods)
|
||||
|
||||
# Make sure that the value we allow overriding on a per-connection
|
||||
# basis are present in the config so merge_parameters can do its thing.
|
||||
self.config.setdefault('ssh_port', None)
|
||||
assert isinstance(self.config['ssh_port'], (int, type(None)))
|
||||
self.config.setdefault('ssh_username', None)
|
||||
assert isinstance(self.config['ssh_username'], (str, type(None)))
|
||||
self.config.setdefault('ssh_priv_key_path', None)
|
||||
assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))
|
||||
self.config.setdefault("ssh_port", None)
|
||||
assert isinstance(self.config["ssh_port"], (int, type(None)))
|
||||
self.config.setdefault("ssh_username", None)
|
||||
assert isinstance(self.config["ssh_username"], (str, type(None)))
|
||||
self.config.setdefault("ssh_priv_key_path", None)
|
||||
assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))
|
||||
|
||||
# None can be used to disable the request timeout.
|
||||
self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
|
||||
|
@ -235,24 +276,25 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
# In general scripted commands shouldn't need a pty and having one
|
||||
# available can confuse some commands, though we may need to make
|
||||
# this configurable in the future.
|
||||
'request_pty': False,
|
||||
# By default disable known_hosts checking (since most VMs expected to be dynamically created).
|
||||
'known_hosts': None,
|
||||
"request_pty": False,
|
||||
# By default disable known_hosts checking (since most VMs expected to be
|
||||
# dynamically created).
|
||||
"known_hosts": None,
|
||||
}
|
||||
|
||||
if 'ssh_known_hosts_file' in self.config:
|
||||
self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None)
|
||||
if isinstance(self._connect_params['known_hosts'], str):
|
||||
known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
|
||||
if "ssh_known_hosts_file" in self.config:
|
||||
self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
|
||||
if isinstance(self._connect_params["known_hosts"], str):
|
||||
known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
|
||||
if not os.path.exists(known_hosts_file):
|
||||
raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
|
||||
self._connect_params['known_hosts'] = known_hosts_file
|
||||
if self._connect_params['known_hosts'] is None:
|
||||
self._connect_params["known_hosts"] = known_hosts_file
|
||||
if self._connect_params["known_hosts"] is None:
|
||||
_LOG.info("%s known_hosts checking is disabled per config.", self)
|
||||
|
||||
if 'ssh_keepalive_interval' in self.config:
|
||||
keepalive_internal = self.config.get('ssh_keepalive_interval')
|
||||
self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal)
|
||||
if "ssh_keepalive_interval" in self.config:
|
||||
keepalive_internal = self.config.get("ssh_keepalive_interval")
|
||||
self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal)
|
||||
|
||||
def _enter_context(self) -> "SshService":
|
||||
# Start the background thread if it's not already running.
|
||||
|
@ -262,9 +304,12 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
super()._enter_context()
|
||||
return self
|
||||
|
||||
def _exit_context(self, ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def _exit_context(
|
||||
self,
|
||||
ex_type: Optional[Type[BaseException]],
|
||||
ex_val: Optional[BaseException],
|
||||
ex_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
# Stop the background thread if it's not needed anymore and potentially
|
||||
# cleanup the cache as well.
|
||||
assert self._in_context
|
||||
|
@ -276,6 +321,7 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
def clear_client_cache(cls) -> None:
|
||||
"""
|
||||
Clears the cache of client connections.
|
||||
|
||||
Note: This may cause in flight operations to fail.
|
||||
"""
|
||||
cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()
|
||||
|
@ -319,24 +365,27 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
# Start with the base config params.
|
||||
connect_params = self._connect_params.copy()
|
||||
|
||||
connect_params['host'] = params['ssh_hostname'] # required
|
||||
connect_params["host"] = params["ssh_hostname"] # required
|
||||
|
||||
if params.get('ssh_port'):
|
||||
connect_params['port'] = int(params.pop('ssh_port'))
|
||||
elif self.config['ssh_port']:
|
||||
connect_params['port'] = int(self.config['ssh_port'])
|
||||
if params.get("ssh_port"):
|
||||
connect_params["port"] = int(params.pop("ssh_port"))
|
||||
elif self.config["ssh_port"]:
|
||||
connect_params["port"] = int(self.config["ssh_port"])
|
||||
|
||||
if 'ssh_username' in params:
|
||||
connect_params['username'] = str(params.pop('ssh_username'))
|
||||
elif self.config['ssh_username']:
|
||||
connect_params['username'] = str(self.config['ssh_username'])
|
||||
if "ssh_username" in params:
|
||||
connect_params["username"] = str(params.pop("ssh_username"))
|
||||
elif self.config["ssh_username"]:
|
||||
connect_params["username"] = str(self.config["ssh_username"])
|
||||
|
||||
priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path'])
|
||||
priv_key_file: Optional[str] = params.get(
|
||||
"ssh_priv_key_path",
|
||||
self.config["ssh_priv_key_path"],
|
||||
)
|
||||
if priv_key_file:
|
||||
priv_key_file = os.path.expanduser(priv_key_file)
|
||||
if not os.path.exists(priv_key_file):
|
||||
raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
|
||||
connect_params['client_keys'] = [priv_key_file]
|
||||
connect_params["client_keys"] = [priv_key_file]
|
||||
|
||||
return connect_params
|
||||
|
||||
|
@ -355,4 +404,6 @@ class SshService(Service, metaclass=ABCMeta):
|
|||
The connection and client objects.
|
||||
"""
|
||||
assert self._in_context
|
||||
return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params))
|
||||
return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
|
||||
self._get_connect_params(params)
|
||||
)
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Service types for implementing declaring Service behavior for Environments to use in mlos_bench.
|
||||
"""Service types for implementing declaring Service behavior for Environments to use in
|
||||
mlos_bench.
|
||||
"""
|
||||
|
||||
from mlos_bench.services.types.authenticator_type import SupportsAuth
|
||||
|
@ -11,18 +11,19 @@ from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
|
|||
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
|
||||
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
|
||||
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
|
||||
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
|
||||
from mlos_bench.services.types.network_provisioner_type import (
|
||||
SupportsNetworkProvisioning,
|
||||
)
|
||||
from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
|
||||
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
|
||||
|
||||
|
||||
__all__ = [
|
||||
'SupportsAuth',
|
||||
'SupportsConfigLoading',
|
||||
'SupportsFileShareOps',
|
||||
'SupportsHostProvisioning',
|
||||
'SupportsLocalExec',
|
||||
'SupportsNetworkProvisioning',
|
||||
'SupportsRemoteConfig',
|
||||
'SupportsRemoteExec',
|
||||
"SupportsAuth",
|
||||
"SupportsConfigLoading",
|
||||
"SupportsFileShareOps",
|
||||
"SupportsHostProvisioning",
|
||||
"SupportsLocalExec",
|
||||
"SupportsNetworkProvisioning",
|
||||
"SupportsRemoteConfig",
|
||||
"SupportsRemoteExec",
|
||||
]
|
||||
|
|
|
@ -2,18 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for authentication for the cloud services.
|
||||
"""
|
||||
"""Protocol interface for authentication for the cloud services."""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsAuth(Protocol):
|
||||
"""
|
||||
Protocol interface for authentication for the cloud services.
|
||||
"""
|
||||
"""Protocol interface for authentication for the cloud services."""
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""
|
||||
|
|
|
@ -2,34 +2,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for helper functions to lookup and load configs.
|
||||
"""
|
||||
"""Protocol interface for helper functions to lookup and load configs."""
|
||||
|
||||
from typing import Any, Dict, List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
||||
|
||||
# Avoid's circular import issues.
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.environments.base_environment import Environment
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsConfigLoading(Protocol):
|
||||
"""
|
||||
Protocol interface for helper functions to lookup and load configs.
|
||||
"""
|
||||
"""Protocol interface for helper functions to lookup and load configs."""
|
||||
|
||||
def resolve_path(self, file_path: str,
|
||||
extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
|
||||
"""
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
|
||||
If `_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
|
||||
`_config_path` is `None` or `path` is absolute, return `path` as is.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -44,11 +48,14 @@ class SupportsConfigLoading(Protocol):
|
|||
An actual path to the config or script.
|
||||
"""
|
||||
|
||||
def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]:
|
||||
def load_config(
|
||||
self,
|
||||
json_file_name: str,
|
||||
schema_type: Optional[ConfigSchema],
|
||||
) -> Union[dict, List[dict]]:
|
||||
"""
|
||||
Load JSON config file. Search for a file relative to `_config_path`
|
||||
if the input path is not absolute.
|
||||
This method is exported to be used as a service.
|
||||
Load JSON config file. Search for a file relative to `_config_path` if the input
|
||||
path is not absolute. This method is exported to be used as a service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -63,12 +70,14 @@ class SupportsConfigLoading(Protocol):
|
|||
Free-format dictionary that contains the configuration.
|
||||
"""
|
||||
|
||||
def build_environment(self, # pylint: disable=too-many-arguments
|
||||
config: dict,
|
||||
tunables: "TunableGroups",
|
||||
global_config: Optional[dict] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional["Service"] = None) -> "Environment":
|
||||
def build_environment(
|
||||
self, # pylint: disable=too-many-arguments
|
||||
config: dict,
|
||||
tunables: "TunableGroups",
|
||||
global_config: Optional[dict] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional["Service"] = None,
|
||||
) -> "Environment":
|
||||
"""
|
||||
Factory method for a new environment with a given config.
|
||||
|
||||
|
@ -98,12 +107,13 @@ class SupportsConfigLoading(Protocol):
|
|||
"""
|
||||
|
||||
def load_environment_list( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
json_file_name: str,
|
||||
tunables: "TunableGroups",
|
||||
global_config: Optional[dict] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional["Service"] = None) -> List["Environment"]:
|
||||
self,
|
||||
json_file_name: str,
|
||||
tunables: "TunableGroups",
|
||||
global_config: Optional[dict] = None,
|
||||
parent_args: Optional[Dict[str, TunableValue]] = None,
|
||||
service: Optional["Service"] = None,
|
||||
) -> List["Environment"]:
|
||||
"""
|
||||
Load and build a list of environments from the config file.
|
||||
|
||||
|
@ -128,12 +138,15 @@ class SupportsConfigLoading(Protocol):
|
|||
A list of new benchmarking environments.
|
||||
"""
|
||||
|
||||
def load_services(self, json_file_names: Iterable[str],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None) -> "Service":
|
||||
def load_services(
|
||||
self,
|
||||
json_file_names: Iterable[str],
|
||||
global_config: Optional[Dict[str, Any]] = None,
|
||||
parent: Optional["Service"] = None,
|
||||
) -> "Service":
|
||||
"""
|
||||
Read the configuration files and bundle all service methods
|
||||
from those configs into a single Service object.
|
||||
Read the configuration files and bundle all service methods from those configs
|
||||
into a single Service object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -2,20 +2,22 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for file share operations.
|
||||
"""
|
||||
"""Protocol interface for file share operations."""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsFileShareOps(Protocol):
|
||||
"""
|
||||
Protocol interface for file share operations.
|
||||
"""
|
||||
"""Protocol interface for file share operations."""
|
||||
|
||||
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
|
||||
def download(
|
||||
self,
|
||||
params: dict,
|
||||
remote_path: str,
|
||||
local_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Downloads contents from a remote share path to a local path.
|
||||
|
||||
|
@ -33,7 +35,13 @@ class SupportsFileShareOps(Protocol):
|
|||
if True (the default), download the entire directory tree.
|
||||
"""
|
||||
|
||||
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
|
||||
def upload(
|
||||
self,
|
||||
params: dict,
|
||||
local_path: str,
|
||||
remote_path: str,
|
||||
recursive: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Uploads contents from a local path to remote share path.
|
||||
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Host/VM boot operations.
|
||||
"""
|
||||
"""Protocol interface for Host/VM boot operations."""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,9 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsHostOps(Protocol):
|
||||
"""
|
||||
Protocol interface for Host/VM boot operations.
|
||||
"""
|
||||
"""Protocol interface for Host/VM boot operations."""
|
||||
|
||||
def start_host(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Host/VM provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for Host/VM provisioning operations."""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,9 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsHostProvisioning(Protocol):
|
||||
"""
|
||||
Protocol interface for Host/VM provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for Host/VM provisioning operations."""
|
||||
|
||||
def provision_host(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
|
@ -46,7 +42,8 @@ class SupportsHostProvisioning(Protocol):
|
|||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
is_setup : bool
|
||||
If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning.
|
||||
If True, wait for Host/VM being deployed; otherwise, wait for successful
|
||||
deprovisioning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -74,7 +71,8 @@ class SupportsHostProvisioning(Protocol):
|
|||
|
||||
def deallocate_host(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Deallocates the Host/VM by shutting it down then releasing the compute resources.
|
||||
Deallocates the Host/VM by shutting it down then releasing the compute
|
||||
resources.
|
||||
|
||||
Note: This can cause the VM to arrive on a new host node when its
|
||||
restarted, which may have different performance characteristics.
|
||||
|
|
|
@ -2,15 +2,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts and commands locally on the scheduler side.
|
||||
"""Protocol interface for Service types that provide helper functions to run scripts and
|
||||
commands locally on the scheduler side.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable
|
||||
|
||||
import tempfile
|
||||
import contextlib
|
||||
import tempfile
|
||||
from typing import (
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
||||
|
@ -18,16 +24,19 @@ from mlos_bench.tunables.tunable import TunableValue
|
|||
@runtime_checkable
|
||||
class SupportsLocalExec(Protocol):
|
||||
"""
|
||||
Protocol interface for a collection of methods to run scripts and commands
|
||||
in an external process on the node acting as the scheduler. Can be useful
|
||||
for data processing due to reduced dependency management complications vs
|
||||
the target environment.
|
||||
Used in LocalEnv and provided by LocalExecService.
|
||||
Protocol interface for a collection of methods to run scripts and commands in an
|
||||
external process on the node acting as the scheduler.
|
||||
|
||||
Can be useful for data processing due to reduced dependency management complications
|
||||
vs the target environment. Used in LocalEnv and provided by LocalExecService.
|
||||
"""
|
||||
|
||||
def local_exec(self, script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, TunableValue]] = None,
|
||||
cwd: Optional[str] = None) -> Tuple[int, str, str]:
|
||||
def local_exec(
|
||||
self,
|
||||
script_lines: Iterable[str],
|
||||
env: Optional[Mapping[str, TunableValue]] = None,
|
||||
cwd: Optional[str] = None,
|
||||
) -> Tuple[int, str, str]:
|
||||
"""
|
||||
Execute the script lines from `script_lines` in a local process.
|
||||
|
||||
|
@ -48,7 +57,10 @@ class SupportsLocalExec(Protocol):
|
|||
A 3-tuple of return code, stdout, and stderr of the script process.
|
||||
"""
|
||||
|
||||
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
|
||||
def temp_dir_context(
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
|
||||
"""
|
||||
Create a temp directory or use the provided path.
|
||||
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Network provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for Network provisioning operations."""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,9 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsNetworkProvisioning(Protocol):
|
||||
"""
|
||||
Protocol interface for Network provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for Network provisioning operations."""
|
||||
|
||||
def provision_network(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
|
@ -46,7 +42,8 @@ class SupportsNetworkProvisioning(Protocol):
|
|||
params : dict
|
||||
Flat dictionary of (key, value) pairs of tunable parameters.
|
||||
is_setup : bool
|
||||
If True, wait for Network being deployed; otherwise, wait for successful deprovisioning.
|
||||
If True, wait for Network being deployed; otherwise, wait for successful
|
||||
deprovisioning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -56,7 +53,11 @@ class SupportsNetworkProvisioning(Protocol):
|
|||
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
|
||||
"""
|
||||
|
||||
def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]:
|
||||
def deprovision_network(
|
||||
self,
|
||||
params: dict,
|
||||
ignore_errors: bool = True,
|
||||
) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Deprovisions the Network by deleting it.
|
||||
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Host/OS operations.
|
||||
"""
|
||||
"""Protocol interface for Host/OS operations."""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,9 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsOSOps(Protocol):
|
||||
"""
|
||||
Protocol interface for Host/OS operations.
|
||||
"""
|
||||
"""Protocol interface for Host/OS operations."""
|
||||
|
||||
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
|
||||
"""
|
||||
|
@ -56,8 +52,8 @@ class SupportsOSOps(Protocol):
|
|||
|
||||
def wait_os_operation(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
|
||||
TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for configuring cloud services.
|
||||
"""
|
||||
"""Protocol interface for configuring cloud services."""
|
||||
|
||||
from typing import Any, Dict, Protocol, Tuple, TYPE_CHECKING, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Dict, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,12 +12,9 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsRemoteConfig(Protocol):
|
||||
"""
|
||||
Protocol interface for configuring cloud services.
|
||||
"""
|
||||
"""Protocol interface for configuring cloud services."""
|
||||
|
||||
def configure(self, config: Dict[str, Any],
|
||||
params: Dict[str, Any]) -> Tuple["Status", dict]:
|
||||
def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Update the parameters of a SaaS service in the cloud.
|
||||
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts on a remote host OS.
|
||||
"""Protocol interface for Service types that provide helper functions to run scripts on
|
||||
a remote host OS.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Iterable, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -15,13 +14,16 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsRemoteExec(Protocol):
|
||||
"""
|
||||
Protocol interface for Service types that provide helper functions to run
|
||||
scripts on a remote host OS.
|
||||
"""Protocol interface for Service types that provide helper functions to run scripts
|
||||
on a remote host OS.
|
||||
"""
|
||||
|
||||
def remote_exec(self, script: Iterable[str], config: dict,
|
||||
env_params: dict) -> Tuple["Status", dict]:
|
||||
def remote_exec(
|
||||
self,
|
||||
script: Iterable[str],
|
||||
config: dict,
|
||||
env_params: dict,
|
||||
) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Run a command on remote host OS.
|
||||
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Protocol interface for VM provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for VM provisioning operations."""
|
||||
|
||||
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.environments.status import Status
|
||||
|
@ -14,9 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
@runtime_checkable
|
||||
class SupportsVMOps(Protocol):
|
||||
"""
|
||||
Protocol interface for VM provisioning operations.
|
||||
"""
|
||||
"""Protocol interface for VM provisioning operations."""
|
||||
|
||||
def vm_provision(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
|
@ -122,8 +118,8 @@ class SupportsVMOps(Protocol):
|
|||
|
||||
def wait_vm_operation(self, params: dict) -> Tuple["Status", dict]:
|
||||
"""
|
||||
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
|
||||
Return TIMED_OUT when timing out.
|
||||
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. Return
|
||||
TIMED_OUT when timing out.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -2,14 +2,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Interfaces to the storage backends for OS Autotune.
|
||||
"""
|
||||
"""Interfaces to the storage backends for OS Autotune."""
|
||||
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.storage.storage_factory import from_config
|
||||
|
||||
__all__ = [
|
||||
'Storage',
|
||||
'from_config',
|
||||
"Storage",
|
||||
"from_config",
|
||||
]
|
||||
|
|
|
@ -2,13 +2,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base interface for accessing the stored benchmark experiment data.
|
||||
"""
|
||||
"""Base interface for accessing the stored benchmark experiment data."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from distutils.util import strtobool # pylint: disable=deprecated-module
|
||||
from typing import Dict, Literal, Optional, Tuple, TYPE_CHECKING
|
||||
from distutils.util import strtobool # pylint: disable=deprecated-module
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple
|
||||
|
||||
import pandas
|
||||
|
||||
|
@ -16,7 +14,9 @@ from mlos_bench.storage.base_tunable_config_data import TunableConfigData
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.storage.base_trial_data import TrialData
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import (
|
||||
TunableConfigTrialGroupData,
|
||||
)
|
||||
|
||||
|
||||
class ExperimentData(metaclass=ABCMeta):
|
||||
|
@ -30,12 +30,15 @@ class ExperimentData(metaclass=ABCMeta):
|
|||
RESULT_COLUMN_PREFIX = "result."
|
||||
CONFIG_COLUMN_PREFIX = "config."
|
||||
|
||||
def __init__(self, *,
|
||||
experiment_id: str,
|
||||
description: str,
|
||||
root_env_config: str,
|
||||
git_repo: str,
|
||||
git_commit: str):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
experiment_id: str,
|
||||
description: str,
|
||||
root_env_config: str,
|
||||
git_repo: str,
|
||||
git_commit: str,
|
||||
):
|
||||
self._experiment_id = experiment_id
|
||||
self._description = description
|
||||
self._root_env_config = root_env_config
|
||||
|
@ -44,16 +47,12 @@ class ExperimentData(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def experiment_id(self) -> str:
|
||||
"""
|
||||
ID of the experiment.
|
||||
"""
|
||||
"""ID of the experiment."""
|
||||
return self._experiment_id
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""
|
||||
Description of the experiment.
|
||||
"""
|
||||
"""Description of the experiment."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
|
@ -123,7 +122,8 @@ class ExperimentData(metaclass=ABCMeta):
|
|||
@property
|
||||
def default_tunable_config_id(self) -> Optional[int]:
|
||||
"""
|
||||
Retrieves the (tunable) config id for the default tunable values for this experiment.
|
||||
Retrieves the (tunable) config id for the default tunable values for this
|
||||
experiment.
|
||||
|
||||
Note: this is by *default* the first trial executed for this experiment.
|
||||
However, it is currently possible that the user changed the tunables config
|
||||
|
@ -140,9 +140,9 @@ class ExperimentData(metaclass=ABCMeta):
|
|||
trials_items = sorted(self.trials.items())
|
||||
if not trials_items:
|
||||
return None
|
||||
for (_trial_id, trial) in trials_items:
|
||||
for _trial_id, trial in trials_items:
|
||||
# Take the first config id marked as "defaults" when it was instantiated.
|
||||
if strtobool(str(trial.metadata_dict.get('is_defaults', False))):
|
||||
if strtobool(str(trial.metadata_dict.get("is_defaults", False))):
|
||||
return trial.tunable_config_id
|
||||
# Fallback (min trial_id)
|
||||
return trials_items[0][1].tunable_config_id
|
||||
|
@ -157,7 +157,8 @@ class ExperimentData(metaclass=ABCMeta):
|
|||
-------
|
||||
results : pandas.DataFrame
|
||||
A DataFrame with configurations and results from all trials of the experiment.
|
||||
Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
|
||||
Has columns
|
||||
[trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
|
||||
followed by tunable config parameters (prefixed with "config.") and
|
||||
trial results (prefixed with "result."). The latter can be NULLs if the
|
||||
trial was not successful.
|
||||
|
|
|
@ -2,15 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base interface for saving and restoring the benchmark data.
|
||||
"""
|
||||
"""Base interface for saving and restoring the benchmark data."""
|
||||
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Optional, List, Tuple, Dict, Iterator, Type, Any
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mlos_bench.config.schemas import ConfigSchema
|
||||
|
@ -24,15 +23,16 @@ _LOG = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Storage(metaclass=ABCMeta):
|
||||
"""
|
||||
An abstract interface between the benchmarking framework
|
||||
and storage systems (e.g., SQLite or MLFLow).
|
||||
"""An abstract interface between the benchmarking framework and storage systems
|
||||
(e.g., SQLite or MLFLow).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
"""
|
||||
Create a new storage object.
|
||||
|
||||
|
@ -48,10 +48,9 @@ class Storage(metaclass=ABCMeta):
|
|||
self._global_config = global_config or {}
|
||||
|
||||
def _validate_json_config(self, config: dict) -> None:
|
||||
"""
|
||||
Reconstructs a basic json config that this class might have been
|
||||
instantiated from in order to validate configs provided outside the
|
||||
file loading mechanism.
|
||||
"""Reconstructs a basic json config that this class might have been instantiated
|
||||
from in order to validate configs provided outside the file loading
|
||||
mechanism.
|
||||
"""
|
||||
json_config: dict = {
|
||||
"class": self.__class__.__module__ + "." + self.__class__.__name__,
|
||||
|
@ -73,13 +72,16 @@ class Storage(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def experiment(self, *,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
tunables: TunableGroups,
|
||||
opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment':
|
||||
def experiment(
|
||||
self,
|
||||
*,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
tunables: TunableGroups,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
) -> "Storage.Experiment":
|
||||
"""
|
||||
Create a new experiment in the storage.
|
||||
|
||||
|
@ -112,26 +114,31 @@ class Storage(metaclass=ABCMeta):
|
|||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Base interface for storing the results of the experiment.
|
||||
|
||||
This class is instantiated in the `Storage.experiment()` method.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
opt_targets: Dict[str, Literal['min', 'max']]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
):
|
||||
self._tunables = tunables.copy()
|
||||
self._trial_id = trial_id
|
||||
self._experiment_id = experiment_id
|
||||
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config)
|
||||
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
|
||||
root_env_config
|
||||
)
|
||||
self._description = description
|
||||
self._opt_targets = opt_targets
|
||||
self._in_context = False
|
||||
|
||||
def __enter__(self) -> 'Storage.Experiment':
|
||||
def __enter__(self) -> "Storage.Experiment":
|
||||
"""
|
||||
Enter the context of the experiment.
|
||||
|
||||
|
@ -143,9 +150,12 @@ class Storage(metaclass=ABCMeta):
|
|||
self._in_context = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType]) -> Literal[False]:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
"""
|
||||
End the context of the experiment.
|
||||
|
||||
|
@ -156,8 +166,11 @@ class Storage(metaclass=ABCMeta):
|
|||
_LOG.debug("Finishing experiment: %s", self)
|
||||
else:
|
||||
assert exc_type and exc_val
|
||||
_LOG.warning("Finishing experiment: %s", self,
|
||||
exc_info=(exc_type, exc_val, exc_tb))
|
||||
_LOG.warning(
|
||||
"Finishing experiment: %s",
|
||||
self,
|
||||
exc_info=(exc_type, exc_val, exc_tb),
|
||||
)
|
||||
assert self._in_context
|
||||
self._teardown(is_ok)
|
||||
self._in_context = False
|
||||
|
@ -168,7 +181,8 @@ class Storage(metaclass=ABCMeta):
|
|||
|
||||
def _setup(self) -> None:
|
||||
"""
|
||||
Create a record of the new experiment or find an existing one in the storage.
|
||||
Create a record of the new experiment or find an existing one in the
|
||||
storage.
|
||||
|
||||
This method is called by `Storage.Experiment.__enter__()`.
|
||||
"""
|
||||
|
@ -187,36 +201,34 @@ class Storage(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def experiment_id(self) -> str:
|
||||
"""Get the Experiment's ID"""
|
||||
"""Get the Experiment's ID."""
|
||||
return self._experiment_id
|
||||
|
||||
@property
|
||||
def trial_id(self) -> int:
|
||||
"""Get the current Trial ID"""
|
||||
"""Get the current Trial ID."""
|
||||
return self._trial_id
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Get the Experiment's description"""
|
||||
"""Get the Experiment's description."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def tunables(self) -> TunableGroups:
|
||||
"""Get the Experiment's tunables"""
|
||||
"""Get the Experiment's tunables."""
|
||||
return self._tunables
|
||||
|
||||
@property
|
||||
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
|
||||
"""
|
||||
Get the Experiment's optimization targets and directions
|
||||
"""
|
||||
"""Get the Experiment's optimization targets and directions."""
|
||||
return self._opt_targets
|
||||
|
||||
@abstractmethod
|
||||
def merge(self, experiment_ids: List[str]) -> None:
|
||||
"""
|
||||
Merge in the results of other (compatible) experiments trials.
|
||||
Used to help warm up the optimizer for this experiment.
|
||||
Merge in the results of other (compatible) experiments trials. Used to help
|
||||
warm up the optimizer for this experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -226,9 +238,7 @@ class Storage(metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Load tunable values for a given config ID.
|
||||
"""
|
||||
"""Load tunable values for a given config ID."""
|
||||
|
||||
@abstractmethod
|
||||
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
|
||||
|
@ -247,8 +257,10 @@ class Storage(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self, last_trial_id: int = -1,
|
||||
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
|
||||
def load(
|
||||
self,
|
||||
last_trial_id: int = -1,
|
||||
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
|
||||
"""
|
||||
Load (tunable values, benchmark scores, status) to warm-up the optimizer.
|
||||
|
||||
|
@ -268,10 +280,15 @@ class Storage(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']:
|
||||
def pending_trials(
|
||||
self,
|
||||
timestamp: datetime,
|
||||
*,
|
||||
running: bool,
|
||||
) -> Iterator["Storage.Trial"]:
|
||||
"""
|
||||
Return an iterator over the pending trials that are scheduled to run
|
||||
on or before the specified timestamp.
|
||||
Return an iterator over the pending trials that are scheduled to run on or
|
||||
before the specified timestamp.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -288,8 +305,12 @@ class Storage(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial':
|
||||
def new_trial(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> "Storage.Trial":
|
||||
"""
|
||||
Create a new experiment run in the storage.
|
||||
|
||||
|
@ -313,13 +334,20 @@ class Storage(metaclass=ABCMeta):
|
|||
# pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
Base interface for storing the results of a single run of the experiment.
|
||||
|
||||
This class is instantiated in the `Storage.Experiment.trial()` method.
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
tunables: TunableGroups, experiment_id: str, trial_id: int,
|
||||
tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']],
|
||||
config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
tunable_config_id: int,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self._tunables = tunables
|
||||
self._experiment_id = experiment_id
|
||||
self._trial_id = trial_id
|
||||
|
@ -332,29 +360,23 @@ class Storage(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def trial_id(self) -> int:
|
||||
"""
|
||||
ID of the current trial.
|
||||
"""
|
||||
"""ID of the current trial."""
|
||||
return self._trial_id
|
||||
|
||||
@property
|
||||
def tunable_config_id(self) -> int:
|
||||
"""
|
||||
ID of the current trial (tunable) configuration.
|
||||
"""
|
||||
"""ID of the current trial (tunable) configuration."""
|
||||
return self._tunable_config_id
|
||||
|
||||
@property
|
||||
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
|
||||
"""
|
||||
Get the Trial's optimization targets and directions.
|
||||
"""
|
||||
"""Get the Trial's optimization targets and directions."""
|
||||
return self._opt_targets
|
||||
|
||||
@property
|
||||
def tunables(self) -> TunableGroups:
|
||||
"""
|
||||
Tunable parameters of the current trial
|
||||
Tunable parameters of the current trial.
|
||||
|
||||
(e.g., application Environment's "config")
|
||||
"""
|
||||
|
@ -362,8 +384,8 @@ class Storage(metaclass=ABCMeta):
|
|||
|
||||
def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Produce a copy of the global configuration updated
|
||||
with the parameters of the current trial.
|
||||
Produce a copy of the global configuration updated with the parameters of
|
||||
the current trial.
|
||||
|
||||
Note: this is not the target Environment's "config" (i.e., tunable
|
||||
params), but rather the internal "config" which consists of a
|
||||
|
@ -377,9 +399,12 @@ class Storage(metaclass=ABCMeta):
|
|||
return config
|
||||
|
||||
@abstractmethod
|
||||
def update(self, status: Status, timestamp: datetime,
|
||||
metrics: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
def update(
|
||||
self,
|
||||
status: Status,
|
||||
timestamp: datetime,
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Update the storage with the results of the experiment.
|
||||
|
||||
|
@ -403,14 +428,21 @@ class Storage(metaclass=ABCMeta):
|
|||
assert metrics is not None
|
||||
opt_targets = set(self._opt_targets.keys())
|
||||
if not opt_targets.issubset(metrics.keys()):
|
||||
_LOG.warning("Trial %s :: opt.targets missing: %s",
|
||||
self, opt_targets.difference(metrics.keys()))
|
||||
_LOG.warning(
|
||||
"Trial %s :: opt.targets missing: %s",
|
||||
self,
|
||||
opt_targets.difference(metrics.keys()),
|
||||
)
|
||||
# raise ValueError()
|
||||
return metrics
|
||||
|
||||
@abstractmethod
|
||||
def update_telemetry(self, status: Status, timestamp: datetime,
|
||||
metrics: List[Tuple[datetime, str, Any]]) -> None:
|
||||
def update_telemetry(
|
||||
self,
|
||||
status: Status,
|
||||
timestamp: datetime,
|
||||
metrics: List[Tuple[datetime, str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Save the experiment's telemetry data and intermediate status.
|
||||
|
||||
|
|
|
@ -2,40 +2,43 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base interface for accessing the stored benchmark trial data.
|
||||
"""
|
||||
"""Base interface for accessing the stored benchmark trial data."""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import pandas
|
||||
from pytz import UTC
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
|
||||
from mlos_bench.storage.util import kv_df_to_dict
|
||||
from mlos_bench.tunables.tunable import TunableValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import (
|
||||
TunableConfigTrialGroupData,
|
||||
)
|
||||
|
||||
|
||||
class TrialData(metaclass=ABCMeta):
|
||||
"""
|
||||
Base interface for accessing the stored experiment benchmark trial data.
|
||||
|
||||
A trial is a single run of an experiment with a given configuration (e.g., set
|
||||
of tunable parameters).
|
||||
A trial is a single run of an experiment with a given configuration (e.g., set of
|
||||
tunable parameters).
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
tunable_config_id: int,
|
||||
ts_start: datetime,
|
||||
ts_end: Optional[datetime],
|
||||
status: Status):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
tunable_config_id: int,
|
||||
ts_start: datetime,
|
||||
ts_end: Optional[datetime],
|
||||
status: Status,
|
||||
):
|
||||
self._experiment_id = experiment_id
|
||||
self._trial_id = trial_id
|
||||
self._tunable_config_id = tunable_config_id
|
||||
|
@ -46,7 +49,10 @@ class TrialData(metaclass=ABCMeta):
|
|||
self._status = status
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Trial :: {self._experiment_id}:{self._trial_id} cid:{self._tunable_config_id} {self._status.name}"
|
||||
return (
|
||||
f"Trial :: {self._experiment_id}:{self._trial_id} "
|
||||
f"cid:{self._tunable_config_id} {self._status.name}"
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
|
@ -55,44 +61,32 @@ class TrialData(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def experiment_id(self) -> str:
|
||||
"""
|
||||
ID of the experiment this trial belongs to.
|
||||
"""
|
||||
"""ID of the experiment this trial belongs to."""
|
||||
return self._experiment_id
|
||||
|
||||
@property
|
||||
def trial_id(self) -> int:
|
||||
"""
|
||||
ID of the trial.
|
||||
"""
|
||||
"""ID of the trial."""
|
||||
return self._trial_id
|
||||
|
||||
@property
|
||||
def ts_start(self) -> datetime:
|
||||
"""
|
||||
Start timestamp of the trial (UTC).
|
||||
"""
|
||||
"""Start timestamp of the trial (UTC)."""
|
||||
return self._ts_start
|
||||
|
||||
@property
|
||||
def ts_end(self) -> Optional[datetime]:
|
||||
"""
|
||||
End timestamp of the trial (UTC).
|
||||
"""
|
||||
"""End timestamp of the trial (UTC)."""
|
||||
return self._ts_end
|
||||
|
||||
@property
|
||||
def status(self) -> Status:
|
||||
"""
|
||||
Status of the trial.
|
||||
"""
|
||||
"""Status of the trial."""
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def tunable_config_id(self) -> int:
|
||||
"""
|
||||
ID of the (tunable) configuration of the trial.
|
||||
"""
|
||||
"""ID of the (tunable) configuration of the trial."""
|
||||
return self._tunable_config_id
|
||||
|
||||
@property
|
||||
|
@ -112,9 +106,7 @@ class TrialData(metaclass=ABCMeta):
|
|||
@property
|
||||
@abstractmethod
|
||||
def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
|
||||
"""
|
||||
Retrieve the trial's (tunable) config trial group data from the storage.
|
||||
"""
|
||||
"""Retrieve the trial's (tunable) config trial group data from the storage."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base interface for accessing the stored benchmark (tunable) config data.
|
||||
"""
|
||||
"""Base interface for accessing the stored benchmark (tunable) config data."""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
@ -21,8 +19,7 @@ class TunableConfigData(metaclass=ABCMeta):
|
|||
A configuration in this context is the set of tunable parameter values.
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
tunable_config_id: int):
|
||||
def __init__(self, *, tunable_config_id: int):
|
||||
self._tunable_config_id = tunable_config_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -35,9 +32,7 @@ class TunableConfigData(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def tunable_config_id(self) -> int:
|
||||
"""
|
||||
Unique ID of the (tunable) configuration.
|
||||
"""
|
||||
"""Unique ID of the (tunable) configuration."""
|
||||
return self._tunable_config_id
|
||||
|
||||
@property
|
||||
|
|
|
@ -2,12 +2,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Base interface for accessing the stored benchmark config trial group data.
|
||||
"""
|
||||
"""Base interface for accessing the stored benchmark config trial group data."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import pandas
|
||||
|
||||
|
@ -19,18 +17,21 @@ if TYPE_CHECKING:
|
|||
|
||||
class TunableConfigTrialGroupData(metaclass=ABCMeta):
|
||||
"""
|
||||
Base interface for accessing the stored experiment benchmark tunable config
|
||||
trial group data.
|
||||
Base interface for accessing the stored experiment benchmark tunable config trial
|
||||
group data.
|
||||
|
||||
A (tunable) config is used to define an instance of values for a set of tunable
|
||||
parameters for a given experiment and can be used by one or more trial instances
|
||||
(e.g., for repeats), which we call a (tunable) config trial group.
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
experiment_id: str,
|
||||
tunable_config_id: int,
|
||||
tunable_config_trial_group_id: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
experiment_id: str,
|
||||
tunable_config_id: int,
|
||||
tunable_config_trial_group_id: Optional[int] = None,
|
||||
):
|
||||
self._experiment_id = experiment_id
|
||||
self._tunable_config_id = tunable_config_id
|
||||
# can be lazily initialized as necessary:
|
||||
|
@ -38,23 +39,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def experiment_id(self) -> str:
|
||||
"""
|
||||
ID of the experiment.
|
||||
"""
|
||||
"""ID of the experiment."""
|
||||
return self._experiment_id
|
||||
|
||||
@property
|
||||
def tunable_config_id(self) -> int:
|
||||
"""
|
||||
ID of the config.
|
||||
"""
|
||||
"""ID of the config."""
|
||||
return self._tunable_config_id
|
||||
|
||||
@abstractmethod
|
||||
def _get_tunable_config_trial_group_id(self) -> int:
|
||||
"""
|
||||
Retrieve the trial's config_trial_group_id from the storage.
|
||||
"""
|
||||
"""Retrieve the trial's config_trial_group_id from the storage."""
|
||||
raise NotImplementedError("subclass must implement")
|
||||
|
||||
@property
|
||||
|
@ -77,13 +72,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
|
|||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id
|
||||
return (
|
||||
self._tunable_config_id == other._tunable_config_id
|
||||
and self._experiment_id == other._experiment_id
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def tunable_config(self) -> TunableConfigData:
|
||||
"""
|
||||
Retrieve the (tunable) config data for this (tunable) config trial group from the storage.
|
||||
Retrieve the (tunable) config data for this (tunable) config trial group from
|
||||
the storage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -94,7 +93,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def trials(self) -> Dict[int, "TrialData"]:
|
||||
"""
|
||||
Retrieve the trials' data for this (tunable) config trial group from the storage.
|
||||
Retrieve the trials' data for this (tunable) config trial group from the
|
||||
storage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -106,7 +106,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def results_df(self) -> pandas.DataFrame:
|
||||
"""
|
||||
Retrieve all results for this (tunable) config trial group as a single DataFrame.
|
||||
Retrieve all results for this (tunable) config trial group as a single
|
||||
DataFrame.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Interfaces to the SQL-based storage backends for OS Autotune.
|
||||
"""
|
||||
"""Interfaces to the SQL-based storage backends for OS Autotune."""
|
||||
from mlos_bench.storage.sql.storage import SqlStorage
|
||||
|
||||
__all__ = [
|
||||
'SqlStorage',
|
||||
"SqlStorage",
|
||||
]
|
||||
|
|
|
@ -2,39 +2,45 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Common SQL methods for accessing the stored benchmark data.
|
||||
"""
|
||||
"""Common SQL methods for accessing the stored benchmark data."""
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pandas
|
||||
from sqlalchemy import Engine, Integer, func, and_, select
|
||||
from sqlalchemy import Engine, Integer, and_, func, select
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.storage.base_experiment_data import ExperimentData
|
||||
from mlos_bench.storage.base_trial_data import TrialData
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.util import utcify_timestamp, utcify_nullable_timestamp
|
||||
from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
|
||||
|
||||
|
||||
def get_trials(
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]:
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
tunable_config_id: Optional[int] = None,
|
||||
) -> Dict[int, TrialData]:
|
||||
"""
|
||||
Gets TrialData for the given experiment_data and optionally additionally
|
||||
restricted by tunable_config_id.
|
||||
Gets TrialData for the given experiment_data and optionally additionally restricted
|
||||
by tunable_config_id.
|
||||
|
||||
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
|
||||
"""
|
||||
from mlos_bench.storage.sql.trial_data import TrialSqlData # pylint: disable=import-outside-toplevel,cyclic-import
|
||||
# pylint: disable=import-outside-toplevel,cyclic-import
|
||||
from mlos_bench.storage.sql.trial_data import TrialSqlData
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Build up sql a statement for fetching trials.
|
||||
stmt = schema.trial.select().where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
).order_by(
|
||||
schema.trial.c.exp_id.asc(),
|
||||
schema.trial.c.trial_id.asc(),
|
||||
stmt = (
|
||||
schema.trial.select()
|
||||
.where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
)
|
||||
.order_by(
|
||||
schema.trial.c.exp_id.asc(),
|
||||
schema.trial.c.trial_id.asc(),
|
||||
)
|
||||
)
|
||||
# Optionally restrict to those using a particular tunable config.
|
||||
if tunable_config_id is not None:
|
||||
|
@ -58,27 +64,36 @@ def get_trials(
|
|||
|
||||
|
||||
def get_results_df(
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
tunable_config_id: Optional[int] = None) -> pandas.DataFrame:
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
tunable_config_id: Optional[int] = None,
|
||||
) -> pandas.DataFrame:
|
||||
"""
|
||||
Gets TrialData for the given experiment_data and optionally additionally
|
||||
restricted by tunable_config_id.
|
||||
Gets TrialData for the given experiment_data and optionally additionally restricted
|
||||
by tunable_config_id.
|
||||
|
||||
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
with engine.connect() as conn:
|
||||
# Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config.
|
||||
tunable_config_group_id_stmt = schema.trial.select().with_only_columns(
|
||||
schema.trial.c.exp_id,
|
||||
schema.trial.c.config_id,
|
||||
func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'),
|
||||
).where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
).group_by(
|
||||
schema.trial.c.exp_id,
|
||||
schema.trial.c.config_id,
|
||||
tunable_config_group_id_stmt = (
|
||||
schema.trial.select()
|
||||
.with_only_columns(
|
||||
schema.trial.c.exp_id,
|
||||
schema.trial.c.config_id,
|
||||
func.min(schema.trial.c.trial_id)
|
||||
.cast(Integer)
|
||||
.label("tunable_config_trial_group_id"),
|
||||
)
|
||||
.where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
)
|
||||
.group_by(
|
||||
schema.trial.c.exp_id,
|
||||
schema.trial.c.config_id,
|
||||
)
|
||||
)
|
||||
# Optionally restrict to those using a particular tunable config.
|
||||
if tunable_config_id is not None:
|
||||
|
@ -88,18 +103,22 @@ def get_results_df(
|
|||
tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()
|
||||
|
||||
# Get each trial's metadata.
|
||||
cur_trials_stmt = select(
|
||||
schema.trial,
|
||||
tunable_config_trial_group_id_subquery,
|
||||
).where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
and_(
|
||||
tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
|
||||
tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
|
||||
),
|
||||
).order_by(
|
||||
schema.trial.c.exp_id.asc(),
|
||||
schema.trial.c.trial_id.asc(),
|
||||
cur_trials_stmt = (
|
||||
select(
|
||||
schema.trial,
|
||||
tunable_config_trial_group_id_subquery,
|
||||
)
|
||||
.where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
and_(
|
||||
tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
|
||||
tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
|
||||
),
|
||||
)
|
||||
.order_by(
|
||||
schema.trial.c.exp_id.asc(),
|
||||
schema.trial.c.trial_id.asc(),
|
||||
)
|
||||
)
|
||||
# Optionally restrict to those using a particular tunable config.
|
||||
if tunable_config_id is not None:
|
||||
|
@ -108,39 +127,48 @@ def get_results_df(
|
|||
)
|
||||
cur_trials = conn.execute(cur_trials_stmt)
|
||||
trials_df = pandas.DataFrame(
|
||||
[(
|
||||
row.trial_id,
|
||||
utcify_timestamp(row.ts_start, origin="utc"),
|
||||
utcify_nullable_timestamp(row.ts_end, origin="utc"),
|
||||
row.config_id,
|
||||
row.tunable_config_trial_group_id,
|
||||
row.status,
|
||||
) for row in cur_trials.fetchall()],
|
||||
[
|
||||
(
|
||||
row.trial_id,
|
||||
utcify_timestamp(row.ts_start, origin="utc"),
|
||||
utcify_nullable_timestamp(row.ts_end, origin="utc"),
|
||||
row.config_id,
|
||||
row.tunable_config_trial_group_id,
|
||||
row.status,
|
||||
)
|
||||
for row in cur_trials.fetchall()
|
||||
],
|
||||
columns=[
|
||||
'trial_id',
|
||||
'ts_start',
|
||||
'ts_end',
|
||||
'tunable_config_id',
|
||||
'tunable_config_trial_group_id',
|
||||
'status',
|
||||
]
|
||||
"trial_id",
|
||||
"ts_start",
|
||||
"ts_end",
|
||||
"tunable_config_id",
|
||||
"tunable_config_trial_group_id",
|
||||
"status",
|
||||
],
|
||||
)
|
||||
|
||||
# Get each trial's config in wide format.
|
||||
configs_stmt = schema.trial.select().with_only_columns(
|
||||
schema.trial.c.trial_id,
|
||||
schema.trial.c.config_id,
|
||||
schema.config_param.c.param_id,
|
||||
schema.config_param.c.param_value,
|
||||
).where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
).join(
|
||||
schema.config_param,
|
||||
schema.config_param.c.config_id == schema.trial.c.config_id,
|
||||
isouter=True
|
||||
).order_by(
|
||||
schema.trial.c.trial_id,
|
||||
schema.config_param.c.param_id,
|
||||
configs_stmt = (
|
||||
schema.trial.select()
|
||||
.with_only_columns(
|
||||
schema.trial.c.trial_id,
|
||||
schema.trial.c.config_id,
|
||||
schema.config_param.c.param_id,
|
||||
schema.config_param.c.param_value,
|
||||
)
|
||||
.where(
|
||||
schema.trial.c.exp_id == experiment_id,
|
||||
)
|
||||
.join(
|
||||
schema.config_param,
|
||||
schema.config_param.c.config_id == schema.trial.c.config_id,
|
||||
isouter=True,
|
||||
)
|
||||
.order_by(
|
||||
schema.trial.c.trial_id,
|
||||
schema.config_param.c.param_id,
|
||||
)
|
||||
)
|
||||
if tunable_config_id is not None:
|
||||
configs_stmt = configs_stmt.where(
|
||||
|
@ -148,41 +176,75 @@ def get_results_df(
|
|||
)
|
||||
configs = conn.execute(configs_stmt)
|
||||
configs_df = pandas.DataFrame(
|
||||
[(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value)
|
||||
for row in configs.fetchall()],
|
||||
columns=['trial_id', 'tunable_config_id', 'param', 'value']
|
||||
[
|
||||
(
|
||||
row.trial_id,
|
||||
row.config_id,
|
||||
ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
|
||||
row.param_value,
|
||||
)
|
||||
for row in configs.fetchall()
|
||||
],
|
||||
columns=["trial_id", "tunable_config_id", "param", "value"],
|
||||
).pivot(
|
||||
index=["trial_id", "tunable_config_id"], columns="param", values="value",
|
||||
index=["trial_id", "tunable_config_id"],
|
||||
columns="param",
|
||||
values="value",
|
||||
)
|
||||
configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp)
|
||||
configs_df = configs_df.apply( # type: ignore[assignment] # (fp)
|
||||
pandas.to_numeric,
|
||||
errors="coerce",
|
||||
).fillna(configs_df)
|
||||
|
||||
# Get each trial's results in wide format.
|
||||
results_stmt = schema.trial_result.select().with_only_columns(
|
||||
schema.trial_result.c.trial_id,
|
||||
schema.trial_result.c.metric_id,
|
||||
schema.trial_result.c.metric_value,
|
||||
).where(
|
||||
schema.trial_result.c.exp_id == experiment_id,
|
||||
).order_by(
|
||||
schema.trial_result.c.trial_id,
|
||||
schema.trial_result.c.metric_id,
|
||||
results_stmt = (
|
||||
schema.trial_result.select()
|
||||
.with_only_columns(
|
||||
schema.trial_result.c.trial_id,
|
||||
schema.trial_result.c.metric_id,
|
||||
schema.trial_result.c.metric_value,
|
||||
)
|
||||
.where(
|
||||
schema.trial_result.c.exp_id == experiment_id,
|
||||
)
|
||||
.order_by(
|
||||
schema.trial_result.c.trial_id,
|
||||
schema.trial_result.c.metric_id,
|
||||
)
|
||||
)
|
||||
if tunable_config_id is not None:
|
||||
results_stmt = results_stmt.join(schema.trial, and_(
|
||||
schema.trial.c.exp_id == schema.trial_result.c.exp_id,
|
||||
schema.trial.c.trial_id == schema.trial_result.c.trial_id,
|
||||
schema.trial.c.config_id == tunable_config_id,
|
||||
))
|
||||
results_stmt = results_stmt.join(
|
||||
schema.trial,
|
||||
and_(
|
||||
schema.trial.c.exp_id == schema.trial_result.c.exp_id,
|
||||
schema.trial.c.trial_id == schema.trial_result.c.trial_id,
|
||||
schema.trial.c.config_id == tunable_config_id,
|
||||
),
|
||||
)
|
||||
results = conn.execute(results_stmt)
|
||||
results_df = pandas.DataFrame(
|
||||
[(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value)
|
||||
for row in results.fetchall()],
|
||||
columns=['trial_id', 'metric', 'value']
|
||||
[
|
||||
(
|
||||
row.trial_id,
|
||||
ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
|
||||
row.metric_value,
|
||||
)
|
||||
for row in results.fetchall()
|
||||
],
|
||||
columns=["trial_id", "metric", "value"],
|
||||
).pivot(
|
||||
index="trial_id", columns="metric", values="value",
|
||||
index="trial_id",
|
||||
columns="metric",
|
||||
values="value",
|
||||
)
|
||||
results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp)
|
||||
results_df = results_df.apply( # type: ignore[assignment] # (fp)
|
||||
pandas.to_numeric,
|
||||
errors="coerce",
|
||||
).fillna(results_df)
|
||||
|
||||
# Concat the trials, configs, and results.
|
||||
return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \
|
||||
.merge(results_df, on="trial_id", how="left")
|
||||
return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
|
||||
results_df,
|
||||
on="trial_id",
|
||||
how="left",
|
||||
)
|
||||
|
|
|
@ -2,43 +2,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Saving and restoring the benchmark data using SQLAlchemy.
|
||||
"""
|
||||
"""Saving and restoring the benchmark data using SQLAlchemy."""
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple, List, Literal, Dict, Iterator, Any
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
|
||||
|
||||
from pytz import UTC
|
||||
|
||||
from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select
|
||||
from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.storage.sql.trial import Trial
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import nullable, utcify_timestamp
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Experiment(Storage.Experiment):
|
||||
"""
|
||||
Logic for retrieving and storing the results of a single experiment.
|
||||
"""
|
||||
"""Logic for retrieving and storing the results of a single experiment."""
|
||||
|
||||
def __init__(self, *,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
opt_targets: Dict[str, Literal['min', 'max']]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
):
|
||||
super().__init__(
|
||||
tunables=tunables,
|
||||
experiment_id=experiment_id,
|
||||
|
@ -56,18 +54,22 @@ class Experiment(Storage.Experiment):
|
|||
# Get git info and the last trial ID for the experiment.
|
||||
# pylint: disable=not-callable
|
||||
exp_info = conn.execute(
|
||||
self._schema.experiment.select().with_only_columns(
|
||||
self._schema.experiment.select()
|
||||
.with_only_columns(
|
||||
self._schema.experiment.c.git_repo,
|
||||
self._schema.experiment.c.git_commit,
|
||||
self._schema.experiment.c.root_env_config,
|
||||
func.max(self._schema.trial.c.trial_id).label("trial_id"),
|
||||
).join(
|
||||
)
|
||||
.join(
|
||||
self._schema.trial,
|
||||
self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
|
||||
isouter=True
|
||||
).where(
|
||||
isouter=True,
|
||||
)
|
||||
.where(
|
||||
self._schema.experiment.c.exp_id == self._experiment_id,
|
||||
).group_by(
|
||||
)
|
||||
.group_by(
|
||||
self._schema.experiment.c.git_repo,
|
||||
self._schema.experiment.c.git_commit,
|
||||
self._schema.experiment.c.root_env_config,
|
||||
|
@ -76,33 +78,47 @@ class Experiment(Storage.Experiment):
|
|||
if exp_info is None:
|
||||
_LOG.info("Start new experiment: %s", self._experiment_id)
|
||||
# It's a new experiment: create a record for it in the database.
|
||||
conn.execute(self._schema.experiment.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
description=self._description,
|
||||
git_repo=self._git_repo,
|
||||
git_commit=self._git_commit,
|
||||
root_env_config=self._root_env_config,
|
||||
))
|
||||
conn.execute(self._schema.objectives.insert().values([
|
||||
{
|
||||
"exp_id": self._experiment_id,
|
||||
"optimization_target": opt_target,
|
||||
"optimization_direction": opt_dir,
|
||||
}
|
||||
for (opt_target, opt_dir) in self.opt_targets.items()
|
||||
]))
|
||||
conn.execute(
|
||||
self._schema.experiment.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
description=self._description,
|
||||
git_repo=self._git_repo,
|
||||
git_commit=self._git_commit,
|
||||
root_env_config=self._root_env_config,
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
self._schema.objectives.insert().values(
|
||||
[
|
||||
{
|
||||
"exp_id": self._experiment_id,
|
||||
"optimization_target": opt_target,
|
||||
"optimization_direction": opt_dir,
|
||||
}
|
||||
for (opt_target, opt_dir) in self.opt_targets.items()
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
if exp_info.trial_id is not None:
|
||||
self._trial_id = exp_info.trial_id + 1
|
||||
_LOG.info("Continue experiment: %s last trial: %s resume from: %d",
|
||||
self._experiment_id, exp_info.trial_id, self._trial_id)
|
||||
_LOG.info(
|
||||
"Continue experiment: %s last trial: %s resume from: %d",
|
||||
self._experiment_id,
|
||||
exp_info.trial_id,
|
||||
self._trial_id,
|
||||
)
|
||||
# TODO: Sanity check that certain critical configs (e.g.,
|
||||
# objectives) haven't changed to be incompatible such that a new
|
||||
# experiment should be started (possibly by prewarming with the
|
||||
# previous one).
|
||||
if exp_info.git_commit != self._git_commit:
|
||||
_LOG.warning("Experiment %s git expected: %s %s",
|
||||
self, exp_info.git_repo, exp_info.git_commit)
|
||||
_LOG.warning(
|
||||
"Experiment %s git expected: %s %s",
|
||||
self,
|
||||
exp_info.git_repo,
|
||||
exp_info.git_commit,
|
||||
)
|
||||
|
||||
def merge(self, experiment_ids: List[str]) -> None:
|
||||
_LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
|
||||
|
@ -115,33 +131,42 @@ class Experiment(Storage.Experiment):
|
|||
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
|
||||
with self._engine.connect() as conn:
|
||||
cur_telemetry = conn.execute(
|
||||
self._schema.trial_telemetry.select().where(
|
||||
self._schema.trial_telemetry.select()
|
||||
.where(
|
||||
self._schema.trial_telemetry.c.exp_id == self._experiment_id,
|
||||
self._schema.trial_telemetry.c.trial_id == trial_id
|
||||
).order_by(
|
||||
self._schema.trial_telemetry.c.trial_id == trial_id,
|
||||
)
|
||||
.order_by(
|
||||
self._schema.trial_telemetry.c.ts,
|
||||
self._schema.trial_telemetry.c.metric_id,
|
||||
)
|
||||
)
|
||||
# Not all storage backends store the original zone info.
|
||||
# We try to ensure data is entered in UTC and augment it on return again here.
|
||||
return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
|
||||
for row in cur_telemetry.fetchall()]
|
||||
return [
|
||||
(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
|
||||
for row in cur_telemetry.fetchall()
|
||||
]
|
||||
|
||||
def load(self, last_trial_id: int = -1,
|
||||
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
|
||||
def load(
|
||||
self,
|
||||
last_trial_id: int = -1,
|
||||
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
|
||||
|
||||
with self._engine.connect() as conn:
|
||||
cur_trials = conn.execute(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
self._schema.trial.c.trial_id,
|
||||
self._schema.trial.c.config_id,
|
||||
self._schema.trial.c.status,
|
||||
).where(
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
self._schema.trial.c.trial_id > last_trial_id,
|
||||
self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']),
|
||||
).order_by(
|
||||
self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
|
||||
)
|
||||
.order_by(
|
||||
self._schema.trial.c.trial_id.asc(),
|
||||
)
|
||||
)
|
||||
|
@ -155,12 +180,24 @@ class Experiment(Storage.Experiment):
|
|||
stat = Status[trial.status]
|
||||
status.append(stat)
|
||||
trial_ids.append(trial.trial_id)
|
||||
configs.append(self._get_key_val(
|
||||
conn, self._schema.config_param, "param", config_id=trial.config_id))
|
||||
configs.append(
|
||||
self._get_key_val(
|
||||
conn,
|
||||
self._schema.config_param,
|
||||
"param",
|
||||
config_id=trial.config_id,
|
||||
)
|
||||
)
|
||||
if stat.is_succeeded():
|
||||
scores.append(self._get_key_val(
|
||||
conn, self._schema.trial_result, "metric",
|
||||
exp_id=self._experiment_id, trial_id=trial.trial_id))
|
||||
scores.append(
|
||||
self._get_key_val(
|
||||
conn,
|
||||
self._schema.trial_result,
|
||||
"metric",
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=trial.trial_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
scores.append(None)
|
||||
|
||||
|
@ -170,55 +207,73 @@ class Experiment(Storage.Experiment):
|
|||
def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to retrieve key-value pairs from the database.
|
||||
|
||||
(E.g., configurations, results, and telemetry).
|
||||
"""
|
||||
cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
|
||||
select(
|
||||
column(f"{field}_id"),
|
||||
column(f"{field}_value"),
|
||||
).select_from(table).where(
|
||||
*[column(key) == val for (key, val) in kwargs.items()]
|
||||
)
|
||||
.select_from(table)
|
||||
.where(*[column(key) == val for (key, val) in kwargs.items()])
|
||||
)
|
||||
# NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
|
||||
# avoid naming conflicts.
|
||||
return dict(
|
||||
row._tuple() for row in cur_result.fetchall() # pylint: disable=protected-access
|
||||
)
|
||||
# NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts.
|
||||
return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access
|
||||
|
||||
@staticmethod
|
||||
def _save_params(conn: Connection, table: Table,
|
||||
params: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def _save_params(
|
||||
conn: Connection,
|
||||
table: Table,
|
||||
params: Dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if not params:
|
||||
return
|
||||
conn.execute(table.insert(), [
|
||||
{
|
||||
**kwargs,
|
||||
"param_id": key,
|
||||
"param_value": nullable(str, val)
|
||||
}
|
||||
for (key, val) in params.items()
|
||||
])
|
||||
conn.execute(
|
||||
table.insert(),
|
||||
[
|
||||
{**kwargs, "param_id": key, "param_value": nullable(str, val)}
|
||||
for (key, val) in params.items()
|
||||
],
|
||||
)
|
||||
|
||||
def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
|
||||
timestamp = utcify_timestamp(timestamp, origin="local")
|
||||
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
|
||||
if running:
|
||||
pending_status = ['PENDING', 'READY', 'RUNNING']
|
||||
pending_status = ["PENDING", "READY", "RUNNING"]
|
||||
else:
|
||||
pending_status = ['PENDING']
|
||||
pending_status = ["PENDING"]
|
||||
with self._engine.connect() as conn:
|
||||
cur_trials = conn.execute(self._schema.trial.select().where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
(self._schema.trial.c.ts_start.is_(None) |
|
||||
(self._schema.trial.c.ts_start <= timestamp)),
|
||||
self._schema.trial.c.ts_end.is_(None),
|
||||
self._schema.trial.c.status.in_(pending_status),
|
||||
))
|
||||
cur_trials = conn.execute(
|
||||
self._schema.trial.select().where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
(
|
||||
self._schema.trial.c.ts_start.is_(None)
|
||||
| (self._schema.trial.c.ts_start <= timestamp)
|
||||
),
|
||||
self._schema.trial.c.ts_end.is_(None),
|
||||
self._schema.trial.c.status.in_(pending_status),
|
||||
)
|
||||
)
|
||||
for trial in cur_trials.fetchall():
|
||||
tunables = self._get_key_val(
|
||||
conn, self._schema.config_param, "param",
|
||||
config_id=trial.config_id)
|
||||
conn,
|
||||
self._schema.config_param,
|
||||
"param",
|
||||
config_id=trial.config_id,
|
||||
)
|
||||
config = self._get_key_val(
|
||||
conn, self._schema.trial_param, "param",
|
||||
exp_id=self._experiment_id, trial_id=trial.trial_id)
|
||||
conn,
|
||||
self._schema.trial_param,
|
||||
"param",
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=trial.trial_id,
|
||||
)
|
||||
yield Trial(
|
||||
engine=self._engine,
|
||||
schema=self._schema,
|
||||
|
@ -233,47 +288,63 @@ class Experiment(Storage.Experiment):
|
|||
|
||||
def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
|
||||
"""
|
||||
Get the config ID for the given tunables. If the config does not exist,
|
||||
create a new record for it.
|
||||
Get the config ID for the given tunables.
|
||||
|
||||
If the config does not exist, create a new record for it.
|
||||
"""
|
||||
config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest()
|
||||
cur_config = conn.execute(self._schema.config.select().where(
|
||||
self._schema.config.c.config_hash == config_hash
|
||||
)).fetchone()
|
||||
config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
|
||||
cur_config = conn.execute(
|
||||
self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
|
||||
).fetchone()
|
||||
if cur_config is not None:
|
||||
return int(cur_config.config_id) # mypy doesn't know it's always int
|
||||
# Config not found, create a new one:
|
||||
config_id: int = conn.execute(self._schema.config.insert().values(
|
||||
config_hash=config_hash)).inserted_primary_key[0]
|
||||
config_id: int = conn.execute(
|
||||
self._schema.config.insert().values(config_hash=config_hash)
|
||||
).inserted_primary_key[0]
|
||||
self._save_params(
|
||||
conn, self._schema.config_param,
|
||||
conn,
|
||||
self._schema.config_param,
|
||||
{tunable.name: tunable.value for (tunable, _group) in tunables},
|
||||
config_id=config_id)
|
||||
config_id=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
|
||||
def new_trial(
|
||||
self,
|
||||
tunables: TunableGroups,
|
||||
ts_start: Optional[datetime] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> Storage.Trial:
|
||||
# MySQL can round microseconds into the future causing scheduler to skip trials.
|
||||
# Truncate microseconds to avoid this issue.
|
||||
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(microsecond=0)
|
||||
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
|
||||
microsecond=0
|
||||
)
|
||||
_LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
|
||||
with self._engine.begin() as conn:
|
||||
try:
|
||||
config_id = self._get_config_id(conn, tunables)
|
||||
conn.execute(self._schema.trial.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
config_id=config_id,
|
||||
ts_start=ts_start,
|
||||
status='PENDING',
|
||||
))
|
||||
conn.execute(
|
||||
self._schema.trial.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
config_id=config_id,
|
||||
ts_start=ts_start,
|
||||
status="PENDING",
|
||||
)
|
||||
)
|
||||
|
||||
# Note: config here is the framework config, not the target
|
||||
# environment config (i.e., tunables).
|
||||
if config is not None:
|
||||
self._save_params(
|
||||
conn, self._schema.trial_param, config,
|
||||
exp_id=self._experiment_id, trial_id=self._trial_id)
|
||||
conn,
|
||||
self._schema.trial_param,
|
||||
config,
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
)
|
||||
|
||||
trial = Trial(
|
||||
engine=self._engine,
|
||||
|
|
|
@ -2,12 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
An interface to access the experiment benchmark data stored in SQL DB.
|
||||
"""
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
"""An interface to access the experiment benchmark data stored in SQL DB."""
|
||||
import logging
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import pandas
|
||||
from sqlalchemy import Engine, Integer, String, func
|
||||
|
@ -15,11 +12,15 @@ from sqlalchemy import Engine, Integer, String, func
|
|||
from mlos_bench.storage.base_experiment_data import ExperimentData
|
||||
from mlos_bench.storage.base_trial_data import TrialData
|
||||
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
|
||||
from mlos_bench.storage.base_tunable_config_trial_group_data import (
|
||||
TunableConfigTrialGroupData,
|
||||
)
|
||||
from mlos_bench.storage.sql import common
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
|
||||
from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
|
||||
from mlos_bench.storage.sql.tunable_config_trial_group_data import (
|
||||
TunableConfigTrialGroupSqlData,
|
||||
)
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
@ -32,14 +33,17 @@ class ExperimentSqlData(ExperimentData):
|
|||
scripts and mlos_bench configuration files.
|
||||
"""
|
||||
|
||||
def __init__(self, *,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
description: str,
|
||||
root_env_config: str,
|
||||
git_repo: str,
|
||||
git_commit: str):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
experiment_id: str,
|
||||
description: str,
|
||||
root_env_config: str,
|
||||
git_repo: str,
|
||||
git_commit: str,
|
||||
):
|
||||
super().__init__(
|
||||
experiment_id=experiment_id,
|
||||
description=description,
|
||||
|
@ -54,9 +58,11 @@ class ExperimentSqlData(ExperimentData):
|
|||
def objectives(self) -> Dict[str, Literal["min", "max"]]:
|
||||
with self._engine.connect() as conn:
|
||||
objectives_db_data = conn.execute(
|
||||
self._schema.objectives.select().where(
|
||||
self._schema.objectives.select()
|
||||
.where(
|
||||
self._schema.objectives.c.exp_id == self._experiment_id,
|
||||
).order_by(
|
||||
)
|
||||
.order_by(
|
||||
self._schema.objectives.c.weight.desc(),
|
||||
self._schema.objectives.c.optimization_target.asc(),
|
||||
)
|
||||
|
@ -66,7 +72,8 @@ class ExperimentSqlData(ExperimentData):
|
|||
for objective in objectives_db_data.fetchall()
|
||||
}
|
||||
|
||||
# TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed.
|
||||
# TODO: provide a way to get individual data to avoid repeated bulk fetches
|
||||
# where only small amounts of data is accessed.
|
||||
# Or else make the TrialData object lazily populate.
|
||||
|
||||
@property
|
||||
|
@ -77,13 +84,17 @@ class ExperimentSqlData(ExperimentData):
|
|||
def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
|
||||
with self._engine.connect() as conn:
|
||||
tunable_config_trial_groups = conn.execute(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
self._schema.trial.c.config_id,
|
||||
func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
|
||||
'tunable_config_trial_group_id'),
|
||||
).where(
|
||||
func.min(self._schema.trial.c.trial_id)
|
||||
.cast(Integer)
|
||||
.label("tunable_config_trial_group_id"), # pylint: disable=not-callable
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
).group_by(
|
||||
)
|
||||
.group_by(
|
||||
self._schema.trial.c.exp_id,
|
||||
self._schema.trial.c.config_id,
|
||||
)
|
||||
|
@ -94,7 +105,7 @@ class ExperimentSqlData(ExperimentData):
|
|||
schema=self._schema,
|
||||
experiment_id=self._experiment_id,
|
||||
tunable_config_id=tunable_config_trial_group.config_id,
|
||||
tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,
|
||||
tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id, # pylint:disable=line-too-long # noqa
|
||||
)
|
||||
for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
|
||||
}
|
||||
|
@ -103,11 +114,14 @@ class ExperimentSqlData(ExperimentData):
|
|||
def tunable_configs(self) -> Dict[int, TunableConfigData]:
|
||||
with self._engine.connect() as conn:
|
||||
tunable_configs = conn.execute(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
|
||||
).where(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
).group_by(
|
||||
)
|
||||
.group_by(
|
||||
self._schema.trial.c.exp_id,
|
||||
self._schema.trial.c.config_id,
|
||||
)
|
||||
|
@ -124,7 +138,8 @@ class ExperimentSqlData(ExperimentData):
|
|||
@property
|
||||
def default_tunable_config_id(self) -> Optional[int]:
|
||||
"""
|
||||
Retrieves the (tunable) config id for the default tunable values for this experiment.
|
||||
Retrieves the (tunable) config id for the default tunable values for this
|
||||
experiment.
|
||||
|
||||
Note: this is by *default* the first trial executed for this experiment.
|
||||
However, it is currently possible that the user changed the tunables config
|
||||
|
@ -136,20 +151,28 @@ class ExperimentSqlData(ExperimentData):
|
|||
"""
|
||||
with self._engine.connect() as conn:
|
||||
query_results = conn.execute(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
|
||||
).where(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
self._schema.trial.c.trial_id.in_(
|
||||
self._schema.trial_param.select().with_only_columns(
|
||||
func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
|
||||
"first_trial_id_with_defaults"),
|
||||
).where(
|
||||
self._schema.trial_param.select()
|
||||
.with_only_columns(
|
||||
func.min(self._schema.trial_param.c.trial_id)
|
||||
.cast(Integer)
|
||||
.label("first_trial_id_with_defaults"), # pylint: disable=not-callable
|
||||
)
|
||||
.where(
|
||||
self._schema.trial_param.c.exp_id == self._experiment_id,
|
||||
self._schema.trial_param.c.param_id == "is_defaults",
|
||||
func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]),
|
||||
).scalar_subquery()
|
||||
)
|
||||
func.lower(self._schema.trial_param.c.param_value, type_=String).in_(
|
||||
["1", "true"]
|
||||
),
|
||||
)
|
||||
.scalar_subquery()
|
||||
),
|
||||
)
|
||||
)
|
||||
min_default_trial_row = query_results.fetchone()
|
||||
|
@ -158,17 +181,24 @@ class ExperimentSqlData(ExperimentData):
|
|||
return min_default_trial_row._tuple()[0]
|
||||
# fallback logic - assume minimum trial_id for experiment
|
||||
query_results = conn.execute(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
|
||||
).where(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
self._schema.trial.c.trial_id.in_(
|
||||
self._schema.trial.select().with_only_columns(
|
||||
func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"),
|
||||
).where(
|
||||
self._schema.trial.select()
|
||||
.with_only_columns(
|
||||
func.min(self._schema.trial.c.trial_id)
|
||||
.cast(Integer)
|
||||
.label("first_trial_id"),
|
||||
)
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
).scalar_subquery()
|
||||
)
|
||||
)
|
||||
.scalar_subquery()
|
||||
),
|
||||
)
|
||||
)
|
||||
min_trial_row = query_results.fetchone()
|
||||
|
|
|
@ -2,17 +2,26 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
DB schema definition.
|
||||
"""
|
||||
"""DB schema definition."""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
from typing import Any, List
|
||||
|
||||
from sqlalchemy import (
|
||||
Engine, MetaData, Dialect, create_mock_engine,
|
||||
Table, Column, Sequence, Integer, Float, String, DateTime,
|
||||
PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Dialect,
|
||||
Engine,
|
||||
Float,
|
||||
ForeignKeyConstraint,
|
||||
Integer,
|
||||
MetaData,
|
||||
PrimaryKeyConstraint,
|
||||
Sequence,
|
||||
String,
|
||||
Table,
|
||||
UniqueConstraint,
|
||||
create_mock_engine,
|
||||
)
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
@ -38,9 +47,7 @@ class _DDL:
|
|||
|
||||
|
||||
class DbSchema:
|
||||
"""
|
||||
A class to define and create the DB schema.
|
||||
"""
|
||||
"""A class to define and create the DB schema."""
|
||||
|
||||
# This class is internal to SqlStorage and is mostly a struct
|
||||
# for all DB tables, so it's ok to disable the warnings.
|
||||
|
@ -53,9 +60,7 @@ class DbSchema:
|
|||
_STATUS_LEN = 16
|
||||
|
||||
def __init__(self, engine: Engine):
|
||||
"""
|
||||
Declare the SQLAlchemy schema for the database.
|
||||
"""
|
||||
"""Declare the SQLAlchemy schema for the database."""
|
||||
_LOG.info("Create the DB schema for: %s", engine)
|
||||
self._engine = engine
|
||||
# TODO: bind for automatic schema updates? (#649)
|
||||
|
@ -69,7 +74,6 @@ class DbSchema:
|
|||
Column("root_env_config", String(1024), nullable=False),
|
||||
Column("git_repo", String(1024), nullable=False),
|
||||
Column("git_commit", String(40), nullable=False),
|
||||
|
||||
PrimaryKeyConstraint("exp_id"),
|
||||
)
|
||||
|
||||
|
@ -84,20 +88,29 @@ class DbSchema:
|
|||
# Will need to adjust the insert and return values to support this
|
||||
# eventually.
|
||||
Column("weight", Float, nullable=True),
|
||||
|
||||
PrimaryKeyConstraint("exp_id", "optimization_target"),
|
||||
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
|
||||
)
|
||||
|
||||
# A workaround for SQLAlchemy issue with autoincrement in DuckDB:
|
||||
if engine.dialect.name == "duckdb":
|
||||
seq_config_id = Sequence('seq_config_id')
|
||||
col_config_id = Column("config_id", Integer, seq_config_id,
|
||||
server_default=seq_config_id.next_value(),
|
||||
nullable=False, primary_key=True)
|
||||
seq_config_id = Sequence("seq_config_id")
|
||||
col_config_id = Column(
|
||||
"config_id",
|
||||
Integer,
|
||||
seq_config_id,
|
||||
server_default=seq_config_id.next_value(),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
)
|
||||
else:
|
||||
col_config_id = Column("config_id", Integer, nullable=False,
|
||||
primary_key=True, autoincrement=True)
|
||||
col_config_id = Column(
|
||||
"config_id",
|
||||
Integer,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
)
|
||||
|
||||
self.config = Table(
|
||||
"config",
|
||||
|
@ -116,7 +129,6 @@ class DbSchema:
|
|||
Column("ts_end", DateTime),
|
||||
# Should match the text IDs of `mlos_bench.environments.Status` enum:
|
||||
Column("status", String(self._STATUS_LEN), nullable=False),
|
||||
|
||||
PrimaryKeyConstraint("exp_id", "trial_id"),
|
||||
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
|
||||
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
|
||||
|
@ -130,7 +142,6 @@ class DbSchema:
|
|||
Column("config_id", Integer, nullable=False),
|
||||
Column("param_id", String(self._ID_LEN), nullable=False),
|
||||
Column("param_value", String(self._PARAM_VALUE_LEN)),
|
||||
|
||||
PrimaryKeyConstraint("config_id", "param_id"),
|
||||
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
|
||||
)
|
||||
|
@ -144,10 +155,11 @@ class DbSchema:
|
|||
Column("trial_id", Integer, nullable=False),
|
||||
Column("param_id", String(self._ID_LEN), nullable=False),
|
||||
Column("param_value", String(self._PARAM_VALUE_LEN)),
|
||||
|
||||
PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
|
||||
ForeignKeyConstraint(["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id]),
|
||||
ForeignKeyConstraint(
|
||||
["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id],
|
||||
),
|
||||
)
|
||||
|
||||
self.trial_status = Table(
|
||||
|
@ -157,10 +169,11 @@ class DbSchema:
|
|||
Column("trial_id", Integer, nullable=False),
|
||||
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
|
||||
Column("status", String(self._STATUS_LEN), nullable=False),
|
||||
|
||||
UniqueConstraint("exp_id", "trial_id", "ts"),
|
||||
ForeignKeyConstraint(["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id]),
|
||||
ForeignKeyConstraint(
|
||||
["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id],
|
||||
),
|
||||
)
|
||||
|
||||
self.trial_result = Table(
|
||||
|
@ -170,10 +183,11 @@ class DbSchema:
|
|||
Column("trial_id", Integer, nullable=False),
|
||||
Column("metric_id", String(self._ID_LEN), nullable=False),
|
||||
Column("metric_value", String(self._METRIC_VALUE_LEN)),
|
||||
|
||||
PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
|
||||
ForeignKeyConstraint(["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id]),
|
||||
ForeignKeyConstraint(
|
||||
["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id],
|
||||
),
|
||||
)
|
||||
|
||||
self.trial_telemetry = Table(
|
||||
|
@ -184,26 +198,25 @@ class DbSchema:
|
|||
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
|
||||
Column("metric_id", String(self._ID_LEN), nullable=False),
|
||||
Column("metric_value", String(self._METRIC_VALUE_LEN)),
|
||||
|
||||
UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
|
||||
ForeignKeyConstraint(["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id]),
|
||||
ForeignKeyConstraint(
|
||||
["exp_id", "trial_id"],
|
||||
[self.trial.c.exp_id, self.trial.c.trial_id],
|
||||
),
|
||||
)
|
||||
|
||||
_LOG.debug("Schema: %s", self._meta)
|
||||
|
||||
def create(self) -> 'DbSchema':
|
||||
"""
|
||||
Create the DB schema.
|
||||
"""
|
||||
def create(self) -> "DbSchema":
|
||||
"""Create the DB schema."""
|
||||
_LOG.info("Create the DB schema")
|
||||
self._meta.create_all(self._engine)
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Produce a string with all SQL statements required to create the schema
|
||||
from scratch in current SQL dialect.
|
||||
Produce a string with all SQL statements required to create the schema from
|
||||
scratch in current SQL dialect.
|
||||
|
||||
That is, return a collection of CREATE TABLE statements and such.
|
||||
NOTE: this method is quite heavy! We use it only once at startup
|
||||
|
|
|
@ -2,35 +2,33 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Saving and restoring the benchmark data in SQL database.
|
||||
"""
|
||||
"""Saving and restoring the benchmark data in SQL database."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
from sqlalchemy import URL, create_engine
|
||||
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.services.base_service import Service
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.storage.sql.experiment import Experiment
|
||||
from mlos_bench.storage.base_experiment_data import ExperimentData
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.storage.sql.experiment import Experiment
|
||||
from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SqlStorage(Storage):
|
||||
"""
|
||||
An implementation of the Storage interface using SQLAlchemy backend.
|
||||
"""
|
||||
"""An implementation of the Storage interface using SQLAlchemy backend."""
|
||||
|
||||
def __init__(self,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
global_config: Optional[dict] = None,
|
||||
service: Optional[Service] = None,
|
||||
):
|
||||
super().__init__(config, global_config, service)
|
||||
lazy_schema_create = self._config.pop("lazy_schema_create", False)
|
||||
self._log_sql = self._config.pop("log_sql", False)
|
||||
|
@ -47,7 +45,7 @@ class SqlStorage(Storage):
|
|||
@property
|
||||
def _schema(self) -> DbSchema:
|
||||
"""Lazily create schema upon first access."""
|
||||
if not hasattr(self, '_db_schema'):
|
||||
if not hasattr(self, "_db_schema"):
|
||||
self._db_schema = DbSchema(self._engine).create()
|
||||
if _LOG.isEnabledFor(logging.DEBUG):
|
||||
_LOG.debug("DDL statements:\n%s", self._schema)
|
||||
|
@ -56,13 +54,16 @@ class SqlStorage(Storage):
|
|||
def __repr__(self) -> str:
|
||||
return self._repr
|
||||
|
||||
def experiment(self, *,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
tunables: TunableGroups,
|
||||
opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment:
|
||||
def experiment(
|
||||
self,
|
||||
*,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
root_env_config: str,
|
||||
description: str,
|
||||
tunables: TunableGroups,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
) -> Storage.Experiment:
|
||||
return Experiment(
|
||||
engine=self._engine,
|
||||
schema=self._schema,
|
||||
|
|
|
@ -2,40 +2,39 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
"""
|
||||
Saving and updating benchmark data using SQLAlchemy backend.
|
||||
"""
|
||||
"""Saving and updating benchmark data using SQLAlchemy backend."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Tuple, Dict, Any
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from sqlalchemy import Engine, Connection
|
||||
from sqlalchemy import Connection, Engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from mlos_bench.environments.status import Status
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.storage.base_storage import Storage
|
||||
from mlos_bench.storage.sql.schema import DbSchema
|
||||
from mlos_bench.tunables.tunable_groups import TunableGroups
|
||||
from mlos_bench.util import nullable, utcify_timestamp
|
||||
|
||||
_LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trial(Storage.Trial):
|
||||
"""
|
||||
Store the results of a single run of the experiment in SQL database.
|
||||
"""
|
||||
"""Store the results of a single run of the experiment in SQL database."""
|
||||
|
||||
def __init__(self, *,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
config_id: int,
|
||||
opt_targets: Dict[str, Literal['min', 'max']],
|
||||
config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
engine: Engine,
|
||||
schema: DbSchema,
|
||||
tunables: TunableGroups,
|
||||
experiment_id: str,
|
||||
trial_id: int,
|
||||
config_id: int,
|
||||
opt_targets: Dict[str, Literal["min", "max"]],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
tunables=tunables,
|
||||
experiment_id=experiment_id,
|
||||
|
@ -47,9 +46,12 @@ class Trial(Storage.Trial):
|
|||
self._engine = engine
|
||||
self._schema = schema
|
||||
|
||||
def update(self, status: Status, timestamp: datetime,
|
||||
metrics: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
def update(
|
||||
self,
|
||||
status: Status,
|
||||
timestamp: datetime,
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# Make sure to convert the timestamp to UTC before storing it in the database.
|
||||
timestamp = utcify_timestamp(timestamp, origin="local")
|
||||
metrics = super().update(status, timestamp, metrics)
|
||||
|
@ -59,13 +61,16 @@ class Trial(Storage.Trial):
|
|||
if status.is_completed():
|
||||
# Final update of the status and ts_end:
|
||||
cur_status = conn.execute(
|
||||
self._schema.trial.update().where(
|
||||
self._schema.trial.update()
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
self._schema.trial.c.trial_id == self._trial_id,
|
||||
self._schema.trial.c.ts_end.is_(None),
|
||||
self._schema.trial.c.status.notin_(
|
||||
['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
|
||||
).values(
|
||||
["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
|
||||
),
|
||||
)
|
||||
.values(
|
||||
status=status.name,
|
||||
ts_end=timestamp,
|
||||
)
|
||||
|
@ -73,29 +78,37 @@ class Trial(Storage.Trial):
|
|||
if cur_status.rowcount not in {1, -1}:
|
||||
_LOG.warning("Trial %s :: update failed: %s", self, status)
|
||||
raise RuntimeError(
|
||||
f"Failed to update the status of the trial {self} to {status}." +
|
||||
f" ({cur_status.rowcount} rows)")
|
||||
f"Failed to update the status of the trial {self} to {status}. "
|
||||
f"({cur_status.rowcount} rows)"
|
||||
)
|
||||
if metrics:
|
||||
conn.execute(self._schema.trial_result.insert().values([
|
||||
{
|
||||
"exp_id": self._experiment_id,
|
||||
"trial_id": self._trial_id,
|
||||
"metric_id": key,
|
||||
"metric_value": nullable(str, val),
|
||||
}
|
||||
for (key, val) in metrics.items()
|
||||
]))
|
||||
conn.execute(
|
||||
self._schema.trial_result.insert().values(
|
||||
[
|
||||
{
|
||||
"exp_id": self._experiment_id,
|
||||
"trial_id": self._trial_id,
|
||||
"metric_id": key,
|
||||
"metric_value": nullable(str, val),
|
||||
}
|
||||
for (key, val) in metrics.items()
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Update of the status and ts_start when starting the trial:
|
||||
assert metrics is None, f"Unexpected metrics for status: {status}"
|
||||
cur_status = conn.execute(
|
||||
self._schema.trial.update().where(
|
||||
self._schema.trial.update()
|
||||
.where(
|
||||
self._schema.trial.c.exp_id == self._experiment_id,
|
||||
self._schema.trial.c.trial_id == self._trial_id,
|
||||
self._schema.trial.c.ts_end.is_(None),
|
||||
self._schema.trial.c.status.notin_(
|
||||
['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
|
||||
).values(
|
||||
["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
|
||||
),
|
||||
)
|
||||
.values(
|
||||
status=status.name,
|
||||
ts_start=timestamp,
|
||||
)
|
||||
|
@ -108,8 +121,12 @@ class Trial(Storage.Trial):
|
|||
raise
|
||||
return metrics
|
||||
|
||||
def update_telemetry(self, status: Status, timestamp: datetime,
|
||||
metrics: List[Tuple[datetime, str, Any]]) -> None:
|
||||
def update_telemetry(
|
||||
self,
|
||||
status: Status,
|
||||
timestamp: datetime,
|
||||
metrics: List[Tuple[datetime, str, Any]],
|
||||
) -> None:
|
||||
super().update_telemetry(status, timestamp, metrics)
|
||||
# Make sure to convert the timestamp to UTC before storing it in the database.
|
||||
timestamp = utcify_timestamp(timestamp, origin="local")
|
||||
|
@ -120,33 +137,42 @@ class Trial(Storage.Trial):
|
|||
# See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
|
||||
with self._engine.begin() as conn:
|
||||
self._update_status(conn, status, timestamp)
|
||||
for (metric_ts, key, val) in metrics:
|
||||
for metric_ts, key, val in metrics:
|
||||
with self._engine.begin() as conn:
|
||||
try:
|
||||
conn.execute(self._schema.trial_telemetry.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
ts=metric_ts,
|
||||
metric_id=key,
|
||||
metric_value=nullable(str, val),
|
||||
))
|
||||
conn.execute(
|
||||
self._schema.trial_telemetry.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
ts=metric_ts,
|
||||
metric_id=key,
|
||||
metric_value=nullable(str, val),
|
||||
)
|
||||
)
|
||||
except IntegrityError as ex:
|
||||
_LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)
|
||||
|
||||
def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
|
||||
"""
|
||||
Insert a new status record into the database.
|
||||
|
||||
This call is idempotent.
|
||||
"""
|
||||
# Make sure to convert the timestamp to UTC before storing it in the database.
|
||||
timestamp = utcify_timestamp(timestamp, origin="local")
|
||||
try:
|
||||
conn.execute(self._schema.trial_status.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
ts=timestamp,
|
||||
status=status.name,
|
||||
))
|
||||
conn.execute(
|
||||
self._schema.trial_status.insert().values(
|
||||
exp_id=self._experiment_id,
|
||||
trial_id=self._trial_id,
|
||||
ts=timestamp,
|
||||
status=status.name,
|
||||
)
|
||||
)
|
||||
except IntegrityError as ex:
|
||||
_LOG.warning("Status with that timestamp already exists: %s %s :: %s",
|
||||
self, timestamp, ex)
|
||||
_LOG.warning(
|
||||
"Status with that timestamp already exists: %s %s :: %s",
|
||||
self,
|
||||
timestamp,
|
||||
ex,
|
||||
)
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче