This commit is contained in:
liuzhe-lz 2022-03-24 23:37:05 +08:00 коммит произвёл GitHub
Родитель 68347c5e60
Коммит 5136a86d11
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
41 изменённых файлов: 288 добавлений и 106 удалений

1
dependencies/develop.txt поставляемый
Просмотреть файл

@ -5,6 +5,7 @@ ipython
jupyterlab
nbsphinx
pylint
pyright
pytest
pytest-azurepipelines
pytest-cov

2
dependencies/required.txt поставляемый
Просмотреть файл

@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8"
typeguard
typing_extensions ; python_version < "3.8"
typing_extensions >= 4.0.0 ; python_version < "3.8"
websockets >= 10.1

Просмотреть файл

@ -0,0 +1,8 @@
Uncategorized Modules
=====================
nni.typehint
------------
.. automodule:: nni.typehint
:members:

Просмотреть файл

@ -1,11 +0,0 @@
Others
======
nni
---
nni.common
----------
nni.utils
---------

Просмотреть файл

@ -9,4 +9,4 @@ API Reference
Model Compression <compression>
Feature Engineering <./python_api/feature_engineering>
Experiment <experiment>
Others <./python_api/others>
Others <others>

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .bohb_advisor import BOHB, BOHBClassArgsValidator

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import warnings

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gp_tuner import GPTuner, GPClassArgsValidator

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .metis_tuner import MetisTuner, MetisClassArgsValidator

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .networkmorphism_tuner import NetworkMorphismTuner, NetworkMorphismClassArgsValidator

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .ppo_tuner import PPOTuner, PPOClassArgsValidator

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import random

Просмотреть файл

@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not.
See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details.
"""
from __future__ import annotations
from enum import Enum
import logging
from .recoverable import Recoverable
from .typehint import TrialMetric
__all__ = ['AssessResult', 'Assessor']
@ -54,7 +57,7 @@ class Assessor(Recoverable):
:class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor`
"""
def assess_trial(self, trial_job_id, trial_history):
def assess_trial(self, trial_job_id: str, trial_history: list[TrialMetric]) -> AssessResult:
"""
Abstract method for determining whether a trial should be killed. Must override.
@ -91,7 +94,7 @@ class Assessor(Recoverable):
"""
raise NotImplementedError('Assessor: assess_trial not implemented')
def trial_end(self, trial_job_id, success):
def trial_end(self, trial_job_id: str, success: bool) -> None:
"""
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
@ -103,22 +106,22 @@ class Assessor(Recoverable):
True if the trial successfully completed; False if failed or terminated.
"""
def load_checkpoint(self):
def load_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self):
def save_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)
def _on_exit(self):
def _on_exit(self) -> None:
pass
def _on_error(self):
def _on_error(self) -> None:
pass

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .serializer import trace, dump, load, is_traceable

Просмотреть файл

@ -2,6 +2,8 @@
# Licensed under the MIT license.
"""
Helper class and functions for tuners to deal with search space.
This script provides a more program-friendly representation of HPO search space.
The format is considered internal helper and is not visible to end users.
@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space.
The random tuner is an intuitive example for this utility.
You should check its code before reading docstrings in this file.
.. attention::
This module does not guarantee forward-compatibility.
If you want to use it outside official NNI repo, it is recommended to copy the script.
"""
from __future__ import annotations
__all__ = [
'ParameterSpec',
'deformat_parameters',
@ -20,10 +30,16 @@ __all__ = [
import math
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple
from typing import Any, Dict, NamedTuple, Tuple, cast
import numpy as np
from nni.typehint import Parameters, SearchSpace
ParameterKey = Tuple['str | int', ...]
FormattedParameters = Dict[ParameterKey, 'float | int']
FormattedSearchSpace = Dict[ParameterKey, 'ParameterSpec']
class ParameterSpec(NamedTuple):
"""
Specification (aka space / range / domain) of one single parameter.
@ -33,29 +49,31 @@ class ParameterSpec(NamedTuple):
name: str # The object key in JSON
type: str # "_type" in JSON
values: List[Any] # "_value" in JSON
values: list[Any] # "_value" in JSON
key: Tuple[str] # The "path" of this parameter
key: ParameterKey # The "path" of this parameter
categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered)
size: int = None # If it's categorical, how many candidates it has
size: int = cast(int, None) # If it's categorical, how many candidates it has
# uniform distributed
low: float = None # Lower bound of uniform parameter
high: float = None # Upper bound of uniform parameter
low: float = cast(float, None) # Lower bound of uniform parameter
high: float = cast(float, None) # Upper bound of uniform parameter
normal_distributed: bool = None # Whether this parameter is uniform or normal distrubuted
mu: float = None # µ of normal parameter
sigma: float = None # σ of normal parameter
normal_distributed: bool = cast(bool, None)
# Whether this parameter is uniform or normal distrubuted
mu: float = cast(float, None) # µ of normal parameter
sigma: float = cast(float, None)# σ of normal parameter
q: Optional[float] = None # If not `None`, the parameter value should be an integer multiple of this
clip: Optional[Tuple[float, float]] = None
q: float | None = None # If not `None`, the parameter value should be an integer multiple of this
clip: tuple[float, float] | None = None
# For q(log)uniform, this equals to "values[:2]"; for others this is None
log_distributed: bool = None # Whether this parameter is log distributed
log_distributed: bool = cast(bool, None)
# Whether this parameter is log distributed
# When true, low/high/mu/sigma describes log of parameter value (like np.lognormal)
def is_activated_in(self, partial_parameters):
def is_activated_in(self, partial_parameters: FormattedParameters) -> bool:
"""
For nested search space, check whether this parameter should be skipped for current set of paremters.
This function must be used in a pattern similar to random tuner. Otherwise it will misbehave.
@ -64,7 +82,7 @@ class ParameterSpec(NamedTuple):
return True
return partial_parameters[self.key[:-2]] == self.key[-2]
def format_search_space(search_space):
def format_search_space(search_space: SearchSpace) -> FormattedSearchSpace:
"""
Convert user provided search space into a dict of ParameterSpec.
The dict key is dict value's `ParameterSpec.key`.
@ -76,7 +94,9 @@ def format_search_space(search_space):
# Remove these comments when we drop 3.6 support.
return {spec.key: spec for spec in formatted}
def deformat_parameters(formatted_parameters, formatted_search_space):
def deformat_parameters(
formatted_parameters: FormattedParameters,
formatted_search_space: FormattedSearchSpace) -> Parameters:
"""
Convert internal format parameters to users' expected format.
@ -88,10 +108,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
3. For "q*", convert x to `round(x / q) * q`, then clip into range.
4. For nested choices, convert flatten key-value pairs into nested structure.
"""
ret = {}
ret: Parameters = {}
for key, x in formatted_parameters.items():
spec = formatted_search_space[key]
if spec.categorical:
x = cast(int, x)
if spec.type == 'randint':
lower = min(math.ceil(float(x)) for x in spec.values)
_assign(ret, key, int(lower + x))
@ -112,7 +133,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
_assign(ret, key, x)
return ret
def format_parameters(parameters, formatted_search_space):
def format_parameters(parameters: Parameters, formatted_search_space: FormattedSearchSpace) -> FormattedParameters:
"""
Convert end users' parameter format back to internal format, mainly for resuming experiments.
@ -123,7 +144,7 @@ def format_parameters(parameters, formatted_search_space):
for key, spec in formatted_search_space.items():
if not spec.is_activated_in(ret):
continue
value = parameters
value: Any = parameters
for name in key:
if isinstance(name, str):
value = value[name]
@ -142,8 +163,8 @@ def format_parameters(parameters, formatted_search_space):
ret[key] = value
return ret
def _format_search_space(parent_key, space):
formatted = []
def _format_search_space(parent_key: ParameterKey, space: SearchSpace) -> list[ParameterSpec]:
formatted: list[ParameterSpec] = []
for name, spec in space.items():
if name == '_name':
continue
@ -155,7 +176,7 @@ def _format_search_space(parent_key, space):
formatted += _format_search_space(key, sub_space)
return formatted
def _format_parameter(key, type_, values):
def _format_parameter(key: ParameterKey, type_: str, values: list[Any]):
spec = SimpleNamespace(
name = key[-1],
type = type_,
@ -197,7 +218,7 @@ def _format_parameter(key, type_, values):
return ParameterSpec(**spec.__dict__)
def _is_nested_choices(values):
def _is_nested_choices(values: list[Any]) -> bool:
assert values # choices should not be empty
for value in values:
if not isinstance(value, dict):
@ -206,9 +227,9 @@ def _is_nested_choices(values):
return False
return True
def _assign(params, key, x):
def _assign(params: Parameters, key: ParameterKey, x: Any) -> None:
if len(key) == 1:
params[key[0]] = x
params[cast(str, key[0])] = x
elif isinstance(key[0], int):
_assign(params, key[1:], x)
else:

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum
class OptimizeMode(Enum):

Просмотреть файл

@ -1,8 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from typing import Any, List, Optional
from typing import Any
common_search_space_types = [
'choice',
@ -19,7 +21,7 @@ common_search_space_types = [
def validate_search_space(
search_space: Any,
support_types: Optional[List[str]] = None,
support_types: list[str] | None = None,
raise_exception: bool = False # for now, in case false positive
) -> bool:

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import base64
import collections.abc

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
try:
import torch

Просмотреть файл

@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase):
else: # custom algorithm
assert self.name is None
assert self.class_name
assert self.code_directory is not None
if not Path(self.code_directory).is_dir():
raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory')

Просмотреть файл

@ -37,6 +37,8 @@ def to_v2(v1):
_move_field(v1_trial, v2, 'command', 'trialCommand')
_move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory')
_move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber')
else:
v1_trial = {}
for algo_type in ['tuner', 'assessor', 'advisor']:
v1_algo = v1.pop(algo_type, None)

Просмотреть файл

@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
if self.password is not None:
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
elif not Path(self.ssh_key_file).is_file():
elif not Path(self.ssh_key_file).is_file(): # type: ignore
raise ValueError(
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
)

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
import json
from typing import List

Просмотреть файл

@ -10,7 +10,7 @@ from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Any
from typing import Any
import colorama
import psutil
@ -34,8 +34,7 @@ class RunMode(Enum):
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
NOTE:
This API is non-stable and is likely to get refactored in next release.
NOTE: This API is non-stable and is likely to get refactored in upcoming release.
"""
# TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
@ -72,15 +71,15 @@ class Experiment:
Web portal port. Or ``None`` if the experiment is not running.
"""
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None:
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None):
nni.runtime.log.init_logger_for_command_line()
self.config: Optional[ExperimentConfig] = None
self.config: ExperimentConfig | None = None
self.id: str = management.generate_experiment_id()
self.port: Optional[int] = None
self._proc: Optional[Popen] = None
self.action = 'create'
self.url_prefix: Optional[str] = None
self.port: int | None = None
self._proc: Popen | psutil.Process | None = None
self._action = 'create'
self.url_prefix: str | None = None
if isinstance(config_or_platform, (str, list)):
self.config = ExperimentConfig(config_or_platform)
@ -101,6 +100,7 @@ class Experiment:
debug
Whether to start in debug mode.
"""
assert self.config is not None
if run_mode is not RunMode.Detach:
atexit.register(self.stop)
@ -114,7 +114,7 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc = launcher.start_experiment(self.action, self.id, config, port, debug, run_mode, self.url_prefix)
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
assert self._proc is not None
self.port = port # port will be None if start up failed
@ -144,16 +144,16 @@ class Experiment:
_logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid)
self.id = None
self.id = None # type: ignore
self.port = None
self._proc = None
_logger.info('Experiment stopped')
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool:
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
"""
Run the experiment.
If ``wait_completion`` is True, this function will block until experiment finish or error.
If ``wait_completion`` is ``True``, this function will block until experiment finish or error.
Return ``True`` when experiment done; or return ``False`` when experiment failed.
@ -247,7 +247,7 @@ class Experiment:
def _resume(exp_id, exp_dir=None):
exp = Experiment(None)
exp.id = exp_id
exp.action = 'resume'
exp._action = 'resume'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp
@ -255,7 +255,7 @@ class Experiment:
def _view(exp_id, exp_dir=None):
exp = Experiment(None)
exp.id = exp_id
exp.action = 'view'
exp._action = 'view'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import contextlib
from dataclasses import dataclass, fields
from datetime import datetime
@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return proc
def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
def _start_rest_server(nni_manager_args, run_mode) -> Popen:
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
from subprocess import CREATE_NEW_PROCESS_GROUP
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore
def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
def start_experiment_retiarii(exp_id, config, port, debug):
pipe = None
proc = None
@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
args['dispatcher_pipe'] = pipe_path
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def get_stopped_experiment_config(exp_id, exp_dir=None):
config_json = get_stopped_experiment_config_json(exp_id, exp_dir)
config = ExperimentConfig(**config_json)
config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
config = ExperimentConfig(**config_json) # type: ignore
if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger.warning(msg, exp_dir, config.experiment_working_directory)

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import random
import string

Просмотреть файл

@ -1,4 +1,6 @@
from io import BufferedIOBase
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys
@ -25,7 +27,7 @@ if sys.platform == 'win32':
_winapi.NULL
)
def connect(self) -> BufferedIOBase:
def connect(self):
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b')
@ -55,7 +57,7 @@ else:
self._socket.bind(self.path)
self._socket.listen(1) # only accepts one connection
def connect(self) -> BufferedIOBase:
def connect(self):
conn, _ = self._socket.accept()
self.file = conn.makefile('rwb')
return self.file

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Optional

Просмотреть файл

@ -1,17 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import os
class Recoverable:
def load_checkpoint(self):
def load_checkpoint(self) -> None:
pass
def save_checkpoint(self):
def save_checkpoint(self) -> None:
pass
def get_checkpoint_path(self):
def get_checkpoint_path(self) -> str | None:
ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path

Просмотреть файл

@ -14,7 +14,7 @@ def get_config_directory() -> Path:
Create it if not exist.
"""
if os.getenv('NNI_CONFIG_DIR') is not None:
config_dir = Path(os.getenv('NNI_CONFIG_DIR'))
config_dir = Path(os.getenv('NNI_CONFIG_DIR')) # type: ignore
elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir():
config_dir = Path(sys.prefix, 'nni')
elif sys.platform == 'win32':
@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
"""
Get a readonly builtin config file.
"""
return Path(nni.__path__[0], 'runtime/default_config', name)
return Path(nni.__path__[0], 'runtime/default_config', name) # type: ignore

Просмотреть файл

@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import sys
from datetime import datetime
@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
_register_handler(StreamHandler(sys.stdout), logging.INFO)
def _prepare_log_dir(path: Optional[str]) -> Path:
def _prepare_log_dir(path: Path | str) -> Path:
if path is None:
return Path()
ret = Path(path)
@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
def __init__(self, log_file: TextIOBase):
self.file: TextIOBase = log_file
self.line_buffer: Optional[str] = None
self.line_start_time: Optional[datetime] = None
self.line_start_time: datetime = datetime.fromtimestamp(0)
def write(self, s: str) -> int:
cur_time = datetime.now()

Просмотреть файл

@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
except Exception as e:
_logger.error('Assessor error')
_logger.exception(e)
raise
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from . import proxy
load_jupyter_server_extension = proxy.setup

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from pathlib import Path
import shutil

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from pathlib import Path

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import json

Просмотреть файл

@ -1,13 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any
from .common.serializer import dump
from .runtime.env_vars import trial_env_vars
from .runtime import platform
from .typehint import Parameters, TrialMetric
__all__ = [
'get_next_parameter',
'get_next_parameters',
'get_current_parameter',
'report_intermediate_result',
'report_final_result',
@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id()
def get_next_parameter():
def get_next_parameter() -> Parameters:
"""
Get the hyperparameters generated by tuner.
@ -32,7 +37,7 @@ def get_next_parameter():
Examples
--------
Assuming the search space is:
Assuming the :doc:`search space </hpo/search_space>` is:
.. code-block::
@ -52,16 +57,22 @@ def get_next_parameter():
Returns
-------
dict
:class:`~nni.typehint.Parameters`
A hyperparameter set sampled from search space.
"""
global _params
_params = platform.get_next_parameter()
if _params is None:
return None
return None # type: ignore
return _params['parameters']
def get_current_parameter(tag=None):
def get_next_parameters() -> Parameters:
"""
Alias of :func:`get_next_parameter`
"""
return get_next_parameter()
def get_current_parameter(tag: str | None = None) -> Any:
global _params
if _params is None:
return None
@ -94,13 +105,13 @@ def get_sequence_id() -> int:
_intermediate_seq = 0
def overwrite_intermediate_seq(value):
def overwrite_intermediate_seq(value: int) -> None:
assert isinstance(value, int)
global _intermediate_seq
_intermediate_seq = value
def report_intermediate_result(metric):
def report_intermediate_result(metric: TrialMetric | dict[str, Any]) -> None:
"""
Reports intermediate result to NNI.
@ -110,11 +121,16 @@ def report_intermediate_result(metric):
and other items can be visualized with web portal.
Typically ``metric`` is per-epoch accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The intermeidate result.
"""
global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = dump({
dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
@ -122,9 +138,9 @@ def report_intermediate_result(metric):
'value': dump(metric)
})
_intermediate_seq += 1
platform.send_metric(metric)
platform.send_metric(dumped_metric)
def report_final_result(metric):
def report_final_result(metric: TrialMetric | dict[str, Any]) -> None:
"""
Reports final result to NNI.
@ -134,14 +150,19 @@ def report_final_result(metric):
and other items can be visualized with web portal.
Typically ``metric`` is the final accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The final result.
"""
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
metric = dump({
dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0,
'value': dump(metric)
})
platform.send_metric(metric)
platform.send_metric(dumped_metric)

Просмотреть файл

@ -8,11 +8,14 @@ A new trial will run with this configuration.
See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
"""
from __future__ import annotations
import logging
import nni
from .recoverable import Recoverable
from .typehint import Parameters, SearchSpace, TrialMetric, TrialRecord
__all__ = ['Tuner']
@ -67,7 +70,7 @@ class Tuner(Recoverable):
:class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
"""
def generate_parameters(self, parameter_id, **kwargs):
def generate_parameters(self, parameter_id: int, **kwargs) -> Parameters:
"""
Abstract method which provides a set of hyper-parameters.
@ -100,7 +103,7 @@ class Tuner(Recoverable):
# we need to design a new exception for this purpose
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
def generate_multiple_parameters(self, parameter_id_list: list[int], **kwargs) -> list[Parameters]:
"""
Callback method which provides multiple sets of hyper-parameters.
@ -135,7 +138,7 @@ class Tuner(Recoverable):
result.append(res)
return result
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
def receive_trial_result(self, parameter_id: int, parameters: Parameters, value: TrialMetric, **kwargs) -> None:
"""
Abstract method invoked when a trial reports its final result. Must override.
@ -165,7 +168,7 @@ class Tuner(Recoverable):
# pylint: disable=attribute-defined-outside-init
self._accept_customized = accept
def trial_end(self, parameter_id, success, **kwargs):
def trial_end(self, parameter_id: int, success: bool, **kwargs) -> None:
"""
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
@ -179,7 +182,7 @@ class Tuner(Recoverable):
Unstable parameters which should be ignored by normal users.
"""
def update_search_space(self, search_space):
def update_search_space(self, search_space: SearchSpace) -> None:
"""
Abstract method for updating the search space. Must override.
@ -194,21 +197,21 @@ class Tuner(Recoverable):
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def load_checkpoint(self):
def load_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self):
def save_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def import_data(self, data):
def import_data(self, data: list[TrialRecord]) -> None:
"""
Internal API under revising, not recommended for end users.
"""
@ -216,8 +219,8 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass
def _on_exit(self):
def _on_exit(self) -> None:
pass
def _on_error(self):
def _on_error(self) -> None:
pass

Просмотреть файл

@ -1,10 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import typing
"""
Types for static checking.
"""
if typing.TYPE_CHECKING or sys.version_info >= (3, 8):
Literal = typing.Literal
__all__ = [
'Literal',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
]
import sys
from typing import Any, Dict, List, TYPE_CHECKING
if TYPE_CHECKING or sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
Literal = typing.Any
from typing_extensions import Literal, TypedDict
Parameters = Dict[str, Any]
"""
Return type of :func:`nni.get_next_parameter`.
For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space </hpo/search_space>`.
Customized tuners do not need to follow the constraint and can use anything serializable.
"""
class _ParameterSearchSpace(TypedDict):
_type: Literal[
'choice', 'randint',
'uniform', 'loguniform', 'quniform', 'qloguniform',
'normal', 'lognormal', 'qnormal', 'qlognormal',
]
_value: List[Any]
SearchSpace = Dict[str, _ParameterSearchSpace]
"""
Type of ``experiment.config.search_space``.
For built-in tuners, the format is detailed in :doc:`/hpo/search_space`.
Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``.
"""
TrialMetric = float
"""
Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`.
For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc).
Customized tuners do not need to follow this constraint and can use anything serializable.
"""
class TrialRecord(TypedDict):
parameter: Parameters
value: TrialMetric

Просмотреть файл

@ -63,6 +63,9 @@ stages:
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
displayName: flake8
- script: |
python -m pyright nni
- job: typescript
pool:
vmImage: ubuntu-latest

14
pyrightconfig.json Normal file
Просмотреть файл

@ -0,0 +1,14 @@
{
"ignore": [
"nni/algorithms",
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/common/serializer.py",
"nni/compression",
"nni/nas",
"nni/retiarii",
"nni/smartparam.py",
"nni/tools"
],
"reportMissingImports": false
}