Enable black, isort, and doc formatters and checks (#774)

Follow-on work to #766.

This enables both formatters (`black` and `isort`) and applies their changes to the repo.

Additionally, since `black` does not make changes to comments or
docstrings, we also enable `docformatter` to reformat docstrings so that
they better align with `pydocstyle` rules.
Without that additional change (and some manual fixups), `pycodestyle`
and `pylint` would still complain about docstring line lengths, for instance.

Finally, we make a minor adjustment to the max line length, setting it to
99 (a limit also mentioned and accepted in PEP 8) instead of black's
default of 88, to avoid wrapping some comments (especially linter overrides).
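For reference, all three tools can be configured centrally via `pyproject.toml`. The exact settings added by this commit are not shown in the excerpts below, so the following is only a minimal sketch of what a matching setup might look like, assuming the tools' standard option names:

[tool.black]
line-length = 99

[tool.isort]
# isort's documented black-compatibility profile; without it the two
# formatters can disagree about import wrapping and trailing commas.
profile = "black"
line_length = 99

[tool.docformatter]
recursive = true
# Wrap docstring summaries and descriptions at the same 99-column limit
# so that pycodestyle/pylint line-length checks stay quiet.
wrap-summaries = 99
wrap-descriptions = 99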
Brian Kroth 2024-07-12 14:56:14 -05:00, committed by GitHub
Parent: fd9c8f9935
Commit: e40ac28317
No key found matching this signature
GPG key ID: B5690EEEBB952194
268 changed files: 9101 additions and 7298 deletions


@ -45,7 +45,6 @@
// Adjust the python interpreter path to point to the conda environment
"python.defaultInterpreterPath": "/opt/conda/envs/mlos/bin/python",
"python.testing.pytestPath": "/opt/conda/envs/mlos/bin/pytest",
"python.formatting.autopep8Path": "/opt/conda/envs/mlos/bin/autopep8",
"python.linting.pylintPath": "/opt/conda/envs/mlos/bin/pylint",
"pylint.path": [
"/opt/conda/envs/mlos/bin/pylint"
@ -71,9 +70,8 @@
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
// TODO: Enable additional formatter extensions:
//"ms-python.black-formatter",
//"ms-python.isort",
"ms-python.black-formatter",
"ms-python.isort",
"ms-python.pylint",
"ms-python.python",
"ms-python.vscode-pylance",


@ -12,6 +12,9 @@ charset = utf-8
# Note: this is not currently supported by all editors or their editorconfig plugins.
max_line_length = 132
[*.py]
max_line_length = 99
# Makefiles need tab indentation
[{Makefile,*.mk}]
indent_style = tab

.gitignore (vendored): 3 changes

@ -1,3 +1,6 @@
# Ignore git directory (ripgrep)
.git/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]


@ -35,7 +35,7 @@ load-plugins=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=132
max-line-length=99
[MESSAGE CONTROL]
disable=
@ -48,5 +48,5 @@ disable=
missing-raises-doc
[STRING]
#check-quote-consistency=yes
check-quote-consistency=yes
check-str-concat-over-line-jumps=yes

.vscode/extensions.json (vendored): 5 changes

@ -14,9 +14,8 @@
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
// TODO: Enable additional formatter extensions:
//"ms-python.black-formatter",
//"ms-python.isort",
"ms-python.black-formatter",
"ms-python.isort",
"ms-python.pylint",
"ms-python.python",
"ms-python.vscode-pylance",

.vscode/settings.json (vendored): 10 changes

@ -125,14 +125,10 @@
],
"esbonio.sphinx.confDir": "${workspaceFolder}/doc/source",
"esbonio.sphinx.buildDir": "${workspaceFolder}/doc/build/",
"autopep8.args": [
"--experimental"
],
"[python]": {
// TODO: Enable black formatter
//"editor.defaultFormatter": "ms-python.black-formatter",
//"editor.formatOnSave": true,
//"editor.formatOnSaveMode": "modifications"
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.formatOnSaveMode": "modifications"
},
// See Also .vscode/launch.json for environment variable args to pytest during debug sessions.
// For the rest, see setup.cfg

Makefile: 103 changes

@ -34,7 +34,7 @@ conda-env: build/conda-env.${CONDA_ENV_NAME}.build-stamp
MLOS_CORE_CONF_FILES := mlos_core/pyproject.toml mlos_core/setup.py mlos_core/MANIFEST.in
MLOS_BENCH_CONF_FILES := mlos_bench/pyproject.toml mlos_bench/setup.py mlos_bench/MANIFEST.in
MLOS_VIZ_CONF_FILES := mlos_viz/pyproject.toml mlos_viz/setup.py mlos_viz/MANIFEST.in
MLOS_GLOBAL_CONF_FILES := setup.cfg # pyproject.toml
MLOS_GLOBAL_CONF_FILES := setup.cfg pyproject.toml
MLOS_PKGS := mlos_core mlos_bench mlos_viz
MLOS_PKG_CONF_FILES := $(MLOS_CORE_CONF_FILES) $(MLOS_BENCH_CONF_FILES) $(MLOS_VIZ_CONF_FILES) $(MLOS_GLOBAL_CONF_FILES)
@ -69,9 +69,9 @@ ifneq (,$(filter format,$(MAKECMDGOALS)))
endif
build/format.${CONDA_ENV_NAME}.build-stamp: build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
# TODO: enable isort and black formatters
#build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
#build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp
build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp
build/format.${CONDA_ENV_NAME}.build-stamp: build/docformatter.${CONDA_ENV_NAME}.build-stamp
build/format.${CONDA_ENV_NAME}.build-stamp:
touch $@
@ -111,8 +111,8 @@ build/isort.${CONDA_ENV_NAME}.build-stamp:
# NOTE: when using pattern rules (involving %) we can only add one line of
# prerequisities, so we use this pattern to compose the list as variables.
# Both isort and licenseheaders alter files, so only run one at a time, by
# making licenseheaders an order-only prerequisite.
# black, licenseheaders, isort, and docformatter all alter files, so only run
# one at a time, by adding prerequisites, but only as necessary.
ISORT_COMMON_PREREQS :=
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
ISORT_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
@ -126,7 +126,7 @@ build/isort.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
build/isort.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_COMMON_PREREQS)
# Reformat python file imports with isort.
conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$?)
conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$+)
touch $@
.PHONY: black
@ -142,8 +142,8 @@ build/black.${CONDA_ENV_NAME}.build-stamp: build/black.mlos_viz.${CONDA_ENV_NAME
build/black.${CONDA_ENV_NAME}.build-stamp:
touch $@
# Both black, licenseheaders, and isort all alter files, so only run one at a time, by
# making licenseheaders and isort an order-only prerequisite.
# black, licenseheaders, isort, and docformatter all alter files, so only run
# one at a time, by adding prerequisites, but only as necessary.
BLACK_COMMON_PREREQS :=
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
BLACK_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
@ -160,13 +160,52 @@ build/black.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS)
# Reformat python files with black.
conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$?)
conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$+)
touch $@
.PHONY: docformatter
docformatter: build/docformatter.${CONDA_ENV_NAME}.build-stamp
ifneq (,$(filter docformatter,$(MAKECMDGOALS)))
FORMAT_PREREQS += build/docformatter.${CONDA_ENV_NAME}.build-stamp
endif
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp
build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp
build/docformatter.${CONDA_ENV_NAME}.build-stamp:
touch $@
# black, licenseheaders, isort, and docformatter all alter files, so only run
# one at a time, by adding prerequisites, but only as necessary.
DOCFORMATTER_COMMON_PREREQS :=
ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS)))
DOCFORMATTER_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
endif
ifneq (,$(filter format isort,$(MAKECMDGOALS)))
DOCFORMATTER_COMMON_PREREQS += build/isort.${CONDA_ENV_NAME}.build-stamp
endif
ifneq (,$(filter format black,$(MAKECMDGOALS)))
DOCFORMATTER_COMMON_PREREQS += build/black.${CONDA_ENV_NAME}.build-stamp
endif
DOCFORMATTER_COMMON_PREREQS += build/conda-env.${CONDA_ENV_NAME}.build-stamp
DOCFORMATTER_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
# docformatter returns non-zero when it changes anything so instead we ignore that
# return code and just have it recheck itself immediately
build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS)
# Reformat python file docstrings with docformatter.
conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$+) || true
conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
touch $@
.PHONY: check
check: pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
# TODO: Enable isort and black checks
#check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
check: isort-check black-check docformatter-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check
.PHONY: black-check
black-check: build/black-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
@ -185,7 +224,27 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/black-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS)
# Check for import sort order.
# Note: if this fails use "make format" or "make black" to fix it.
conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$?)
conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$+)
touch $@
.PHONY: docformatter-check
docformatter-check: build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp
docformatter-check: build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp
docformatter-check: build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp
# Make sure docformatter format rules run before docformatter-check rules.
build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES)
build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES)
build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES)
BLACK_CHECK_COMMON_PREREQS := build/conda-env.${CONDA_ENV_NAME}.build-stamp
BLACK_CHECK_COMMON_PREREQS += $(FORMAT_PREREQS)
BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS)
# Check for import sort order.
# Note: if this fails use "make format" or "make docformatter" to fix it.
conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+)
touch $@
.PHONY: isort-check
@ -204,7 +263,7 @@ ISORT_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES)
build/isort-check.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_CHECK_COMMON_PREREQS)
# Note: if this fails use "make format" or "make isort" to fix it.
conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$?)
conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$+)
touch $@
.PHONY: pycodestyle
@ -723,7 +782,12 @@ clean-doc:
.PHONY: clean-format
clean-format:
# TODO: add black and isort rules
rm -f build/black.${CONDA_ENV_NAME}.build-stamp
rm -f build/black.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/docformatter.${CONDA_ENV_NAME}.build-stamp
rm -f build/docformatter.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/isort.${CONDA_ENV_NAME}.build-stamp
rm -f build/isort.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/licenseheaders.${CONDA_ENV_NAME}.build-stamp
rm -f build/licenseheaders-prereqs.${CONDA_ENV_NAME}.build-stamp
@ -733,6 +797,13 @@ clean-check:
rm -f build/pylint.${CONDA_ENV_NAME}.build-stamp
rm -f build/pylint.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/mypy.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/black-check.build-stamp
rm -f build/black-check.${CONDA_ENV_NAME}.build-stamp
rm -f build/black-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/docformatter-check.${CONDA_ENV_NAME}.build-stamp
rm -f build/docformatter-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/isort-check.${CONDA_ENV_NAME}.build-stamp
rm -f build/isort-check.mlos_*.${CONDA_ENV_NAME}.build-stamp
rm -f build/pycodestyle.build-stamp
rm -f build/pycodestyle.${CONDA_ENV_NAME}.build-stamp
rm -f build/pycodestyle.mlos_*.${CONDA_ENV_NAME}.build-stamp


@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -25,10 +25,10 @@ dependencies:
# See comments in mlos.yml.
#- gcc_linux-64
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -28,10 +28,10 @@ dependencies:
# This also requires a more recent vs2015_runtime from conda-forge.
- pyrfr>=0.9.0
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -24,10 +24,10 @@ dependencies:
# FIXME: https://github.com/microsoft/MLOS/issues/727
- python<3.12
- pip:
- autopep8>=1.7.0
- bump2version
- check-jsonschema
- isort
- docformatter
- licenseheaders
- mypy
- pandas-stubs


@ -32,14 +32,23 @@ def pytest_configure(config: pytest.Config) -> None:
Add some additional (global) configuration steps for pytest.
"""
# Workaround some issues loading emukit in certain environments.
if os.environ.get('DISPLAY', None):
if os.environ.get("DISPLAY", None):
try:
import matplotlib # pylint: disable=import-outside-toplevel
matplotlib.rcParams['backend'] = 'agg'
if is_master(config) or dict(getattr(config, 'workerinput', {}))['workerid'] == 'gw0':
import matplotlib # pylint: disable=import-outside-toplevel
matplotlib.rcParams["backend"] = "agg"
if is_master(config) or dict(getattr(config, "workerinput", {}))["workerid"] == "gw0":
# Only warn once.
warn(UserWarning('DISPLAY environment variable is set, which can cause problems in some setups (e.g. WSL). '
+ f'Adjusting matplotlib backend to "{matplotlib.rcParams["backend"]}" to compensate.'))
warn(
UserWarning(
(
"DISPLAY environment variable is set, "
"which can cause problems in some setups (e.g. WSL). "
f"Adjusting matplotlib backend to '{matplotlib.rcParams['backend']}' "
"to compensate."
)
)
)
except ImportError:
pass


@ -87,10 +87,11 @@ autosummary_generate = True
numpydoc_class_members_toctree = False
autodoc_default_options = {
'members': True,
'undoc-members': True,
# Don't generate documentation for some (non-private) functions that are more for internal implementation use.
'exclude-members': 'mlos_bench.util.check_required_params'
"members": True,
"undoc-members": True,
# Don't generate documentation for some (non-private) functions that are more
# for internal implementation use.
"exclude-members": "mlos_bench.util.check_required_params",
}
# Generate the plots for the gallery


@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
mlos_bench is a framework to help automate benchmarking and and
OS/application parameter autotuning.
"""mlos_bench is a framework to help automate benchmarking and and OS/application
parameter autotuning.
"""


@ -2,6 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
mlos_bench.config
"""
"""mlos_bench.config."""


@ -3,41 +3,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Script for post-processing FIO results for mlos_bench.
"""
"""Script for post-processing FIO results for mlos_bench."""
import argparse
import itertools
import json
from typing import Any, Iterator, Tuple
import pandas
def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]:
"""
Flatten every dict in the hierarchy and rename the keys with the dict path.
"""
"""Flatten every dict in the hierarchy and rename the keys with the dict path."""
if isinstance(data, dict):
for (key, val) in data.items():
for key, val in data.items():
yield from _flat_dict(val, f"{path}.{key}")
else:
yield (path, data)
def _main(input_file: str, output_file: str, prefix: str) -> None:
"""
Convert FIO read data from JSON to tall CSV.
"""
with open(input_file, mode='r', encoding='utf-8') as fh_input:
"""Convert FIO read data from JSON to tall CSV."""
with open(input_file, mode="r", encoding="utf-8") as fh_input:
json_data = json.load(fh_input)
data = list(itertools.chain(
_flat_dict(json_data["jobs"][0], prefix),
_flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util")
))
data = list(
itertools.chain(
_flat_dict(json_data["jobs"][0], prefix),
_flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"),
)
)
tall_df = pandas.DataFrame(data, columns=["metric", "value"])
tall_df.to_csv(output_file, index=False)
@ -50,12 +45,18 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.")
parser.add_argument(
"input", help="FIO benchmark results in JSON format (downloaded from a remote VM).")
"input",
help="FIO benchmark results in JSON format (downloaded from a remote VM).",
)
parser.add_argument(
"output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).")
"output",
help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).",
)
parser.add_argument(
"--prefix", default="fio",
help="Prefix of the metric IDs (default 'fio')")
"--prefix",
default="fio",
help="Prefix of the metric IDs (default 'fio')",
)
args = parser.parse_args()
_main(args.input, args.output, args.prefix)


@ -9,22 +9,24 @@ Helper script to generate Redis config from tunable parameters JSON.
Run: `./generate_redis_config.py ./input-params.json ./output-redis.cfg`
"""
import json
import argparse
import json
def _main(fname_input: str, fname_output: str) -> None:
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in json.load(fh_tunables).items():
line = f'{key} {val}'
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
fname_output, "wt", encoding="utf-8", newline=""
) as fh_config:
for key, val in json.load(fh_tunables).items():
line = f"{key} {val}"
fh_config.write(line + "\n")
print(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="generate Redis config from tunable parameters JSON.")
description="generate Redis config from tunable parameters JSON."
)
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("output", help="Output Redis config file.")
args = parser.parse_args()


@ -3,9 +3,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Script for post-processing redis-benchmark results.
"""
"""Script for post-processing redis-benchmark results."""
import argparse
@ -13,26 +11,25 @@ import pandas as pd
def _main(input_file: str, output_file: str) -> None:
"""
Re-shape Redis benchmark CSV results from wide to long.
"""
"""Re-shape Redis benchmark CSV results from wide to long."""
df_wide = pd.read_csv(input_file)
# Format the results from wide to long
# The target is columns of metric and value to act as key-value pairs.
df_long = (
df_wide
.melt(id_vars=["test"])
df_wide.melt(id_vars=["test"])
.assign(metric=lambda df: df["test"] + "_" + df["variable"])
.drop(columns=["test", "variable"])
.loc[:, ["metric", "value"]]
)
# Add a default `score` metric to the end of the dataframe.
df_long = pd.concat([
df_long,
pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]})
])
df_long = pd.concat(
[
df_long,
pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}),
]
)
df_long.to_csv(output_file, index=False)
print(f"Converted: {input_file} -> {output_file}")
@ -41,8 +38,13 @@ def _main(input_file: str, output_file: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.")
parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).")
parser.add_argument("output", help="Converted Redis benchmark data" +
" (to be consumed by OS Autotune framework).")
parser.add_argument(
"input",
help="Redis benchmark results (downloaded from a remote VM).",
)
parser.add_argument(
"output",
help="Converted Redis benchmark data (to be consumed by OS Autotune framework).",
)
args = parser.parse_args()
_main(args.input, args.output)


@ -6,16 +6,18 @@
"""
Python script to parse through JSON and create new config file.
This script will be run in the SCHEDULER.
NEW_CFG will need to be copied over to the VM (/etc/default/grub.d).
This script will be run in the SCHEDULER. NEW_CFG will need to be copied over to the VM
(/etc/default/grub.d).
"""
import json
JSON_CONFIG_FILE = "config-boot-time.json"
NEW_CFG = "zz-mlos-boot-params.cfg"
with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \
open(NEW_CFG, 'w', encoding='UTF-8') as fh_config:
with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open(
NEW_CFG, "w", encoding="UTF-8"
) as fh_config:
for key, val in json.load(fh_json).items():
fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$'
f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n')
fh_config.write(
'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n'
)


@ -9,14 +9,15 @@ Helper script to generate GRUB config from tunable parameters JSON.
Run: `./generate_grub_config.py ./input-boot-params.json ./output-grub.cfg`
"""
import json
import argparse
import json
def _main(fname_input: str, fname_output: str) -> None:
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \
open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in json.load(fh_tunables).items():
with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open(
fname_output, "wt", encoding="utf-8", newline=""
) as fh_config:
for key, val in json.load(fh_tunables).items():
line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"'
fh_config.write(line + "\n")
print(line)
@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate GRUB config from tunable parameters JSON.")
description="Generate GRUB config from tunable parameters JSON."
)
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("output", help="Output shell script to configure GRUB.")
args = parser.parse_args()


@ -6,11 +6,14 @@
"""
Helper script to generate a script to update kernel parameters from tunables JSON.
Run: `./generate_kernel_config_script.py ./kernel-params.json ./kernel-params-meta.json ./config-kernel.sh`
Run:
./generate_kernel_config_script.py \
./kernel-params.json ./kernel-params-meta.json \
./config-kernel.sh
"""
import json
import argparse
import json
def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
@ -22,7 +25,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
tunables_meta = json.load(fh_meta)
with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config:
for (key, val) in tunables_data.items():
for key, val in tunables_data.items():
meta = tunables_meta.get(key, {})
name_prefix = meta.get("name_prefix", "")
line = f'echo "{val}" > {name_prefix}{key}'
@ -33,7 +36,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="generate a script to update kernel parameters from tunables JSON.")
description="generate a script to update kernel parameters from tunables JSON."
)
parser.add_argument("input", help="JSON file with tunable parameters.")
parser.add_argument("meta", help="JSON file with tunable parameters metadata.")


@ -2,14 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A module for managing config schemas and their validation.
"""
from mlos_bench.config.schemas.config_schemas import ConfigSchema, CONFIG_SCHEMA_DIR
"""A module for managing config schemas and their validation."""
from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema
__all__ = [
'ConfigSchema',
'CONFIG_SCHEMA_DIR',
"ConfigSchema",
"CONFIG_SCHEMA_DIR",
]


@ -2,18 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A simple class for describing where to find different config schemas and validating configs against them.
"""A simple class for describing where to find different config schemas and validating
configs against them.
"""
import json # schema files are pure json - no comments
import logging
from enum import Enum
from os import path, walk, environ
from os import environ, path, walk
from typing import Dict, Iterator, Mapping
import json # schema files are pure json - no comments
import jsonschema
from referencing import Registry, Resource
from referencing.jsonschema import DRAFT202012
@ -28,16 +27,21 @@ CONFIG_SCHEMA_DIR = path_join(path.dirname(__file__), abs_path=True)
# It is used in `ConfigSchema.validate()` method below.
# NOTE: this may cause pytest to fail if it's expecting exceptions
# to be raised for invalid configs.
_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION'
_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower()
in {'true', 'y', 'yes', 'on', '1'})
_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION"
_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in {
"true",
"y",
"yes",
"on",
"1",
}
# Note: we separate out the SchemaStore from a class method on ConfigSchema
# because of issues with mypy/pylint and non-Enum-member class members.
class SchemaStore(Mapping):
"""
A simple class for storing schemas and subschemas for the validator to reference.
"""A simple class for storing schemas and subschemas for the validator to
reference.
"""
# A class member mapping of schema id to schema object.
@ -58,7 +62,9 @@ class SchemaStore(Mapping):
@classmethod
def _load_schemas(cls) -> None:
"""Loads all schemas and subschemas into the schema store for the validator to reference."""
"""Loads all schemas and subschemas into the schema store for the validator to
reference.
"""
if cls._SCHEMA_STORE:
return
for root, _, files in walk(CONFIG_SCHEMA_DIR):
@ -78,13 +84,17 @@ class SchemaStore(Mapping):
@classmethod
def _load_registry(cls) -> None:
"""Also store them in a Registry object for referencing by recent versions of jsonschema."""
"""Also store them in a Registry object for referencing by recent versions of
jsonschema.
"""
if not cls._SCHEMA_STORE:
cls._load_schemas()
cls._REGISTRY = Registry().with_resources([
(url, Resource.from_contents(schema, default_specification=DRAFT202012))
for url, schema in cls._SCHEMA_STORE.items()
])
cls._REGISTRY = Registry().with_resources(
[
(url, Resource.from_contents(schema, default_specification=DRAFT202012))
for url, schema in cls._SCHEMA_STORE.items()
]
)
@property
def registry(self) -> Registry:
@ -98,9 +108,7 @@ SCHEMA_STORE = SchemaStore()
class ConfigSchema(Enum):
"""
An enum to help describe schema types and help validate configs against them.
"""
"""An enum to help describe schema types and help validate configs against them."""
CLI = path_join(CONFIG_SCHEMA_DIR, "cli/cli-schema.json")
GLOBALS = path_join(CONFIG_SCHEMA_DIR, "cli/globals-schema.json")


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Simple class to help with nested dictionary $var templating.
"""
"""Simple class to help with nested dictionary $var templating."""
from copy import deepcopy
from string import Template
@ -13,10 +11,8 @@ from typing import Any, Dict, Optional
from mlos_bench.os_environ import environ
class DictTemplater: # pylint: disable=too-few-public-methods
"""
Simple class to help with nested dictionary $var templating.
"""
class DictTemplater: # pylint: disable=too-few-public-methods
"""Simple class to help with nested dictionary $var templating."""
def __init__(self, source_dict: Dict[str, Any]):
"""
@ -32,9 +28,12 @@ class DictTemplater: # pylint: disable=too-few-public-methods
# The source/target dictionary to expand.
self._dict: Dict[str, Any] = {}
def expand_vars(self, *,
extra_source_dict: Optional[Dict[str, Any]] = None,
use_os_env: bool = False) -> Dict[str, Any]:
def expand_vars(
self,
*,
extra_source_dict: Optional[Dict[str, Any]] = None,
use_os_env: bool = False,
) -> Dict[str, Any]:
"""
Expand the template variables in the destination dictionary.
@ -55,10 +54,13 @@ class DictTemplater: # pylint: disable=too-few-public-methods
assert isinstance(self._dict, dict)
return self._dict
def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any:
"""
Recursively expand $var strings in the currently operating dictionary.
"""
def _expand_vars(
self,
value: Any,
extra_source_dict: Optional[Dict[str, Any]],
use_os_env: bool,
) -> Any:
"""Recursively expand $var strings in the currently operating dictionary."""
if isinstance(value, str):
# First try to expand all $vars internally.
value = Template(value).safe_substitute(self._dict)
@ -71,7 +73,7 @@ class DictTemplater: # pylint: disable=too-few-public-methods
elif isinstance(value, dict):
# Note: we use a loop instead of dict comprehension in order to
# allow secondary expansion of subsequent values immediately.
for (key, val) in value.items():
for key, val in value.items():
value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
elif isinstance(value, list):
value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]


@ -2,26 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tunable Environments for mlos_bench.
"""
"""Tunable Environments for mlos_bench."""
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.mock_env import MockEnv
from mlos_bench.environments.remote.remote_env import RemoteEnv
from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.environments.mock_env import MockEnv
from mlos_bench.environments.remote.remote_env import RemoteEnv
from mlos_bench.environments.status import Status
__all__ = [
'Status',
'Environment',
'MockEnv',
'RemoteEnv',
'LocalEnv',
'LocalFileShareEnv',
'CompositeEnv',
"Status",
"Environment",
"MockEnv",
"RemoteEnv",
"LocalEnv",
"LocalFileShareEnv",
"CompositeEnv",
]


@ -2,19 +2,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A hierarchy of benchmark environments.
"""
"""A hierarchy of benchmark environments."""
import abc
import json
import logging
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
from typing_extensions import Literal
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pytz import UTC
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
@ -32,20 +41,19 @@ _LOG = logging.getLogger(__name__)
class Environment(metaclass=abc.ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
An abstract base of all benchmark environments.
"""
"""An abstract base of all benchmark environments."""
@classmethod
def new(cls,
*,
env_name: str,
class_name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
) -> "Environment":
def new(
cls,
*,
env_name: str,
class_name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
) -> "Environment":
"""
Factory method for a new environment with a given config.
@ -83,16 +91,18 @@ class Environment(metaclass=abc.ABCMeta):
config=config,
global_config=global_config,
tunables=tunables,
service=service
service=service,
)
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment with a given config.
@ -123,24 +133,33 @@ class Environment(metaclass=abc.ABCMeta):
self._const_args: Dict[str, TunableValue] = config.get("const_args", {})
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Environment: '%s' Service: %s", name,
self._service.pprint() if self._service else None)
_LOG.debug(
"Environment: '%s' Service: %s",
name,
self._service.pprint() if self._service else None,
)
if tunables is None:
_LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name)
_LOG.warning(
(
"No tunables provided for %s. "
"Tunable inheritance across composite environments may be broken."
),
name,
)
tunables = TunableGroups()
groups = self._expand_groups(
config.get("tunable_params", []),
(global_config or {}).get("tunable_params_map", {}))
(global_config or {}).get("tunable_params_map", {}),
)
_LOG.debug("Tunable groups for: '%s' :: %s", name, groups)
self._tunable_params = tunables.subgroup(groups)
# If a parameter comes from the tunables, do not require it in the const_args or globals
req_args = (
set(config.get("required_args", [])) -
set(self._tunable_params.get_param_values().keys())
req_args = set(config.get("required_args", [])) - set(
self._tunable_params.get_param_values().keys()
)
merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
self._const_args = self._expand_vars(self._const_args, global_config or {})
@ -149,14 +168,12 @@ class Environment(metaclass=abc.ABCMeta):
_LOG.debug("Parameters for '%s' :: %s", name, self._params)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Config for: '%s'\n%s",
name, json.dumps(self.config, indent=2))
_LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))
def _validate_json_config(self, config: dict, name: str) -> None:
"""
Reconstructs a basic json config that this class might have been
instantiated from in order to validate configs provided outside the
file loading mechanism.
"""Reconstructs a basic json config that this class might have been instantiated
from in order to validate configs provided outside the file loading
mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@ -168,8 +185,10 @@ class Environment(metaclass=abc.ABCMeta):
ConfigSchema.ENVIRONMENT.validate(json_config)
@staticmethod
def _expand_groups(groups: Iterable[str],
groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]:
def _expand_groups(
groups: Iterable[str],
groups_exp: Dict[str, Union[str, Sequence[str]]],
) -> List[str]:
"""
Expand `$tunable_group` into actual names of the tunable groups.
@ -191,7 +210,12 @@ class Environment(metaclass=abc.ABCMeta):
if grp[:1] == "$":
tunable_group_name = grp[1:]
if tunable_group_name not in groups_exp:
raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}")
raise KeyError(
(
f"Expected tunable group name ${tunable_group_name} "
"undefined in {groups_exp}"
)
)
add_groups = groups_exp[tunable_group_name]
res += [add_groups] if isinstance(add_groups, str) else add_groups
else:
@ -199,10 +223,11 @@ class Environment(metaclass=abc.ABCMeta):
return res
@staticmethod
def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict:
"""
Expand `$var` into actual values of the variables.
"""
def _expand_vars(
params: Dict[str, TunableValue],
global_config: Dict[str, TunableValue],
) -> dict:
"""Expand `$var` into actual values of the variables."""
return DictTemplater(params).expand_vars(extra_source_dict=global_config)
@property
@ -210,10 +235,8 @@ class Environment(metaclass=abc.ABCMeta):
assert self._service is not None
return self._service.config_loader_service
def __enter__(self) -> 'Environment':
"""
Enter the environment's benchmarking context.
"""
def __enter__(self) -> "Environment":
"""Enter the environment's benchmarking context."""
_LOG.debug("Environment START :: %s", self)
assert not self._in_context
if self._service:
@ -221,12 +244,13 @@ class Environment(metaclass=abc.ABCMeta):
self._in_context = True
return self
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exit the context of the benchmarking environment.
"""
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exit the context of the benchmarking environment."""
ex_throw = None
if ex_val is None:
_LOG.debug("Environment END :: %s", self)
@ -256,8 +280,8 @@ class Environment(metaclass=abc.ABCMeta):
def pprint(self, indent: int = 4, level: int = 0) -> str:
"""
Pretty-print the environment configuration.
For composite environments, print all children environments as well.
Pretty-print the environment configuration. For composite environments, print
all children environments as well.
Parameters
----------
@ -277,8 +301,8 @@ class Environment(metaclass=abc.ABCMeta):
def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
"""
Plug tunable values into the base config. If the tunable group is unknown,
ignore it (it might belong to another environment). This method should
never mutate the original config or the tunables.
ignore it (it might belong to another environment). This method should never
mutate the original config or the tunables.
Parameters
----------
@ -293,7 +317,8 @@ class Environment(metaclass=abc.ABCMeta):
"""
return tunables.get_param_values(
group_names=list(self._tunable_params.get_covariant_group_names()),
into_params=self._const_args.copy())
into_params=self._const_args.copy(),
)
@property
def tunable_params(self) -> TunableGroups:
@ -310,21 +335,23 @@ class Environment(metaclass=abc.ABCMeta):
@property
def parameters(self) -> Dict[str, TunableValue]:
"""
Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
Note that before `.setup()` is called, all tunables will be set to None.
Key/value pairs of all environment parameters (i.e., `const_args` and
`tunable_params`). Note that before `.setup()` is called, all tunables will be
set to None.
Returns
-------
parameters : Dict[str, TunableValue]
Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`).
Key/value pairs of all environment parameters
(i.e., `const_args` and `tunable_params`).
"""
return self._params
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Set up a new benchmark environment, if necessary. This method must be
idempotent, i.e., calling it several times in a row should be
equivalent to a single call.
idempotent, i.e., calling it several times in a row should be equivalent to a
single call.
Parameters
----------
@ -353,10 +380,15 @@ class Environment(metaclass=abc.ABCMeta):
# (Derived classes still have to check `self._tunable_params.is_updated()`).
is_updated = self._tunable_params.is_updated()
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, {
name: self._tunable_params.is_updated([name])
for name in self._tunable_params.get_covariant_group_names()
})
_LOG.debug(
"Env '%s': Tunable groups reset = %s :: %s",
self,
is_updated,
{
name: self._tunable_params.is_updated([name])
for name in self._tunable_params.get_covariant_group_names()
},
)
else:
_LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)
@ -371,9 +403,10 @@ class Environment(metaclass=abc.ABCMeta):
def teardown(self) -> None:
"""
Tear down the benchmark environment. This method must be idempotent,
i.e., calling it several times in a row should be equivalent to a
single call.
Tear down the benchmark environment.
This method must be idempotent, i.e., calling it several times in a row should
be equivalent to a single call.
"""
_LOG.info("Teardown %s", self)
# Make sure we create a context before invoking setup/run/status/teardown


@ -2,20 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Composite benchmark environment.
"""
"""Composite benchmark environment."""
import logging
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type
from typing_extensions import Literal
from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@ -23,17 +21,17 @@ _LOG = logging.getLogger(__name__)
class CompositeEnv(Environment):
"""
Composite benchmark environment.
"""
"""Composite benchmark environment."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment with a given config.
@ -53,8 +51,13 @@ class CompositeEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
# By default, the Environment includes only the tunables explicitly specified
# in the "tunable_params" section of the config. `CompositeEnv`, however, must
@ -70,17 +73,27 @@ class CompositeEnv(Environment):
# each CompositeEnv gets a copy of the original global config and adjusts it with
# the `const_args` specific to it.
global_config = (global_config or {}).copy()
for (key, val) in self._const_args.items():
for key, val in self._const_args.items():
global_config.setdefault(key, val)
for child_config_file in config.get("include_children", []):
for env in self._config_loader_service.load_environment_list(
child_config_file, tunables, global_config, self._const_args, self._service):
child_config_file,
tunables,
global_config,
self._const_args,
self._service,
):
self._add_child(env, tunables)
for child_config in config.get("children", []):
env = self._config_loader_service.build_environment(
child_config, tunables, global_config, self._const_args, self._service)
child_config,
tunables,
global_config,
self._const_args,
self._service,
)
self._add_child(env, tunables)
_LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)
@ -92,9 +105,12 @@ class CompositeEnv(Environment):
self._child_contexts = [env.__enter__() for env in self._children]
return super().__enter__()
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
ex_throw = None
for env in reversed(self._children):
try:
@ -111,9 +127,7 @@ class CompositeEnv(Environment):
@property
def children(self) -> List[Environment]:
"""
Return the list of child environments.
"""
"""Return the list of child environments."""
return self._children
def pprint(self, indent: int = 4, level: int = 0) -> str:
@ -132,12 +146,16 @@ class CompositeEnv(Environment):
pretty : str
Pretty-printed environment configuration.
"""
return super().pprint(indent, level) + '\n' + '\n'.join(
child.pprint(indent, level + 1) for child in self._children)
return (
super().pprint(indent, level)
+ "\n"
+ "\n".join(child.pprint(indent, level + 1) for child in self._children)
)
def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
"""
Add a new child environment to the composite environment.
This method is called from the constructor only.
"""
_LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
@ -165,14 +183,16 @@ class CompositeEnv(Environment):
"""
assert self._in_context
self._is_ready = super().setup(tunables, global_config) and all(
env_context.setup(tunables, global_config) for env_context in self._child_contexts)
env_context.setup(tunables, global_config) for env_context in self._child_contexts
)
return self._is_ready
def teardown(self) -> None:
"""
Tear down the children environments. This method is idempotent,
i.e., calling it several times is equivalent to a single call.
The environments are being torn down in the reverse order.
Tear down the children environments.
This method is idempotent, i.e., calling it several times is equivalent to a
single call. The environments are being torn down in the reverse order.
"""
assert self._in_context
for env_context in reversed(self._child_contexts):
@ -181,9 +201,9 @@ class CompositeEnv(Environment):
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
"""
Submit a new experiment to the environment.
Return the result of the *last* child environment if successful,
or the status of the last failed environment otherwise.
Submit a new experiment to the environment. Return the result of the *last*
child environment if successful, or the status of the last failed environment
otherwise.
Returns
-------
@ -238,5 +258,6 @@ class CompositeEnv(Environment):
final_status = final_status or status
_LOG.info("Final status: %s :: %s", self, final_status)
# Return the status and the timestamp of the last child environment or the first failed child environment.
# Return the status and the timestamp of the last child environment or the
# first failed child environment.
return (final_status, timestamp, joint_telemetry)


@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Local Environments for mlos_bench.
"""
"""Local Environments for mlos_bench."""
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv
__all__ = [
'LocalEnv',
'LocalFileShareEnv',
"LocalEnv",
"LocalFileShareEnv",
]


@ -2,27 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Scheduler-side benchmark environment to run scripts locally.
"""
"""Scheduler-side benchmark environment to run scripts locally."""
import json
import logging
import sys
from contextlib import nullcontext
from datetime import datetime
from tempfile import TemporaryDirectory
from contextlib import nullcontext
from types import TracebackType
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union
from typing_extensions import Literal
import pandas
from typing_extensions import Literal
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.script_env import ScriptEnv
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable import TunableValue
@ -34,17 +30,17 @@ _LOG = logging.getLogger(__name__)
class LocalEnv(ScriptEnv):
# pylint: disable=too-many-instance-attributes
"""
Scheduler-side Environment that runs scripts locally.
"""
"""Scheduler-side Environment that runs scripts locally."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for local execution.
@ -67,11 +63,17 @@ class LocalEnv(ScriptEnv):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
"LocalEnv requires a service that supports local execution"
assert self._service is not None and isinstance(
self._service, SupportsLocalExec
), "LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
self._temp_dir: Optional[str] = None
@ -85,16 +87,19 @@ class LocalEnv(ScriptEnv):
def __enter__(self) -> Environment:
assert self._temp_dir is None and self._temp_dir_context is None
self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir"))
self._temp_dir_context = self._local_exec_service.temp_dir_context(
self.config.get("temp_dir"),
)
self._temp_dir = self._temp_dir_context.__enter__()
return super().__enter__()
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exit the context of the benchmarking environment.
"""
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exit the context of the benchmarking environment."""
assert not (self._temp_dir is None or self._temp_dir_context is None)
self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
self._temp_dir = None
@ -103,8 +108,8 @@ class LocalEnv(ScriptEnv):
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if the environment is ready and set up the application
and benchmarks, if necessary.
Check if the environment is ready and set up the application and benchmarks, if
necessary.
Parameters
----------
@ -139,10 +144,14 @@ class LocalEnv(ScriptEnv):
fname = path_join(self._temp_dir, self._dump_meta_file)
_LOG.debug("Dump tunables metadata to file: %s", fname)
with open(fname, "wt", encoding="utf-8") as fh_meta:
json.dump({
tunable.name: tunable.meta
for (tunable, _group) in self._tunable_params if tunable.meta
}, fh_meta)
json.dump(
{
tunable.name: tunable.meta
for (tunable, _group) in self._tunable_params
if tunable.meta
},
fh_meta,
)
if self._script_setup:
(return_code, _output) = self._local_exec(self._script_setup, self._temp_dir)
@ -182,18 +191,28 @@ class LocalEnv(ScriptEnv):
_LOG.debug("Not reading the data at: %s", self)
return (Status.SUCCEEDED, timestamp, stdout_data)
data = self._normalize_columns(pandas.read_csv(
self._config_loader_service.resolve_path(
self._read_results_file, extra_paths=[self._temp_dir]),
index_col=False,
))
data = self._normalize_columns(
pandas.read_csv(
self._config_loader_service.resolve_path(
self._read_results_file,
extra_paths=[self._temp_dir],
),
index_col=False,
)
)
_LOG.debug("Read data:\n%s", data)
if list(data.columns) == ["metric", "value"]:
_LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data))
_LOG.info(
"Local results have (metric,value) header and %d rows: assume long format",
len(data),
)
data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list())
# Try to convert string metrics to numbers.
data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive)
data = data.apply( # type: ignore[assignment] # (false positive)
pandas.to_numeric,
errors="coerce",
).fillna(data)
elif len(data) == 1:
_LOG.info("Local results have 1 row: assume wide format")
else:
@ -205,14 +224,12 @@ class LocalEnv(ScriptEnv):
@staticmethod
def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
"""
Strip trailing spaces from column names (Windows only).
"""
"""Strip trailing spaces from column names (Windows only)."""
# Windows cmd interpretation of > redirect symbols can leave trailing spaces in
# the final column, which leads to misnamed columns.
# For now, we simply strip trailing spaces from column names to account for that.
if sys.platform == 'win32':
data.rename(str.rstrip, axis='columns', inplace=True)
if sys.platform == "win32":
data.rename(str.rstrip, axis="columns", inplace=True)
return data
def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
@ -224,24 +241,24 @@ class LocalEnv(ScriptEnv):
assert self._temp_dir is not None
try:
fname = self._config_loader_service.resolve_path(
self._read_telemetry_file, extra_paths=[self._temp_dir])
self._read_telemetry_file,
extra_paths=[self._temp_dir],
)
# TODO: Use the timestamp of the CSV file as our status timestamp?
# FIXME: We should not be assuming that the only output file type is a CSV.
data = self._normalize_columns(
pandas.read_csv(fname, index_col=False))
data = self._normalize_columns(pandas.read_csv(fname, index_col=False))
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
expected_col_names = ["timestamp", "metric", "value"]
if len(data.columns) != len(expected_col_names):
raise ValueError(f'Telemetry data must have columns {expected_col_names}')
raise ValueError(f"Telemetry data must have columns {expected_col_names}")
if list(data.columns) != expected_col_names:
# Assume no header - this is ok for telemetry data.
data = pandas.read_csv(
fname, index_col=False, names=expected_col_names)
data = pandas.read_csv(fname, index_col=False, names=expected_col_names)
data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")
except FileNotFoundError as ex:
@ -250,15 +267,17 @@ class LocalEnv(ScriptEnv):
_LOG.debug("Read telemetry data:\n%s", data)
col_dtypes: Mapping[int, Type] = {0: datetime}
return (status, timestamp, [
(pandas.Timestamp(ts).to_pydatetime(), metric, value)
for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
])
return (
status,
timestamp,
[
(pandas.Timestamp(ts).to_pydatetime(), metric, value)
for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
],
)
def teardown(self) -> None:
"""
Clean up the local environment.
"""
"""Clean up the local environment."""
if self._script_teardown:
_LOG.info("Local teardown: %s", self)
(return_code, _output) = self._local_exec(self._script_teardown)
@ -285,7 +304,10 @@ class LocalEnv(ScriptEnv):
env_params = self._get_env_params()
_LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params)
(return_code, stdout, stderr) = self._local_exec_service.local_exec(
script, env=env_params, cwd=cwd)
script,
env=env_params,
cwd=cwd,
)
if return_code != 0:
_LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr)
return (return_code, {"stdout": stdout, "stderr": stderr})


@ -2,22 +2,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Scheduler-side Environment to run scripts locally
and upload/download data to the shared storage.
"""Scheduler-side Environment to run scripts locally and upload/download data to the
shared storage.
"""
import logging
from datetime import datetime
from string import Template
from typing import Any, Dict, List, Generator, Iterable, Mapping, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
from mlos_bench.environments.status import Status
from mlos_bench.environments.local.local_env import LocalEnv
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@ -25,18 +23,19 @@ _LOG = logging.getLogger(__name__)
class LocalFileShareEnv(LocalEnv):
"""
Scheduler-side Environment that runs scripts locally
and uploads/downloads data to the shared file storage.
"""Scheduler-side Environment that runs scripts locally and uploads/downloads data
to the shared file storage.
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new application environment with a given config.
@ -60,34 +59,41 @@ class LocalFileShareEnv(LocalEnv):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
"LocalEnv requires a service that supports local execution"
assert self._service is not None and isinstance(
self._service, SupportsLocalExec
), "LocalEnv requires a service that supports local execution"
self._local_exec_service: SupportsLocalExec = self._service
assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
"LocalEnv requires a service that supports file upload/download operations"
assert self._service is not None and isinstance(
self._service, SupportsFileShareOps
), "LocalEnv requires a service that supports file upload/download operations"
self._file_share_service: SupportsFileShareOps = self._service
self._upload = self._template_from_to("upload")
self._download = self._template_from_to("download")
def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]:
"""Convert a list of {"from": "...", "to": "..."} to a list of pairs of
string.Template objects so that we can plug in self._params into it later.
"""
Convert a list of {"from": "...", "to": "..."} to a list of pairs
of string.Template objects so that we can plug in self._params into it later.
"""
return [
(Template(d['from']), Template(d['to']))
for d in self.config.get(config_key, [])
]
return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]
@staticmethod
def _expand(from_to: Iterable[Tuple[Template, Template]],
params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
def _expand(
from_to: Iterable[Tuple[Template, Template]],
params: Mapping[str, TunableValue],
) -> Generator[Tuple[str, str], None, None]:
"""
Substitute $var parameters in from/to path templates.
Return a generator of (str, str) pairs of paths.
"""
return (
@ -120,9 +126,15 @@ class LocalFileShareEnv(LocalEnv):
assert self._temp_dir is not None
params = self._get_env_params(restrict=False)
params["PWD"] = self._temp_dir
for (path_from, path_to) in self._expand(self._upload, params):
self._file_share_service.upload(self._params, self._config_loader_service.resolve_path(
path_from, extra_paths=[self._temp_dir]), path_to)
for path_from, path_to in self._expand(self._upload, params):
self._file_share_service.upload(
self._params,
self._config_loader_service.resolve_path(
path_from,
extra_paths=[self._temp_dir],
),
path_to,
)
return self._is_ready
def _download_files(self, ignore_missing: bool = False) -> None:
@ -138,11 +150,16 @@ class LocalFileShareEnv(LocalEnv):
assert self._temp_dir is not None
params = self._get_env_params(restrict=False)
params["PWD"] = self._temp_dir
for (path_from, path_to) in self._expand(self._download, params):
for path_from, path_to in self._expand(self._download, params):
try:
self._file_share_service.download(self._params,
path_from, self._config_loader_service.resolve_path(
path_to, extra_paths=[self._temp_dir]))
self._file_share_service.download(
self._params,
path_from,
self._config_loader_service.resolve_path(
path_to,
extra_paths=[self._temp_dir],
),
)
except FileNotFoundError as ex:
_LOG.warning("Cannot download: %s", path_from)
if not ignore_missing:
@ -153,8 +170,8 @@ class LocalFileShareEnv(LocalEnv):
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
"""
Download benchmark results from the shared storage
and run post-processing scripts locally.
Download benchmark results from the shared storage and run post-processing
scripts locally.
Returns
-------
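To make the `_template_from_to`/`_expand` pair concrete, a hand-rolled sketch (the config entry and parameter values are hypothetical, and `safe_substitute` is used here for brevity):

    from string import Template

    upload = [{"from": "$PWD/results.csv", "to": "experiments/$experiment_id/results.csv"}]
    pairs = [(Template(d["from"]), Template(d["to"])) for d in upload]
    params = {"PWD": "/tmp/mlos_run", "experiment_id": "demo"}
    expanded = [(src.safe_substitute(params), dst.safe_substitute(params)) for (src, dst) in pairs]
    # [('/tmp/mlos_run/results.csv', 'experiments/demo/results.csv')]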


@ -2,40 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Scheduler-side environment to mock the benchmark results.
"""
"""Scheduler-side environment to mock the benchmark results."""
import random
import logging
import random
from datetime import datetime
from typing import Dict, Optional, Tuple
import numpy
from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.tunables import Tunable, TunableGroups, TunableValue
_LOG = logging.getLogger(__name__)
class MockEnv(Environment):
"""
Scheduler-side environment to mock the benchmark results.
"""
"""Scheduler-side environment to mock the benchmark results."""
_NOISE_VAR = 0.2
"""Variance of the Gaussian noise added to the benchmark value."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment that produces mock benchmark data.
@ -55,8 +53,13 @@ class MockEnv(Environment):
service: Service
An optional service object. Not used by this class.
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
seed = int(self.config.get("mock_env_seed", -1))
self._random = random.Random(seed or None) if seed >= 0 else None
self._range = self.config.get("mock_env_range")
@ -81,9 +84,9 @@ class MockEnv(Environment):
return result
# Simple convex function of all tunable parameters.
score = numpy.mean(numpy.square([
self._normalized(tunable) for (tunable, _group) in self._tunable_params
]))
score = numpy.mean(
numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
)
# Add noise and shift the benchmark value from [0, 1] to a given range.
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0
@ -97,15 +100,16 @@ class MockEnv(Environment):
def _normalized(tunable: Tunable) -> float:
"""
Get the NORMALIZED value of a tunable.
That is, map current value to the [0, 1] range.
"""
val = None
if tunable.is_categorical:
val = (tunable.categories.index(tunable.category) /
float(len(tunable.categories) - 1))
val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1)
elif tunable.is_numerical:
val = ((tunable.numerical_value - tunable.range[0]) /
float(tunable.range[1] - tunable.range[0]))
val = (tunable.numerical_value - tunable.range[0]) / float(
tunable.range[1] - tunable.range[0]
)
else:
raise ValueError("Invalid parameter type: " + tunable.type)
# Explicitly clip the value in case of numerical errors.
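Plugging illustrative tunables into the `_normalized` logic above:

    import numpy

    # Numerical tunable: value 512 in range (8, 1024) -> ~0.496
    num = (512 - 8) / float(1024 - 8)
    # Categorical tunable: "c" out of ["a", "b", "c"] -> 1.0
    cat = ["a", "b", "c"].index("c") / float(3 - 1)
    score = numpy.mean(numpy.square([num, cat]))  # the simple convex function above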


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Remote Tunable Environments for mlos_bench.
"""
"""Remote Tunable Environments for mlos_bench."""
from mlos_bench.environments.remote.host_env import HostEnv
from mlos_bench.environments.remote.network_env import NetworkEnv
@ -14,10 +12,10 @@ from mlos_bench.environments.remote.saas_env import SaaSEnv
from mlos_bench.environments.remote.vm_env import VMEnv
__all__ = [
'HostEnv',
'NetworkEnv',
'OSEnv',
'RemoteEnv',
'SaaSEnv',
'VMEnv',
"HostEnv",
"NetworkEnv",
"OSEnv",
"RemoteEnv",
"SaaSEnv",
"VMEnv",
]


@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Remote host Environment.
"""
from typing import Optional
"""Remote host Environment."""
import logging
from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
@ -19,17 +16,17 @@ _LOG = logging.getLogger(__name__)
class HostEnv(Environment):
"""
Remote host environment.
"""
"""Remote host environment."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for host operations.
@ -50,10 +47,17 @@ class HostEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM/host, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \
"HostEnv requires a service that supports host provisioning operations"
assert self._service is not None and isinstance(
self._service, SupportsHostProvisioning
), "HostEnv requires a service that supports host provisioning operations"
self._host_service: SupportsHostProvisioning = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@ -88,9 +92,7 @@ class HostEnv(Environment):
return self._is_ready
def teardown(self) -> None:
"""
Shut down the Host and release it.
"""
"""Shut down the Host and release it."""
_LOG.info("Host tear down: %s", self)
(status, params) = self._host_service.deprovision_host(self._params)
if status.is_pending():


@ -2,17 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Network Environment.
"""
from typing import Optional
"""Network Environment."""
import logging
from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
from mlos_bench.services.types.network_provisioner_type import (
SupportsNetworkProvisioning,
)
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
@ -26,13 +25,15 @@ class NetworkEnv(Environment):
but no real tuning is expected for it ... yet.
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for network operations.
@ -53,14 +54,21 @@ class NetworkEnv(Environment):
An optional service object (e.g., providing methods to
deploy a network, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
# Virtual networks can be used for more than one experiment, so by default
# we don't attempt to deprovision them.
self._deprovision_on_teardown = config.get("deprovision_on_teardown", False)
assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \
"NetworkEnv requires a service that supports network provisioning"
assert self._service is not None and isinstance(
self._service, SupportsNetworkProvisioning
), "NetworkEnv requires a service that supports network provisioning"
self._network_service: SupportsNetworkProvisioning = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@ -96,15 +104,16 @@ class NetworkEnv(Environment):
return self._is_ready
def teardown(self) -> None:
"""
Shut down the Network and releases it.
"""
"""Shut down the Network and releases it."""
if not self._deprovision_on_teardown:
_LOG.info("Skipping Network deprovision: %s", self)
return
# Else
_LOG.info("Network tear down: %s", self)
(status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True)
(status, params) = self._network_service.deprovision_network(
self._params,
ignore_errors=True,
)
if status.is_pending():
(status, _) = self._network_service.wait_network_deployment(params, is_setup=False)


@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
OS-level remote Environment on Azure.
"""
from typing import Optional
"""OS-level remote Environment on Azure."""
import logging
from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
@ -21,17 +18,17 @@ _LOG = logging.getLogger(__name__)
class OSEnv(Environment):
"""
OS Level Environment for a host.
"""
"""OS Level Environment for a host."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for remote execution.
@ -54,14 +51,22 @@ class OSEnv(Environment):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
"RemoteEnv requires a service that supports host operations"
assert self._service is not None and isinstance(
self._service, SupportsHostOps
), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
assert self._service is not None and isinstance(self._service, SupportsOSOps), \
"RemoteEnv requires a service that supports host operations"
assert self._service is not None and isinstance(
self._service, SupportsOSOps
), "RemoteEnv requires a service that supports host operations"
self._os_service: SupportsOSOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@ -98,9 +103,7 @@ class OSEnv(Environment):
return self._is_ready
def teardown(self) -> None:
"""
Clean up and shut down the host without deprovisioning it.
"""
"""Clean up and shut down the host without deprovisioning it."""
_LOG.info("OS tear down: %s", self)
(status, params) = self._os_service.shutdown(self._params)
if status.is_pending():


@ -14,11 +14,11 @@ from typing import Dict, Iterable, Optional, Tuple
from pytz import UTC
from mlos_bench.environments.status import Status
from mlos_bench.environments.script_env import ScriptEnv
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.host_ops_type import SupportsHostOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv):
e.g. Application Environment
"""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for remote execution.
@ -61,24 +63,31 @@ class RemoteEnv(ScriptEnv):
An optional service object (e.g., providing methods to
deploy or reboot a Host, VM, OS, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
self._wait_boot = self.config.get("wait_boot", False)
assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \
"RemoteEnv requires a service that supports remote execution operations"
assert self._service is not None and isinstance(
self._service, SupportsRemoteExec
), "RemoteEnv requires a service that supports remote execution operations"
self._remote_exec_service: SupportsRemoteExec = self._service
if self._wait_boot:
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
"RemoteEnv requires a service that supports host operations"
assert self._service is not None and isinstance(
self._service, SupportsHostOps
), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
"""
Check if the environment is ready and set up the application
and benchmarks on a remote host.
Check if the environment is ready and set up the application and benchmarks on a
remote host.
Parameters
----------
@ -143,9 +152,7 @@ class RemoteEnv(ScriptEnv):
return (status, timestamp, output)
def teardown(self) -> None:
"""
Clean up and shut down the remote environment.
"""
"""Clean up and shut down the remote environment."""
if self._script_teardown:
_LOG.info("Remote teardown: %s", self)
(status, _timestamp, _output) = self._remote_exec(self._script_teardown)
@ -170,7 +177,10 @@ class RemoteEnv(ScriptEnv):
env_params = self._get_env_params()
_LOG.debug("Submit script: %s with %s", self, env_params)
(status, output) = self._remote_exec_service.remote_exec(
script, config=self._params, env_params=env_params)
script,
config=self._params,
env_params=env_params,
)
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)
if status in {Status.PENDING, Status.SUCCEEDED}:
(status, output) = self._remote_exec_service.get_remote_exec_results(output)


@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Cloud-based (configurable) SaaS environment.
"""
from typing import Optional
"""Cloud-based (configurable) SaaS environment."""
import logging
from typing import Optional
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
@ -20,17 +17,17 @@ _LOG = logging.getLogger(__name__)
class SaaSEnv(Environment):
"""
Cloud-based (configurable) SaaS environment.
"""
"""Cloud-based (configurable) SaaS environment."""
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for (configurable) cloud-based SaaS instance.
@ -51,15 +48,22 @@ class SaaSEnv(Environment):
An optional service object
(e.g., providing methods to configure the remote service).
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
assert self._service is not None and isinstance(self._service, SupportsHostOps), \
"RemoteEnv requires a service that supports host operations"
assert self._service is not None and isinstance(
self._service, SupportsHostOps
), "RemoteEnv requires a service that supports host operations"
self._host_service: SupportsHostOps = self._service
assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \
"SaaSEnv requires a service that supports remote host configuration API"
assert self._service is not None and isinstance(
self._service, SupportsRemoteConfig
), "SaaSEnv requires a service that supports remote host configuration API"
self._config_service: SupportsRemoteConfig = self._service
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
@ -85,7 +89,9 @@ class SaaSEnv(Environment):
return False
(status, _) = self._config_service.configure(
self._params, self._tunable_params.get_param_values())
self._params,
self._tunable_params.get_param_values(),
)
if not status.is_succeeded():
return False
@ -94,7 +100,7 @@ class SaaSEnv(Environment):
return False
# Azure Flex DB instances currently require a VM reboot after reconfiguration.
if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'):
if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"):
_LOG.info("Restarting: %s", self)
(status, params) = self._host_service.restart_host(self._params)
if status.is_pending():


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
"Remote" VM (Host) Environment.
"""
"""Remote VM (Host) Environment."""
import logging


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base scriptable benchmark environment.
"""
"""Base scriptable benchmark environment."""
import abc
import logging
@ -15,26 +13,25 @@ from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import try_parse_val
_LOG = logging.getLogger(__name__)
class ScriptEnv(Environment, metaclass=abc.ABCMeta):
"""
Base Environment that runs scripts for setup/run/teardown.
"""
"""Base Environment that runs scripts for setup/run/teardown."""
_RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")
def __init__(self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None):
def __init__(
self,
*,
name: str,
config: dict,
global_config: Optional[dict] = None,
tunables: Optional[TunableGroups] = None,
service: Optional[Service] = None,
):
"""
Create a new environment for script execution.
@ -64,19 +61,29 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
An optional service object (e.g., providing methods to
deploy or reboot a VM, etc.).
"""
super().__init__(name=name, config=config, global_config=global_config,
tunables=tunables, service=service)
super().__init__(
name=name,
config=config,
global_config=global_config,
tunables=tunables,
service=service,
)
self._script_setup = self.config.get("setup")
self._script_run = self.config.get("run")
self._script_teardown = self.config.get("teardown")
self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {})
self._shell_env_params_rename: Dict[str, str] = self.config.get(
"shell_env_params_rename", {}
)
results_stdout_pattern = self.config.get("results_stdout_pattern")
self._results_stdout_pattern: Optional[re.Pattern[str]] = \
re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None
self._results_stdout_pattern: Optional[re.Pattern[str]] = (
re.compile(results_stdout_pattern, flags=re.MULTILINE)
if results_stdout_pattern
else None
)
def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
"""
@ -117,4 +124,6 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta):
if not self._results_stdout_pattern:
return {}
_LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)}
return {
key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)
}
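As an illustration of the stdout extraction above, with a hypothetical `results_stdout_pattern` of two capture groups:

    import re

    from mlos_bench.util import try_parse_val

    pattern = re.compile(r"^(\w+)=(.+)$", flags=re.MULTILINE)
    stdout = "score=0.95\nelapsed=12\n"
    results = {key: try_parse_val(val) for (key, val) in pattern.findall(stdout)}
    # {'score': 0.95, 'elapsed': 12}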


@ -2,17 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Enum for the status of the benchmark/environment.
"""
"""Enum for the status of the benchmark/environment."""
import enum
class Status(enum.Enum):
"""
Enum for the status of the benchmark/environment.
"""
"""Enum for the status of the benchmark/environment."""
UNKNOWN = 0
PENDING = 1
@ -24,9 +20,7 @@ class Status(enum.Enum):
TIMED_OUT = 7
def is_good(self) -> bool:
"""
Check if the status of the benchmark/environment is good.
"""
"""Check if the status of the benchmark/environment is good."""
return self in {
Status.PENDING,
Status.READY,
@ -35,9 +29,8 @@ class Status(enum.Enum):
}
def is_completed(self) -> bool:
"""
Check if the status of the benchmark/environment is
one of {SUCCEEDED, CANCELED, FAILED, TIMED_OUT}.
"""Check if the status of the benchmark/environment is one of {SUCCEEDED,
CANCELED, FAILED, TIMED_OUT}.
"""
return self in {
Status.SUCCEEDED,
@ -47,37 +40,25 @@ class Status(enum.Enum):
}
def is_pending(self) -> bool:
"""
Check if the status of the benchmark/environment is PENDING.
"""
"""Check if the status of the benchmark/environment is PENDING."""
return self == Status.PENDING
def is_ready(self) -> bool:
"""
Check if the status of the benchmark/environment is READY.
"""
"""Check if the status of the benchmark/environment is READY."""
return self == Status.READY
def is_succeeded(self) -> bool:
"""
Check if the status of the benchmark/environment is SUCCEEDED.
"""
"""Check if the status of the benchmark/environment is SUCCEEDED."""
return self == Status.SUCCEEDED
def is_failed(self) -> bool:
"""
Check if the status of the benchmark/environment is FAILED.
"""
"""Check if the status of the benchmark/environment is FAILED."""
return self == Status.FAILED
def is_canceled(self) -> bool:
"""
Check if the status of the benchmark/environment is CANCELED.
"""
"""Check if the status of the benchmark/environment is CANCELED."""
return self == Status.CANCELED
def is_timed_out(self) -> bool:
"""
Check if the status of the benchmark/environment is TIMED_OUT.
"""
"""Check if the status of the benchmark/environment is TIMED_OUT."""
return self == Status.TIMED_OUT
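A quick usage sketch of these helpers:

    status = Status.TIMED_OUT
    assert status.is_completed()  # TIMED_OUT is a terminal state
    assert status.is_timed_out()
    assert not status.is_good()
    assert not status.is_succeeded()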


@ -2,25 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
EventLoopContext class definition.
"""
from asyncio import AbstractEventLoop
from concurrent.futures import Future
from typing import Any, Coroutine, Optional, TypeVar
from threading import Lock as ThreadLock, Thread
"""EventLoopContext class definition."""
import asyncio
import logging
import sys
from asyncio import AbstractEventLoop
from concurrent.futures import Future
from threading import Lock as ThreadLock
from threading import Thread
from typing import Any, Coroutine, Optional, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name
CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name
if sys.version_info >= (3, 9):
FutureReturnType: TypeAlias = Future[CoroReturnType]
else:
@ -31,15 +29,15 @@ _LOG = logging.getLogger(__name__)
class EventLoopContext:
"""
EventLoopContext encapsulates a background thread for asyncio event
loop processing as an aid for context managers.
EventLoopContext encapsulates a background thread for asyncio event loop processing
as an aid for context managers.
There is generally only expected to be one of these, either as a base
class instance if it's specific to that functionality or for the full
mlos_bench process to support parallel trial runners, for instance.
There is generally only expected to be one of these, either as a base class instance
if it's specific to that functionality or for the full mlos_bench process to support
parallel trial runners, for instance.
It's enter() and exit() routines are expected to be called from the
caller's context manager routines (e.g., __enter__ and __exit__).
Its enter() and exit() routines are expected to be called from the caller's context
manager routines (e.g., __enter__ and __exit__).
"""
def __init__(self) -> None:
@ -49,17 +47,13 @@ class EventLoopContext:
self._event_loop_thread_refcnt: int = 0
def _run_event_loop(self) -> None:
"""
Runs the asyncio event loop in a background thread.
"""
"""Runs the asyncio event loop in a background thread."""
assert self._event_loop is not None
asyncio.set_event_loop(self._event_loop)
self._event_loop.run_forever()
def enter(self) -> None:
"""
Manages starting the background thread for event loop processing.
"""
"""Manages starting the background thread for event loop processing."""
# Start the background thread if it's not already running.
with self._event_loop_thread_lock:
if not self._event_loop_thread:
@ -74,9 +68,7 @@ class EventLoopContext:
self._event_loop_thread_refcnt += 1
def exit(self) -> None:
"""
Manages cleaning up the background thread for event loop processing.
"""
"""Manages cleaning up the background thread for event loop processing."""
with self._event_loop_thread_lock:
self._event_loop_thread_refcnt -= 1
assert self._event_loop_thread_refcnt >= 0
@ -92,8 +84,8 @@ class EventLoopContext:
def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
"""
Runs the given coroutine in the background event loop thread and
returns a Future that can be used to wait for the result.
Runs the given coroutine in the background event loop thread and returns a
Future that can be used to wait for the result.
Parameters
----------
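A hedged usage sketch of this context (assuming `EventLoopContext` is importable from `mlos_bench.event_loop_context`):

    import asyncio

    from mlos_bench.event_loop_context import EventLoopContext

    async def answer() -> int:
        await asyncio.sleep(0.01)
        return 42

    ctx = EventLoopContext()
    ctx.enter()  # starts (or refcounts) the background event-loop thread
    try:
        future = ctx.run_coroutine(answer())
        assert future.result(timeout=1) == 42
    finally:
        ctx.exit()  # stops the thread once the refcount drops to zero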


@ -3,8 +3,8 @@
# Licensed under the MIT License.
#
"""
A helper class to load the configuration files, parse the command line parameters,
and instantiate the main components of mlos_bench system.
A helper class to load the configuration files, parse the command line parameters, and
instantiate the main components of mlos_bench system.
It is used in `mlos_bench.run` module to run the benchmark/optimizer from the
command line.
@ -13,34 +13,26 @@ command line.
import argparse
import logging
import sys
from typing import Any, Dict, Iterable, List, Optional, Tuple
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.util import try_parse_val
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
from mlos_bench.storage.base_storage import Storage
from mlos_bench.services.base_service import Service
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.services.local.local_exec import LocalExecService
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import try_parse_val
_LOG_LEVEL = logging.INFO
_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s'
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT)
_LOG = logging.getLogger(__name__)
@ -48,22 +40,22 @@ _LOG = logging.getLogger(__name__)
class Launcher:
# pylint: disable=too-few-public-methods,too-many-instance-attributes
"""
Command line launcher for mlos_bench and mlos_core.
"""
"""Command line launcher for mlos_bench and mlos_core."""
def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
# pylint: disable=too-many-statements
_LOG.info("Launch: %s", description)
epilog = """
Additional --key=value pairs can be specified to augment or override values listed in --globals.
Other required_args values can also be pulled from shell environment variables.
Additional --key=value pairs can be specified to augment or override
values listed in --globals.
Other required_args values can also be pulled from shell environment
variables.
For additional details, please see the website or the README.md files in the source tree:
For additional details, please see the website or the README.md files in
the source tree:
<https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
"""
parser = argparse.ArgumentParser(description=f"{description} : {long_text}",
epilog=epilog)
parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
(args, args_rest) = self._parse_args(parser, argv)
# Bootstrap config loader: command line takes priority.
@ -101,16 +93,18 @@ class Launcher:
args_rest,
{key: val for (key, val) in config.items() if key not in vars(args)},
)
# experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI.
# experiment_id is generally taken from --globals files, but we also allow
# overriding it on the CLI.
# It's useful to keep it there explicitly mostly for the --help output.
if args.experiment_id:
self.global_config['experiment_id'] = args.experiment_id
# trial_config_repeat_count is a scheduler property but it's convenient to set it via command line
self.global_config["experiment_id"] = args.experiment_id
# trial_config_repeat_count is a scheduler property but it's convenient to
# set it via command line
if args.trial_config_repeat_count:
self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
# Ensure that the trial_id is present since it gets used by some other
# configs but is typically controlled by the run optimize loop.
self.global_config.setdefault('trial_id', 1)
self.global_config.setdefault("trial_id", 1)
self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
assert isinstance(self.global_config, dict)
@ -118,24 +112,31 @@ class Launcher:
# --service cli args should override the config file values.
service_files: List[str] = config.get("services", []) + (args.service or [])
assert isinstance(self._parent_service, SupportsConfigLoading)
self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service)
self._parent_service = self._parent_service.load_services(
service_files,
self.global_config,
self._parent_service,
)
env_path = args.environment or config.get("environment")
if not env_path:
_LOG.error("No environment config specified.")
parser.error("At least the Environment config must be specified." +
" Run `mlos_bench --help` and consult `README.md` for more info.")
parser.error(
"At least the Environment config must be specified."
+ " Run `mlos_bench --help` and consult `README.md` for more info."
)
self.root_env_config = self._config_loader.resolve_path(env_path)
self.environment: Environment = self._config_loader.load_environment(
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service)
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
)
_LOG.info("Init environment: %s", self.environment)
# NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
self.tunables = self._init_tunable_values(
args.random_init or config.get("random_init", False),
config.get("random_seed") if args.random_seed is None else args.random_seed,
config.get("tunable_values", []) + (args.tunable_values or [])
config.get("tunable_values", []) + (args.tunable_values or []),
)
_LOG.info("Init tunables: %s", self.tunables)
@ -145,106 +146,170 @@ class Launcher:
self.storage = self._load_storage(args.storage or config.get("storage"))
_LOG.info("Init storage: %s", self.storage)
self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True))
self.teardown: bool = (
bool(args.teardown)
if args.teardown is not None
else bool(config.get("teardown", True))
)
self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
_LOG.info("Init scheduler: %s", self.scheduler)
@property
def config_loader(self) -> ConfigPersistenceService:
"""
Get the config loader service.
"""
"""Get the config loader service."""
return self._config_loader
@property
def service(self) -> Service:
"""
Get the parent service.
"""
"""Get the parent service."""
return self._parent_service
@staticmethod
def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]:
"""
Parse the command line arguments.
"""
def _parse_args(
parser: argparse.ArgumentParser,
argv: Optional[List[str]],
) -> Tuple[argparse.Namespace, List[str]]:
"""Parse the command line arguments."""
parser.add_argument(
'--config', required=False,
help='Main JSON5 configuration file. Its keys are the same as the' +
' command line options and can be overridden by the latter.\n' +
'\n' +
' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' +
' for additional config examples for this and other arguments.')
"--config",
required=False,
help="Main JSON5 configuration file. Its keys are the same as the"
+ " command line options and can be overridden by the latter.\n"
+ "\n"
+ " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ "
+ " for additional config examples for this and other arguments.",
)
parser.add_argument(
'--log_file', '--log-file', required=False,
help='Path to the log file. Use stdout if omitted.')
"--log_file",
"--log-file",
required=False,
help="Path to the log file. Use stdout if omitted.",
)
parser.add_argument(
'--log_level', '--log-level', required=False, type=str,
help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' +
' Set to DEBUG for debug, WARNING for warnings only.')
"--log_level",
"--log-level",
required=False,
type=str,
help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}."
+ " Set to DEBUG for debug, WARNING for warnings only.",
)
parser.add_argument(
'--config_path', '--config-path', '--config-paths', '--config_paths',
nargs="+", action='extend', required=False,
help='One or more locations of JSON config files.')
"--config_path",
"--config-path",
"--config-paths",
"--config_paths",
nargs="+",
action="extend",
required=False,
help="One or more locations of JSON config files.",
)
parser.add_argument(
'--service', '--services',
nargs='+', action='extend', required=False,
help='Path to JSON file with the configuration of the service(s) for environment(s) to use.')
"--service",
"--services",
nargs="+",
action="extend",
required=False,
help=(
"Path to JSON file with the configuration "
"of the service(s) for environment(s) to use."
),
)
parser.add_argument(
'--environment', required=False,
help='Path to JSON file with the configuration of the benchmarking environment(s).')
"--environment",
required=False,
help="Path to JSON file with the configuration of the benchmarking environment(s).",
)
parser.add_argument(
'--optimizer', required=False,
help='Path to the optimizer configuration file. If omitted, run' +
' a single trial with default (or specified in --tunable_values).')
"--optimizer",
required=False,
help="Path to the optimizer configuration file. If omitted, run"
+ " a single trial with default (or specified in --tunable_values).",
)
parser.add_argument(
'--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int,
help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.')
"--trial_config_repeat_count",
"--trial-config-repeat-count",
required=False,
type=int,
help=(
"Number of times to repeat each config. "
"Default is 1 trial per config, though more may be advised."
),
)
parser.add_argument(
'--scheduler', required=False,
help='Path to the scheduler configuration file. By default, use' +
' a single worker synchronous scheduler.')
"--scheduler",
required=False,
help="Path to the scheduler configuration file. By default, use"
+ " a single worker synchronous scheduler.",
)
parser.add_argument(
'--storage', required=False,
help='Path to the storage configuration file.' +
' If omitted, use the ephemeral in-memory SQL storage.')
"--storage",
required=False,
help="Path to the storage configuration file."
+ " If omitted, use the ephemeral in-memory SQL storage.",
)
parser.add_argument(
'--random_init', '--random-init', required=False, default=False,
dest='random_init', action='store_true',
help='Initialize tunables with random values. (Before applying --tunable_values).')
"--random_init",
"--random-init",
required=False,
default=False,
dest="random_init",
action="store_true",
help="Initialize tunables with random values. (Before applying --tunable_values).",
)
parser.add_argument(
'--random_seed', '--random-seed', required=False, type=int,
help='Seed to use with --random_init')
"--random_seed",
"--random-seed",
required=False,
type=int,
help="Seed to use with --random_init",
)
parser.add_argument(
'--tunable_values', '--tunable-values', nargs="+", action='extend', required=False,
help='Path to one or more JSON files that contain values of the tunable' +
' parameters. This can be used for a single trial (when no --optimizer' +
' is specified) or as default values for the first run in optimization.')
"--tunable_values",
"--tunable-values",
nargs="+",
action="extend",
required=False,
help="Path to one or more JSON files that contain values of the tunable"
+ " parameters. This can be used for a single trial (when no --optimizer"
+ " is specified) or as default values for the first run in optimization.",
)
parser.add_argument(
'--globals', nargs="+", action='extend', required=False,
help='Path to one or more JSON files that contain additional' +
' [private] parameters of the benchmarking environment.')
"--globals",
nargs="+",
action="extend",
required=False,
help="Path to one or more JSON files that contain additional"
+ " [private] parameters of the benchmarking environment.",
)
parser.add_argument(
'--no_teardown', '--no-teardown', required=False, default=None,
dest='teardown', action='store_false',
help='Disable teardown of the environment after the benchmark.')
"--no_teardown",
"--no-teardown",
required=False,
default=None,
dest="teardown",
action="store_false",
help="Disable teardown of the environment after the benchmark.",
)
parser.add_argument(
'--experiment_id', '--experiment-id', required=False, default=None,
"--experiment_id",
"--experiment-id",
required=False,
default=None,
help="""
Experiment ID to use for the benchmark.
If omitted, the value from the --cli config or --globals is used.
@ -254,7 +319,7 @@ class Launcher:
changes are made to config files, scripts, versions, etc.
This is left as a manual operation as detection of what is
"incompatible" is not easily automatable across systems.
"""
""",
)
# By default we use the command line arguments, but allow the caller to
@ -267,9 +332,7 @@ class Launcher:
@staticmethod
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
"""
Helper function to parse global key/value pairs from the command line.
"""
"""Helper function to parse global key/value pairs from the command line."""
_LOG.debug("Extra args: %s", cmdline)
config: Dict[str, TunableValue] = {}
@ -296,16 +359,17 @@ class Launcher:
_LOG.debug("Parsed config: %s", config)
return config
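A simplified sketch of that key/value parsing (hand-rolled here; the real `_try_parse_extra_args` does more validation and may also accept the two-token `--key value` form):

    from typing import Dict, Iterable

    from mlos_bench.tunables.tunable import TunableValue
    from mlos_bench.util import try_parse_val

    def parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
        config: Dict[str, TunableValue] = {}
        for arg in cmdline:
            key, _, val = arg.lstrip("-").partition("=")
            config[key] = try_parse_val(val)  # "3" -> 3, "true" -> True, etc.
        return config

    parse_extra_args(["--experiment_id=demo", "--trial_config_repeat_count=3"])
    # {'experiment_id': 'demo', 'trial_config_repeat_count': 3}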
def _load_config(self,
args_globals: Iterable[str],
config_path: Iterable[str],
args_rest: Iterable[str],
global_config: Dict[str, Any]) -> Dict[str, Any]:
def _load_config(
self,
args_globals: Iterable[str],
config_path: Iterable[str],
args_rest: Iterable[str],
global_config: Dict[str, Any],
) -> Dict[str, Any]:
"""Get key/value pairs of the global configuration parameters from the specified
config files (if any) and command line arguments.
"""
Get key/value pairs of the global configuration parameters
from the specified config files (if any) and command line arguments.
"""
for config_file in (args_globals or []):
for config_file in args_globals or []:
conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS)
assert isinstance(conf, dict)
global_config.update(conf)
@ -314,19 +378,24 @@ class Launcher:
global_config["config_path"] = config_path
return global_config
def _init_tunable_values(self, random_init: bool, seed: Optional[int],
args_tunables: Optional[str]) -> TunableGroups:
"""
Initialize the tunables and load key/value pairs of the tunable values
from given JSON files, if specified.
def _init_tunable_values(
self,
random_init: bool,
seed: Optional[int],
args_tunables: Optional[str],
) -> TunableGroups:
"""Initialize the tunables and load key/value pairs of the tunable values from
given JSON files, if specified.
"""
tunables = self.environment.tunable_params
_LOG.debug("Init tunables: default = %s", tunables)
if random_init:
tunables = MockOptimizer(
tunables=tunables, service=None,
config={"start_with_defaults": False, "seed": seed}).suggest()
tunables=tunables,
service=None,
config={"start_with_defaults": False, "seed": seed},
).suggest()
_LOG.debug("Init tunables: random = %s", tunables)
if args_tunables is not None:
@ -340,50 +409,64 @@ class Launcher:
def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer:
"""
Instantiate the Optimizer object from JSON config file, if specified
in the --optimizer command line option. If config file not specified,
create a one-shot optimizer to run a single benchmark trial.
Instantiate the Optimizer object from JSON config file, if specified in the
--optimizer command line option.
If config file not specified, create a one-shot optimizer to run a single
benchmark trial.
"""
if args_optimizer is None:
# global_config may contain additional properties, so we need to
# strip those out before instantiating the basic oneshot optimizer.
config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS}
return OneShotOptimizer(
self.tunables, config=config, service=self._parent_service)
config = {
key: val
for key, val in self.global_config.items()
if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS
}
return OneShotOptimizer(self.tunables, config=config, service=self._parent_service)
class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER)
assert isinstance(class_config, Dict)
optimizer = self._config_loader.build_optimizer(tunables=self.tunables,
service=self._parent_service,
config=class_config,
global_config=self.global_config)
optimizer = self._config_loader.build_optimizer(
tunables=self.tunables,
service=self._parent_service,
config=class_config,
global_config=self.global_config,
)
return optimizer
def _load_storage(self, args_storage: Optional[str]) -> Storage:
"""
Instantiate the Storage object from JSON file provided in the --storage
command line parameter. If omitted, create an ephemeral in-memory SQL
storage instead.
Instantiate the Storage object from JSON file provided in the --storage command
line parameter.
If omitted, create an ephemeral in-memory SQL storage instead.
"""
if args_storage is None:
# pylint: disable=import-outside-toplevel
from mlos_bench.storage.sql.storage import SqlStorage
return SqlStorage(service=self._parent_service,
config={
"drivername": "sqlite",
"database": ":memory:",
"lazy_schema_create": True,
})
return SqlStorage(
service=self._parent_service,
config={
"drivername": "sqlite",
"database": ":memory:",
"lazy_schema_create": True,
},
)
class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE)
assert isinstance(class_config, Dict)
storage = self._config_loader.build_storage(service=self._parent_service,
config=class_config,
global_config=self.global_config)
storage = self._config_loader.build_storage(
service=self._parent_service,
config=class_config,
global_config=self.global_config,
)
return storage
def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler:
"""
Instantiate the Scheduler object from JSON file provided in the --scheduler
command line parameter.
Create a simple synchronous single-threaded scheduler if omitted.
"""
# Set `teardown` for scheduler only to prevent conflicts with other configs.
@ -392,6 +475,7 @@ class Launcher:
if args_scheduler is None:
# pylint: disable=import-outside-toplevel
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
return SyncScheduler(
# All config values can be overridden from global config
config={


@ -2,18 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Interfaces and wrapper classes for optimizers to be used in Autotune.
"""
"""Interfaces and wrapper classes for optimizers to be used in Autotune."""
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
__all__ = [
'Optimizer',
'MockOptimizer',
'OneShotOptimizer',
'MlosCoreOptimizer',
"Optimizer",
"MockOptimizer",
"OneShotOptimizer",
"MlosCoreOptimizer",
]


@ -2,34 +2,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for an interface between the benchmarking framework
and mlos_core optimizers.
"""Base class for an interface between the benchmarking framework and mlos_core
optimizers.
"""
import logging
from abc import ABCMeta, abstractmethod
from distutils.util import strtobool # pylint: disable=deprecated-module
from distutils.util import strtobool # pylint: disable=deprecated-module
from types import TracebackType
from typing import Dict, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal
from ConfigSpace import ConfigurationSpace
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.base_service import Service
from mlos_bench.environments.status import Status
from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace
_LOG = logging.getLogger(__name__)
class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
"""
An abstract interface between the benchmarking framework and mlos_core optimizers.
class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
"""An abstract interface between the benchmarking framework and mlos_core
optimizers.
"""
# See Also: mlos_bench/mlos_bench/config/schemas/optimizers/optimizer-schema.json
@ -40,13 +38,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
"start_with_defaults",
}
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
"""
Create a new optimizer for the given configuration space defined by the tunables.
Create a new optimizer for the given configuration space defined by the
tunables.
Parameters
----------
@ -68,19 +69,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
self._seed = int(config.get("seed", 42))
self._in_context = False
experiment_id = self._global_config.get('experiment_id')
experiment_id = self._global_config.get("experiment_id")
self.experiment_id = str(experiment_id).strip() if experiment_id else None
self._iter = 0
# If False, use the optimizer to suggest the initial configuration;
# if True (default), use the already initialized values for the first iteration.
self._start_with_defaults: bool = bool(
strtobool(str(self._config.pop('start_with_defaults', True))))
self._max_iter = int(self._config.pop('max_suggestions', 100))
strtobool(str(self._config.pop("start_with_defaults", True)))
)
self._max_iter = int(self._config.pop("max_suggestions", 100))
opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'})
opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"})
self._opt_targets: Dict[str, Literal[1, -1]] = {}
for (opt_target, opt_dir) in opt_targets.items():
for opt_target, opt_dir in opt_targets.items():
if opt_dir == "min":
self._opt_targets[opt_target] = 1
elif opt_dir == "max":
@ -89,10 +91,9 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}")
def _validate_json_config(self, config: dict) -> None:
"""
Reconstructs a basic json config that this class might have been
instantiated from in order to validate configs provided outside the
file loading mechanism.
"""Reconstructs a basic json config that this class might have been instantiated
from in order to validate configs provided outside the file loading
mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@ -108,21 +109,20 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
)
return f"{self.name}({opt_targets},config={self._config})"
def __enter__(self) -> 'Optimizer':
"""
Enter the optimizer's context.
"""
def __enter__(self) -> "Optimizer":
"""Enter the optimizer's context."""
_LOG.debug("Optimizer START :: %s", self)
assert not self._in_context
self._in_context = True
return self
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exit the context of the optimizer.
"""
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exit the context of the optimizer."""
if ex_val is None:
_LOG.debug("Optimizer END :: %s", self)
else:
@ -154,15 +154,14 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def seed(self) -> int:
"""
The random seed for the optimizer.
"""
"""The random seed for the optimizer."""
return self._seed
@property
def start_with_defaults(self) -> bool:
"""
Return True if the optimizer should start with the default values.
Note: This parameter is mutable and will be reset to False after the
defaults are first suggested.
"""
@ -198,16 +197,16 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def name(self) -> str:
"""
The name of the optimizer. We save this information in
mlos_bench storage to track the source of each configuration.
The name of the optimizer.
We save this information in mlos_bench storage to track the source of each
configuration.
"""
return self.__class__.__name__
@property
def targets(self) -> Dict[str, Literal['min', 'max']]:
"""
A dictionary of {target: direction} of optimization targets.
"""
def targets(self) -> Dict[str, Literal["min", "max"]]:
"""A dictionary of {target: direction} of optimization targets."""
return {
opt_target: "min" if opt_dir == 1 else "max"
for (opt_target, opt_dir) in self._opt_targets.items()
@ -215,16 +214,18 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
@property
def supports_preload(self) -> bool:
"""
Return True if the optimizer supports pre-loading the data from previous experiments.
"""Return True if the optimizer supports pre-loading the data from previous
experiments.
"""
return True
@abstractmethod
def bulk_register(self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None) -> bool:
def bulk_register(
self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None,
) -> bool:
"""
Pre-load the optimizer with the bulk data from previous experiments.
@ -242,8 +243,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
is_not_empty : bool
True if there is data to register, false otherwise.
"""
_LOG.info("Update the optimizer with: %d configs, %d scores, %d status values",
len(configs or []), len(scores or []), len(status or []))
_LOG.info(
"Update the optimizer with: %d configs, %d scores, %d status values",
len(configs or []),
len(scores or []),
len(status or []),
)
if len(configs or []) != len(scores or []):
raise ValueError("Numbers of configs and scores do not match.")
if status is not None and len(configs or []) != len(status or []):
@ -256,9 +261,8 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
def suggest(self) -> TunableGroups:
"""
Generate the next suggestion.
Base class' implementation increments the iteration count
and returns the current values of the tunables.
Generate the next suggestion. Base class' implementation increments the
iteration count and returns the current values of the tunables.
Returns
-------
@ -272,8 +276,12 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
return self._tunables.copy()
@abstractmethod
def register(self, tunables: TunableGroups, status: Status,
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
def register(
self,
tunables: TunableGroups,
status: Status,
score: Optional[Dict[str, TunableValue]] = None,
) -> Optional[Dict[str, float]]:
"""
Register the observation for the given configuration.
@ -294,18 +302,25 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
Benchmark scores extracted (and possibly transformed)
from the dataframe that's being MINIMIZED.
"""
_LOG.info("Iteration %d :: Register: %s = %s score: %s",
self._iter, tunables, status, score)
_LOG.info(
"Iteration %d :: Register: %s = %s score: %s",
self._iter,
tunables,
status,
score,
)
if status.is_succeeded() == (score is None): # XOR
raise ValueError("Status and score must be consistent.")
return self._get_scores(status, score)
def _get_scores(self, status: Status,
scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]]
) -> Optional[Dict[str, float]]:
def _get_scores(
self,
status: Status,
scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]],
) -> Optional[Dict[str, float]]:
"""
Extract a scalar benchmark score from the dataframe.
Change the sign if we are maximizing.
Extract a scalar benchmark score from the dataframe. Change the sign if we are
maximizing.
Parameters
----------
@ -331,7 +346,7 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
assert scores is not None
target_metrics: Dict[str, float] = {}
for (opt_target, opt_dir) in self._opt_targets.items():
for opt_target, opt_dir in self._opt_targets.items():
val = scores[opt_target]
assert val is not None
target_metrics[opt_target] = float(val) * opt_dir
@ -341,12 +356,15 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr
def not_converged(self) -> bool:
"""
Return True if not converged, False otherwise.
Base implementation just checks the iteration count.
"""
return self._iter < self._max_iter
@abstractmethod
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
def get_best_observation(
self,
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
"""
Get the best observation so far.
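
The sign convention above is easy to misread: optimization directions are stored as multipliers (min -> 1, max -> -1) so that every target can be treated uniformly as a minimization. A minimal standalone sketch of that convention, using hypothetical target names rather than anything from this class:

from typing import Dict

from typing_extensions import Literal

# Hypothetical targets: minimize latency, maximize throughput.
opt_targets: Dict[str, Literal[1, -1]] = {"latency": 1, "throughput": -1}


def adjust_signs(scores: Dict[str, float]) -> Dict[str, float]:
    """Multiply each score by its direction so the optimizer always MINIMIZES."""
    return {name: score * opt_targets[name] for (name, score) in scores.items()}


result = adjust_signs({"latency": 12.5, "throughput": 900.0})
assert result == {"latency": 12.5, "throughput": -900.0}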

View file

@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Functions to convert TunableGroups to ConfigSpace for use with the mlos_core optimizers.
"""Functions to convert TunableGroups to ConfigSpace for use with the mlos_core
optimizers.
"""
import logging
from typing import Dict, List, Optional, Tuple, Union
from ConfigSpace import (
@ -21,9 +20,10 @@ from ConfigSpace import (
Normal,
Uniform,
)
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import try_parse_val, nullable
from mlos_bench.util import nullable, try_parse_val
_LOG = logging.getLogger(__name__)
@ -31,6 +31,7 @@ _LOG = logging.getLogger(__name__)
class TunableValueKind:
"""
Enum for the kind of the tunable value (special or not).
It is not a true enum because ConfigSpace wants string values.
"""
@ -40,15 +41,16 @@ class TunableValueKind:
def _normalize_weights(weights: List[float]) -> List[float]:
"""
Helper function for normalizing weights to probabilities.
"""
"""Helper function for normalizing weights to probabilities."""
total = sum(weights)
return [w / total for w in weights]
def _tunable_to_configspace(
tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace:
tunable: Tunable,
group_name: Optional[str] = None,
cost: int = 0,
) -> ConfigurationSpace:
"""
Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects,
wrapped in a ConfigurationSpace for composability.
@ -71,14 +73,17 @@ def _tunable_to_configspace(
meta = {"group": group_name, "cost": cost} # {"scaling": ""}
if tunable.type == "categorical":
return ConfigurationSpace({
tunable.name: CategoricalHyperparameter(
name=tunable.name,
choices=tunable.categories,
weights=_normalize_weights(tunable.weights) if tunable.weights else None,
default_value=tunable.default,
meta=meta)
})
return ConfigurationSpace(
{
tunable.name: CategoricalHyperparameter(
name=tunable.name,
choices=tunable.categories,
weights=_normalize_weights(tunable.weights) if tunable.weights else None,
default_value=tunable.default,
meta=meta,
)
}
)
distribution: Union[Uniform, Normal, Beta, None] = None
if tunable.distribution == "uniform":
@ -86,12 +91,12 @@ def _tunable_to_configspace(
elif tunable.distribution == "normal":
distribution = Normal(
mu=tunable.distribution_params["mu"],
sigma=tunable.distribution_params["sigma"]
sigma=tunable.distribution_params["sigma"],
)
elif tunable.distribution == "beta":
distribution = Beta(
alpha=tunable.distribution_params["alpha"],
beta=tunable.distribution_params["beta"]
beta=tunable.distribution_params["beta"],
)
elif tunable.distribution is not None:
raise TypeError(f"Invalid Distribution Type: {tunable.distribution}")
@ -103,22 +108,26 @@ def _tunable_to_configspace(
log=bool(tunable.is_log),
q=nullable(int, tunable.quantization),
distribution=distribution,
default=(int(tunable.default)
if tunable.in_range(tunable.default) and tunable.default is not None
else None),
meta=meta
default=(
int(tunable.default)
if tunable.in_range(tunable.default) and tunable.default is not None
else None
),
meta=meta,
)
elif tunable.type == "float":
range_hp = Float(
name=tunable.name,
bounds=tunable.range,
log=bool(tunable.is_log),
q=tunable.quantization, # type: ignore[arg-type]
q=tunable.quantization, # type: ignore[arg-type]
distribution=distribution, # type: ignore[arg-type]
default=(float(tunable.default)
if tunable.in_range(tunable.default) and tunable.default is not None
else None),
meta=meta
default=(
float(tunable.default)
if tunable.in_range(tunable.default) and tunable.default is not None
else None
),
meta=meta,
)
else:
raise TypeError(f"Invalid Parameter Type: {tunable.type}")
@ -136,31 +145,38 @@ def _tunable_to_configspace(
# Create three hyperparameters: one for regular values,
# one for special values, and one to choose between the two.
(special_name, type_name) = special_param_names(tunable.name)
conf_space = ConfigurationSpace({
tunable.name: range_hp,
special_name: CategoricalHyperparameter(
name=special_name,
choices=tunable.special,
weights=special_weights,
default_value=tunable.default if tunable.default in tunable.special else None,
meta=meta
),
type_name: CategoricalHyperparameter(
name=type_name,
choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
weights=switch_weights,
default_value=TunableValueKind.SPECIAL,
),
})
conf_space.add_condition(EqualsCondition(
conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL))
conf_space.add_condition(EqualsCondition(
conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE))
conf_space = ConfigurationSpace(
{
tunable.name: range_hp,
special_name: CategoricalHyperparameter(
name=special_name,
choices=tunable.special,
weights=special_weights,
default_value=tunable.default if tunable.default in tunable.special else None,
meta=meta,
),
type_name: CategoricalHyperparameter(
name=type_name,
choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE],
weights=switch_weights,
default_value=TunableValueKind.SPECIAL,
),
}
)
conf_space.add_condition(
EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)
)
conf_space.add_condition(
EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)
)
return conf_space
def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace:
def tunable_groups_to_configspace(
tunables: TunableGroups,
seed: Optional[int] = None,
) -> ConfigurationSpace:
"""
Convert TunableGroups to hyperparameters in ConfigurationSpace.
@ -178,11 +194,16 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] =
A new ConfigurationSpace instance that corresponds to the input TunableGroups.
"""
space = ConfigurationSpace(seed=seed)
for (tunable, group) in tunables:
for tunable, group in tunables:
space.add_configuration_space(
prefix="", delimiter="",
prefix="",
delimiter="",
configuration_space=_tunable_to_configspace(
tunable, group.name, group.get_current_cost()))
tunable,
group.name,
group.get_current_cost(),
),
)
return space
@ -201,7 +222,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
A ConfigSpace Configuration.
"""
values: Dict[str, TunableValue] = {}
for (tunable, _group) in tunables:
for tunable, _group in tunables:
if tunable.special:
(special_name, type_name) = special_param_names(tunable.name)
if tunable.value in tunable.special:
@ -219,13 +240,11 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration:
def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
"""
Remove the fields that correspond to special values in ConfigSpace.
In particular, remove any key suffixes added by `special_param_names`.
"""
data = data.copy()
specials = [
special_param_name_strip(k)
for k in data.keys() if special_param_name_is_temp(k)
]
specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)]
for k in specials:
(special_name, type_name) = special_param_names(k)
if data[type_name] == TunableValueKind.SPECIAL:
@ -240,8 +259,8 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]:
def special_param_names(name: str) -> Tuple[str, str]:
"""
Generate the names of the auxiliary hyperparameters that correspond
to a tunable that can have special values.
Generate the names of the auxiliary hyperparameters that correspond to a tunable
that can have special values.
NOTE: `!` characters are currently disallowed in Tunable names in order to handle this logic.
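
For context on the three-hyperparameter trick built above, the same pattern can be reproduced with ConfigSpace (pre-1.0 API) directly. A minimal sketch for a hypothetical tunable vm_memory with a single special value -1; the "!special"/"!type" suffixes here are illustrative, mirroring the naming scheme rather than quoting it:

from ConfigSpace import (
    CategoricalHyperparameter,
    ConfigurationSpace,
    EqualsCondition,
    UniformIntegerHyperparameter,
)

cs = ConfigurationSpace(seed=42)
cs.add_hyperparameters([
    UniformIntegerHyperparameter("vm_memory", lower=1, upper=1024),  # regular range
    CategoricalHyperparameter("vm_memory!special", choices=[-1]),  # special value(s)
    CategoricalHyperparameter("vm_memory!type", choices=["special", "range"]),
])
# Activate exactly one of the two value parameters, depending on the switch.
cs.add_condition(EqualsCondition(cs["vm_memory!special"], cs["vm_memory!type"], "special"))
cs.add_condition(EqualsCondition(cs["vm_memory"], cs["vm_memory!type"], "range"))
print(cs.sample_configuration())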

View file

@ -2,38 +2,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Grid search optimizer for mlos_bench.
"""
"""Grid search optimizer for mlos_bench."""
import logging
from typing import Dict, Iterable, Optional, Sequence, Set, Tuple
from typing import Dict, Iterable, Set, Optional, Sequence, Tuple
import numpy as np
import ConfigSpace
import numpy as np
from ConfigSpace.util import generate_grid
from mlos_bench.environments.status import Status
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
from mlos_bench.services.base_service import Service
_LOG = logging.getLogger(__name__)
class GridSearchOptimizer(TrackBestOptimizer):
"""
Grid search optimizer.
"""
"""Grid search optimizer."""
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(tunables, config, global_config, service)
# Track the grid as a set of tuples of tunable values and reconstruct the
@ -52,11 +49,21 @@ class GridSearchOptimizer(TrackBestOptimizer):
def _sanity_check(self) -> None:
size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables])
if size == np.inf:
raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}")
raise ValueError(
f"Unquantized tunables are not supported for grid search: {self._tunables}"
)
if size > 10000:
_LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables)
_LOG.warning(
"Large number %d of config points requested for grid search: %s",
size,
self._tunables,
)
if size > self._max_iter:
_LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter)
_LOG.warning(
"Grid search size %d, is greater than max iterations %d",
size,
self._max_iter,
)
def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
"""
@ -69,12 +76,14 @@ class GridSearchOptimizer(TrackBestOptimizer):
# names instead of the order given by TunableGroups.
configs = [
configspace_data_to_tunable_values(dict(config))
for config in
generate_grid(self.config_space, {
tunable.name: int(tunable.cardinality)
for (tunable, _group) in self._tunables
if tunable.quantization or tunable.type == "int"
})
for config in generate_grid(
self.config_space,
{
tunable.name: int(tunable.cardinality)
for (tunable, _group) in self._tunables
if tunable.quantization or tunable.type == "int"
},
)
]
names = set(tuple(config.keys()) for config in configs)
assert len(names) == 1
@ -104,15 +113,17 @@ class GridSearchOptimizer(TrackBestOptimizer):
# See NOTEs above.
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)
def bulk_register(self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None) -> bool:
def bulk_register(
self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None,
) -> bool:
if not super().bulk_register(configs, scores, status):
return False
if status is None:
status = [Status.SUCCEEDED] * len(configs)
for (params, score, trial_status) in zip(configs, scores, status):
for params, score, trial_status in zip(configs, scores, status):
tunables = self._tunables.copy().assign(params)
self.register(tunables, trial_status, score)
if _LOG.isEnabledFor(logging.DEBUG):
@ -121,9 +132,7 @@ class GridSearchOptimizer(TrackBestOptimizer):
return True
def suggest(self) -> TunableGroups:
"""
Generate the next grid search suggestion.
"""
"""Generate the next grid search suggestion."""
tunables = super().suggest()
if self._start_with_defaults:
_LOG.info("Use default values for the first trial")
@ -153,20 +162,35 @@ class GridSearchOptimizer(TrackBestOptimizer):
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
return tunables
def register(self, tunables: TunableGroups, status: Status,
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
def register(
self,
tunables: TunableGroups,
status: Status,
score: Optional[Dict[str, TunableValue]] = None,
) -> Optional[Dict[str, float]]:
registered_score = super().register(tunables, status, score)
try:
config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()))
config = dict(
ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
)
self._suggested_configs.remove(tuple(config.values()))
except KeyError:
_LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables)
_LOG.warning(
(
"Attempted to remove missing config "
"(previously registered?) from suggested set: %s"
),
tunables,
)
return registered_score
def not_converged(self) -> bool:
if self._iter > self._max_iter:
if bool(self._pending_configs):
_LOG.warning("Exceeded max iterations, but still have %d pending configs: %s",
len(self._pending_configs), list(self._pending_configs.keys()))
_LOG.warning(
"Exceeded max iterations, but still have %d pending configs: %s",
len(self._pending_configs),
list(self._pending_configs.keys()),
)
return False
return bool(self._pending_configs)
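
As background for the `_get_grid` logic above: ConfigSpace's `generate_grid` takes a number of steps per numeric hyperparameter and grid search then walks the cross product. A minimal sketch with a made-up parameter (not one of the tunables in this repo):

from ConfigSpace import ConfigurationSpace, UniformIntegerHyperparameter
from ConfigSpace.util import generate_grid

cs = ConfigurationSpace(seed=42)
cs.add_hyperparameter(UniformIntegerHyperparameter("x", lower=0, upper=8))
# 5 steps over [0, 8] -> x in {0, 2, 4, 6, 8}
grid = generate_grid(cs, {"x": 5})
print([dict(config) for config in grid])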

View file

@ -2,72 +2,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A wrapper for mlos_core optimizers for mlos_bench.
"""
"""A wrapper for mlos_core optimizers for mlos_bench."""
import logging
import os
from types import TracebackType
from typing import Dict, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal
import pandas as pd
from mlos_core.optimizers import (
BaseOptimizer, OptimizerType, OptimizerFactory, SpaceAdapterType, DEFAULT_OPTIMIZER_TYPE
)
from typing_extensions import Literal
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.optimizers.convert_configspace import (
TunableValueKind,
configspace_data_to_tunable_values,
special_param_names,
)
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_core.optimizers import (
DEFAULT_OPTIMIZER_TYPE,
BaseOptimizer,
OptimizerFactory,
OptimizerType,
SpaceAdapterType,
)
_LOG = logging.getLogger(__name__)
class MlosCoreOptimizer(Optimizer):
"""
A wrapper class for the mlos_core optimizers.
"""
"""A wrapper class for the mlos_core optimizers."""
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(tunables, config, global_config, service)
opt_type = getattr(OptimizerType, self._config.pop(
'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name))
opt_type = getattr(
OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name)
)
if opt_type == OptimizerType.SMAC:
output_directory = self._config.get('output_directory')
output_directory = self._config.get("output_directory")
if output_directory is not None:
# If output_directory is specified, turn it into an absolute path.
self._config['output_directory'] = os.path.abspath(output_directory)
self._config["output_directory"] = os.path.abspath(output_directory)
else:
_LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.")
_LOG.warning(
(
"SMAC optimizer output_directory was null. "
"SMAC will use a temporary directory."
)
)
# Make sure max_trials >= max_iterations.
if 'max_trials' not in self._config:
self._config['max_trials'] = self._max_iter
assert int(self._config['max_trials']) >= self._max_iter, \
f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
if "max_trials" not in self._config:
self._config["max_trials"] = self._max_iter
assert (
int(self._config["max_trials"]) >= self._max_iter
), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}"
if 'run_name' not in self._config and self.experiment_id:
self._config['run_name'] = self.experiment_id
if "run_name" not in self._config and self.experiment_id:
self._config["run_name"] = self.experiment_id
space_adapter_type = self._config.pop('space_adapter_type', None)
space_adapter_config = self._config.pop('space_adapter_config', {})
space_adapter_type = self._config.pop("space_adapter_type", None)
space_adapter_config = self._config.pop("space_adapter_config", {})
if space_adapter_type is not None:
space_adapter_type = getattr(SpaceAdapterType, space_adapter_type)
@ -81,9 +87,12 @@ class MlosCoreOptimizer(Optimizer):
space_adapter_kwargs=space_adapter_config,
)
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
self._opt.cleanup()
return super().__exit__(ex_type, ex_val, ex_tb)
@ -91,10 +100,12 @@ class MlosCoreOptimizer(Optimizer):
def name(self) -> str:
return f"{self.__class__.__name__}:{self._opt.__class__.__name__}"
def bulk_register(self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None) -> bool:
def bulk_register(
self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None,
) -> bool:
if not super().bulk_register(configs, scores, status):
return False
@ -102,7 +113,8 @@ class MlosCoreOptimizer(Optimizer):
df_configs = self._to_df(configs) # Impute missing values, if necessary
df_scores = self._adjust_signs_df(
pd.DataFrame([{} if score is None else score for score in scores]))
pd.DataFrame([{} if score is None else score for score in scores])
)
opt_targets = list(self._opt_targets)
if status is not None:
@ -126,17 +138,15 @@ class MlosCoreOptimizer(Optimizer):
return True
def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame:
"""
In-place adjust the signs of the scores for MINIMIZATION problem.
"""
for (opt_target, opt_dir) in self._opt_targets.items():
"""In-place adjust the signs of the scores for MINIMIZATION problem."""
for opt_target, opt_dir in self._opt_targets.items():
df_scores[opt_target] *= opt_dir
return df_scores
def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame:
"""
Select from past trials only the columns required in this experiment and
impute default values for the tunables that are missing in the dataframe.
Select from past trials only the columns required in this experiment and impute
default values for the tunables that are missing in the dataframe.
Parameters
----------
@ -151,7 +161,7 @@ class MlosCoreOptimizer(Optimizer):
df_configs = pd.DataFrame(configs)
tunables_names = list(self._tunables.get_param_values().keys())
missing_cols = set(tunables_names).difference(df_configs.columns)
for (tunable, _group) in self._tunables:
for tunable, _group in self._tunables:
if tunable.name in missing_cols:
df_configs[tunable.name] = tunable.default
else:
@ -183,22 +193,34 @@ class MlosCoreOptimizer(Optimizer):
df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults)
self._start_with_defaults = False
_LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config)
return tunables.assign(
configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict()))
def register(self, tunables: TunableGroups, status: Status,
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION
def register(
self,
tunables: TunableGroups,
status: Status,
score: Optional[Dict[str, TunableValue]] = None,
) -> Optional[Dict[str, float]]:
registered_score = super().register(
tunables,
status,
score,
) # Sign-adjusted for MINIMIZATION
if status.is_completed():
assert registered_score is not None
df_config = self._to_df([tunables.get_param_values()])
_LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config)
# TODO: Specify (in the config) which metrics to pass to the optimizer.
# Issue: https://github.com/microsoft/MLOS/issues/745
self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float))
self._opt.register(
configs=df_config,
scores=pd.DataFrame([registered_score], dtype=float),
)
return registered_score
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
def get_best_observation(
self,
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
(df_config, df_score, _df_context) = self._opt.get_best_observations()
if len(df_config) == 0:
return (None, None)
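
The `_to_df` imputation above is plain pandas. A minimal sketch with hypothetical tunable names, exercising both branches: a column missing from every past trial gets the default outright, and a partially present column gets per-row fills:

import pandas as pd

configs = [
    {"vm_memory": 512},  # older trial: no sched_latency value recorded
    {"vm_memory": 1024, "sched_latency": 6000},
]
defaults = {"vm_memory": 1024, "sched_latency": 4000, "io_depth": 8}

df_configs = pd.DataFrame(configs)
for name, default in defaults.items():
    if name not in df_configs.columns:
        df_configs[name] = default  # column entirely missing: use the default
    else:
        df_configs[name] = df_configs[name].fillna(default)  # fill per-row gaps
print(df_configs)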

View file

@ -2,35 +2,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Mock optimizer for mlos_bench.
"""
"""Mock optimizer for mlos_bench."""
import random
import logging
import random
from typing import Callable, Dict, Optional, Sequence
from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class MockOptimizer(TrackBestOptimizer):
"""
Mock optimizer to test the Environment API.
"""
"""Mock optimizer to test the Environment API."""
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(tunables, config, global_config, service)
rnd = random.Random(self.seed)
self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
@ -39,15 +35,17 @@ class MockOptimizer(TrackBestOptimizer):
"int": lambda tunable: rnd.randint(*tunable.range),
}
def bulk_register(self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None) -> bool:
def bulk_register(
self,
configs: Sequence[dict],
scores: Sequence[Optional[Dict[str, TunableValue]]],
status: Optional[Sequence[Status]] = None,
) -> bool:
if not super().bulk_register(configs, scores, status):
return False
if status is None:
status = [Status.SUCCEEDED] * len(configs)
for (params, score, trial_status) in zip(configs, scores, status):
for params, score, trial_status in zip(configs, scores, status):
tunables = self._tunables.copy().assign(params)
self.register(tunables, trial_status, score)
if _LOG.isEnabledFor(logging.DEBUG):
@ -56,15 +54,13 @@ class MockOptimizer(TrackBestOptimizer):
return True
def suggest(self) -> TunableGroups:
"""
Generate the next (random) suggestion.
"""
"""Generate the next (random) suggestion."""
tunables = super().suggest()
if self._start_with_defaults:
_LOG.info("Use default tunable values")
self._start_with_defaults = False
else:
for (tunable, _group) in tunables:
for tunable, _group in tunables:
tunable.value = self._random[tunable.type](tunable)
_LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
return tunables
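
The per-type sampler table above is the whole trick of the mock optimizer. A self-contained sketch of the same pattern, using a SimpleNamespace stand-in for a Tunable (the real categorical sampler is analogous):

import random
from types import SimpleNamespace

rnd = random.Random(42)
samplers = {
    "categorical": lambda t: rnd.choice(t.categories),
    "float": lambda t: rnd.uniform(*t.range),
    "int": lambda t: rnd.randint(*t.range),
}

tunable = SimpleNamespace(type="int", range=(0, 100))
print(samplers[tunable.type](tunable))  # deterministic given the seed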

View file

@ -2,16 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
No-op optimizer for mlos_bench that proposes a single configuration.
"""
"""No-op optimizer for mlos_bench that proposes a single configuration."""
import logging
from typing import Optional
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
_LOG = logging.getLogger(__name__)
@ -19,24 +17,25 @@ _LOG = logging.getLogger(__name__)
class OneShotOptimizer(MockOptimizer):
"""
No-op optimizer that proposes a single configuration and returns.
Explicit configs (partial or full) are possible using configuration files.
"""
# TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(tunables, config, global_config, service)
_LOG.info("Run a single iteration for: %s", self._tunables)
self._max_iter = 1 # Always run for just one iteration.
def suggest(self) -> TunableGroups:
"""
Always produce the same (initial) suggestion.
"""
"""Always produce the same (initial) suggestion."""
tunables = super().suggest()
self._start_with_defaults = True
return tunables

View file

@ -2,40 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Mock optimizer for mlos_bench.
"""
"""Mock optimizer for mlos_bench."""
import logging
from abc import ABCMeta
from typing import Dict, Optional, Tuple, Union
from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
"""
Base Optimizer class that keeps track of the best score and configuration.
"""
"""Base Optimizer class that keeps track of the best score and configuration."""
def __init__(self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
tunables: TunableGroups,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(tunables, config, global_config, service)
self._best_config: Optional[TunableGroups] = None
self._best_score: Optional[Dict[str, float]] = None
def register(self, tunables: TunableGroups, status: Status,
score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]:
def register(
self,
tunables: TunableGroups,
status: Status,
score: Optional[Dict[str, TunableValue]] = None,
) -> Optional[Dict[str, float]]:
registered_score = super().register(tunables, status, score)
if status.is_succeeded() and self._is_better(registered_score):
self._best_score = registered_score
@ -43,13 +44,11 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
return registered_score
def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool:
"""
Compare the optimization scores to the best ones so far lexicographically.
"""
"""Compare the optimization scores to the best ones so far lexicographically."""
if self._best_score is None:
return True
assert registered_score is not None
for (opt_target, best_score) in self._best_score.items():
for opt_target, best_score in self._best_score.items():
score = registered_score[opt_target]
if score < best_score:
return True
@ -57,7 +56,9 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
return False
return False
def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
def get_best_observation(
self,
) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
if self._best_score is None:
return (None, None)
score = self._get_scores(Status.SUCCEEDED, self._best_score)
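
The `_is_better` comparison above is lexicographic over the sign-adjusted (i.e., minimized) targets, in their configured order. A standalone sketch of the same logic:

from typing import Dict, Optional


def is_better(best: Optional[Dict[str, float]], candidate: Dict[str, float]) -> bool:
    if best is None:
        return True  # anything beats "no best score yet"
    for target, best_score in best.items():  # dicts preserve the configured order
        if candidate[target] < best_score:
            return True  # strictly better on the first differing target
        if candidate[target] > best_score:
            return False
    return False  # an exact tie is not an improvement


# Tie on latency, strictly better on cost -> better overall.
assert is_better({"latency": 10.0, "cost": 5.0}, {"latency": 10.0, "cost": 4.0})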

View file

@ -3,8 +3,8 @@
# Licensed under the MIT License.
#
"""
Simple platform agnostic abstraction for the OS environment variables.
Meant as a replacement for os.environ vs nt.environ.
Simple platform agnostic abstraction for the OS environment variables. Meant as a
replacement for os.environ vs nt.environ.
Example
-------
@ -22,16 +22,18 @@ else:
from typing_extensions import TypeAlias
if sys.version_info >= (3, 9):
EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object
# pylint: disable=protected-access,disable=unsubscriptable-object
EnvironType: TypeAlias = os._Environ[str]
else:
EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access
# Handle case sensitivity differences between platforms.
# https://stackoverflow.com/a/19023293
if sys.platform == 'win32':
import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
if sys.platform == "win32":
import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8)
environ: EnvironType = nt.environ
else:
environ: EnvironType = os.environ
__all__ = ['environ']
__all__ = ["environ"]

View file

@ -20,8 +20,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
def _main(argv: Optional[List[str]] = None
) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
def _main(
argv: Optional[List[str]] = None,
) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv)

View file

@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Interfaces and implementations of the optimization loop scheduling policies.
"""
"""Interfaces and implementations of the optimization loop scheduling policies."""
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
__all__ = [
'Scheduler',
'SyncScheduler',
"Scheduler",
"SyncScheduler",
]

View file

@ -2,20 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for the optimization loop scheduling policies.
"""
"""Base class for the optimization loop scheduling policies."""
import json
import logging
from datetime import datetime
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, Optional, Tuple, Type
from typing_extensions import Literal
from pytz import UTC
from typing_extensions import Literal
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
@ -28,22 +25,23 @@ _LOG = logging.getLogger(__name__)
class Scheduler(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base class for the optimization loop scheduling policies.
"""
"""Base class for the optimization loop scheduling policies."""
def __init__(self, *,
config: Dict[str, Any],
global_config: Dict[str, Any],
environment: Environment,
optimizer: Optimizer,
storage: Storage,
root_env_config: str):
def __init__(
self,
*,
config: Dict[str, Any],
global_config: Dict[str, Any],
environment: Environment,
optimizer: Optimizer,
storage: Storage,
root_env_config: str,
):
"""
Create a new instance of the scheduler. The constructor of this
and the derived classes is called by the persistence service
after reading the class JSON configuration. Other objects like
the Environment and Optimizer are provided by the Launcher.
Create a new instance of the scheduler. The constructor of this and the derived
classes is called by the persistence service after reading the class JSON
configuration. Other objects like the Environment and Optimizer are provided by
the Launcher.
Parameters
----------
@ -61,8 +59,11 @@ class Scheduler(metaclass=ABCMeta):
Path to the root environment configuration.
"""
self.global_config = global_config
config = merge_parameters(dest=config.copy(), source=global_config,
required_keys=["experiment_id", "trial_id"])
config = merge_parameters(
dest=config.copy(),
source=global_config,
required_keys=["experiment_id", "trial_id"],
)
self._experiment_id = config["experiment_id"].strip()
self._trial_id = int(config["trial_id"])
@ -72,7 +73,9 @@ class Scheduler(metaclass=ABCMeta):
self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
if self._trial_config_repeat_count <= 0:
raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}")
raise ValueError(
f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
)
self._do_teardown = bool(config.get("teardown", True))
@ -96,10 +99,8 @@ class Scheduler(metaclass=ABCMeta):
"""
return self.__class__.__name__
def __enter__(self) -> 'Scheduler':
"""
Enter the scheduler's context.
"""
def __enter__(self) -> "Scheduler":
"""Enter the scheduler's context."""
_LOG.debug("Scheduler START :: %s", self)
assert self.experiment is None
self.environment.__enter__()
@ -118,13 +119,13 @@ class Scheduler(metaclass=ABCMeta):
).__enter__()
return self
def __exit__(self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exit the context of the scheduler.
"""
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""Exit the context of the scheduler."""
if ex_val is None:
_LOG.debug("Scheduler END :: %s", self)
else:
@ -139,12 +140,14 @@ class Scheduler(metaclass=ABCMeta):
@abstractmethod
def start(self) -> None:
"""
Start the optimization loop.
"""
"""Start the optimization loop."""
assert self.experiment is not None
_LOG.info("START: Experiment: %s Env: %s Optimizer: %s",
self.experiment, self.environment, self.optimizer)
_LOG.info(
"START: Experiment: %s Env: %s Optimizer: %s",
self.experiment,
self.environment,
self.optimizer,
)
if _LOG.isEnabledFor(logging.INFO):
_LOG.info("Root Environment:\n%s", self.environment.pprint())
@ -155,6 +158,7 @@ class Scheduler(metaclass=ABCMeta):
def teardown(self) -> None:
"""
Tear down the environment.
Call it after the completion of the `.start()` in the scheduler context.
"""
assert self.experiment is not None
@ -162,17 +166,13 @@ class Scheduler(metaclass=ABCMeta):
self.environment.teardown()
def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
"""
Get the best observation from the optimizer.
"""
"""Get the best observation from the optimizer."""
(best_score, best_config) = self.optimizer.get_best_observation()
_LOG.info("Env: %s best score: %s", self.environment, best_score)
return (best_score, best_config)
def load_config(self, config_id: int) -> TunableGroups:
"""
Load the existing tunable configuration from the storage.
"""
"""Load the existing tunable configuration from the storage."""
assert self.experiment is not None
tunable_values = self.experiment.load_tunable_config(config_id)
tunables = self.environment.tunable_params.assign(tunable_values)
@ -183,9 +183,11 @@ class Scheduler(metaclass=ABCMeta):
def _schedule_new_optimizer_suggestions(self) -> bool:
"""
Optimizer part of the loop. Load the results of the executed trials
into the optimizer, suggest new configurations, and add them to the queue.
Return True if optimization is not over, False otherwise.
Optimizer part of the loop.
Load the results of the executed trials into the optimizer, suggest new
configurations, and add them to the queue. Return True if optimization is not
over, False otherwise.
"""
assert self.experiment is not None
(trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
@ -201,33 +203,38 @@ class Scheduler(metaclass=ABCMeta):
return not_done
def schedule_trial(self, tunables: TunableGroups) -> None:
"""
Add a configuration to the queue of trials.
"""
"""Add a configuration to the queue of trials."""
for repeat_i in range(1, self._trial_config_repeat_count + 1):
self._add_trial_to_queue(tunables, config={
# Add some additional metadata to track for the trial such as the
# optimizer config used.
# Note: these values are unfortunately mutable at the moment.
# Consider them as hints of what the config was when the trial *started*.
# It is possible that the experiment configs were changed
# between resuming the experiment (since that is not currently
# prevented).
"optimizer": self.optimizer.name,
"repeat_i": repeat_i,
"is_defaults": tunables.is_defaults(),
**{
f"opt_{key}_{i}": val
for (i, opt_target) in enumerate(self.optimizer.targets.items())
for (key, val) in zip(["target", "direction"], opt_target)
}
})
self._add_trial_to_queue(
tunables,
config={
# Add some additional metadata to track for the trial such as the
# optimizer config used.
# Note: these values are unfortunately mutable at the moment.
# Consider them as hints of what the config was when the trial *started*.
# It is possible that the experiment configs were changed
# between resuming the experiment (since that is not currently
# prevented).
"optimizer": self.optimizer.name,
"repeat_i": repeat_i,
"is_defaults": tunables.is_defaults(),
**{
f"opt_{key}_{i}": val
for (i, opt_target) in enumerate(self.optimizer.targets.items())
for (key, val) in zip(["target", "direction"], opt_target)
},
},
)
def _add_trial_to_queue(self, tunables: TunableGroups,
ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None) -> None:
def _add_trial_to_queue(
self,
tunables: TunableGroups,
ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None,
) -> None:
"""
Add a configuration to the queue of trials.
A wrapper for the `Experiment.new_trial` method.
"""
assert self.experiment is not None
@ -236,7 +243,9 @@ class Scheduler(metaclass=ABCMeta):
def _run_schedule(self, running: bool = False) -> None:
"""
Scheduler part of the loop. Check for pending trials in the queue and run them.
Scheduler part of the loop.
Check for pending trials in the queue and run them.
"""
assert self.experiment is not None
for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
@ -245,6 +254,7 @@ class Scheduler(metaclass=ABCMeta):
def not_done(self) -> bool:
"""
Check the stopping conditions.
By default, stop when the optimizer converges or max limit of trials reached.
"""
return self.optimizer.not_converged() and (
@ -254,7 +264,9 @@ class Scheduler(metaclass=ABCMeta):
@abstractmethod
def run_trial(self, trial: Storage.Trial) -> None:
"""
Set up and run a single trial. Save the results in the storage.
Set up and run a single trial.
Save the results in the storage.
"""
assert self.experiment is not None
self._trial_count += 1
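
The nested comprehension in `schedule_trial` above is dense; this is what it expands to for a hypothetical pair of targets:

targets = {"latency": "min", "throughput": "max"}  # as returned by Optimizer.targets
metadata = {
    f"opt_{key}_{i}": val
    for (i, opt_target) in enumerate(targets.items())
    for (key, val) in zip(["target", "direction"], opt_target)
}
assert metadata == {
    "opt_target_0": "latency",
    "opt_direction_0": "min",
    "opt_target_1": "throughput",
    "opt_direction_1": "max",
}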

View file

@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A simple single-threaded synchronous optimization loop implementation.
"""
"""A simple single-threaded synchronous optimization loop implementation."""
import logging
from datetime import datetime
@ -19,14 +17,10 @@ _LOG = logging.getLogger(__name__)
class SyncScheduler(Scheduler):
"""
A simple single-threaded synchronous optimization loop implementation.
"""
"""A simple single-threaded synchronous optimization loop implementation."""
def start(self) -> None:
"""
Start the optimization loop.
"""
"""Start the optimization loop."""
super().start()
is_warm_up = self.optimizer.supports_preload
@ -42,7 +36,9 @@ class SyncScheduler(Scheduler):
def run_trial(self, trial: Storage.Trial) -> None:
"""
Set up and run a single trial. Save the results in the storage.
Set up and run a single trial.
Save the results in the storage.
"""
super().run_trial(trial)
@ -53,7 +49,8 @@ class SyncScheduler(Scheduler):
trial.update(Status.FAILED, datetime.now(UTC))
return
(status, timestamp, results) = self.environment.run() # Block and wait for the final result.
# Block and wait for the final result.
(status, timestamp, results) = self.environment.run()
_LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results)
# In async mode (TODO), poll the environment for status and telemetry

View file

@ -2,17 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Services for implementing Environments for mlos_bench.
"""
"""Services for implementing Environments for mlos_bench."""
from mlos_bench.services.base_service import Service
from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.base_service import Service
from mlos_bench.services.local.local_exec import LocalExecService
__all__ = [
'Service',
'FileShareService',
'LocalExecService',
"Service",
"FileShareService",
"LocalExecService",
]

View file

@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for remote file shares.
"""
"""Base class for remote file shares."""
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
@ -18,14 +15,15 @@ _LOG = logging.getLogger(__name__)
class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
"""
An abstract base of all file shares.
"""
"""An abstract base of all file shares."""
def __init__(self, config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new file share with a given config.
@ -43,12 +41,20 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [self.upload, self.download])
config,
global_config,
parent,
self.merge_methods(methods, [self.upload, self.download]),
)
@abstractmethod
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
def download(
self,
params: dict,
remote_path: str,
local_path: str,
recursive: bool = True,
) -> None:
"""
Downloads contents from a remote share path to a local path.
@ -66,11 +72,22 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
if True (the default), download the entire directory tree.
"""
params = params or {}
_LOG.info("Download from File Share %s recursively: %s -> %s (%s)",
"" if recursive else "non", remote_path, local_path, params)
_LOG.info(
"Download from File Share %s recursively: %s -> %s (%s)",
"" if recursive else "non",
remote_path,
local_path,
params,
)
@abstractmethod
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
def upload(
self,
params: dict,
local_path: str,
remote_path: str,
recursive: bool = True,
) -> None:
"""
Uploads contents from a local path to a remote share path.
@ -87,5 +104,10 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta):
if True (the default), upload the entire directory tree.
"""
params = params or {}
_LOG.info("Upload to File Share %s recursively: %s -> %s (%s)",
"" if recursive else "non", local_path, remote_path, params)
_LOG.info(
"Upload to File Share %s recursively: %s -> %s (%s)",
"" if recursive else "non",
local_path,
remote_path,
params,
)
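
For orientation, a concrete subclass only has to implement the two abstract methods and call super() for the shared logging. A hypothetical local-mount implementation (not part of mlos_bench) might look like:

import shutil

from mlos_bench.services.base_fileshare import FileShareService


class MountedFileShareService(FileShareService):
    """Hypothetical file share backed by a locally mounted path."""

    def download(
        self, params: dict, remote_path: str, local_path: str, recursive: bool = True
    ) -> None:
        super().download(params, remote_path, local_path, recursive)  # logs the transfer
        if recursive:
            shutil.copytree(remote_path, local_path, dirs_exist_ok=True)
        else:
            shutil.copy(remote_path, local_path)

    def upload(
        self, params: dict, local_path: str, remote_path: str, recursive: bool = True
    ) -> None:
        super().upload(params, local_path, remote_path, recursive)
        if recursive:
            shutil.copytree(local_path, remote_path, dirs_exist_ok=True)
        else:
            shutil.copy(local_path, remote_path)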

View file

@ -2,15 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for the service mix-ins.
"""
"""Base class for the service mix-ins."""
import json
import logging
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
@ -21,16 +19,16 @@ _LOG = logging.getLogger(__name__)
class Service:
"""
An abstract base of all Environment Services and used to build up mix-ins.
"""
"""An abstract base of all Environment Services and used to build up mix-ins."""
@classmethod
def new(cls,
class_name: str,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None) -> "Service":
def new(
cls,
class_name: str,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None,
) -> "Service":
"""
Factory method for a new service with a given config.
@ -57,11 +55,13 @@ class Service:
assert issubclass(cls, Service)
return instantiate_from_config(cls, class_name, config, global_config, parent)
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new service with a given config.
@ -101,12 +101,15 @@ class Service:
_LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)
@staticmethod
def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]:
def merge_methods(
ext_methods: Union[Dict[str, Callable], List[Callable], None],
local_methods: Union[Dict[str, Callable], List[Callable]],
) -> Dict[str, Callable]:
"""
Merge methods from the external caller with the local ones.
This function is usually called by the derived class constructor
just before invoking the constructor of the base class.
This function is usually called by the derived class constructor just before
invoking the constructor of the base class.
"""
if isinstance(local_methods, dict):
local_methods = local_methods.copy()
@ -138,9 +141,12 @@ class Service:
self._in_context = True
return self
def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""
Exit the Service mix-in context.
@ -170,21 +176,24 @@ class Service:
"""
Enters the context for this particular Service instance.
Called by the base __enter__ method of the Service class so it can be
used with mix-ins and overridden by subclasses.
Called by the base __enter__ method of the Service class so it can be used with
mix-ins and overridden by subclasses.
"""
assert not self._in_context
self._in_context = True
return self
def _exit_context(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
def _exit_context(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
"""
Exits the context for this particular Service instance.
Called by the base __exit__ method of the Service class so it can be
used with mix-ins and overridden by subclasses.
Called by the base __exit__ method of the Service class so it can be used with
mix-ins and overridden by subclasses.
"""
# pylint: disable=unused-argument
assert self._in_context
@ -192,13 +201,13 @@ class Service:
return False
def _validate_json_config(self, config: dict) -> None:
"""
Reconstructs a basic json config that this class might have been
instantiated from in order to validate configs provided outside the
file loading mechanism.
"""Reconstructs a basic json config that this class might have been instantiated
from in order to validate configs provided outside the file loading
mechanism.
"""
if self.__class__ == Service:
# Skip over the case where we instantiate a bare base Service class in order to build up a mix-in.
# Skip over the case where we instantiate a bare base Service class in
# order to build up a mix-in.
assert config == {}
return
json_config: dict = {
@ -212,9 +221,7 @@ class Service:
return f"{self.__class__.__name__}@{hex(id(self))}"
def pprint(self) -> str:
"""
Produce a human-readable string listing all public methods of the service.
"""
"""Produce a human-readable string listing all public methods of the service."""
return f"{self} ::\n" + "\n".join(
f' "{key}": {getattr(val, "__self__", "stand-alone")}'
for (key, val) in self._service_methods.items()
@ -265,10 +272,11 @@ class Service:
# Unfortunately, by creating a set, we may destroy the ability to
# preserve the context enter/exit order, but hopefully it doesn't
# matter.
svc_method.__self__ for _, svc_method in self._service_methods.items()
svc_method.__self__
for _, svc_method in self._service_methods.items()
# Note: some methods are actually stand alone functions, so we need
# to filter them out.
if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
}
def export(self) -> Dict[str, Callable]:
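
As an aside on `merge_methods` above: it normalizes both inputs to name-to-callable dicts and lets the caller's entries win. A simplified re-implementation of that behavior, not the exact library code:

from typing import Callable, Dict, List, Union


def merge_methods(
    ext_methods: Union[Dict[str, Callable], List[Callable], None],
    local_methods: Union[Dict[str, Callable], List[Callable]],
) -> Dict[str, Callable]:
    if not isinstance(local_methods, dict):
        local_methods = {m.__name__: m for m in local_methods}
    if not ext_methods:
        return dict(local_methods)
    if not isinstance(ext_methods, dict):
        ext_methods = {m.__name__: m for m in ext_methods}
    return {**local_methods, **ext_methods}  # external entries override local ones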

View file

@ -2,22 +2,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Helper functions to load, instantiate, and serialize Python objects
that encapsulate benchmark environments, tunable parameters, and
service functions.
"""Helper functions to load, instantiate, and serialize Python objects that encapsulate
benchmark environments, tunable parameters, and service functions.
"""
import json # For logging only
import logging
import os
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
import json # For logging only
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING
import json5 # To read configs with comments and other JSON5 syntax features
from jsonschema import ValidationError, SchemaError
import json5 # To read configs with comments and other JSON5 syntax features
from jsonschema import SchemaError, ValidationError
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
@ -26,7 +32,12 @@ from mlos_bench.services.base_service import Service
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import instantiate_from_config, merge_parameters, path_join, preprocess_dynamic_configs
from mlos_bench.util import (
instantiate_from_config,
merge_parameters,
path_join,
preprocess_dynamic_configs,
)
if sys.version_info < (3, 10):
from importlib_resources import files
@ -34,25 +45,27 @@ else:
from importlib.resources import files
if TYPE_CHECKING:
from mlos_bench.storage.base_storage import Storage
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.storage.base_storage import Storage
_LOG = logging.getLogger(__name__)
class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
Collection of methods to deserialize the Environment, Service, and TunableGroups objects.
"""Collection of methods to deserialize the Environment, Service, and TunableGroups
objects.
"""
BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/")
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of config persistence service.
@ -69,17 +82,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
self.resolve_path,
self.load_config,
self.prepare_class_load,
self.build_service,
self.build_environment,
self.load_services,
self.load_environment,
self.load_environment_list,
])
config,
global_config,
parent,
self.merge_methods(
methods,
[
self.resolve_path,
self.load_config,
self.prepare_class_load,
self.build_service,
self.build_environment,
self.load_services,
self.load_environment,
self.load_environment_list,
],
),
)
self._config_loader_service = self
@ -107,11 +125,10 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
return list(self._config_path) # make a copy to avoid modifications
def resolve_path(self, file_path: str,
extra_paths: Optional[Iterable[str]] = None) -> str:
def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
"""
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
If `_config_path` is `None` or `path` is absolute, return `path` as is.
Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
`_config_path` is `None` or `path` is absolute, return `path` as is.
Parameters
----------
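For illustration, the resolution rule above boils down to roughly this sketch (a hypothetical helper, not the actual method, which searches the configured `_config_path` entries):

import os

def resolve_path_sketch(file_path, config_paths):
    # Absolute paths are returned unchanged.
    if os.path.isabs(file_path):
        return file_path
    # Otherwise try each configured base directory in order.
    for base in config_paths:
        candidate = os.path.join(base, file_path)
        if os.path.exists(candidate):
            return candidate
    # Not resolved: return the input as-is.
    return file_path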
@ -138,14 +155,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
_LOG.debug("Path not resolved: %s", file_path)
return file_path
def load_config(self,
json_file_name: str,
schema_type: Optional[ConfigSchema],
) -> Dict[str, Any]:
def load_config(
self,
json_file_name: str,
schema_type: Optional[ConfigSchema],
) -> Dict[str, Any]:
"""
Load JSON config file. Search for a file relative to `_config_path`
if the input path is not absolute.
This method is exported to be used as a service.
Load JSON config file. Search for a file relative to `_config_path` if the input
path is not absolute. This method is exported to be used as a service.
Parameters
----------
@ -161,16 +178,22 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
"""
json_file_name = self.resolve_path(json_file_name)
_LOG.info("Load config: %s", json_file_name)
with open(json_file_name, mode='r', encoding='utf-8') as fh_json:
with open(json_file_name, mode="r", encoding="utf-8") as fh_json:
config = json5.load(fh_json)
if schema_type is not None:
try:
schema_type.validate(config)
except (ValidationError, SchemaError) as ex:
_LOG.error("Failed to validate config %s against schema type %s at %s",
json_file_name, schema_type.name, schema_type.value)
raise ValueError(f"Failed to validate config {json_file_name} against " +
f"schema type {schema_type.name} at {schema_type.value}") from ex
_LOG.error(
"Failed to validate config %s against schema type %s at %s",
json_file_name,
schema_type.name,
schema_type.value,
)
raise ValueError(
f"Failed to validate config {json_file_name} against "
f"schema type {schema_type.name} at {schema_type.value}"
) from ex
if isinstance(config, dict) and config.get("$schema"):
# Remove $schema attributes from the config after we've validated
# them to avoid passing them on to other objects
@ -181,15 +204,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
del config["$schema"]
else:
_LOG.warning("Config %s is not validated against a schema.", json_file_name)
return config # type: ignore[no-any-return]
return config # type: ignore[no-any-return]
def prepare_class_load(self, config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]:
def prepare_class_load(
self,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""
Extract the class instantiation parameters from the configuration.
Mix in the global parameters and resolve the local file system paths
where required.
Extract the class instantiation parameters from the configuration. Mix in the
global parameters and resolve the local file system paths where required.
Parameters
----------
@ -228,19 +253,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
raise ValueError(f"Parameter {key} must be a string or a list")
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Instantiating: %s with config:\n%s",
class_name, json.dumps(class_config, indent=2))
_LOG.debug(
"Instantiating: %s with config:\n%s",
class_name,
json.dumps(class_config, indent=2),
)
return (class_name, class_config)
def build_optimizer(self, *,
tunables: TunableGroups,
service: Service,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None) -> Optimizer:
def build_optimizer(
self,
*,
tunables: TunableGroups,
service: Service,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
) -> Optimizer:
"""
Instantiation of mlos_bench Optimizer
that depends on Service and TunableGroups.
Instantiation of mlos_bench Optimizer that depends on Service and TunableGroups.
A class *MUST* have a constructor that takes four named arguments:
(tunables, config, global_config, service)
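As a sketch of that contract, a conforming subclass might look like the following (the class name and body are hypothetical; only the argument names come from the docstring above):

class MyOptimizer(Optimizer):
    # Hypothetical subclass showing the four required named constructor arguments.
    def __init__(self, *, tunables, config, global_config=None, service=None):
        super().__init__(
            tunables=tunables,
            config=config,
            global_config=global_config,
            service=service,
        )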
@ -266,18 +296,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
if tunables_path is not None:
tunables = self._load_tunables(tunables_path, tunables)
(class_name, class_config) = self.prepare_class_load(config, global_config)
inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract]
tunables=tunables,
config=class_config,
global_config=global_config,
service=service)
inst = instantiate_from_config(
Optimizer, # type: ignore[type-abstract]
class_name,
tunables=tunables,
config=class_config,
global_config=global_config,
service=service,
)
_LOG.info("Created: Optimizer %s", inst)
return inst
def build_storage(self, *,
service: Service,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None) -> "Storage":
def build_storage(
self,
*,
service: Service,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
) -> "Storage":
"""
Instantiation of mlos_bench Storage objects.
@ -296,21 +332,29 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A new instance of the Storage class.
"""
(class_name, class_config) = self.prepare_class_load(config, global_config)
from mlos_bench.storage.base_storage import Storage # pylint: disable=import-outside-toplevel
inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract]
config=class_config,
global_config=global_config,
service=service)
# pylint: disable=import-outside-toplevel
from mlos_bench.storage.base_storage import Storage
inst = instantiate_from_config(
Storage, # type: ignore[type-abstract]
class_name,
config=class_config,
global_config=global_config,
service=service,
)
_LOG.info("Created: Storage %s", inst)
return inst
def build_scheduler(self, *,
config: Dict[str, Any],
global_config: Dict[str, Any],
environment: Environment,
optimizer: Optimizer,
storage: "Storage",
root_env_config: str) -> "Scheduler":
def build_scheduler(
self,
*,
config: Dict[str, Any],
global_config: Dict[str, Any],
environment: Environment,
optimizer: Optimizer,
storage: "Storage",
root_env_config: str,
) -> "Scheduler":
"""
Instantiation of mlos_bench Scheduler.
@ -335,23 +379,30 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A new instance of the Scheduler.
"""
(class_name, class_config) = self.prepare_class_load(config, global_config)
from mlos_bench.schedulers.base_scheduler import Scheduler # pylint: disable=import-outside-toplevel
inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract]
config=class_config,
global_config=global_config,
environment=environment,
optimizer=optimizer,
storage=storage,
root_env_config=root_env_config)
# pylint: disable=import-outside-toplevel
from mlos_bench.schedulers.base_scheduler import Scheduler
inst = instantiate_from_config(
Scheduler, # type: ignore[type-abstract]
class_name,
config=class_config,
global_config=global_config,
environment=environment,
optimizer=optimizer,
storage=storage,
root_env_config=root_env_config,
)
_LOG.info("Created: Scheduler %s", inst)
return inst
def build_environment(self, # pylint: disable=too-many-arguments
config: Dict[str, Any],
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None) -> Environment:
def build_environment(
self, # pylint: disable=too-many-arguments
config: Dict[str, Any],
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None,
) -> Environment:
"""
Factory method for a new environment with a given config.
@ -391,16 +442,24 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
tunables = self._load_tunables(env_tunables_path, tunables)
_LOG.debug("Creating env: %s :: %s", env_name, env_class)
env = Environment.new(env_name=env_name, class_name=env_class,
config=env_config, global_config=global_config,
tunables=tunables, service=service)
env = Environment.new(
env_name=env_name,
class_name=env_class,
config=env_config,
global_config=global_config,
tunables=tunables,
service=service,
)
_LOG.info("Created env: %s :: %s", env_name, env)
return env
def _build_standalone_service(self, config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None) -> Service:
def _build_standalone_service(
self,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
) -> Service:
"""
Factory method for a new service with a given config.
@ -425,9 +484,12 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
_LOG.info("Created service: %s", service)
return service
def _build_composite_service(self, config_list: Iterable[Dict[str, Any]],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None) -> Service:
def _build_composite_service(
self,
config_list: Iterable[Dict[str, Any]],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
) -> Service:
"""
Factory method for a new service with a given config.
@ -453,18 +515,21 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
service.register(parent.export())
for config in config_list:
service.register(self._build_standalone_service(
config, global_config, service).export())
service.register(
self._build_standalone_service(config, global_config, service).export()
)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Created mix-in service: %s", service)
return service
def build_service(self,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None) -> Service:
def build_service(
self,
config: Dict[str, Any],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
) -> Service:
"""
Factory method for a new service with a given config.
@ -486,8 +551,7 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
services from the list plus the parent mix-in.
"""
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Build service from config:\n%s",
json.dumps(config, indent=2))
_LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2))
assert isinstance(config, dict)
config_list: List[Dict[str, Any]]
@ -502,12 +566,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
return self._build_composite_service(config_list, global_config, parent)
def load_environment(self, # pylint: disable=too-many-arguments
json_file_name: str,
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None) -> Environment:
def load_environment(
self, # pylint: disable=too-many-arguments
json_file_name: str,
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None,
) -> Environment:
"""
Load and build new environment from the config file.
@ -534,12 +600,14 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
assert isinstance(config, dict)
return self.build_environment(config, tunables, global_config, parent_args, service)
def load_environment_list(self, # pylint: disable=too-many-arguments
json_file_name: str,
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None) -> List[Environment]:
def load_environment_list(
self, # pylint: disable=too-many-arguments
json_file_name: str,
tunables: TunableGroups,
global_config: Optional[Dict[str, Any]] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional[Service] = None,
) -> List[Environment]:
"""
Load and build a list of environments from the config file.
@ -564,16 +632,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
A list of new benchmarking environments.
"""
config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT)
return [
self.build_environment(config, tunables, global_config, parent_args, service)
]
return [self.build_environment(config, tunables, global_config, parent_args, service)]
def load_services(self, json_file_names: Iterable[str],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None) -> Service:
def load_services(
self,
json_file_names: Iterable[str],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
) -> Service:
"""
Read the configuration files and bundle all service methods
from those configs into a single Service object.
Read the configuration files and bundle all service methods from those configs
into a single Service object.
Parameters
----------
@ -589,16 +658,18 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
service : Service
A collection of service methods.
"""
_LOG.info("Load services: %s parent: %s",
json_file_names, parent.__class__.__name__)
_LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__)
service = Service({}, global_config, parent)
for fname in json_file_names:
config = self.load_config(fname, ConfigSchema.SERVICE)
service.register(self.build_service(config, global_config, service).export())
return service
def _load_tunables(self, json_file_names: Iterable[str],
parent: TunableGroups) -> TunableGroups:
def _load_tunables(
self,
json_file_names: Iterable[str],
parent: TunableGroups,
) -> TunableGroups:
"""
Load a collection of tunable parameters from JSON files into the parent
TunableGroup.


@ -2,13 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Local scheduler side Services for mlos_bench.
"""
"""Local scheduler side Services for mlos_bench."""
from mlos_bench.services.local.local_exec import LocalExecService
__all__ = [
'LocalExecService',
"LocalExecService",
]


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Helper functions to run scripts and commands locally on the scheduler side.
"""
"""Helper functions to run scripts and commands locally on the scheduler side."""
import errno
import logging
@ -12,10 +10,18 @@ import os
import shlex
import subprocess
import sys
from string import Template
from typing import (
Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
from mlos_bench.os_environ import environ
@ -31,9 +37,9 @@ _LOG = logging.getLogger(__name__)
def split_cmdline(cmdline: str) -> Iterable[List[str]]:
"""
A single command line may contain multiple commands separated by
special characters (e.g., &&, ||, etc.), so further split the
command line into an array of subcommand arrays.
A single command line may contain multiple commands separated by special characters
(e.g., &&, ||, etc.), so further split the command line into an array of subcommand
arrays.
Parameters
----------
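A minimal sketch of the splitting idea (an assumption-laden illustration built on shlex, not the actual implementation):

import shlex

def split_cmdline_sketch(cmdline):
    # POSIX mode respects quoting; punctuation_chars makes runs of shell
    # operator characters (e.g. "&&", "||", "|") their own tokens.
    lexer = shlex.shlex(cmdline, posix=True, punctuation_chars=True)
    lexer.whitespace_split = True
    subcmd = []
    for token in lexer:
        if token in ("&&", "||", "|", "&", ";"):
            if subcmd:
                yield subcmd  # emit the subcommand accumulated so far
            subcmd = []
        else:
            subcmd.append(token)
    if subcmd:
        yield subcmd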
@ -66,16 +72,20 @@ def split_cmdline(cmdline: str) -> Iterable[List[str]]:
class LocalExecService(TempDirContextService, SupportsLocalExec):
"""
Collection of methods to run scripts and commands in an external process
on the node acting as the scheduler. Can be useful for data processing
due to reduced dependency management complications vs the target environment.
Collection of methods to run scripts and commands in an external process on the node
acting as the scheduler.
Can be useful for data processing due to reduced dependency management complications
vs the target environment.
"""
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of a service to run scripts locally.
@ -92,14 +102,19 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [self.local_exec])
config,
global_config,
parent,
self.merge_methods(methods, [self.local_exec]),
)
self.abort_on_error = self.config.get("abort_on_error", True)
def local_exec(self, script_lines: Iterable[str],
env: Optional[Mapping[str, "TunableValue"]] = None,
cwd: Optional[str] = None) -> Tuple[int, str, str]:
def local_exec(
self,
script_lines: Iterable[str],
env: Optional[Mapping[str, "TunableValue"]] = None,
cwd: Optional[str] = None,
) -> Tuple[int, str, str]:
"""
Execute the script lines from `script_lines` in a local process.
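A usage sketch (the service instance name is hypothetical; the return tuple shape follows this docstring):

# Run a short script fragment and collect its output.
(return_code, stdout, stderr) = local_exec_service.local_exec(
    ["echo $TEST_VAR"],
    env={"TEST_VAR": "hello"},
    cwd="/tmp",
)
assert return_code == 0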
@ -141,8 +156,8 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]:
"""
Resolve the local script path (the first token) in the (sub)command line
tokens to its full path.
Resolve the local script path (the first token) in the (sub)command line tokens
to its full path.
Parameters
----------
@ -167,9 +182,12 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
subcmd_tokens.insert(0, sys.executable)
return subcmd_tokens
def _local_exec_script(self, script_line: str,
env_params: Optional[Mapping[str, "TunableValue"]],
cwd: str) -> Tuple[int, str, str]:
def _local_exec_script(
self,
script_line: str,
env_params: Optional[Mapping[str, "TunableValue"]],
cwd: str,
) -> Tuple[int, str, str]:
"""
Execute the script from `script_path` in a local process.
@ -198,7 +216,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
if env_params:
env = {key: str(val) for (key, val) in env_params.items()}
if sys.platform == 'win32':
if sys.platform == "win32":
# A hack to run Python on Windows with env variables set:
env_copy = environ.copy()
env_copy["PYTHONPATH"] = ""
@ -206,7 +224,7 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
env = env_copy
try:
if sys.platform != 'win32':
if sys.platform != "win32":
cmd = [" ".join(cmd)]
_LOG.info("Run: %s", cmd)
@ -214,8 +232,15 @@ class LocalExecService(TempDirContextService, SupportsLocalExec):
_LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env))
_LOG.debug("Current working dir: %s", cwd)
proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True,
text=True, check=False, capture_output=True)
proc = subprocess.run(
cmd,
env=env or None,
cwd=cwd,
shell=True,
text=True,
check=False,
capture_output=True,
)
_LOG.debug("Run: return code = %d", proc.returncode)
return (proc.returncode, proc.stdout, proc.stderr)


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Helper functions to work with temp files locally on the scheduler side.
"""
"""Helper functions to work with temp files locally on the scheduler side."""
import abc
import logging
@ -21,21 +19,23 @@ _LOG = logging.getLogger(__name__)
class TempDirContextService(Service, metaclass=abc.ABCMeta):
"""
A *base* service class that provides a method to create a temporary
directory context for local scripts.
A *base* service class that provides a method to create a temporary directory
context for local scripts.
It is inherited by LocalExecService and MockLocalExecService.
This class is not supposed to be used as a standalone service.
It is inherited by LocalExecService and MockLocalExecService. This class is not
supposed to be used as a standalone service.
"""
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of a service that provides temporary directory context
for local exec service.
Create a new instance of a service that provides temporary directory context for
local exec service.
Parameters
----------
@ -50,8 +50,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [self.temp_dir_context])
config,
global_config,
parent,
self.merge_methods(methods, [self.temp_dir_context]),
)
self._temp_dir = self.config.get("temp_dir")
if self._temp_dir:
@ -61,7 +63,10 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta):
self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir)
_LOG.info("%s: temp dir: %s", self, self._temp_dir)
def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]:
def temp_dir_context(
self,
path: Optional[str] = None,
) -> Union[TemporaryDirectory, nullcontext]:
"""
Create a temp directory or use the provided path.
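A usage sketch (hypothetical caller; per the docstring, a fresh TemporaryDirectory is created when no path is given, while a provided path is wrapped in a nullcontext):

with service.temp_dir_context() as temp_dir:
    # temp_dir is cleaned up on exit only if it was created here.
    print("working in", temp_dir)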


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Azure-specific benchmark environments for mlos_bench.
"""
"""Azure-specific benchmark environments for mlos_bench."""
from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService
@ -12,11 +10,10 @@ from mlos_bench.services.remote.azure.azure_network_services import AzureNetwork
from mlos_bench.services.remote.azure.azure_saas import AzureSaaSConfigService
from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
__all__ = [
'AzureAuthService',
'AzureFileShareService',
'AzureNetworkService',
'AzureSaaSConfigService',
'AzureVMService',
"AzureAuthService",
"AzureFileShareService",
"AzureNetworkService",
"AzureSaaSConfigService",
"AzureVMService",
]

Просмотреть файл

@ -2,19 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection Service functions for managing VMs on Azure.
"""
"""A collection Service functions for managing VMs on Azure."""
import logging
from base64 import b64decode
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
from pytz import UTC
import azure.identity as azure_id
from azure.keyvault.secrets import SecretClient
from pytz import UTC
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.authenticator_type import SupportsAuth
@ -24,17 +21,17 @@ _LOG = logging.getLogger(__name__)
class AzureAuthService(Service, SupportsAuth):
"""
Helper methods to get access to Azure services.
"""
"""Helper methods to get access to Azure services."""
_REQ_INTERVAL = 300 # = 5 min
_REQ_INTERVAL = 300 # = 5 min
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of Azure authentication services proxy.
@ -51,11 +48,16 @@ class AzureAuthService(Service, SupportsAuth):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
self.get_access_token,
self.get_auth_headers,
])
config,
global_config,
parent,
self.merge_methods(
methods,
[
self.get_access_token,
self.get_auth_headers,
],
),
)
# This parameter can come from the command line as a string, so conversion is needed.
@ -71,12 +73,13 @@ class AzureAuthService(Service, SupportsAuth):
# Verify info required for SP auth early
if "spClientId" in self.config:
check_required_params(
self.config, {
self.config,
{
"spClientId",
"keyVaultName",
"certName",
"tenant",
}
},
)
def _init_sp(self) -> None:
@ -105,12 +108,14 @@ class AzureAuthService(Service, SupportsAuth):
cert_bytes = b64decode(secret.value)
# Reauthenticate as the service principal.
self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes)
self._cred = azure_id.CertificateCredential(
tenant_id=tenant_id,
client_id=sp_client_id,
certificate_data=cert_bytes,
)
def get_access_token(self) -> str:
"""
Get the access token from Azure CLI, if expired.
"""
"""Get the access token from Azure CLI, if expired."""
# Ensure we are logged as the Service Principal, if provided
if "spClientId" in self.config:
self._init_sp()
@ -126,7 +131,5 @@ class AzureAuthService(Service, SupportsAuth):
return self._access_token
def get_auth_headers(self) -> dict:
"""
Get the authorization part of HTTP headers for REST API calls.
"""
"""Get the authorization part of HTTP headers for REST API calls."""
return {"Authorization": "Bearer " + self.get_access_token()}


@ -2,15 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for certain Azure Services classes that do deployments.
"""
"""Base class for certain Azure Services classes that do deployments."""
import abc
import json
import time
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import requests
@ -26,33 +23,34 @@ _LOG = logging.getLogger(__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"""
Helper methods to manage and deploy Azure resources via REST APIs.
"""
"""Helper methods to manage and deploy Azure resources via REST APIs."""
_POLL_INTERVAL = 4 # seconds
_POLL_TIMEOUT = 300 # seconds
_REQUEST_TIMEOUT = 5 # seconds
_POLL_INTERVAL = 4 # seconds
_POLL_TIMEOUT = 300 # seconds
_REQUEST_TIMEOUT = 5 # seconds
_REQUEST_TOTAL_RETRIES = 10 # Total number of retries for each request
_REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
# Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
_REQUEST_RETRY_BACKOFF_FACTOR = 0.3
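# For example, with the default factor of 0.3 the successive delays are
# 0.3 * 2**0 = 0.3 s, 0.3 * 2**1 = 0.6 s, 0.3 * 2**2 = 1.2 s, and so on.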
# Azure Resources Deployment REST API as described in
# https://docs.microsoft.com/en-us/rest/api/resources/deployments
_URL_DEPLOY = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Resources" +
"/deployments/{deployment_name}" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Resources"
"/deployments/{deployment_name}"
"?api-version=2022-05-01"
)
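The dropped `+` operators rely on Python's implicit concatenation of adjacent string literals, resolved at parse time; a quick sketch of the equivalence:

# Adjacent string literals are merged by the parser, so these are identical:
url_a = "https://management.azure.com" "/subscriptions/{subscription}"
url_b = "https://management.azure.com" + "/subscriptions/{subscription}"
assert url_a == url_b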
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of an Azure Services proxy.
@ -70,38 +68,49 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"""
super().__init__(config, global_config, parent, methods)
check_required_params(self.config, [
"subscription",
"resourceGroup",
])
check_required_params(
self.config,
[
"subscription",
"resourceGroup",
],
)
# These parameters can come from the command line as strings, so conversion is needed.
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES))
self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR))
self._total_retries = int(
self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
)
self._backoff_factor = float(
self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
)
self._deploy_template = {}
self._deploy_params = {}
if self.config.get("deploymentTemplatePath") is not None:
# TODO: Provide external schema validation?
template = self.config_loader_service.load_config(
self.config['deploymentTemplatePath'], schema_type=None)
self.config["deploymentTemplatePath"],
schema_type=None,
)
assert template is not None and isinstance(template, dict)
self._deploy_template = template
# Allow for recursive variable expansion as we do with global params and const_args.
deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config)
deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
extra_source_dict=global_config
)
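# E.g. (illustrative values): {"vmName": "$prefix-vm"} with a global_config of
# {"prefix": "exp1"} would expand to {"vmName": "exp1-vm"}.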
self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
else:
_LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.")
_LOG.info(
"No deploymentTemplatePath provided. Deployment services will be unavailable.",
)
@property
def deploy_params(self) -> dict:
"""
Get the deployment parameters.
"""
"""Get the deployment parameters."""
return self._deploy_params
@abc.abstractmethod
@ -122,24 +131,24 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
raise NotImplementedError("Should be overridden by subclass.")
def _get_session(self, params: dict) -> requests.Session:
"""
Get a session object that includes automatic retries and headers for REST API calls.
"""Get a session object that includes automatic retries and headers for REST API
calls.
"""
total_retries = params.get("requestTotalRetries", self._total_retries)
backoff_factor = params.get("requestBackoffFactor", self._backoff_factor)
session = requests.Session()
session.mount(
"https://",
HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)))
HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
)
session.headers.update(self._get_headers())
return session
def _get_headers(self) -> dict:
"""
Get the headers for the REST API calls.
"""
assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
"Authorization service not provided. Include service-auth.jsonc?"
"""Get the headers for the REST API calls."""
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
return self._parent.get_auth_headers()
@staticmethod
@ -235,9 +244,11 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
return (Status.FAILED, {})
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Response: %s\n%s", response,
json.dumps(response.json(), indent=2)
if response.content else "")
_LOG.debug(
"Response: %s\n%s",
response,
json.dumps(response.json(), indent=2) if response.content else "",
)
if response.status_code == 200:
output = response.json()
@ -252,15 +263,16 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
"""
Wait for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Wait for a pending operation on an Azure resource to resolve to SUCCEEDED or
FAILED. Return TIMED_OUT when timing out.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
If True, wait for the resource to be deployed; otherwise, wait for successful deprovisioning.
If True, wait for the resource to be deployed; otherwise, wait for
successful deprovisioning.
Returns
-------
@ -270,15 +282,22 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
params = self._set_default_params(params)
_LOG.info("Wait for %s to %s", params.get("deploymentName"),
"provision" if is_setup else "deprovision")
_LOG.info(
"Wait for %s to %s",
params.get("deploymentName"),
"provision" if is_setup else "deprovision",
)
return self._wait_while(self._check_deployment, Status.PENDING, params)
def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
loop_status: Status, params: dict) -> Tuple[Status, dict]:
def _wait_while(
self,
func: Callable[[dict], Tuple[Status, dict]],
loop_status: Status,
params: dict,
) -> Tuple[Status, dict]:
"""
Invoke `func` periodically while the status is equal to `loop_status`.
Return TIMED_OUT when timing out.
Invoke `func` periodically while the status is equal to `loop_status`. Return
TIMED_OUT when timing out.
Parameters
----------
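A simplified sketch of this loop (the real method also resolves the deployment name and adapts the poll delay; Status is the enum used throughout this module):

import time

def wait_while_sketch(func, loop_status, params, poll_interval=4.0, poll_timeout=300.0):
    # Poll func(params) until its status changes or the timeout expires.
    ts_timeout = time.time() + poll_timeout
    while time.time() < ts_timeout:
        (status, output) = func(params)
        if status != loop_status:
            return (status, output)
        time.sleep(poll_interval)
    return (Status.TIMED_OUT, {})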
@ -297,12 +316,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"""
params = self._set_default_params(params)
config = merge_parameters(
dest=self.config.copy(), source=params, required_keys=["deploymentName"])
dest=self.config.copy(),
source=params,
required_keys=["deploymentName"],
)
poll_period = params.get("pollInterval", self._poll_interval)
_LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s",
config["deploymentName"], loop_status, poll_period, self._poll_timeout)
_LOG.debug(
"Wait for %s status %s :: poll %.2f timeout %d s",
config["deploymentName"],
loop_status,
poll_period,
self._poll_timeout,
)
ts_timeout = time.time() + self._poll_timeout
poll_delay = poll_period
@ -326,10 +353,10 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
_LOG.warning("Request timed out: %s", params)
return (Status.TIMED_OUT, {})
def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements
def _check_deployment(self, params: dict) -> Tuple[Status, dict]:
# pylint: disable=too-many-return-statements
"""
Check if Azure deployment exists.
Return SUCCEEDED if true, PENDING otherwise.
Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.
Parameters
----------
@ -352,7 +379,7 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"subscription",
"resourceGroup",
"deploymentName",
]
],
)
_LOG.info("Check deployment: %s", config["deploymentName"])
@ -413,13 +440,20 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
if not self._deploy_template:
raise ValueError(f"Missing deployment template: {self}")
params = self._set_default_params(params)
config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"])
config = merge_parameters(
dest=self.config.copy(),
source=params,
required_keys=["deploymentName"],
)
_LOG.info("Deploy: %s :: %s", config["deploymentName"], params)
params = merge_parameters(dest=self._deploy_params.copy(), source=params)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Deploy: %s merged params ::\n%s",
config["deploymentName"], json.dumps(params, indent=2))
_LOG.debug(
"Deploy: %s merged params ::\n%s",
config["deploymentName"],
json.dumps(params, indent=2),
)
url = self._URL_DEPLOY.format(
subscription=config["subscription"],
@ -432,22 +466,29 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
"mode": "Incremental",
"template": self._deploy_template,
"parameters": {
key: {"value": val} for (key, val) in params.items()
key: {"value": val}
for (key, val) in params.items()
if key in self._deploy_template.get("parameters", {})
}
},
}
}
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))
response = requests.put(url, json=json_req,
headers=self._get_headers(), timeout=self._request_timeout)
response = requests.put(
url,
json=json_req,
headers=self._get_headers(),
timeout=self._request_timeout,
)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Response: %s\n%s", response,
json.dumps(response.json(), indent=2)
if response.content else "")
_LOG.debug(
"Response: %s\n%s",
response,
json.dumps(response.json(), indent=2) if response.content else "",
)
else:
_LOG.info("Response: %s", response)


@ -2,37 +2,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection FileShare functions for interacting with Azure File Shares.
"""
"""A collection FileShare functions for interacting with Azure File Shares."""
import os
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union
from azure.storage.fileshare import ShareClient
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient
from mlos_bench.services.base_service import Service
from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.base_service import Service
from mlos_bench.util import check_required_params
_LOG = logging.getLogger(__name__)
class AzureFileShareService(FileShareService):
"""
Helper methods for interacting with Azure File Share
"""
"""Helper methods for interacting with Azure File Share."""
_SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new file share Service for Azure environments with a given config.
@ -50,16 +47,19 @@ class AzureFileShareService(FileShareService):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [self.upload, self.download])
config,
global_config,
parent,
self.merge_methods(methods, [self.upload, self.download]),
)
check_required_params(
self.config, {
self.config,
{
"storageAccountName",
"storageFileShareName",
"storageAccountKey",
}
},
)
self._share_client = ShareClient.from_share_url(
@ -70,7 +70,13 @@ class AzureFileShareService(FileShareService):
credential=self.config["storageAccountKey"],
)
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
def download(
self,
params: dict,
remote_path: str,
local_path: str,
recursive: bool = True,
) -> None:
super().download(params, remote_path, local_path, recursive)
dir_client = self._share_client.get_directory_client(remote_path)
if dir_client.exists():
@ -95,16 +101,21 @@ class AzureFileShareService(FileShareService):
# Translate into non-Azure exception:
raise FileNotFoundError(f"Cannot download: {remote_path}") from ex
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
def upload(
self,
params: dict,
local_path: str,
remote_path: str,
recursive: bool = True,
) -> None:
super().upload(params, local_path, remote_path, recursive)
self._upload(local_path, remote_path, recursive, set())
def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
"""
Upload contents from a local path to an Azure file share.
This method is called from `.upload()` above. We need it to avoid exposing
the `seen` parameter and to make `.upload()` match the base class' virtual
method.
Upload contents from a local path to an Azure file share. This method is called
from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
make `.upload()` match the base class' virtual method.
Parameters
----------
@ -143,8 +154,8 @@ class AzureFileShareService(FileShareService):
def _remote_makedirs(self, remote_path: str) -> None:
"""
Create remote directories for the entire path.
Succeeds even if some or all directories along the path already exist.
Create remote directories for the entire path. Succeeds even if some or all
directories along the path already exist.
Parameters
----------


@ -2,47 +2,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection Service functions for managing virtual networks on Azure.
"""
"""A collection Service functions for managing virtual networks on Azure."""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
from mlos_bench.services.remote.azure.azure_deployment_services import (
AzureDeploymentService,
)
from mlos_bench.services.types.network_provisioner_type import (
SupportsNetworkProvisioning,
)
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
"""
Helper methods to manage Virtual Networks on Azure.
"""
"""Helper methods to manage Virtual Networks on Azure."""
# Azure Compute REST API calls as described in
# https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
# From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01
# From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 # pylint: disable=line-too-long # noqa
_URL_DEPROVISION = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Network" +
"/virtualNetwork/{vnet_name}" +
"/delete" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Network"
"/virtualNetwork/{vnet_name}"
"/delete"
"?api-version=2023-05-01"
)
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of Azure Network services proxy.
@ -59,25 +60,35 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
# SupportsNetworkProvisioning
self.provision_network,
self.deprovision_network,
self.wait_network_deployment,
])
config,
global_config,
parent,
self.merge_methods(
methods,
[
# SupportsNetworkProvisioning
self.provision_network,
self.deprovision_network,
self.wait_network_deployment,
],
),
)
if not self._deploy_template:
raise ValueError("AzureNetworkService requires a deployment template:\n"
+ f"config={config}\nglobal_config={global_config}")
raise ValueError(
"AzureNetworkService requires a deployment template:\n"
+ f"config={config}\nglobal_config={global_config}"
)
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
# Try to provide a semi-sane default for the deploymentName if not provided,
# since this is a common way to set the deploymentName and can save some
# config work for the caller.
if "vnetName" in params and "deploymentName" not in params:
params["deploymentName"] = f"{params['vnetName']}-deployment"
_LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
_LOG.info(
"deploymentName missing from params. Defaulting to '%s'.",
params["deploymentName"],
)
return params
def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
@ -148,15 +159,18 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning):
"resourceGroup",
"deploymentName",
"vnetName",
]
],
)
_LOG.info("Deprovision Network: %s", config["vnetName"])
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
(status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vnet_name=config["vnetName"],
))
(status, results) = self._azure_rest_api_post_helper(
config,
self._URL_DEPROVISION.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vnet_name=config["vnetName"],
),
)
if ignore_errors and status == Status.FAILED:
_LOG.warning("Ignoring error: %s", results)
status = Status.SUCCEEDED


@ -2,11 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection Service functions for configuring SaaS instances on Azure.
"""
"""A collection Service functions for configuring SaaS instances on Azure."""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import requests
@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
"""
Helper methods to configure Azure Flex services.
"""
"""Helper methods to configure Azure Flex services."""
_REQUEST_TIMEOUT = 5 # seconds
@ -33,20 +28,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
# https://learn.microsoft.com/en-us/rest/api/mariadb/configurations
_URL_CONFIGURE = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/{provider}" +
"/{server_type}/{vm_name}" +
"/{update}" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/{provider}"
"/{server_type}/{vm_name}"
"/{update}"
"?api-version={api_version}"
)
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of Azure services proxy.
@ -63,18 +60,20 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
self.configure,
self.is_config_pending
])
config,
global_config,
parent,
self.merge_methods(methods, [self.configure, self.is_config_pending]),
)
check_required_params(self.config, {
"subscription",
"resourceGroup",
"provider",
})
check_required_params(
self.config,
{
"subscription",
"resourceGroup",
"provider",
},
)
# Provide sane defaults for known DB providers.
provider = self.config.get("provider")
@ -118,8 +117,7 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
# These parameters can come from the command line as strings, so conversion is needed.
self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
def configure(self, config: Dict[str, Any],
params: Dict[str, Any]) -> Tuple[Status, dict]:
def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
Update the parameters of an Azure DB service.
@ -157,33 +155,39 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
If "isConfigPendingReboot" is set to True, rebooting a VM is necessary.
Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED}
"""
config = merge_parameters(
dest=self.config.copy(), source=config, required_keys=["vmName"])
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_get.format(vm_name=config["vmName"])
_LOG.debug("Request: GET %s", url)
response = requests.put(
url, headers=self._get_headers(), timeout=self._request_timeout)
response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout)
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})
if response.status_code != 200:
return (Status.FAILED, {})
# Currently, Azure Flex servers require a VM reboot.
return (Status.SUCCEEDED, {"isConfigPendingReboot": any(
{'False': False, 'True': True}[val['properties']['isConfigPendingRestart']]
for val in response.json()['value']
)})
return (
Status.SUCCEEDED,
{
"isConfigPendingReboot": any(
{"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]]
for val in response.json()["value"]
)
},
)
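# E.g. (illustrative): a response body of {"value": [{"properties":
# {"isConfigPendingRestart": "True"}}]} yields isConfigPendingReboot == True.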
def _get_headers(self) -> dict:
"""
Get the headers for the REST API calls.
"""
assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
"Authorization service not provided. Include service-auth.jsonc?"
"""Get the headers for the REST API calls."""
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
return self._parent.get_auth_headers()
def _config_one(self, config: Dict[str, Any],
param_name: str, param_value: Any) -> Tuple[Status, dict]:
def _config_one(
self,
config: Dict[str, Any],
param_name: str,
param_value: Any,
) -> Tuple[Status, dict]:
"""
Update a single parameter of the Azure DB service.
@ -202,13 +206,15 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
config = merge_parameters(
dest=self.config.copy(), source=config, required_keys=["vmName"])
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
_LOG.debug("Request: PUT %s", url)
response = requests.put(url, headers=self._get_headers(),
json={"properties": {"value": str(param_value)}},
timeout=self._request_timeout)
response = requests.put(
url,
headers=self._get_headers(),
json={"properties": {"value": str(param_value)}},
timeout=self._request_timeout,
)
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})
@ -216,11 +222,10 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
return (Status.SUCCEEDED, {})
return (Status.FAILED, {})
def _config_many(self, config: Dict[str, Any],
params: Dict[str, Any]) -> Tuple[Status, dict]:
def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
Update the parameters of an Azure DB service one-by-one
(used when the batch API is not available for it).
Update the parameters of an Azure DB service one-by-one (used when the batch
API is not available for it).
Parameters
----------
@ -235,14 +240,13 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
for (param_name, param_value) in params.items():
for param_name, param_value in params.items():
(status, result) = self._config_one(config, param_name, param_value)
if not status.is_succeeded():
return (status, result)
return (Status.SUCCEEDED, {})
def _config_batch(self, config: Dict[str, Any],
params: Dict[str, Any]) -> Tuple[Status, dict]:
def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
"""
Batch update the parameters of an Azure DB service.
@ -259,19 +263,21 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig):
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
config = merge_parameters(
dest=self.config.copy(), source=config, required_keys=["vmName"])
config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
url = self._url_config_set.format(vm_name=config["vmName"])
json_req = {
"value": [
{"name": key, "properties": {"value": str(val)}}
for (key, val) in params.items()
{"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
],
# "resetAllToDefault": "True"
}
_LOG.debug("Request: POST %s", url)
response = requests.post(url, headers=self._get_headers(),
json=json_req, timeout=self._request_timeout)
response = requests.post(
url,
headers=self._get_headers(),
json=json_req,
timeout=self._request_timeout,
)
_LOG.debug("Response: %s :: %s", response, response.text)
if response.status_code == 504:
return (Status.TIMED_OUT, {})


@ -2,33 +2,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection Service functions for managing VMs on Azure.
"""
"""A collection Service functions for managing VMs on Azure."""
import json
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import requests
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.remote.azure.azure_deployment_services import (
AzureDeploymentService,
)
from mlos_bench.services.types.host_ops_type import SupportsHostOps
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec):
"""
Helper methods to manage VMs on Azure.
"""
class AzureVMService(
AzureDeploymentService,
SupportsHostProvisioning,
SupportsHostOps,
SupportsOSOps,
SupportsRemoteExec,
):
"""Helper methods to manage VMs on Azure."""
# pylint: disable=too-many-ancestors
@ -37,34 +40,34 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start
_URL_START = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Compute" +
"/virtualMachines/{vm_name}" +
"/start" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/start"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off
_URL_STOP = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Compute" +
"/virtualMachines/{vm_name}" +
"/powerOff" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/powerOff"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate
_URL_DEALLOCATE = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Compute" +
"/virtualMachines/{vm_name}" +
"/deallocate" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/deallocate"
"?api-version=2022-03-01"
)
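Aside: the trailing "+" operators dropped in the URL templates above are redundant because Python joins adjacent string literals at compile time, so the reformatted templates build the same URL strings. A minimal sketch, with a made-up host and path:

url_template = (
    "https://example.com"  # hypothetical host
    "/some/{resource}"  # hypothetical path
    "?api-version=2022-03-01"
)
assert url_template == "https://example.com/some/{resource}?api-version=2022-03-01"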
@ -76,42 +79,44 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete
# _URL_DEPROVISION = (
# "https://management.azure.com" +
# "/subscriptions/{subscription}" +
# "/resourceGroups/{resource_group}" +
# "/providers/Microsoft.Compute" +
# "/virtualMachines/{vm_name}" +
# "/delete" +
# "https://management.azure.com"
# "/subscriptions/{subscription}"
# "/resourceGroups/{resource_group}"
# "/providers/Microsoft.Compute"
# "/virtualMachines/{vm_name}"
# "/delete"
# "?api-version=2022-03-01"
# )
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart
_URL_REBOOT = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Compute" +
"/virtualMachines/{vm_name}" +
"/restart" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/restart"
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command
_URL_REXEC_RUN = (
"https://management.azure.com" +
"/subscriptions/{subscription}" +
"/resourceGroups/{resource_group}" +
"/providers/Microsoft.Compute" +
"/virtualMachines/{vm_name}" +
"/runCommand" +
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/runCommand"
"?api-version=2022-03-01"
)
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of Azure VM services proxy.
@ -128,26 +133,31 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
New methods to register with the service.
"""
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
# SupportsHostProvisioning
self.provision_host,
self.deprovision_host,
self.deallocate_host,
self.wait_host_deployment,
# SupportsHostOps
self.start_host,
self.stop_host,
self.restart_host,
self.wait_host_operation,
# SupportsOSOps
self.shutdown,
self.reboot,
self.wait_os_operation,
# SupportsRemoteExec
self.remote_exec,
self.get_remote_exec_results,
])
config,
global_config,
parent,
self.merge_methods(
methods,
[
# SupportsHostProvisioning
self.provision_host,
self.deprovision_host,
self.deallocate_host,
self.wait_host_deployment,
# SupportsHostOps
self.start_host,
self.stop_host,
self.restart_host,
self.wait_host_operation,
# SupportsOSOps
self.shutdown,
self.reboot,
self.wait_os_operation,
# SupportsRemoteExec
self.remote_exec,
self.get_remote_exec_results,
],
),
)
# As a convenience, allow reading customData out of a file, rather than
@ -156,19 +166,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# can be done using the `base64()` string function inside the ARM template.
self._custom_data_file = self.config.get("customDataFile", None)
if self._custom_data_file:
if self._deploy_params.get('customData', None):
if self._deploy_params.get("customData", None):
raise ValueError("Both customDataFile and customData are specified.")
self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file)
with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh:
self._custom_data_file = self.config_loader_service.resolve_path(
self._custom_data_file
)
with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh:
self._deploy_params["customData"] = custom_data_fh.read()
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use
# Try and provide a semi-sane default for the deploymentName if not provided
# since this is a common way to set the deploymentName and can save some
# config work for the caller.
if "vmName" in params and "deploymentName" not in params:
params["deploymentName"] = f"{params['vmName']}-deployment"
_LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"])
_LOG.info(
"deploymentName missing from params. Defaulting to '%s'.",
params["deploymentName"],
)
return params
def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
@ -263,20 +278,24 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"resourceGroup",
"deploymentName",
"vmName",
]
],
)
_LOG.info("Deprovision VM: %s", config["vmName"])
_LOG.info("Deprovision deployment: %s", config["deploymentName"])
# TODO: Properly deprovision *all* resources specified in the ARM template.
return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
))
return self._azure_rest_api_post_helper(
config,
self._URL_DEPROVISION.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
),
)
def deallocate_host(self, params: dict) -> Tuple[Status, dict]:
"""
Deallocates the VM on Azure by shutting it down then releasing the compute resources.
Deallocates the VM on Azure by shutting it down then releasing the compute
resources.
Note: This can cause the VM to arrive on a new host node when it's
restarted, which may have different performance characteristics.
@ -300,14 +319,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
]
],
)
_LOG.info("Deallocate VM: %s", config["vmName"])
return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
))
return self._azure_rest_api_post_helper(
config,
self._URL_DEALLOCATE.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
),
)
def start_host(self, params: dict) -> Tuple[Status, dict]:
"""
@ -332,14 +354,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
]
],
)
_LOG.info("Start VM: %s :: %s", config["vmName"], params)
return self._azure_rest_api_post_helper(config, self._URL_START.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
))
return self._azure_rest_api_post_helper(
config,
self._URL_START.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
),
)
def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
"""
@ -366,14 +391,17 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
]
],
)
_LOG.info("Stop VM: %s", config["vmName"])
return self._azure_rest_api_post_helper(config, self._URL_STOP.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
))
return self._azure_rest_api_post_helper(
config,
self._URL_STOP.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
),
)
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
return self.stop_host(params, force)
@ -403,20 +431,27 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
]
],
)
_LOG.info("Reboot VM: %s", config["vmName"])
return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
))
return self._azure_rest_api_post_helper(
config,
self._URL_REBOOT.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
),
)
def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
return self.restart_host(params, force)
def remote_exec(self, script: Iterable[str], config: dict,
env_params: dict) -> Tuple[Status, dict]:
def remote_exec(
self,
script: Iterable[str],
config: dict,
env_params: dict,
) -> Tuple[Status, dict]:
"""
Run a command on Azure VM.
@ -446,7 +481,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
"subscription",
"resourceGroup",
"vmName",
]
],
)
if _LOG.isEnabledFor(logging.INFO):
@ -455,7 +490,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
json_req = {
"commandId": "RunShellScript",
"script": list(script),
"parameters": [{"name": key, "value": val} for (key, val) in env_params.items()]
"parameters": [{"name": key, "value": val} for (key, val) in env_params.items()],
}
url = self._URL_REXEC_RUN.format(
@ -468,12 +503,18 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
_LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2))
response = requests.post(
url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout)
url,
json=json_req,
headers=self._get_headers(),
timeout=self._request_timeout,
)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Response: %s\n%s", response,
json.dumps(response.json(), indent=2)
if response.content else "")
_LOG.debug(
"Response: %s\n%s",
response,
json.dumps(response.json(), indent=2) if response.content else "",
)
else:
_LOG.info("Response: %s", response)
@ -481,10 +522,10 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH
# TODO: extract the results from JSON response
return (Status.SUCCEEDED, config)
elif response.status_code == 202:
return (Status.PENDING, {
**config,
"asyncResultsUrl": response.headers.get("Azure-AsyncOperation")
})
return (
Status.PENDING,
{**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")},
)
else:
_LOG.error("Response: %s :: %s", response, response.text)
# _LOG.error("Bad Request:\n%s", response.request.body)

View file

@ -4,8 +4,8 @@
#
"""SSH remote service."""
from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService
from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService
__all__ = [
"SshHostService",

View file

@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection functions for interacting with SSH servers as file shares.
"""
"""A collection functions for interacting with SSH servers as file shares."""
import logging
from enum import Enum
from typing import Tuple, Union
import logging
from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection
from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp
from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.remote.ssh.ssh_service import SshService
@ -21,9 +18,7 @@ _LOG = logging.getLogger(__name__)
class CopyMode(Enum):
"""
Copy mode enum.
"""
"""Copy mode enum."""
DOWNLOAD = 1
UPLOAD = 2
@ -32,17 +27,23 @@ class CopyMode(Enum):
class SshFileShareService(FileShareService, SshService):
"""A collection of functions for interacting with SSH servers as file shares."""
async def _start_file_copy(self, params: dict, mode: CopyMode,
local_path: str, remote_path: str,
recursive: bool = True) -> None:
async def _start_file_copy(
self,
params: dict,
mode: CopyMode,
local_path: str,
remote_path: str,
recursive: bool = True,
) -> None:
# pylint: disable=too-many-arguments
"""
Starts a file copy operation
Starts a file copy operation.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
Flat dictionary of (key, value) pairs of parameters (used for
establishing the connection).
mode : CopyMode
Whether to download or upload the file.
local_path : str
@ -74,40 +75,70 @@ class SshFileShareService(FileShareService, SshService):
raise ValueError(f"Unknown copy mode: {mode}")
return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
def download(
self,
params: dict,
remote_path: str,
local_path: str,
recursive: bool = True,
) -> None:
params = merge_parameters(
dest=self.config.copy(),
source=params,
required_keys=[
"ssh_hostname",
]
],
)
super().download(params, remote_path, local_path, recursive)
file_copy_future = self._run_coroutine(
self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive))
self._start_file_copy(
params,
CopyMode.DOWNLOAD,
local_path,
remote_path,
recursive,
)
)
try:
file_copy_future.result()
except (OSError, SFTPError) as ex:
_LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex)
_LOG.error(
"Failed to download %s to %s from %s: %s",
remote_path,
local_path,
params,
ex,
)
if isinstance(ex, SFTPNoSuchFile) or (
isinstance(ex, SFTPFailure) and ex.code == 4
and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory"))
isinstance(ex, SFTPFailure)
and ex.code == 4
and any(
msg.lower() in ex.reason.lower()
for msg in ("File not found", "No such file or directory")
)
):
_LOG.warning("File %s does not exist on %s", remote_path, params)
raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
raise ex
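Given that mapping, a hypothetical caller can treat a missing remote file as a benign condition (service instance and paths invented):

try:
    fileshare.download(params, "/remote/results.csv", "./results.csv")
except FileNotFoundError:
    pass  # no results were produced on this host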
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
def upload(
self,
params: dict,
local_path: str,
remote_path: str,
recursive: bool = True,
) -> None:
params = merge_parameters(
dest=self.config.copy(),
source=params,
required_keys=[
"ssh_hostname",
]
],
)
super().upload(params, local_path, remote_path, recursive)
file_copy_future = self._run_coroutine(
self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive))
self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)
)
try:
file_copy_future.result()
except (OSError, SFTPError) as ex:

View file

@ -2,39 +2,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection Service functions for managing hosts via SSH.
"""
"""A collection Service functions for managing hosts via SSH."""
import logging
from concurrent.futures import Future
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import logging
from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError
from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.services.remote.ssh.ssh_service import SshService
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.services.types.os_ops_type import SupportsOSOps
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
"""
Helper methods to manage machines via SSH.
"""
"""Helper methods to manage machines via SSH."""
# pylint: disable=too-many-instance-attributes
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
"""
Create a new instance of an SSH Service.
@ -53,24 +50,36 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
# Same methods are also provided by the AzureVMService class
# pylint: disable=duplicate-code
super().__init__(
config, global_config, parent,
self.merge_methods(methods, [
self.shutdown,
self.reboot,
self.wait_os_operation,
self.remote_exec,
self.get_remote_exec_results,
]))
config,
global_config,
parent,
self.merge_methods(
methods,
[
self.shutdown,
self.reboot,
self.wait_os_operation,
self.remote_exec,
self.get_remote_exec_results,
],
),
)
self._shell = self.config.get("ssh_shell", "/bin/bash")
async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess:
async def _run_cmd(
self,
params: dict,
script: Iterable[str],
env_params: dict,
) -> SSHCompletedProcess:
"""
Runs a command asynchronously on a host via SSH.
Parameters
----------
params : dict
Flat dictionary of (key, value) pairs of parameters (used for establishing the connection).
Flat dictionary of (key, value) pairs of parameters (used for
establishing the connection).
cmd : str
Command(s) to run via shell.
@ -83,19 +92,29 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
# Script should be an iterable of lines, not an iterable string.
script = [script]
connection, _ = await self._get_client_connection(params)
# Note: passing environment variables to SSH servers is typically restricted to just some LC_* values.
# Note: passing environment variables to SSH servers is typically restricted
# to just some LC_* values.
# Handle transferring environment variables by making a script to set them.
env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()]
script_lines = env_script_lines + [
line_split for line in script for line_split in line.splitlines()
]
# Note: connection.run() uses "exec" with a shell by default.
script_str = '\n'.join(script_lines)
script_str = "\n".join(script_lines)
_LOG.debug("Running script on %s:\n%s", connection, script_str)
return await connection.run(script_str,
check=False,
timeout=self._request_timeout,
env=env_params)
return await connection.run(
script_str,
check=False,
timeout=self._request_timeout,
env=env_params,
)
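To make the env-vars-as-script behavior above concrete, a small sketch with invented names and values:

env_params = {"ITERATIONS": "10", "OUTPUT_DIR": "/tmp/out"}  # hypothetical
script = ["run_benchmark.sh --fast"]  # hypothetical
# The resulting script_str sent over the connection would be:
#   export ITERATIONS='10'
#   export OUTPUT_DIR='/tmp/out'
#   run_benchmark.sh --fast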
def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]:
def remote_exec(
self,
script: Iterable[str],
config: dict,
env_params: dict,
) -> Tuple["Status", dict]:
"""
Start running a command on remote host OS.
@ -122,9 +141,15 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
source=config,
required_keys=[
"ssh_hostname",
]
],
)
config["asyncRemoteExecResultsFuture"] = self._run_coroutine(
self._run_cmd(
config,
script,
env_params,
)
)
config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params))
return (Status.PENDING, config)
def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]:
@ -155,7 +180,11 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout
stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr
return (
Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED,
(
Status.SUCCEEDED
if result.exit_status == 0 and result.returncode == 0
else Status.FAILED
),
{
"stdout": stdout,
"stderr": stderr,
@ -167,7 +196,8 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
return (Status.FAILED, {"result": result})
def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]:
"""_summary_
"""
_summary_
Parameters
----------
@ -187,9 +217,9 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
source=params,
required_keys=[
"ssh_hostname",
]
],
)
cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list])
cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list])
script = rf"""
if [[ $EUID -ne 0 ]]; then
sudo=$(command -v sudo)
@ -224,10 +254,10 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
cmd_opts_list = [
'shutdown -h now',
'poweroff',
'halt -p',
'systemctl poweroff',
"shutdown -h now",
"poweroff",
"halt -p",
"systemctl poweroff",
]
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
@ -249,18 +279,18 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
cmd_opts_list = [
'shutdown -r now',
'reboot',
'halt --reboot',
'systemctl reboot',
'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1',
"shutdown -r now",
"reboot",
"halt --reboot",
"systemctl reboot",
"kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1",
]
return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params)
def wait_os_operation(self, params: dict) -> Tuple[Status, dict]:
"""
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
TIMED_OUT when timing out.
Parameters
----------

View file

@ -2,25 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
A collection of functions for interacting with SSH servers as file shares.
"""
from abc import ABCMeta
from asyncio import Event as CoroEvent, Lock as CoroLock
from warnings import warn
from types import TracebackType
from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union
from threading import current_thread
"""A collection functions for interacting with SSH servers as file shares."""
import logging
import os
from abc import ABCMeta
from asyncio import Event as CoroEvent
from asyncio import Lock as CoroLock
from threading import current_thread
from types import TracebackType
from typing import (
Any,
Callable,
Coroutine,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
from warnings import warn
import asyncssh
from asyncssh.connection import SSHClientConnection
from mlos_bench.event_loop_context import (
CoroReturnType,
EventLoopContext,
FutureReturnType,
)
from mlos_bench.services.base_service import Service
from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType
from mlos_bench.util import nullable
_LOG = logging.getLogger(__name__)
@ -30,13 +43,13 @@ class SshClient(asyncssh.SSHClient):
"""
Wrapper around SSHClient to help provide connection caching and reconnect logic.
Used by the SshService to try and maintain a single connection to hosts,
handle reconnects if possible, and use that to run commands rather than
reconnect for each command.
Used by the SshService to try and maintain a single connection to hosts, handle
reconnects if possible, and use that to run commands rather than reconnect for each
command.
"""
_CONNECTION_PENDING = 'INIT'
_CONNECTION_LOST = 'LOST'
_CONNECTION_PENDING = "INIT"
_CONNECTION_LOST = "LOST"
def __init__(self, *args: tuple, **kwargs: dict):
self._connection_id: str = SshClient._CONNECTION_PENDING
@ -50,12 +63,16 @@ class SshClient(asyncssh.SSHClient):
@staticmethod
def id_from_connection(connection: SSHClientConnection) -> str:
"""Gets a unique id repr for the connection."""
return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access
# pylint: disable=protected-access
return f"{connection._username}@{connection._host}:{connection._port}"
@staticmethod
def id_from_params(connect_params: dict) -> str:
"""Gets a unique id repr for the connection."""
return f"{connect_params.get('username')}@{connect_params['host']}:{connect_params.get('port')}"
return (
f"{connect_params.get('username')}@{connect_params['host']}"
f":{connect_params.get('port')}"
)
def connection_made(self, conn: SSHClientConnection) -> None:
"""
@ -64,8 +81,12 @@ class SshClient(asyncssh.SSHClient):
Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
"""
self._conn_event.clear()
_LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \
# pylint: disable=protected-access
_LOG.debug(
"%s: Connection made by %s: %s",
current_thread().name,
conn._options.env, # pylint: disable=protected-access
conn,
)
self._connection_id = SshClient.id_from_connection(conn)
self._connection = conn
self._conn_event.set()
@ -75,18 +96,26 @@ class SshClient(asyncssh.SSHClient):
self._conn_event.clear()
_LOG.debug("%s: %s", current_thread().name, "connection_lost")
if exc is None:
_LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc)
_LOG.debug(
"%s: gracefully disconnected ssh from %s: %s",
current_thread().name,
self._connection_id,
exc,
)
else:
_LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc)
_LOG.debug(
"%s: ssh connection lost on %s: %s",
current_thread().name,
self._connection_id,
exc,
)
self._connection_id = SshClient._CONNECTION_LOST
self._connection = None
self._conn_event.set()
return super().connection_lost(exc)
async def connection(self) -> Optional[SSHClientConnection]:
"""
Waits for and returns the SSHClientConnection to be established or lost.
"""
"""Waits for and returns the SSHClientConnection to be established or lost."""
_LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
await self._conn_event.wait()
_LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
@ -96,6 +125,7 @@ class SshClient(asyncssh.SSHClient):
class SshClientCache:
"""
Manages a cache of SshClient connections.
Note: Only one per event loop thread supported.
See additional details in SshService comments.
"""
@ -114,6 +144,7 @@ class SshClientCache:
def enter(self) -> None:
"""
Manages the cache lifecycle with reference counting.
To be used in the __enter__ method of a caller's context manager.
"""
self._refcnt += 1
@ -121,6 +152,7 @@ class SshClientCache:
def exit(self) -> None:
"""
Manages the cache lifecycle with reference counting.
To be used in the __exit__ method of a caller's context manager.
"""
self._refcnt -= 1
@ -130,7 +162,10 @@ class SshClientCache:
warn(RuntimeWarning("SshClientCache lock was still held on exit."))
self._cache_lock.release()
async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]:
async def get_client_connection(
self,
connect_params: dict,
) -> Tuple[SSHClientConnection, SshClient]:
"""
Gets a (possibly cached) client connection.
@ -153,13 +188,21 @@ class SshClientCache:
_LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
connection = await client.connection()
if not connection:
_LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id)
_LOG.debug(
"%s: Removing stale client connection %s from cache.",
current_thread().name,
connection_id,
)
self._cache.pop(connection_id)
# Try to reconnect next.
else:
_LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
if connection_id not in self._cache:
_LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id)
_LOG.debug(
"%s: Establishing client connection to %s",
current_thread().name,
connection_id,
)
connection, client = await asyncssh.create_connection(SshClient, **connect_params)
assert isinstance(client, SshClient)
self._cache[connection_id] = (connection, client)
@ -167,18 +210,14 @@ class SshClientCache:
return self._cache[connection_id]
def cleanup(self) -> None:
"""
Closes all cached connections.
"""
for (connection, _) in self._cache.values():
"""Closes all cached connections."""
for connection, _ in self._cache.values():
connection.close()
self._cache = {}
class SshService(Service, metaclass=ABCMeta):
"""
Base class for SSH services.
"""
"""Base class for SSH services."""
# AsyncSSH requires an asyncio event loop to be running to work.
# However, running that event loop blocks the main thread.
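A minimal sketch of that background-event-loop pattern in plain asyncio terms (this illustrates the general idiom, not the project's EventLoopContext API):

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def probe() -> str:  # stand-in coroutine
    return "ok"

# Schedule work onto the loop thread; only the calling thread blocks on the result.
future = asyncio.run_coroutine_threadsafe(probe(), loop)
print(future.result(timeout=5))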
@ -210,21 +249,23 @@ class SshService(Service, metaclass=ABCMeta):
_REQUEST_TIMEOUT: Optional[float] = None # seconds
def __init__(self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None):
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
global_config: Optional[Dict[str, Any]] = None,
parent: Optional[Service] = None,
methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
super().__init__(config, global_config, parent, methods)
# Make sure that the value we allow overriding on a per-connection
# basis are present in the config so merge_parameters can do its thing.
self.config.setdefault('ssh_port', None)
assert isinstance(self.config['ssh_port'], (int, type(None)))
self.config.setdefault('ssh_username', None)
assert isinstance(self.config['ssh_username'], (str, type(None)))
self.config.setdefault('ssh_priv_key_path', None)
assert isinstance(self.config['ssh_priv_key_path'], (str, type(None)))
self.config.setdefault("ssh_port", None)
assert isinstance(self.config["ssh_port"], (int, type(None)))
self.config.setdefault("ssh_username", None)
assert isinstance(self.config["ssh_username"], (str, type(None)))
self.config.setdefault("ssh_priv_key_path", None)
assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))
# None can be used to disable the request timeout.
self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
@ -235,24 +276,25 @@ class SshService(Service, metaclass=ABCMeta):
# In general scripted commands shouldn't need a pty and having one
# available can confuse some commands, though we may need to make
# this configurable in the future.
'request_pty': False,
# By default disable known_hosts checking (since most VMs expected to be dynamically created).
'known_hosts': None,
"request_pty": False,
# By default disable known_hosts checking (since most VMs expected to be
# dynamically created).
"known_hosts": None,
}
if 'ssh_known_hosts_file' in self.config:
self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None)
if isinstance(self._connect_params['known_hosts'], str):
known_hosts_file = os.path.expanduser(self._connect_params['known_hosts'])
if "ssh_known_hosts_file" in self.config:
self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
if isinstance(self._connect_params["known_hosts"], str):
known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
if not os.path.exists(known_hosts_file):
raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
self._connect_params['known_hosts'] = known_hosts_file
if self._connect_params['known_hosts'] is None:
self._connect_params["known_hosts"] = known_hosts_file
if self._connect_params["known_hosts"] is None:
_LOG.info("%s known_hosts checking is disabled per config.", self)
if 'ssh_keepalive_interval' in self.config:
keepalive_internal = self.config.get('ssh_keepalive_interval')
self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal)
if "ssh_keepalive_interval" in self.config:
keepalive_internal = self.config.get("ssh_keepalive_interval")
self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal)
def _enter_context(self) -> "SshService":
# Start the background thread if it's not already running.
@ -262,9 +304,12 @@ class SshService(Service, metaclass=ABCMeta):
super()._enter_context()
return self
def _exit_context(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
def _exit_context(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> Literal[False]:
# Stop the background thread if it's not needed anymore and potentially
# cleanup the cache as well.
assert self._in_context
@ -276,6 +321,7 @@ class SshService(Service, metaclass=ABCMeta):
def clear_client_cache(cls) -> None:
"""
Clears the cache of client connections.
Note: This may cause in flight operations to fail.
"""
cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()
@ -319,24 +365,27 @@ class SshService(Service, metaclass=ABCMeta):
# Start with the base config params.
connect_params = self._connect_params.copy()
connect_params['host'] = params['ssh_hostname'] # required
connect_params["host"] = params["ssh_hostname"] # required
if params.get('ssh_port'):
connect_params['port'] = int(params.pop('ssh_port'))
elif self.config['ssh_port']:
connect_params['port'] = int(self.config['ssh_port'])
if params.get("ssh_port"):
connect_params["port"] = int(params.pop("ssh_port"))
elif self.config["ssh_port"]:
connect_params["port"] = int(self.config["ssh_port"])
if 'ssh_username' in params:
connect_params['username'] = str(params.pop('ssh_username'))
elif self.config['ssh_username']:
connect_params['username'] = str(self.config['ssh_username'])
if "ssh_username" in params:
connect_params["username"] = str(params.pop("ssh_username"))
elif self.config["ssh_username"]:
connect_params["username"] = str(self.config["ssh_username"])
priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path'])
priv_key_file: Optional[str] = params.get(
"ssh_priv_key_path",
self.config["ssh_priv_key_path"],
)
if priv_key_file:
priv_key_file = os.path.expanduser(priv_key_file)
if not os.path.exists(priv_key_file):
raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
connect_params['client_keys'] = [priv_key_file]
connect_params["client_keys"] = [priv_key_file]
return connect_params
@ -355,4 +404,6 @@ class SshService(Service, metaclass=ABCMeta):
The connection and client objects.
"""
assert self._in_context
return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params))
return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
self._get_connect_params(params)
)

View file

@ -2,8 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Service types for declaring Service behavior for Environments to use in mlos_bench.
"""Service types for declaring Service behavior for Environments to use in
mlos_bench.
"""
from mlos_bench.services.types.authenticator_type import SupportsAuth
@ -11,18 +11,19 @@ from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
from mlos_bench.services.types.local_exec_type import SupportsLocalExec
from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning
from mlos_bench.services.types.network_provisioner_type import (
SupportsNetworkProvisioning,
)
from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec
__all__ = [
'SupportsAuth',
'SupportsConfigLoading',
'SupportsFileShareOps',
'SupportsHostProvisioning',
'SupportsLocalExec',
'SupportsNetworkProvisioning',
'SupportsRemoteConfig',
'SupportsRemoteExec',
"SupportsAuth",
"SupportsConfigLoading",
"SupportsFileShareOps",
"SupportsHostProvisioning",
"SupportsLocalExec",
"SupportsNetworkProvisioning",
"SupportsRemoteConfig",
"SupportsRemoteExec",
]

View file

@ -2,18 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for authentication for the cloud services.
"""
"""Protocol interface for authentication for the cloud services."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class SupportsAuth(Protocol):
"""
Protocol interface for authentication for the cloud services.
"""
"""Protocol interface for authentication for the cloud services."""
def get_access_token(self) -> str:
"""

View file

@ -2,34 +2,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for helper functions to look up and load configs.
"""
"""Protocol interface for helper functions to look up and load configs."""
from typing import Any, Dict, List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Protocol,
Union,
runtime_checkable,
)
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.tunables.tunable import TunableValue
# Avoid's circular import issues.
if TYPE_CHECKING:
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.services.base_service import Service
from mlos_bench.environments.base_environment import Environment
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
@runtime_checkable
class SupportsConfigLoading(Protocol):
"""
Protocol interface for helper functions to lookup and load configs.
"""
"""Protocol interface for helper functions to lookup and load configs."""
def resolve_path(self, file_path: str,
extra_paths: Optional[Iterable[str]] = None) -> str:
def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str:
"""
Prepend the suitable `_config_path` to `path` if the latter is not absolute.
If `_config_path` is `None` or `path` is absolute, return `path` as is.
Prepend the suitable `_config_path` to `path` if the latter is not absolute. If
`_config_path` is `None` or `path` is absolute, return `path` as is.
Parameters
----------
@ -44,11 +48,14 @@ class SupportsConfigLoading(Protocol):
An actual path to the config or script.
"""
def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]:
def load_config(
self,
json_file_name: str,
schema_type: Optional[ConfigSchema],
) -> Union[dict, List[dict]]:
"""
Load JSON config file. Search for a file relative to `_config_path`
if the input path is not absolute.
This method is exported to be used as a service.
Load JSON config file. Search for a file relative to `_config_path` if the input
path is not absolute. This method is exported to be used as a service.
Parameters
----------
@ -63,12 +70,14 @@ class SupportsConfigLoading(Protocol):
Free-format dictionary that contains the configuration.
"""
def build_environment(self, # pylint: disable=too-many-arguments
config: dict,
tunables: "TunableGroups",
global_config: Optional[dict] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional["Service"] = None) -> "Environment":
def build_environment(
self, # pylint: disable=too-many-arguments
config: dict,
tunables: "TunableGroups",
global_config: Optional[dict] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional["Service"] = None,
) -> "Environment":
"""
Factory method for a new environment with a given config.
@ -98,12 +107,13 @@ class SupportsConfigLoading(Protocol):
"""
def load_environment_list( # pylint: disable=too-many-arguments
self,
json_file_name: str,
tunables: "TunableGroups",
global_config: Optional[dict] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional["Service"] = None) -> List["Environment"]:
self,
json_file_name: str,
tunables: "TunableGroups",
global_config: Optional[dict] = None,
parent_args: Optional[Dict[str, TunableValue]] = None,
service: Optional["Service"] = None,
) -> List["Environment"]:
"""
Load and build a list of environments from the config file.
@ -128,12 +138,15 @@ class SupportsConfigLoading(Protocol):
A list of new benchmarking environments.
"""
def load_services(self, json_file_names: Iterable[str],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None) -> "Service":
def load_services(
self,
json_file_names: Iterable[str],
global_config: Optional[Dict[str, Any]] = None,
parent: Optional["Service"] = None,
) -> "Service":
"""
Read the configuration files and bundle all service methods
from those configs into a single Service object.
Read the configuration files and bundle all service methods from those configs
into a single Service object.
Parameters
----------

View file

@ -2,20 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for file share operations.
"""
"""Protocol interface for file share operations."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class SupportsFileShareOps(Protocol):
"""
Protocol interface for file share operations.
"""
"""Protocol interface for file share operations."""
def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None:
def download(
self,
params: dict,
remote_path: str,
local_path: str,
recursive: bool = True,
) -> None:
"""
Downloads contents from a remote share path to a local path.
@ -33,7 +35,13 @@ class SupportsFileShareOps(Protocol):
if True (the default), download the entire directory tree.
"""
def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None:
def upload(
self,
params: dict,
local_path: str,
remote_path: str,
recursive: bool = True,
) -> None:
"""
Uploads contents from a local path to remote share path.

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Host/VM boot operations.
"""
"""Protocol interface for Host/VM boot operations."""
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsHostOps(Protocol):
"""
Protocol interface for Host/VM boot operations.
"""
"""Protocol interface for Host/VM boot operations."""
def start_host(self, params: dict) -> Tuple["Status", dict]:
"""

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Host/VM provisioning operations.
"""
"""Protocol interface for Host/VM provisioning operations."""
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsHostProvisioning(Protocol):
"""
Protocol interface for Host/VM provisioning operations.
"""
"""Protocol interface for Host/VM provisioning operations."""
def provision_host(self, params: dict) -> Tuple["Status", dict]:
"""
@ -46,7 +42,8 @@ class SupportsHostProvisioning(Protocol):
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning.
If True, wait for Host/VM being deployed; otherwise, wait for successful
deprovisioning.
Returns
-------
@ -74,7 +71,8 @@ class SupportsHostProvisioning(Protocol):
def deallocate_host(self, params: dict) -> Tuple["Status", dict]:
"""
Deallocates the Host/VM by shutting it down then releasing the compute resources.
Deallocates the Host/VM by shutting it down then releasing the compute
resources.
Note: This can cause the VM to arrive on a new host node when it's
restarted, which may have different performance characteristics.

View file

@ -2,15 +2,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Service types that provide helper functions to run
scripts and commands locally on the scheduler side.
"""Protocol interface for Service types that provide helper functions to run scripts and
commands locally on the scheduler side.
"""
from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable
import tempfile
import contextlib
import tempfile
from typing import (
Iterable,
Mapping,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from mlos_bench.tunables.tunable import TunableValue
@ -18,16 +24,19 @@ from mlos_bench.tunables.tunable import TunableValue
@runtime_checkable
class SupportsLocalExec(Protocol):
"""
Protocol interface for a collection of methods to run scripts and commands
in an external process on the node acting as the scheduler. Can be useful
for data processing due to reduced dependency management complications vs
the target environment.
Used in LocalEnv and provided by LocalExecService.
Protocol interface for a collection of methods to run scripts and commands in an
external process on the node acting as the scheduler.
Can be useful for data processing due to reduced dependency management complications
vs the target environment. Used in LocalEnv and provided by LocalExecService.
"""
def local_exec(self, script_lines: Iterable[str],
env: Optional[Mapping[str, TunableValue]] = None,
cwd: Optional[str] = None) -> Tuple[int, str, str]:
def local_exec(
self,
script_lines: Iterable[str],
env: Optional[Mapping[str, TunableValue]] = None,
cwd: Optional[str] = None,
) -> Tuple[int, str, str]:
"""
Execute the script lines from `script_lines` in a local process.
@ -48,7 +57,10 @@ class SupportsLocalExec(Protocol):
A 3-tuple of return code, stdout, and stderr of the script process.
"""
def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
def temp_dir_context(
self,
path: Optional[str] = None,
) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]:
"""
Create a temp directory or use the provided path.

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Network provisioning operations.
"""
"""Protocol interface for Network provisioning operations."""
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsNetworkProvisioning(Protocol):
"""
Protocol interface for Network provisioning operations.
"""
"""Protocol interface for Network provisioning operations."""
def provision_network(self, params: dict) -> Tuple["Status", dict]:
"""
@ -46,7 +42,8 @@ class SupportsNetworkProvisioning(Protocol):
params : dict
Flat dictionary of (key, value) pairs of tunable parameters.
is_setup : bool
If True, wait for Network being deployed; otherwise, wait for successful deprovisioning.
If True, wait for Network being deployed; otherwise, wait for successful
deprovisioning.
Returns
-------
@ -56,7 +53,11 @@ class SupportsNetworkProvisioning(Protocol):
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]:
def deprovision_network(
self,
params: dict,
ignore_errors: bool = True,
) -> Tuple["Status", dict]:
"""
Deprovisions the Network by deleting it.

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Host/OS operations.
"""
"""Protocol interface for Host/OS operations."""
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsOSOps(Protocol):
"""
Protocol interface for Host/OS operations.
"""
"""Protocol interface for Host/OS operations."""
def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
"""
@ -56,8 +52,8 @@ class SupportsOSOps(Protocol):
def wait_os_operation(self, params: dict) -> Tuple["Status", dict]:
"""
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return
TIMED_OUT when timing out.
Parameters
----------

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for configuring cloud services.
"""
"""Protocol interface for configuring cloud services."""
from typing import Any, Dict, Protocol, Tuple, TYPE_CHECKING, runtime_checkable
from typing import TYPE_CHECKING, Any, Dict, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,12 +12,9 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsRemoteConfig(Protocol):
"""
Protocol interface for configuring cloud services.
"""
"""Protocol interface for configuring cloud services."""
def configure(self, config: Dict[str, Any],
params: Dict[str, Any]) -> Tuple["Status", dict]:
def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]:
"""
Update the parameters of a SaaS service in the cloud.

View file

@ -2,12 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for Service types that provide helper functions to run
scripts on a remote host OS.
"""Protocol interface for Service types that provide helper functions to run scripts on
a remote host OS.
"""
from typing import Iterable, Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -15,13 +14,16 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsRemoteExec(Protocol):
"""
Protocol interface for Service types that provide helper functions to run
scripts on a remote host OS.
"""Protocol interface for Service types that provide helper functions to run scripts
on a remote host OS.
"""
def remote_exec(self, script: Iterable[str], config: dict,
env_params: dict) -> Tuple["Status", dict]:
def remote_exec(
self,
script: Iterable[str],
config: dict,
env_params: dict,
) -> Tuple["Status", dict]:
"""
Run a command on remote host OS.

View file

@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Protocol interface for VM provisioning operations.
"""
"""Protocol interface for VM provisioning operations."""
from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable
if TYPE_CHECKING:
from mlos_bench.environments.status import Status
@ -14,9 +12,7 @@ if TYPE_CHECKING:
@runtime_checkable
class SupportsVMOps(Protocol):
"""
Protocol interface for VM provisioning operations.
"""
"""Protocol interface for VM provisioning operations."""
def vm_provision(self, params: dict) -> Tuple["Status", dict]:
"""
@ -122,8 +118,8 @@ class SupportsVMOps(Protocol):
def wait_vm_operation(self, params: dict) -> Tuple["Status", dict]:
"""
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED.
Return TIMED_OUT when timing out.
Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. Return
TIMED_OUT when timing out.
Parameters
----------

View file

@ -2,14 +2,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Interfaces to the storage backends for OS Autotune.
"""
"""Interfaces to the storage backends for OS Autotune."""
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.storage_factory import from_config
__all__ = [
'Storage',
'from_config',
"Storage",
"from_config",
]

View file

@ -2,13 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for accessing the stored benchmark experiment data.
"""
"""Base interface for accessing the stored benchmark experiment data."""
from abc import ABCMeta, abstractmethod
from distutils.util import strtobool # pylint: disable=deprecated-module
from typing import Dict, Literal, Optional, Tuple, TYPE_CHECKING
from distutils.util import strtobool # pylint: disable=deprecated-module
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple
import pandas
@ -16,7 +14,9 @@ from mlos_bench.storage.base_tunable_config_data import TunableConfigData
if TYPE_CHECKING:
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
from mlos_bench.storage.base_tunable_config_trial_group_data import (
TunableConfigTrialGroupData,
)
class ExperimentData(metaclass=ABCMeta):
@ -30,12 +30,15 @@ class ExperimentData(metaclass=ABCMeta):
RESULT_COLUMN_PREFIX = "result."
CONFIG_COLUMN_PREFIX = "config."
def __init__(self, *,
experiment_id: str,
description: str,
root_env_config: str,
git_repo: str,
git_commit: str):
def __init__(
self,
*,
experiment_id: str,
description: str,
root_env_config: str,
git_repo: str,
git_commit: str,
):
self._experiment_id = experiment_id
self._description = description
self._root_env_config = root_env_config
@ -44,16 +47,12 @@ class ExperimentData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
"""
ID of the experiment.
"""
"""ID of the experiment."""
return self._experiment_id
@property
def description(self) -> str:
"""
Description of the experiment.
"""
"""Description of the experiment."""
return self._description
@property
@ -123,7 +122,8 @@ class ExperimentData(metaclass=ABCMeta):
@property
def default_tunable_config_id(self) -> Optional[int]:
"""
Retrieves the (tunable) config id for the default tunable values for this experiment.
Retrieves the (tunable) config id for the default tunable values for this
experiment.
Note: this is by *default* the first trial executed for this experiment.
However, it is currently possible that the user changed the tunables config
@ -140,9 +140,9 @@ class ExperimentData(metaclass=ABCMeta):
trials_items = sorted(self.trials.items())
if not trials_items:
return None
for (_trial_id, trial) in trials_items:
for _trial_id, trial in trials_items:
# Take the first config id marked as "defaults" when it was instantiated.
if strtobool(str(trial.metadata_dict.get('is_defaults', False))):
if strtobool(str(trial.metadata_dict.get("is_defaults", False))):
return trial.tunable_config_id
# Fallback (min trial_id)
return trials_items[0][1].tunable_config_id
@ -157,7 +157,8 @@ class ExperimentData(metaclass=ABCMeta):
-------
results : pandas.DataFrame
A DataFrame with configurations and results from all trials of the experiment.
Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
Has columns
[trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status]
followed by tunable config parameters (prefixed with "config.") and
trial results (prefixed with "result."). The latter can be NULLs if the
trial was not successful.

View file

@ -2,15 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for saving and restoring the benchmark data.
"""
"""Base interface for saving and restoring the benchmark data."""
import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
from typing import Optional, List, Tuple, Dict, Iterator, Type, Any
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
from typing_extensions import Literal
from mlos_bench.config.schemas import ConfigSchema
@ -24,15 +23,16 @@ _LOG = logging.getLogger(__name__)
class Storage(metaclass=ABCMeta):
"""
An abstract interface between the benchmarking framework
and storage systems (e.g., SQLite or MLflow).
"""An abstract interface between the benchmarking framework and storage systems
(e.g., SQLite or MLflow).
"""
def __init__(self,
config: Dict[str, Any],
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
config: Dict[str, Any],
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
"""
Create a new storage object.
@ -48,10 +48,9 @@ class Storage(metaclass=ABCMeta):
self._global_config = global_config or {}
def _validate_json_config(self, config: dict) -> None:
"""
Reconstructs a basic json config that this class might have been
instantiated from in order to validate configs provided outside the
file loading mechanism.
"""Reconstructs a basic json config that this class might have been instantiated
from in order to validate configs provided outside the file loading
mechanism.
"""
json_config: dict = {
"class": self.__class__.__module__ + "." + self.__class__.__name__,
@ -73,13 +72,16 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
def experiment(self, *,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
tunables: TunableGroups,
opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment':
def experiment(
self,
*,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
tunables: TunableGroups,
opt_targets: Dict[str, Literal["min", "max"]],
) -> "Storage.Experiment":
"""
Create a new experiment in the storage.
@ -112,26 +114,31 @@ class Storage(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of the experiment.
This class is instantiated in the `Storage.experiment()` method.
"""
def __init__(self,
*,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
opt_targets: Dict[str, Literal['min', 'max']]):
def __init__(
self,
*,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
opt_targets: Dict[str, Literal["min", "max"]],
):
self._tunables = tunables.copy()
self._trial_id = trial_id
self._experiment_id = experiment_id
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config)
(self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
root_env_config
)
self._description = description
self._opt_targets = opt_targets
self._in_context = False
def __enter__(self) -> 'Storage.Experiment':
def __enter__(self) -> "Storage.Experiment":
"""
Enter the context of the experiment.
@ -143,9 +150,12 @@ class Storage(metaclass=ABCMeta):
self._in_context = True
return self
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
"""
End the context of the experiment.
@ -156,8 +166,11 @@ class Storage(metaclass=ABCMeta):
_LOG.debug("Finishing experiment: %s", self)
else:
assert exc_type and exc_val
_LOG.warning("Finishing experiment: %s", self,
exc_info=(exc_type, exc_val, exc_tb))
_LOG.warning(
"Finishing experiment: %s",
self,
exc_info=(exc_type, exc_val, exc_tb),
)
assert self._in_context
self._teardown(is_ok)
self._in_context = False
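The `__enter__`/`__exit__` pair above always runs `_teardown()` and returns `False` so exceptions still propagate; a self-contained toy version of the same pattern (illustrative only, not the real class):

from types import TracebackType
from typing import Literal, Optional, Type

class TeardownOnExit:
    """Toy context manager mirroring the teardown-and-reraise pattern."""

    def __enter__(self) -> "TeardownOnExit":
        print("setup")
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Literal[False]:
        print(f"teardown (is_ok={exc_type is None})")
        return False  # Never swallow exceptions; let them propagate.

with TeardownOnExit():
    print("inside context")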
@ -168,7 +181,8 @@ class Storage(metaclass=ABCMeta):
def _setup(self) -> None:
"""
Create a record of the new experiment or find an existing one in the storage.
Create a record of the new experiment or find an existing one in the
storage.
This method is called by `Storage.Experiment.__enter__()`.
"""
@ -187,36 +201,34 @@ class Storage(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
"""Get the Experiment's ID"""
"""Get the Experiment's ID."""
return self._experiment_id
@property
def trial_id(self) -> int:
"""Get the current Trial ID"""
"""Get the current Trial ID."""
return self._trial_id
@property
def description(self) -> str:
"""Get the Experiment's description"""
"""Get the Experiment's description."""
return self._description
@property
def tunables(self) -> TunableGroups:
"""Get the Experiment's tunables"""
"""Get the Experiment's tunables."""
return self._tunables
@property
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
"""
Get the Experiment's optimization targets and directions
"""
"""Get the Experiment's optimization targets and directions."""
return self._opt_targets
@abstractmethod
def merge(self, experiment_ids: List[str]) -> None:
"""
Merge in the results of other (compatible) experiments' trials.
Used to help warm up the optimizer for this experiment.
Merge in the results of other (compatible) experiments' trials. Used to help
warm up the optimizer for this experiment.
Parameters
----------
@ -226,9 +238,7 @@ class Storage(metaclass=ABCMeta):
@abstractmethod
def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
"""
Load tunable values for a given config ID.
"""
"""Load tunable values for a given config ID."""
@abstractmethod
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
@ -247,8 +257,10 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
def load(self, last_trial_id: int = -1,
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
def load(
self,
last_trial_id: int = -1,
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
"""
Load (tunable values, benchmark scores, status) to warm-up the optimizer.
@ -268,10 +280,15 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']:
def pending_trials(
self,
timestamp: datetime,
*,
running: bool,
) -> Iterator["Storage.Trial"]:
"""
Return an iterator over the pending trials that are scheduled to run
on or before the specified timestamp.
Return an iterator over the pending trials that are scheduled to run on or
before the specified timestamp.
Parameters
----------
@ -288,8 +305,12 @@ class Storage(metaclass=ABCMeta):
"""
@abstractmethod
def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial':
def new_trial(
self,
tunables: TunableGroups,
ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None,
) -> "Storage.Trial":
"""
Create a new experiment run in the storage.
@ -313,13 +334,20 @@ class Storage(metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of a single run of the experiment.
This class is instantiated in the `Storage.Experiment.trial()` method.
"""
def __init__(self, *,
tunables: TunableGroups, experiment_id: str, trial_id: int,
tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']],
config: Optional[Dict[str, Any]] = None):
def __init__(
self,
*,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
tunable_config_id: int,
opt_targets: Dict[str, Literal["min", "max"]],
config: Optional[Dict[str, Any]] = None,
):
self._tunables = tunables
self._experiment_id = experiment_id
self._trial_id = trial_id
@ -332,29 +360,23 @@ class Storage(metaclass=ABCMeta):
@property
def trial_id(self) -> int:
"""
ID of the current trial.
"""
"""ID of the current trial."""
return self._trial_id
@property
def tunable_config_id(self) -> int:
"""
ID of the current trial (tunable) configuration.
"""
"""ID of the current trial (tunable) configuration."""
return self._tunable_config_id
@property
def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
"""
Get the Trial's optimization targets and directions.
"""
"""Get the Trial's optimization targets and directions."""
return self._opt_targets
@property
def tunables(self) -> TunableGroups:
"""
Tunable parameters of the current trial
Tunable parameters of the current trial.
(e.g., application Environment's "config")
"""
@ -362,8 +384,8 @@ class Storage(metaclass=ABCMeta):
def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Produce a copy of the global configuration updated
with the parameters of the current trial.
Produce a copy of the global configuration updated with the parameters of
the current trial.
Note: this is not the target Environment's "config" (i.e., tunable
params), but rather the internal "config" which consists of a
@ -377,9 +399,12 @@ class Storage(metaclass=ABCMeta):
return config
@abstractmethod
def update(self, status: Status, timestamp: datetime,
metrics: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
def update(
self,
status: Status,
timestamp: datetime,
metrics: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""
Update the storage with the results of the experiment.
@ -403,14 +428,21 @@ class Storage(metaclass=ABCMeta):
assert metrics is not None
opt_targets = set(self._opt_targets.keys())
if not opt_targets.issubset(metrics.keys()):
_LOG.warning("Trial %s :: opt.targets missing: %s",
self, opt_targets.difference(metrics.keys()))
_LOG.warning(
"Trial %s :: opt.targets missing: %s",
self,
opt_targets.difference(metrics.keys()),
)
# raise ValueError()
return metrics
@abstractmethod
def update_telemetry(self, status: Status, timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]]) -> None:
def update_telemetry(
self,
status: Status,
timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]],
) -> None:
"""
Save the experiment's telemetry data and intermediate status.
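Putting the abstract API together, a hedged sketch of how a scheduler might drive it; `storage_experiment` and `run_benchmark` are placeholders for objects constructed elsewhere, and only the `pending_trials()` and `update()` signatures shown above are assumed:

from datetime import datetime

from pytz import UTC

def run_pending(storage_experiment, run_benchmark) -> None:
    # Fetch trials scheduled to start by now and record their results.
    for trial in storage_experiment.pending_trials(datetime.now(UTC), running=False):
        status, metrics = run_benchmark(trial.tunables, trial.config())
        trial.update(status, datetime.now(UTC), metrics)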


@ -2,40 +2,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for accessing the stored benchmark trial data.
"""
"""Base interface for accessing the stored benchmark trial data."""
from abc import ABCMeta, abstractmethod
from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional
import pandas
from pytz import UTC
from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
from mlos_bench.storage.util import kv_df_to_dict
from mlos_bench.tunables.tunable import TunableValue
if TYPE_CHECKING:
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
from mlos_bench.storage.base_tunable_config_trial_group_data import (
TunableConfigTrialGroupData,
)
class TrialData(metaclass=ABCMeta):
"""
Base interface for accessing the stored experiment benchmark trial data.
A trial is a single run of an experiment with a given configuration (e.g., set
of tunable parameters).
A trial is a single run of an experiment with a given configuration (e.g., set of
tunable parameters).
"""
def __init__(self, *,
experiment_id: str,
trial_id: int,
tunable_config_id: int,
ts_start: datetime,
ts_end: Optional[datetime],
status: Status):
def __init__(
self,
*,
experiment_id: str,
trial_id: int,
tunable_config_id: int,
ts_start: datetime,
ts_end: Optional[datetime],
status: Status,
):
self._experiment_id = experiment_id
self._trial_id = trial_id
self._tunable_config_id = tunable_config_id
@ -46,7 +49,10 @@ class TrialData(metaclass=ABCMeta):
self._status = status
def __repr__(self) -> str:
return f"Trial :: {self._experiment_id}:{self._trial_id} cid:{self._tunable_config_id} {self._status.name}"
return (
f"Trial :: {self._experiment_id}:{self._trial_id} "
f"cid:{self._tunable_config_id} {self._status.name}"
)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
@ -55,44 +61,32 @@ class TrialData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
"""
ID of the experiment this trial belongs to.
"""
"""ID of the experiment this trial belongs to."""
return self._experiment_id
@property
def trial_id(self) -> int:
"""
ID of the trial.
"""
"""ID of the trial."""
return self._trial_id
@property
def ts_start(self) -> datetime:
"""
Start timestamp of the trial (UTC).
"""
"""Start timestamp of the trial (UTC)."""
return self._ts_start
@property
def ts_end(self) -> Optional[datetime]:
"""
End timestamp of the trial (UTC).
"""
"""End timestamp of the trial (UTC)."""
return self._ts_end
@property
def status(self) -> Status:
"""
Status of the trial.
"""
"""Status of the trial."""
return self._status
@property
def tunable_config_id(self) -> int:
"""
ID of the (tunable) configuration of the trial.
"""
"""ID of the (tunable) configuration of the trial."""
return self._tunable_config_id
@property
@ -112,9 +106,7 @@ class TrialData(metaclass=ABCMeta):
@property
@abstractmethod
def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
"""
Retrieve the trial's (tunable) config trial group data from the storage.
"""
"""Retrieve the trial's (tunable) config trial group data from the storage."""
@property
@abstractmethod


@ -2,9 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for accessing the stored benchmark (tunable) config data.
"""
"""Base interface for accessing the stored benchmark (tunable) config data."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional
@ -21,8 +19,7 @@ class TunableConfigData(metaclass=ABCMeta):
A configuration in this context is the set of tunable parameter values.
"""
def __init__(self, *,
tunable_config_id: int):
def __init__(self, *, tunable_config_id: int):
self._tunable_config_id = tunable_config_id
def __repr__(self) -> str:
@ -35,9 +32,7 @@ class TunableConfigData(metaclass=ABCMeta):
@property
def tunable_config_id(self) -> int:
"""
Unique ID of the (tunable) configuration.
"""
"""Unique ID of the (tunable) configuration."""
return self._tunable_config_id
@property


@ -2,12 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for accessing the stored benchmark config trial group data.
"""
"""Base interface for accessing the stored benchmark config trial group data."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional
import pandas
@ -19,18 +17,21 @@ if TYPE_CHECKING:
class TunableConfigTrialGroupData(metaclass=ABCMeta):
"""
Base interface for accessing the stored experiment benchmark tunable config
trial group data.
Base interface for accessing the stored experiment benchmark tunable config trial
group data.
A (tunable) config is used to define an instance of values for a set of tunable
parameters for a given experiment and can be used by one or more trial instances
(e.g., for repeats), which we call a (tunable) config trial group.
"""
def __init__(self, *,
experiment_id: str,
tunable_config_id: int,
tunable_config_trial_group_id: Optional[int] = None):
def __init__(
self,
*,
experiment_id: str,
tunable_config_id: int,
tunable_config_trial_group_id: Optional[int] = None,
):
self._experiment_id = experiment_id
self._tunable_config_id = tunable_config_id
# can be lazily initialized as necessary:
@ -38,23 +39,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@property
def experiment_id(self) -> str:
"""
ID of the experiment.
"""
"""ID of the experiment."""
return self._experiment_id
@property
def tunable_config_id(self) -> int:
"""
ID of the config.
"""
"""ID of the config."""
return self._tunable_config_id
@abstractmethod
def _get_tunable_config_trial_group_id(self) -> int:
"""
Retrieve the trial's config_trial_group_id from the storage.
"""
"""Retrieve the trial's config_trial_group_id from the storage."""
raise NotImplementedError("subclass must implement")
@property
@ -77,13 +72,17 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id
return (
self._tunable_config_id == other._tunable_config_id
and self._experiment_id == other._experiment_id
)
@property
@abstractmethod
def tunable_config(self) -> TunableConfigData:
"""
Retrieve the (tunable) config data for this (tunable) config trial group from the storage.
Retrieve the (tunable) config data for this (tunable) config trial group from
the storage.
Returns
-------
@ -94,7 +93,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@abstractmethod
def trials(self) -> Dict[int, "TrialData"]:
"""
Retrieve the trials' data for this (tunable) config trial group from the storage.
Retrieve the trials' data for this (tunable) config trial group from the
storage.
Returns
-------
@ -106,7 +106,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta):
@abstractmethod
def results_df(self) -> pandas.DataFrame:
"""
Retrieve all results for this (tunable) config trial group as a single DataFrame.
Retrieve all results for this (tunable) config trial group as a single
DataFrame.
Returns
-------


@ -2,11 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Interfaces to the SQL-based storage backends for OS Autotune.
"""
"""Interfaces to the SQL-based storage backends for OS Autotune."""
from mlos_bench.storage.sql.storage import SqlStorage
__all__ = [
'SqlStorage',
"SqlStorage",
]


@ -2,39 +2,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Common SQL methods for accessing the stored benchmark data.
"""
"""Common SQL methods for accessing the stored benchmark data."""
from typing import Dict, Optional
import pandas
from sqlalchemy import Engine, Integer, func, and_, select
from sqlalchemy import Engine, Integer, and_, func, select
from mlos_bench.environments.status import Status
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.util import utcify_timestamp, utcify_nullable_timestamp
from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
def get_trials(
engine: Engine,
schema: DbSchema,
experiment_id: str,
tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]:
engine: Engine,
schema: DbSchema,
experiment_id: str,
tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
"""
Gets TrialData for the given experiment_id and optionally
restricted by tunable_config_id.
Gets TrialData for the given experiment_id, optionally restricted further by
tunable_config_id.
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
"""
from mlos_bench.storage.sql.trial_data import TrialSqlData # pylint: disable=import-outside-toplevel,cyclic-import
# pylint: disable=import-outside-toplevel,cyclic-import
from mlos_bench.storage.sql.trial_data import TrialSqlData
with engine.connect() as conn:
# Build up sql a statement for fetching trials.
stmt = schema.trial.select().where(
schema.trial.c.exp_id == experiment_id,
).order_by(
schema.trial.c.exp_id.asc(),
schema.trial.c.trial_id.asc(),
stmt = (
schema.trial.select()
.where(
schema.trial.c.exp_id == experiment_id,
)
.order_by(
schema.trial.c.exp_id.asc(),
schema.trial.c.trial_id.asc(),
)
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@ -58,27 +64,36 @@ def get_trials(
def get_results_df(
engine: Engine,
schema: DbSchema,
experiment_id: str,
tunable_config_id: Optional[int] = None) -> pandas.DataFrame:
engine: Engine,
schema: DbSchema,
experiment_id: str,
tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
"""
Gets a results DataFrame for the given experiment_id and optionally
restricted by tunable_config_id.
Gets a results DataFrame for the given experiment_id, optionally restricted
further by tunable_config_id.
Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.
"""
# pylint: disable=too-many-locals
with engine.connect() as conn:
# Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config.
tunable_config_group_id_stmt = schema.trial.select().with_only_columns(
schema.trial.c.exp_id,
schema.trial.c.config_id,
func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'),
).where(
schema.trial.c.exp_id == experiment_id,
).group_by(
schema.trial.c.exp_id,
schema.trial.c.config_id,
tunable_config_group_id_stmt = (
schema.trial.select()
.with_only_columns(
schema.trial.c.exp_id,
schema.trial.c.config_id,
func.min(schema.trial.c.trial_id)
.cast(Integer)
.label("tunable_config_trial_group_id"),
)
.where(
schema.trial.c.exp_id == experiment_id,
)
.group_by(
schema.trial.c.exp_id,
schema.trial.c.config_id,
)
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@ -88,18 +103,22 @@ def get_results_df(
tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()
# Get each trial's metadata.
cur_trials_stmt = select(
schema.trial,
tunable_config_trial_group_id_subquery,
).where(
schema.trial.c.exp_id == experiment_id,
and_(
tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
),
).order_by(
schema.trial.c.exp_id.asc(),
schema.trial.c.trial_id.asc(),
cur_trials_stmt = (
select(
schema.trial,
tunable_config_trial_group_id_subquery,
)
.where(
schema.trial.c.exp_id == experiment_id,
and_(
tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
),
)
.order_by(
schema.trial.c.exp_id.asc(),
schema.trial.c.trial_id.asc(),
)
)
# Optionally restrict to those using a particular tunable config.
if tunable_config_id is not None:
@ -108,39 +127,48 @@ def get_results_df(
)
cur_trials = conn.execute(cur_trials_stmt)
trials_df = pandas.DataFrame(
[(
row.trial_id,
utcify_timestamp(row.ts_start, origin="utc"),
utcify_nullable_timestamp(row.ts_end, origin="utc"),
row.config_id,
row.tunable_config_trial_group_id,
row.status,
) for row in cur_trials.fetchall()],
[
(
row.trial_id,
utcify_timestamp(row.ts_start, origin="utc"),
utcify_nullable_timestamp(row.ts_end, origin="utc"),
row.config_id,
row.tunable_config_trial_group_id,
row.status,
)
for row in cur_trials.fetchall()
],
columns=[
'trial_id',
'ts_start',
'ts_end',
'tunable_config_id',
'tunable_config_trial_group_id',
'status',
]
"trial_id",
"ts_start",
"ts_end",
"tunable_config_id",
"tunable_config_trial_group_id",
"status",
],
)
# Get each trial's config in wide format.
configs_stmt = schema.trial.select().with_only_columns(
schema.trial.c.trial_id,
schema.trial.c.config_id,
schema.config_param.c.param_id,
schema.config_param.c.param_value,
).where(
schema.trial.c.exp_id == experiment_id,
).join(
schema.config_param,
schema.config_param.c.config_id == schema.trial.c.config_id,
isouter=True
).order_by(
schema.trial.c.trial_id,
schema.config_param.c.param_id,
configs_stmt = (
schema.trial.select()
.with_only_columns(
schema.trial.c.trial_id,
schema.trial.c.config_id,
schema.config_param.c.param_id,
schema.config_param.c.param_value,
)
.where(
schema.trial.c.exp_id == experiment_id,
)
.join(
schema.config_param,
schema.config_param.c.config_id == schema.trial.c.config_id,
isouter=True,
)
.order_by(
schema.trial.c.trial_id,
schema.config_param.c.param_id,
)
)
if tunable_config_id is not None:
configs_stmt = configs_stmt.where(
@ -148,41 +176,75 @@ def get_results_df(
)
configs = conn.execute(configs_stmt)
configs_df = pandas.DataFrame(
[(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value)
for row in configs.fetchall()],
columns=['trial_id', 'tunable_config_id', 'param', 'value']
[
(
row.trial_id,
row.config_id,
ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
row.param_value,
)
for row in configs.fetchall()
],
columns=["trial_id", "tunable_config_id", "param", "value"],
).pivot(
index=["trial_id", "tunable_config_id"], columns="param", values="value",
index=["trial_id", "tunable_config_id"],
columns="param",
values="value",
)
configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp)
configs_df = configs_df.apply( # type: ignore[assignment] # (fp)
pandas.to_numeric,
errors="coerce",
).fillna(configs_df)
# Get each trial's results in wide format.
results_stmt = schema.trial_result.select().with_only_columns(
schema.trial_result.c.trial_id,
schema.trial_result.c.metric_id,
schema.trial_result.c.metric_value,
).where(
schema.trial_result.c.exp_id == experiment_id,
).order_by(
schema.trial_result.c.trial_id,
schema.trial_result.c.metric_id,
results_stmt = (
schema.trial_result.select()
.with_only_columns(
schema.trial_result.c.trial_id,
schema.trial_result.c.metric_id,
schema.trial_result.c.metric_value,
)
.where(
schema.trial_result.c.exp_id == experiment_id,
)
.order_by(
schema.trial_result.c.trial_id,
schema.trial_result.c.metric_id,
)
)
if tunable_config_id is not None:
results_stmt = results_stmt.join(schema.trial, and_(
schema.trial.c.exp_id == schema.trial_result.c.exp_id,
schema.trial.c.trial_id == schema.trial_result.c.trial_id,
schema.trial.c.config_id == tunable_config_id,
))
results_stmt = results_stmt.join(
schema.trial,
and_(
schema.trial.c.exp_id == schema.trial_result.c.exp_id,
schema.trial.c.trial_id == schema.trial_result.c.trial_id,
schema.trial.c.config_id == tunable_config_id,
),
)
results = conn.execute(results_stmt)
results_df = pandas.DataFrame(
[(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value)
for row in results.fetchall()],
columns=['trial_id', 'metric', 'value']
[
(
row.trial_id,
ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
row.metric_value,
)
for row in results.fetchall()
],
columns=["trial_id", "metric", "value"],
).pivot(
index="trial_id", columns="metric", values="value",
index="trial_id",
columns="metric",
values="value",
)
results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp)
results_df = results_df.apply( # type: ignore[assignment] # (fp)
pandas.to_numeric,
errors="coerce",
).fillna(results_df)
# Concat the trials, configs, and results.
return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \
.merge(results_df, on="trial_id", how="left")
return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
results_df,
on="trial_id",
how="left",
)
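The long-to-wide reshaping above is plain pandas; a standalone sketch with fabricated rows shows the pivot plus the numeric-coercion fallback used for both the configs and results frames:

import pandas

# Long key-value rows, as they come back from the trial_result table.
long_df = pandas.DataFrame(
    [
        (1, "result.score", "0.42"),
        (1, "result.latency", "12.5"),
        (2, "result.score", "0.37"),
    ],
    columns=["trial_id", "metric", "value"],
)

# One row per trial, one column per metric; coerce numerics and fall back
# to the original strings for anything non-numeric.
wide_df = long_df.pivot(index="trial_id", columns="metric", values="value")
wide_df = wide_df.apply(pandas.to_numeric, errors="coerce").fillna(wide_df)
print(wide_df)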


@ -2,43 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Saving and restoring the benchmark data using SQLAlchemy.
"""
"""Saving and restoring the benchmark data using SQLAlchemy."""
import logging
import hashlib
import logging
from datetime import datetime
from typing import Optional, Tuple, List, Literal, Dict, Iterator, Any
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
from pytz import UTC
from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select
from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select
from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.trial import Trial
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import nullable, utcify_timestamp
_LOG = logging.getLogger(__name__)
class Experiment(Storage.Experiment):
"""
Logic for retrieving and storing the results of a single experiment.
"""
"""Logic for retrieving and storing the results of a single experiment."""
def __init__(self, *,
engine: Engine,
schema: DbSchema,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
opt_targets: Dict[str, Literal['min', 'max']]):
def __init__(
self,
*,
engine: Engine,
schema: DbSchema,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
opt_targets: Dict[str, Literal["min", "max"]],
):
super().__init__(
tunables=tunables,
experiment_id=experiment_id,
@ -56,18 +54,22 @@ class Experiment(Storage.Experiment):
# Get git info and the last trial ID for the experiment.
# pylint: disable=not-callable
exp_info = conn.execute(
self._schema.experiment.select().with_only_columns(
self._schema.experiment.select()
.with_only_columns(
self._schema.experiment.c.git_repo,
self._schema.experiment.c.git_commit,
self._schema.experiment.c.root_env_config,
func.max(self._schema.trial.c.trial_id).label("trial_id"),
).join(
)
.join(
self._schema.trial,
self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
isouter=True
).where(
isouter=True,
)
.where(
self._schema.experiment.c.exp_id == self._experiment_id,
).group_by(
)
.group_by(
self._schema.experiment.c.git_repo,
self._schema.experiment.c.git_commit,
self._schema.experiment.c.root_env_config,
@ -76,33 +78,47 @@ class Experiment(Storage.Experiment):
if exp_info is None:
_LOG.info("Start new experiment: %s", self._experiment_id)
# It's a new experiment: create a record for it in the database.
conn.execute(self._schema.experiment.insert().values(
exp_id=self._experiment_id,
description=self._description,
git_repo=self._git_repo,
git_commit=self._git_commit,
root_env_config=self._root_env_config,
))
conn.execute(self._schema.objectives.insert().values([
{
"exp_id": self._experiment_id,
"optimization_target": opt_target,
"optimization_direction": opt_dir,
}
for (opt_target, opt_dir) in self.opt_targets.items()
]))
conn.execute(
self._schema.experiment.insert().values(
exp_id=self._experiment_id,
description=self._description,
git_repo=self._git_repo,
git_commit=self._git_commit,
root_env_config=self._root_env_config,
)
)
conn.execute(
self._schema.objectives.insert().values(
[
{
"exp_id": self._experiment_id,
"optimization_target": opt_target,
"optimization_direction": opt_dir,
}
for (opt_target, opt_dir) in self.opt_targets.items()
]
)
)
else:
if exp_info.trial_id is not None:
self._trial_id = exp_info.trial_id + 1
_LOG.info("Continue experiment: %s last trial: %s resume from: %d",
self._experiment_id, exp_info.trial_id, self._trial_id)
_LOG.info(
"Continue experiment: %s last trial: %s resume from: %d",
self._experiment_id,
exp_info.trial_id,
self._trial_id,
)
# TODO: Sanity check that certain critical configs (e.g.,
# objectives) haven't changed to be incompatible such that a new
# experiment should be started (possibly by prewarming with the
# previous one).
if exp_info.git_commit != self._git_commit:
_LOG.warning("Experiment %s git expected: %s %s",
self, exp_info.git_repo, exp_info.git_commit)
_LOG.warning(
"Experiment %s git expected: %s %s",
self,
exp_info.git_repo,
exp_info.git_commit,
)
def merge(self, experiment_ids: List[str]) -> None:
_LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
@ -115,33 +131,42 @@ class Experiment(Storage.Experiment):
def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
with self._engine.connect() as conn:
cur_telemetry = conn.execute(
self._schema.trial_telemetry.select().where(
self._schema.trial_telemetry.select()
.where(
self._schema.trial_telemetry.c.exp_id == self._experiment_id,
self._schema.trial_telemetry.c.trial_id == trial_id
).order_by(
self._schema.trial_telemetry.c.trial_id == trial_id,
)
.order_by(
self._schema.trial_telemetry.c.ts,
self._schema.trial_telemetry.c.metric_id,
)
)
# Not all storage backends store the original zone info.
# We try to ensure data is entered in UTC and re-attach the UTC zone info on return here.
return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
for row in cur_telemetry.fetchall()]
return [
(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
for row in cur_telemetry.fetchall()
]
def load(self, last_trial_id: int = -1,
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
def load(
self,
last_trial_id: int = -1,
) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
with self._engine.connect() as conn:
cur_trials = conn.execute(
self._schema.trial.select().with_only_columns(
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.trial_id,
self._schema.trial.c.config_id,
self._schema.trial.c.status,
).where(
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id > last_trial_id,
self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']),
).order_by(
self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
)
.order_by(
self._schema.trial.c.trial_id.asc(),
)
)
@ -155,12 +180,24 @@ class Experiment(Storage.Experiment):
stat = Status[trial.status]
status.append(stat)
trial_ids.append(trial.trial_id)
configs.append(self._get_key_val(
conn, self._schema.config_param, "param", config_id=trial.config_id))
configs.append(
self._get_key_val(
conn,
self._schema.config_param,
"param",
config_id=trial.config_id,
)
)
if stat.is_succeeded():
scores.append(self._get_key_val(
conn, self._schema.trial_result, "metric",
exp_id=self._experiment_id, trial_id=trial.trial_id))
scores.append(
self._get_key_val(
conn,
self._schema.trial_result,
"metric",
exp_id=self._experiment_id,
trial_id=trial.trial_id,
)
)
else:
scores.append(None)
@ -170,55 +207,73 @@ class Experiment(Storage.Experiment):
def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
"""
Helper method to retrieve key-value pairs from the database.
(e.g., configurations, results, and telemetry).
"""
cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
select(
column(f"{field}_id"),
column(f"{field}_value"),
).select_from(table).where(
*[column(key) == val for (key, val) in kwargs.items()]
)
.select_from(table)
.where(*[column(key) == val for (key, val) in kwargs.items()])
)
# NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
# avoid naming conflicts.
return dict(
row._tuple() for row in cur_result.fetchall() # pylint: disable=protected-access
)
# NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts.
return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access
@staticmethod
def _save_params(conn: Connection, table: Table,
params: Dict[str, Any], **kwargs: Any) -> None:
def _save_params(
conn: Connection,
table: Table,
params: Dict[str, Any],
**kwargs: Any,
) -> None:
if not params:
return
conn.execute(table.insert(), [
{
**kwargs,
"param_id": key,
"param_value": nullable(str, val)
}
for (key, val) in params.items()
])
conn.execute(
table.insert(),
[
{**kwargs, "param_id": key, "param_value": nullable(str, val)}
for (key, val) in params.items()
],
)
def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
timestamp = utcify_timestamp(timestamp, origin="local")
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
if running:
pending_status = ['PENDING', 'READY', 'RUNNING']
pending_status = ["PENDING", "READY", "RUNNING"]
else:
pending_status = ['PENDING']
pending_status = ["PENDING"]
with self._engine.connect() as conn:
cur_trials = conn.execute(self._schema.trial.select().where(
self._schema.trial.c.exp_id == self._experiment_id,
(self._schema.trial.c.ts_start.is_(None) |
(self._schema.trial.c.ts_start <= timestamp)),
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.in_(pending_status),
))
cur_trials = conn.execute(
self._schema.trial.select().where(
self._schema.trial.c.exp_id == self._experiment_id,
(
self._schema.trial.c.ts_start.is_(None)
| (self._schema.trial.c.ts_start <= timestamp)
),
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.in_(pending_status),
)
)
for trial in cur_trials.fetchall():
tunables = self._get_key_val(
conn, self._schema.config_param, "param",
config_id=trial.config_id)
conn,
self._schema.config_param,
"param",
config_id=trial.config_id,
)
config = self._get_key_val(
conn, self._schema.trial_param, "param",
exp_id=self._experiment_id, trial_id=trial.trial_id)
conn,
self._schema.trial_param,
"param",
exp_id=self._experiment_id,
trial_id=trial.trial_id,
)
yield Trial(
engine=self._engine,
schema=self._schema,
@ -233,47 +288,63 @@ class Experiment(Storage.Experiment):
def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
"""
Get the config ID for the given tunables. If the config does not exist,
create a new record for it.
Get the config ID for the given tunables.
If the config does not exist, create a new record for it.
"""
config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest()
cur_config = conn.execute(self._schema.config.select().where(
self._schema.config.c.config_hash == config_hash
)).fetchone()
config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
cur_config = conn.execute(
self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
).fetchone()
if cur_config is not None:
return int(cur_config.config_id) # mypy doesn't know it's always int
# Config not found, create a new one:
config_id: int = conn.execute(self._schema.config.insert().values(
config_hash=config_hash)).inserted_primary_key[0]
config_id: int = conn.execute(
self._schema.config.insert().values(config_hash=config_hash)
).inserted_primary_key[0]
self._save_params(
conn, self._schema.config_param,
conn,
self._schema.config_param,
{tunable.name: tunable.value for (tunable, _group) in tunables},
config_id=config_id)
config_id=config_id,
)
return config_id
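The dedup scheme above keys configs by a SHA-256 digest of their stringified tunables; a minimal in-memory sketch of the same get-or-create logic, with a dict standing in for the config table:

import hashlib

_seen: dict = {}  # Stand-in for the config table: hash -> config_id.

def get_or_create_config_id(params: dict) -> int:
    canonical = str(sorted(params.items()))  # The real code hashes str(tunables).
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    if digest not in _seen:
        _seen[digest] = len(_seen) + 1  # Stand-in for the DB's autoincrement id.
    return _seen[digest]

assert get_or_create_config_id({"a": 1}) == get_or_create_config_id({"a": 1})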
def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
def new_trial(
self,
tunables: TunableGroups,
ts_start: Optional[datetime] = None,
config: Optional[Dict[str, Any]] = None,
) -> Storage.Trial:
# MySQL can round microseconds into the future, causing the scheduler to skip trials.
# Truncate microseconds to avoid this issue.
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(microsecond=0)
ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
microsecond=0
)
_LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
with self._engine.begin() as conn:
try:
config_id = self._get_config_id(conn, tunables)
conn.execute(self._schema.trial.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
config_id=config_id,
ts_start=ts_start,
status='PENDING',
))
conn.execute(
self._schema.trial.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
config_id=config_id,
ts_start=ts_start,
status="PENDING",
)
)
# Note: config here is the framework config, not the target
# environment config (i.e., tunables).
if config is not None:
self._save_params(
conn, self._schema.trial_param, config,
exp_id=self._experiment_id, trial_id=self._trial_id)
conn,
self._schema.trial_param,
config,
exp_id=self._experiment_id,
trial_id=self._trial_id,
)
trial = Trial(
engine=self._engine,


@ -2,12 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
An interface to access the experiment benchmark data stored in SQL DB.
"""
from typing import Dict, Literal, Optional
"""An interface to access the experiment benchmark data stored in SQL DB."""
import logging
from typing import Dict, Literal, Optional
import pandas
from sqlalchemy import Engine, Integer, String, func
@ -15,11 +12,15 @@ from sqlalchemy import Engine, Integer, String, func
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_trial_data import TrialData
from mlos_bench.storage.base_tunable_config_data import TunableConfigData
from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData
from mlos_bench.storage.base_tunable_config_trial_group_data import (
TunableConfigTrialGroupData,
)
from mlos_bench.storage.sql import common
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData
from mlos_bench.storage.sql.tunable_config_trial_group_data import (
TunableConfigTrialGroupSqlData,
)
_LOG = logging.getLogger(__name__)
@ -32,14 +33,17 @@ class ExperimentSqlData(ExperimentData):
scripts and mlos_bench configuration files.
"""
def __init__(self, *,
engine: Engine,
schema: DbSchema,
experiment_id: str,
description: str,
root_env_config: str,
git_repo: str,
git_commit: str):
def __init__(
self,
*,
engine: Engine,
schema: DbSchema,
experiment_id: str,
description: str,
root_env_config: str,
git_repo: str,
git_commit: str,
):
super().__init__(
experiment_id=experiment_id,
description=description,
@ -54,9 +58,11 @@ class ExperimentSqlData(ExperimentData):
def objectives(self) -> Dict[str, Literal["min", "max"]]:
with self._engine.connect() as conn:
objectives_db_data = conn.execute(
self._schema.objectives.select().where(
self._schema.objectives.select()
.where(
self._schema.objectives.c.exp_id == self._experiment_id,
).order_by(
)
.order_by(
self._schema.objectives.c.weight.desc(),
self._schema.objectives.c.optimization_target.asc(),
)
@ -66,7 +72,8 @@ class ExperimentSqlData(ExperimentData):
for objective in objectives_db_data.fetchall()
}
# TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data are accessed.
# TODO: provide a way to get individual data to avoid repeated bulk fetches
# where only small amounts of data are accessed.
# Or else make the TrialData object lazily populate.
@property
@ -77,13 +84,17 @@ class ExperimentSqlData(ExperimentData):
def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
with self._engine.connect() as conn:
tunable_config_trial_groups = conn.execute(
self._schema.trial.select().with_only_columns(
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.config_id,
func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
'tunable_config_trial_group_id'),
).where(
func.min(self._schema.trial.c.trial_id)
.cast(Integer)
.label("tunable_config_trial_group_id"), # pylint: disable=not-callable
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
).group_by(
)
.group_by(
self._schema.trial.c.exp_id,
self._schema.trial.c.config_id,
)
@ -94,7 +105,7 @@ class ExperimentSqlData(ExperimentData):
schema=self._schema,
experiment_id=self._experiment_id,
tunable_config_id=tunable_config_trial_group.config_id,
tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,
tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id, # pylint:disable=line-too-long # noqa
)
for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
}
@ -103,11 +114,14 @@ class ExperimentSqlData(ExperimentData):
def tunable_configs(self) -> Dict[int, TunableConfigData]:
with self._engine.connect() as conn:
tunable_configs = conn.execute(
self._schema.trial.select().with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
).where(
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
).group_by(
)
.group_by(
self._schema.trial.c.exp_id,
self._schema.trial.c.config_id,
)
@ -124,7 +138,8 @@ class ExperimentSqlData(ExperimentData):
@property
def default_tunable_config_id(self) -> Optional[int]:
"""
Retrieves the (tunable) config id for the default tunable values for this experiment.
Retrieves the (tunable) config id for the default tunable values for this
experiment.
Note: this is by *default* the first trial executed for this experiment.
However, it is currently possible that the user changed the tunables config
@ -136,20 +151,28 @@ class ExperimentSqlData(ExperimentData):
"""
with self._engine.connect() as conn:
query_results = conn.execute(
self._schema.trial.select().with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
).where(
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id.in_(
self._schema.trial_param.select().with_only_columns(
func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable
"first_trial_id_with_defaults"),
).where(
self._schema.trial_param.select()
.with_only_columns(
func.min(self._schema.trial_param.c.trial_id)
.cast(Integer)
.label("first_trial_id_with_defaults"), # pylint: disable=not-callable
)
.where(
self._schema.trial_param.c.exp_id == self._experiment_id,
self._schema.trial_param.c.param_id == "is_defaults",
func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]),
).scalar_subquery()
)
func.lower(self._schema.trial_param.c.param_value, type_=String).in_(
["1", "true"]
),
)
.scalar_subquery()
),
)
)
min_default_trial_row = query_results.fetchone()
@ -158,17 +181,24 @@ class ExperimentSqlData(ExperimentData):
return min_default_trial_row._tuple()[0]
# Fallback logic: assume the minimum trial_id for the experiment.
query_results = conn.execute(
self._schema.trial.select().with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label('config_id'),
).where(
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.config_id.cast(Integer).label("config_id"),
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id.in_(
self._schema.trial.select().with_only_columns(
func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"),
).where(
self._schema.trial.select()
.with_only_columns(
func.min(self._schema.trial.c.trial_id)
.cast(Integer)
.label("first_trial_id"),
)
.where(
self._schema.trial.c.exp_id == self._experiment_id,
).scalar_subquery()
)
)
.scalar_subquery()
),
)
)
min_trial_row = query_results.fetchone()


@ -2,17 +2,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
DB schema definition.
"""
"""DB schema definition."""
import logging
from typing import List, Any
from typing import Any, List
from sqlalchemy import (
Engine, MetaData, Dialect, create_mock_engine,
Table, Column, Sequence, Integer, Float, String, DateTime,
PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint,
Column,
DateTime,
Dialect,
Engine,
Float,
ForeignKeyConstraint,
Integer,
MetaData,
PrimaryKeyConstraint,
Sequence,
String,
Table,
UniqueConstraint,
create_mock_engine,
)
_LOG = logging.getLogger(__name__)
@ -38,9 +47,7 @@ class _DDL:
class DbSchema:
"""
A class to define and create the DB schema.
"""
"""A class to define and create the DB schema."""
# This class is internal to SqlStorage and is mostly a struct
# for all DB tables, so it's ok to disable the warnings.
@ -53,9 +60,7 @@ class DbSchema:
_STATUS_LEN = 16
def __init__(self, engine: Engine):
"""
Declare the SQLAlchemy schema for the database.
"""
"""Declare the SQLAlchemy schema for the database."""
_LOG.info("Create the DB schema for: %s", engine)
self._engine = engine
# TODO: bind for automatic schema updates? (#649)
@ -69,7 +74,6 @@ class DbSchema:
Column("root_env_config", String(1024), nullable=False),
Column("git_repo", String(1024), nullable=False),
Column("git_commit", String(40), nullable=False),
PrimaryKeyConstraint("exp_id"),
)
@ -84,20 +88,29 @@ class DbSchema:
# Will need to adjust the insert and return values to support this
# eventually.
Column("weight", Float, nullable=True),
PrimaryKeyConstraint("exp_id", "optimization_target"),
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
)
# A workaround for a SQLAlchemy issue with autoincrement in DuckDB:
if engine.dialect.name == "duckdb":
seq_config_id = Sequence('seq_config_id')
col_config_id = Column("config_id", Integer, seq_config_id,
server_default=seq_config_id.next_value(),
nullable=False, primary_key=True)
seq_config_id = Sequence("seq_config_id")
col_config_id = Column(
"config_id",
Integer,
seq_config_id,
server_default=seq_config_id.next_value(),
nullable=False,
primary_key=True,
)
else:
col_config_id = Column("config_id", Integer, nullable=False,
primary_key=True, autoincrement=True)
col_config_id = Column(
"config_id",
Integer,
nullable=False,
primary_key=True,
autoincrement=True,
)
self.config = Table(
"config",
@ -116,7 +129,6 @@ class DbSchema:
Column("ts_end", DateTime),
# Should match the text IDs of `mlos_bench.environments.Status` enum:
Column("status", String(self._STATUS_LEN), nullable=False),
PrimaryKeyConstraint("exp_id", "trial_id"),
ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
@ -130,7 +142,6 @@ class DbSchema:
Column("config_id", Integer, nullable=False),
Column("param_id", String(self._ID_LEN), nullable=False),
Column("param_value", String(self._PARAM_VALUE_LEN)),
PrimaryKeyConstraint("config_id", "param_id"),
ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
)
@ -144,10 +155,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("param_id", String(self._ID_LEN), nullable=False),
Column("param_value", String(self._PARAM_VALUE_LEN)),
PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
ForeignKeyConstraint(["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id]),
ForeignKeyConstraint(
["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id],
),
)
self.trial_status = Table(
@ -157,10 +169,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("status", String(self._STATUS_LEN), nullable=False),
UniqueConstraint("exp_id", "trial_id", "ts"),
ForeignKeyConstraint(["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id]),
ForeignKeyConstraint(
["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id],
),
)
self.trial_result = Table(
@ -170,10 +183,11 @@ class DbSchema:
Column("trial_id", Integer, nullable=False),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),
PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
ForeignKeyConstraint(["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id]),
ForeignKeyConstraint(
["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id],
),
)
self.trial_telemetry = Table(
@ -184,26 +198,25 @@ class DbSchema:
Column("ts", DateTime(timezone=True), nullable=False, default="now"),
Column("metric_id", String(self._ID_LEN), nullable=False),
Column("metric_value", String(self._METRIC_VALUE_LEN)),
UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
ForeignKeyConstraint(["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id]),
ForeignKeyConstraint(
["exp_id", "trial_id"],
[self.trial.c.exp_id, self.trial.c.trial_id],
),
)
_LOG.debug("Schema: %s", self._meta)
def create(self) -> 'DbSchema':
"""
Create the DB schema.
"""
def create(self) -> "DbSchema":
"""Create the DB schema."""
_LOG.info("Create the DB schema")
self._meta.create_all(self._engine)
return self
def __repr__(self) -> str:
"""
Produce a string with all SQL statements required to create the schema
from scratch in the current SQL dialect.
Produce a string with all SQL statements required to create the schema from
scratch in the current SQL dialect.
That is, return a collection of CREATE TABLE statements and such.
NOTE: this method is quite heavy! We use it only once at startup
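The `__repr__` above relies on SQLAlchemy's mock-engine trick to stringify DDL without a live database; a minimal sketch of the same idea against a throwaway table:

from sqlalchemy import Column, Integer, MetaData, String, Table, create_mock_engine

meta = MetaData()
Table("demo", meta, Column("id", Integer, primary_key=True), Column("name", String(16)))

ddl = []

def _dump(sql, *args, **kwargs):
    # The mock engine "executes" DDL by handing each statement to us.
    ddl.append(str(sql.compile(dialect=engine.dialect)).rstrip() + ";")

engine = create_mock_engine("sqlite://", _dump)
meta.create_all(engine, checkfirst=False)  # checkfirst would need a real connection.
print("\n".join(ddl))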


@ -2,35 +2,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Saving and restoring the benchmark data in SQL database.
"""
"""Saving and restoring the benchmark data in SQL database."""
import logging
from typing import Dict, Literal, Optional
from sqlalchemy import URL, create_engine
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.services.base_service import Service
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.storage.sql.experiment import Experiment
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.experiment import Experiment
from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.tunables.tunable_groups import TunableGroups
_LOG = logging.getLogger(__name__)
class SqlStorage(Storage):
"""
An implementation of the Storage interface using the SQLAlchemy backend.
"""
"""An implementation of the Storage interface using the SQLAlchemy backend."""
def __init__(self,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None):
def __init__(
self,
config: dict,
global_config: Optional[dict] = None,
service: Optional[Service] = None,
):
super().__init__(config, global_config, service)
lazy_schema_create = self._config.pop("lazy_schema_create", False)
self._log_sql = self._config.pop("log_sql", False)
@ -47,7 +45,7 @@ class SqlStorage(Storage):
@property
def _schema(self) -> DbSchema:
"""Lazily create schema upon first access."""
if not hasattr(self, '_db_schema'):
if not hasattr(self, "_db_schema"):
self._db_schema = DbSchema(self._engine).create()
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("DDL statements:\n%s", self._schema)
@ -56,13 +54,16 @@ class SqlStorage(Storage):
def __repr__(self) -> str:
return self._repr
def experiment(self, *,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
tunables: TunableGroups,
opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment:
def experiment(
self,
*,
experiment_id: str,
trial_id: int,
root_env_config: str,
description: str,
tunables: TunableGroups,
opt_targets: Dict[str, Literal["min", "max"]],
) -> Storage.Experiment:
return Experiment(
engine=self._engine,
schema=self._schema,


@ -2,40 +2,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Saving and updating benchmark data using SQLAlchemy backend.
"""
"""Saving and updating benchmark data using SQLAlchemy backend."""
import logging
from datetime import datetime
from typing import List, Literal, Optional, Tuple, Dict, Any
from typing import Any, Dict, List, Literal, Optional, Tuple
from sqlalchemy import Engine, Connection
from sqlalchemy import Connection, Engine
from sqlalchemy.exc import IntegrityError
from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.storage.base_storage import Storage
from mlos_bench.storage.sql.schema import DbSchema
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import nullable, utcify_timestamp
_LOG = logging.getLogger(__name__)
class Trial(Storage.Trial):
"""
Store the results of a single run of the experiment in SQL database.
"""
"""Store the results of a single run of the experiment in SQL database."""
def __init__(self, *,
engine: Engine,
schema: DbSchema,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
config_id: int,
opt_targets: Dict[str, Literal['min', 'max']],
config: Optional[Dict[str, Any]] = None):
def __init__(
self,
*,
engine: Engine,
schema: DbSchema,
tunables: TunableGroups,
experiment_id: str,
trial_id: int,
config_id: int,
opt_targets: Dict[str, Literal["min", "max"]],
config: Optional[Dict[str, Any]] = None,
):
super().__init__(
tunables=tunables,
experiment_id=experiment_id,
@ -47,9 +46,12 @@ class Trial(Storage.Trial):
self._engine = engine
self._schema = schema
def update(self, status: Status, timestamp: datetime,
metrics: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
def update(
self,
status: Status,
timestamp: datetime,
metrics: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
metrics = super().update(status, timestamp, metrics)
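A standalone sketch of the UTC-normalization idea; the real `utcify_timestamp` helper takes an `origin` argument, while this simplified version just assumes naive timestamps are local:

from datetime import datetime, timezone

def utcify(ts: datetime) -> datetime:
    if ts.tzinfo is None:
        ts = ts.astimezone()  # Attach the local zone to naive timestamps.
    return ts.astimezone(timezone.utc)

print(utcify(datetime.now()))  # Always tz-aware UTC on the way to the database.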
@ -59,13 +61,16 @@ class Trial(Storage.Trial):
if status.is_completed():
# Final update of the status and ts_end:
cur_status = conn.execute(
self._schema.trial.update().where(
self._schema.trial.update()
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
).values(
["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
),
)
.values(
status=status.name,
ts_end=timestamp,
)
@ -73,29 +78,37 @@ class Trial(Storage.Trial):
if cur_status.rowcount not in {1, -1}:
_LOG.warning("Trial %s :: update failed: %s", self, status)
raise RuntimeError(
f"Failed to update the status of the trial {self} to {status}." +
f" ({cur_status.rowcount} rows)")
f"Failed to update the status of the trial {self} to {status}. "
f"({cur_status.rowcount} rows)"
)
if metrics:
conn.execute(self._schema.trial_result.insert().values([
{
"exp_id": self._experiment_id,
"trial_id": self._trial_id,
"metric_id": key,
"metric_value": nullable(str, val),
}
for (key, val) in metrics.items()
]))
conn.execute(
self._schema.trial_result.insert().values(
[
{
"exp_id": self._experiment_id,
"trial_id": self._trial_id,
"metric_id": key,
"metric_value": nullable(str, val),
}
for (key, val) in metrics.items()
]
)
)
else:
# Update of the status and ts_start when starting the trial:
assert metrics is None, f"Unexpected metrics for status: {status}"
cur_status = conn.execute(
self._schema.trial.update().where(
self._schema.trial.update()
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']),
).values(
["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
),
)
.values(
status=status.name,
ts_start=timestamp,
)
@ -108,8 +121,12 @@ class Trial(Storage.Trial):
raise
return metrics
def update_telemetry(self, status: Status, timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]]) -> None:
def update_telemetry(
self,
status: Status,
timestamp: datetime,
metrics: List[Tuple[datetime, str, Any]],
) -> None:
super().update_telemetry(status, timestamp, metrics)
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
@ -120,33 +137,42 @@ class Trial(Storage.Trial):
# See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
with self._engine.begin() as conn:
self._update_status(conn, status, timestamp)
for (metric_ts, key, val) in metrics:
for metric_ts, key, val in metrics:
with self._engine.begin() as conn:
try:
conn.execute(self._schema.trial_telemetry.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
ts=metric_ts,
metric_id=key,
metric_value=nullable(str, val),
))
conn.execute(
self._schema.trial_telemetry.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
ts=metric_ts,
metric_id=key,
metric_value=nullable(str, val),
)
)
except IntegrityError as ex:
_LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)
def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
"""
Insert a new status record into the database.
This call is idempotent.
"""
# Make sure to convert the timestamp to UTC before storing it in the database.
timestamp = utcify_timestamp(timestamp, origin="local")
try:
conn.execute(self._schema.trial_status.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
ts=timestamp,
status=status.name,
))
conn.execute(
self._schema.trial_status.insert().values(
exp_id=self._experiment_id,
trial_id=self._trial_id,
ts=timestamp,
status=status.name,
)
)
except IntegrityError as ex:
_LOG.warning("Status with that timestamp already exists: %s %s :: %s",
self, timestamp, ex)
_LOG.warning(
"Status with that timestamp already exists: %s %s :: %s",
self,
timestamp,
ex,
)
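The IntegrityError handling above makes status and telemetry inserts idempotent; a self-contained sketch with an in-memory SQLite database and a hypothetical table of the same shape:

from sqlalchemy import (
    Column, Integer, MetaData, String, Table, UniqueConstraint, create_engine,
)
from sqlalchemy.exc import IntegrityError

engine = create_engine("sqlite://")  # In-memory database for the sketch.
meta = MetaData()
trial_status = Table(
    "trial_status",
    meta,
    Column("trial_id", Integer, nullable=False),
    Column("ts", String(32), nullable=False),
    Column("status", String(16), nullable=False),
    UniqueConstraint("trial_id", "ts"),
)
meta.create_all(engine)

def record_status(trial_id: int, ts: str, status: str) -> None:
    try:
        with engine.begin() as conn:
            conn.execute(trial_status.insert().values(trial_id=trial_id, ts=ts, status=status))
    except IntegrityError:
        pass  # Duplicate (trial_id, ts): the record already exists; ignore.

record_status(1, "2024-07-12T00:00:00", "RUNNING")
record_status(1, "2024-07-12T00:00:00", "RUNNING")  # Idempotent: no error, no duplicate row.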

Some files were not shown because too many files changed in this diff.