# Pull Request

## Title

Type hint fixups

---

## Description

A few type hint fixups.

---

## Type of Change

- 🔄 Refactor

---

## Testing

- CI
- Local

---

## Additional Notes

After #916 

---
This commit is contained in:
Brian Kroth 2025-01-09 16:42:26 -06:00 коммит произвёл GitHub
Родитель d7294fa887
Коммит 6d91add928
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
16 изменённых файлов: 41 добавлений и 27 удалений

Просмотреть файл

@ -2,12 +2,19 @@
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [pre-commit]
# Note "require_serial" actually controls whether that particular hook's files
# are partitioned and the hook executable called in parallel across them, not
# whether hooks themselves are parallelized.
# As such, some hooks (e.g., pylint) which do internal parallelism need it set
# for effeciency and correctness anyways.
repos:
#
# Formatting
#
# NOTE: checks that adjust files are marked with the special "manual" stage
# and "require_serial" so that we can easily call them via `make`
# NOTE: checks that adjust files are marked with the special "manual" stage so
# that we can easily call them via `make`.
#
#
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
@ -18,15 +25,12 @@ repos:
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
require_serial: true
stages: [pre-commit, manual]
# TODO:
#- id: pretty-format-json
# args: [--autofix, --no-sort-keys]
# require_serial: true
# stages: [pre-commit, manual]
- id: trailing-whitespace
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/johann-petrak/licenseheaders
rev: v0.8.8
@ -34,32 +38,29 @@ repos:
- id: licenseheaders
files: '\.(sh|cmd|ps1|sql|py)$'
args: [-t, doc/mit-license.tmpl, -E, .py, .sh, .ps1, .sql, .cmd, -x, mlos_bench/setup.py, mlos_core/setup.py, mlos_viz/setup.py, -f]
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py310-plus]
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
require_serial: true
args: ["-j", "-1"]
stages: [pre-commit, manual]
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/PyCQA/docformatter
rev: 06907d0 # v1.7.5
hooks:
- id: docformatter
require_serial: true
stages: [pre-commit, manual]
#
# Linting
@ -69,6 +70,8 @@ repos:
hooks:
- id: pydocstyle
types: [python]
additional_dependencies:
- tomli
# Use pylint and mypy from the local (conda) environment so that vscode can reuse them too.
- repo: local
hooks:
@ -82,6 +85,7 @@ repos:
entry: pylint
language: system
types: [python]
require_serial: true
args: [
"-j0",
"--rcfile=pyproject.toml",
@ -97,6 +101,7 @@ repos:
entry: mypy
language: system
types: [python]
require_serial: true
exclude: |
(?x)^(
doc/source/conf.py|

6
.vscode/settings.json поставляемый
Просмотреть файл

@ -167,5 +167,9 @@
"--log-level=DEBUG",
"."
],
"python.testing.unittestEnabled": false
"python.testing.unittestEnabled": false,
"debugpy.debugJustMyCode": false,
"python.analysis.autoImportCompletions": true,
"python.analysis.supportRestructuredText": true,
"python.analysis.typeCheckingMode": "standard"
}

Просмотреть файл

@ -8,6 +8,7 @@ import abc
import json
import logging
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal
@ -28,7 +29,7 @@ if TYPE_CHECKING:
_LOG = logging.getLogger(__name__)
class Environment(metaclass=abc.ABCMeta):
class Environment(ContextManager, metaclass=abc.ABCMeta):
# pylint: disable=too-many-instance-attributes
"""An abstract base of all benchmark environments."""

Просмотреть файл

@ -91,7 +91,7 @@ class LocalFileShareEnv(LocalEnv):
def _expand(
from_to: Iterable[tuple[Template, Template]],
params: Mapping[str, TunableValue],
) -> Generator[tuple[str, str], None, None]:
) -> Generator[tuple[str, str]]:
"""
Substitute $var parameters in from/to path templates.

Просмотреть файл

@ -9,6 +9,7 @@ optimizers.
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from contextlib import AbstractContextManager as ContextManager
from types import TracebackType
from typing import Literal
@ -25,7 +26,7 @@ from mlos_bench.util import strtobool
_LOG = logging.getLogger(__name__)
class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
class Optimizer(ContextManager, metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
"""An abstract interface between the benchmarking framework and mlos_core
optimizers.
"""

Просмотреть файл

@ -7,6 +7,7 @@
import json
import logging
from abc import ABCMeta, abstractmethod
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal
@ -23,7 +24,7 @@ from mlos_bench.util import merge_parameters
_LOG = logging.getLogger(__name__)
class Scheduler(metaclass=ABCMeta):
class Scheduler(ContextManager, metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""Base class for the optimization loop scheduling policies."""

Просмотреть файл

@ -9,6 +9,7 @@ from __future__ import annotations
import json
import logging
from collections.abc import Callable
from contextlib import AbstractContextManager as ContextManager
from types import TracebackType
from typing import Any, Literal
@ -19,7 +20,7 @@ from mlos_bench.util import instantiate_from_config
_LOG = logging.getLogger(__name__)
class Service:
class Service(ContextManager):
"""An abstract base of all Environment Services and used to build up mix-ins."""
@classmethod

Просмотреть файл

@ -67,7 +67,7 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
Free-format dictionary of global parameters.
parent : Service
An optional parent service that can provide mixin functions.
methods : Union[dict[str, Callable], list[Callable], None]
methods : dict[str, Callable] | list[Callable] | None
New methods to register with the service.
"""
super().__init__(
@ -166,7 +166,7 @@ class ConfigPersistenceService(Service, SupportsConfigLoading):
Returns
-------
config : Union[dict, list[dict]]
config : dict | list[dict]
Free-format dictionary that contains the configuration.
"""
assert isinstance(json, str)

Просмотреть файл

@ -25,6 +25,7 @@ mlos_bench.storage.base_trial_data.TrialData :
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal
@ -132,7 +133,7 @@ class Storage(metaclass=ABCMeta):
the results of the experiment and related data.
"""
class Experiment(metaclass=ABCMeta):
class Experiment(ContextManager, metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of the experiment.

Просмотреть файл

@ -51,7 +51,7 @@ def ssh_test_server(
ssh_test_server_hostname: str,
docker_compose_project_name: str,
locked_docker_services: DockerServices,
) -> Generator[SshTestServerInfo, None, None]:
) -> Generator[SshTestServerInfo]:
"""
Fixture for getting the ssh test server services setup via docker-compose using
pytest-docker.

Просмотреть файл

@ -23,7 +23,7 @@ from mlos_bench.util import path_join
@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper]:
"""
Provides a context manager for a temporary file that can be closed and still
unlinked.

Просмотреть файл

@ -38,7 +38,7 @@ def storage() -> SqlStorage:
def exp_storage(
storage: SqlStorage,
tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
@ -60,7 +60,7 @@ def exp_storage(
@pytest.fixture
def exp_no_tunables_storage(
storage: SqlStorage,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
@ -84,7 +84,7 @@ def exp_no_tunables_storage(
def mixed_numerics_exp_storage(
storage: SqlStorage,
mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3
storage.

Просмотреть файл

@ -179,13 +179,13 @@ class TunableGroups:
self._index[name][name] = value
return self._index[name][name]
def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup], None, None]:
def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup]]:
"""
An iterator over all tunables in the group.
Returns
-------
[(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup], None, None]
[(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup]]
An iterator over all tunables in all groups. Each element is a 2-tuple
of an instance of the Tunable parameter and covariant group it belongs to.
"""

Просмотреть файл

@ -93,7 +93,7 @@ def test_configspace_quant_repatch() -> None:
new_meta[QUANTIZATION_BINS_META_KEY] = 21
hp.meta = new_meta
monkey_patch_hp_quantization(hp)
samples_set = set(hp.sample_value(100, seed=RandomState(SEED)))
samples_set: set[int] = set(hp.sample_value(100, seed=RandomState(SEED)))
quantized_values_new = set(range(5, 96, 10))
assert samples_set.issubset(set(range(0, 101, 5)))
assert len(samples_set - quantized_values_new) < len(samples_set)

Просмотреть файл

@ -62,7 +62,6 @@ disable = [
"consider-using-assignment-expr",
"docstring-first-line-empty",
"missing-raises-doc",
"unnecessary-default-type-args", # affects Generator type hints, but we still support python 3.8
]
[tool.pylint.string]

Просмотреть файл

@ -73,6 +73,7 @@ exclude_also =
#
[mypy]
cache_fine_grained = True
#ignore_missing_imports = True
warn_unused_configs = True
warn_unused_ignores = True