Mirror of https://github.com/microsoft/nni.git
Typehint and copyright header (#4669)
This commit is contained in:
Parent: 68347c5e60
Commit: 5136a86d11
@@ -5,6 +5,7 @@ ipython
 jupyterlab
 nbsphinx
 pylint
+pyright
 pytest
 pytest-azurepipelines
 pytest-cov
@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
 scipy < 1.8 ; python_version < "3.8"
 scipy ; python_version >= "3.8"
 typeguard
-typing_extensions ; python_version < "3.8"
+typing_extensions >= 4.0.0 ; python_version < "3.8"
 websockets >= 10.1
@@ -0,0 +1,8 @@
+Uncategorized Modules
+=====================
+
+nni.typehint
+------------
+
+.. automodule:: nni.typehint
+:members:
@@ -1,11 +0,0 @@
-Others
-======
-
-nni
----
-
-nni.common
-----------
-
-nni.utils
---------
@@ -9,4 +9,4 @@ API Reference
 Model Compression <compression>
 Feature Engineering <./python_api/feature_engineering>
 Experiment <experiment>
-Others <./python_api/others>
+Others <others>
@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .bohb_advisor import BOHB, BOHBClassArgsValidator
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 import warnings

@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .gp_tuner import GPTuner, GPClassArgsValidator
@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .metis_tuner import MetisTuner, MetisClassArgsValidator
@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .networkmorphism_tuner import NetworkMorphismTuner, NetworkMorphismClassArgsValidator
@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .ppo_tuner import PPOTuner, PPOClassArgsValidator
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import copy
 import logging
 import random
@@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not.
 See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details.
 """

+from __future__ import annotations
+
 from enum import Enum
 import logging

 from .recoverable import Recoverable
+from .typehint import TrialMetric

 __all__ = ['AssessResult', 'Assessor']

@@ -54,7 +57,7 @@ class Assessor(Recoverable):
 :class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor`
 """

-def assess_trial(self, trial_job_id, trial_history):
+def assess_trial(self, trial_job_id: str, trial_history: list[TrialMetric]) -> AssessResult:
 """
 Abstract method for determining whether a trial should be killed. Must override.

@@ -91,7 +94,7 @@ class Assessor(Recoverable):
 """
 raise NotImplementedError('Assessor: assess_trial not implemented')

-def trial_end(self, trial_job_id, success):
+def trial_end(self, trial_job_id: str, success: bool) -> None:
 """
 Abstract method invoked when a trial is completed or terminated. Do nothing by default.

@@ -103,22 +106,22 @@ class Assessor(Recoverable):
 True if the trial successfully completed; False if failed or terminated.
 """

-def load_checkpoint(self):
+def load_checkpoint(self) -> None:
 """
 Internal API under revising, not recommended for end users.
 """
 checkpoin_path = self.get_checkpoint_path()
 _logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

-def save_checkpoint(self):
+def save_checkpoint(self) -> None:
 """
 Internal API under revising, not recommended for end users.
 """
 checkpoin_path = self.get_checkpoint_path()
 _logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

-def _on_exit(self):
+def _on_exit(self) -> None:
 pass

-def _on_error(self):
+def _on_error(self) -> None:
 pass
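Illustrative note (not part of the commit): a custom assessor implements the interface annotated above. A minimal sketch, assuming the public nni.assessor module; the class name and stopping rule are hypothetical:

    from nni.assessor import Assessor, AssessResult

    class MyAssessor(Assessor):  # hypothetical example
        def assess_trial(self, trial_job_id, trial_history):
            # Keep young trials; stop a trial whose latest metric falls below
            # half of the best metric it has reported so far.
            if len(trial_history) < 3:
                return AssessResult.Good
            if trial_history[-1] >= 0.5 * max(trial_history):
                return AssessResult.Good
            return AssessResult.Bad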
@@ -1 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .serializer import trace, dump, load, is_traceable
@@ -2,6 +2,8 @@
 # Licensed under the MIT license.

 """
 Helper class and functions for tuners to deal with search space.

 This script provides a more program-friendly representation of HPO search space.
+The format is considered internal helper and is not visible to end users.
+
@@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space.

 The random tuner is an intuitive example for this utility.
 You should check its code before reading docstrings in this file.
+
+.. attention::
+
+This module does not guarantee forward-compatibility.
+
+If you want to use it outside official NNI repo, it is recommended to copy the script.
 """

+from __future__ import annotations
+
 __all__ = [
 'ParameterSpec',
 'deformat_parameters',
@@ -20,10 +30,16 @@ __all__ = [

 import math
 from types import SimpleNamespace
-from typing import Any, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, NamedTuple, Tuple, cast

 import numpy as np

+from nni.typehint import Parameters, SearchSpace
+
+ParameterKey = Tuple['str | int', ...]
+FormattedParameters = Dict[ParameterKey, 'float | int']
+FormattedSearchSpace = Dict[ParameterKey, 'ParameterSpec']
+
 class ParameterSpec(NamedTuple):
 """
 Specification (aka space / range / domain) of one single parameter.
@@ -33,29 +49,31 @@ class ParameterSpec(NamedTuple):

 name: str # The object key in JSON
 type: str # "_type" in JSON
-values: List[Any] # "_value" in JSON
+values: list[Any] # "_value" in JSON

-key: Tuple[str] # The "path" of this parameter
+key: ParameterKey # The "path" of this parameter

 categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered)
-size: int = None # If it's categorical, how many candidates it has
+size: int = cast(int, None) # If it's categorical, how many candidates it has

 # uniform distributed
-low: float = None # Lower bound of uniform parameter
-high: float = None # Upper bound of uniform parameter
+low: float = cast(float, None) # Lower bound of uniform parameter
+high: float = cast(float, None) # Upper bound of uniform parameter

-normal_distributed: bool = None # Whether this parameter is uniform or normal distrubuted
-mu: float = None # µ of normal parameter
-sigma: float = None # σ of normal parameter
+normal_distributed: bool = cast(bool, None)
+# Whether this parameter is uniform or normal distrubuted
+mu: float = cast(float, None) # µ of normal parameter
+sigma: float = cast(float, None)# σ of normal parameter

-q: Optional[float] = None # If not `None`, the parameter value should be an integer multiple of this
-clip: Optional[Tuple[float, float]] = None
+q: float | None = None # If not `None`, the parameter value should be an integer multiple of this
+clip: tuple[float, float] | None = None
 # For q(log)uniform, this equals to "values[:2]"; for others this is None

-log_distributed: bool = None # Whether this parameter is log distributed
+log_distributed: bool = cast(bool, None)
+# Whether this parameter is log distributed
 # When true, low/high/mu/sigma describes log of parameter value (like np.lognormal)

-def is_activated_in(self, partial_parameters):
+def is_activated_in(self, partial_parameters: FormattedParameters) -> bool:
 """
 For nested search space, check whether this parameter should be skipped for current set of paremters.
 This function must be used in a pattern similar to random tuner. Otherwise it will misbehave.
@@ -64,7 +82,7 @@ class ParameterSpec(NamedTuple):
 return True
 return partial_parameters[self.key[:-2]] == self.key[-2]

-def format_search_space(search_space):
+def format_search_space(search_space: SearchSpace) -> FormattedSearchSpace:
 """
 Convert user provided search space into a dict of ParameterSpec.
 The dict key is dict value's `ParameterSpec.key`.
@@ -76,7 +94,9 @@ def format_search_space(search_space):
 # Remove these comments when we drop 3.6 support.
 return {spec.key: spec for spec in formatted}

-def deformat_parameters(formatted_parameters, formatted_search_space):
+def deformat_parameters(
+formatted_parameters: FormattedParameters,
+formatted_search_space: FormattedSearchSpace) -> Parameters:
 """
 Convert internal format parameters to users' expected format.

@@ -88,10 +108,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
 3. For "q*", convert x to `round(x / q) * q`, then clip into range.
 4. For nested choices, convert flatten key-value pairs into nested structure.
 """
-ret = {}
+ret: Parameters = {}
 for key, x in formatted_parameters.items():
 spec = formatted_search_space[key]
 if spec.categorical:
+x = cast(int, x)
 if spec.type == 'randint':
 lower = min(math.ceil(float(x)) for x in spec.values)
 _assign(ret, key, int(lower + x))
@@ -112,7 +133,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
 _assign(ret, key, x)
 return ret

-def format_parameters(parameters, formatted_search_space):
+def format_parameters(parameters: Parameters, formatted_search_space: FormattedSearchSpace) -> FormattedParameters:
 """
 Convert end users' parameter format back to internal format, mainly for resuming experiments.

@@ -123,7 +144,7 @@ def format_parameters(parameters, formatted_search_space):
 for key, spec in formatted_search_space.items():
 if not spec.is_activated_in(ret):
 continue
-value = parameters
+value: Any = parameters
 for name in key:
 if isinstance(name, str):
 value = value[name]
@@ -142,8 +163,8 @@ def format_parameters(parameters, formatted_search_space):
 ret[key] = value
 return ret

-def _format_search_space(parent_key, space):
-formatted = []
+def _format_search_space(parent_key: ParameterKey, space: SearchSpace) -> list[ParameterSpec]:
+formatted: list[ParameterSpec] = []
 for name, spec in space.items():
 if name == '_name':
 continue
@@ -155,7 +176,7 @@ def _format_search_space(parent_key, space):
 formatted += _format_search_space(key, sub_space)
 return formatted

-def _format_parameter(key, type_, values):
+def _format_parameter(key: ParameterKey, type_: str, values: list[Any]):
 spec = SimpleNamespace(
 name = key[-1],
 type = type_,
@@ -197,7 +218,7 @@ def _format_parameter(key, type_, values):

 return ParameterSpec(**spec.__dict__)

-def _is_nested_choices(values):
+def _is_nested_choices(values: list[Any]) -> bool:
 assert values # choices should not be empty
 for value in values:
 if not isinstance(value, dict):
@@ -206,9 +227,9 @@ def _is_nested_choices(values):
 return False
 return True

-def _assign(params, key, x):
+def _assign(params: Parameters, key: ParameterKey, x: Any) -> None:
 if len(key) == 1:
-params[key[0]] = x
+params[cast(str, key[0])] = x
 elif isinstance(key[0], int):
 _assign(params, key[1:], x)
 else:
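Illustrative note (not part of the commit): the Parameters and SearchSpace aliases used in the annotations above describe plain dicts. A minimal sketch of the two shapes, assuming the standard NNI HPO search-space format:

    # One sampled point (Parameters) drawn from a search space definition (SearchSpace).
    search_space = {
        'lr': {'_type': 'loguniform', '_value': [1e-5, 1e-1]},
        'batch_size': {'_type': 'choice', '_value': [16, 32, 64]},
    }
    parameters = {'lr': 0.003, 'batch_size': 32}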
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from enum import Enum

 class OptimizeMode(Enum):
@@ -1,8 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import logging
-from typing import Any, List, Optional
+from typing import Any

 common_search_space_types = [
 'choice',
@@ -19,7 +21,7 @@ common_search_space_types = [

 def validate_search_space(
 search_space: Any,
-support_types: Optional[List[str]] = None,
+support_types: list[str] | None = None,
 raise_exception: bool = False # for now, in case false positive
 ) -> bool:

@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import abc
 import base64
 import collections.abc
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 try:
 import torch
@@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase):
 else: # custom algorithm
 assert self.name is None
 assert self.class_name
+assert self.code_directory is not None
 if not Path(self.code_directory).is_dir():
 raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory')

@@ -37,6 +37,8 @@ def to_v2(v1):
 _move_field(v1_trial, v2, 'command', 'trialCommand')
 _move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory')
 _move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber')
+else:
+v1_trial = {}

 for algo_type in ['tuner', 'assessor', 'advisor']:
 v1_algo = v1.pop(algo_type, None)
|
|||
|
||||
if self.password is not None:
|
||||
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
|
||||
elif not Path(self.ssh_key_file).is_file():
|
||||
elif not Path(self.ssh_key_file).is_file(): # type: ignore
|
||||
raise ValueError(
|
||||
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
|
||||
)
|
||||
|
|
|
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from dataclasses import dataclass
 import json
 from typing import List
@@ -10,7 +10,7 @@ from pathlib import Path
 import socket
 from subprocess import Popen
 import time
-from typing import Optional, Any
+from typing import Any

 import colorama
 import psutil
@@ -34,8 +34,7 @@ class RunMode(Enum):
 - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
 - Detach: do not stop NNI manager when Python script exits.

-NOTE:
-This API is non-stable and is likely to get refactored in next release.
+NOTE: This API is non-stable and is likely to get refactored in upcoming release.
 """
 # TODO:
 # NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
@@ -72,15 +71,15 @@ class Experiment:
 Web portal port. Or ``None`` if the experiment is not running.
 """

-def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None:
+def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None):
 nni.runtime.log.init_logger_for_command_line()

-self.config: Optional[ExperimentConfig] = None
+self.config: ExperimentConfig | None = None
 self.id: str = management.generate_experiment_id()
-self.port: Optional[int] = None
-self._proc: Optional[Popen] = None
-self.action = 'create'
-self.url_prefix: Optional[str] = None
+self.port: int | None = None
+self._proc: Popen | psutil.Process | None = None
+self._action = 'create'
+self.url_prefix: str | None = None

 if isinstance(config_or_platform, (str, list)):
 self.config = ExperimentConfig(config_or_platform)
@@ -101,6 +100,7 @@ class Experiment:
 debug
 Whether to start in debug mode.
 """
+assert self.config is not None
 if run_mode is not RunMode.Detach:
 atexit.register(self.stop)

@@ -114,7 +114,7 @@ class Experiment:
 log_dir = Path.home() / f'nni-experiments/{self.id}/log'
 nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

-self._proc = launcher.start_experiment(self.action, self.id, config, port, debug, run_mode, self.url_prefix)
+self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
 assert self._proc is not None

 self.port = port # port will be None if start up failed
@@ -144,16 +144,16 @@ class Experiment:
 _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
 kill_command(self._proc.pid)

-self.id = None
+self.id = None # type: ignore
 self.port = None
 self._proc = None
 _logger.info('Experiment stopped')

-def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool:
+def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
 """
 Run the experiment.

-If ``wait_completion`` is True, this function will block until experiment finish or error.
+If ``wait_completion`` is ``True``, this function will block until experiment finish or error.

 Return ``True`` when experiment done; or return ``False`` when experiment failed.

@@ -247,7 +247,7 @@ class Experiment:
 def _resume(exp_id, exp_dir=None):
 exp = Experiment(None)
 exp.id = exp_id
-exp.action = 'resume'
+exp._action = 'resume'
 exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
 return exp

@@ -255,7 +255,7 @@ class Experiment:
 def _view(exp_id, exp_dir=None):
 exp = Experiment(None)
 exp.id = exp_id
-exp.action = 'view'
+exp._action = 'view'
 exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
 return exp

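Illustrative note (not part of the commit): the Experiment class above is normally driven from a Python script. A minimal sketch, assuming the nni.experiment API; field names may differ slightly between NNI versions and the trial script is hypothetical:

    from nni.experiment import Experiment

    experiment = Experiment('local')                     # training service platform
    experiment.config.trial_command = 'python trial.py'  # hypothetical trial script
    experiment.config.trial_code_directory = '.'
    experiment.config.trial_concurrency = 1
    experiment.config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-5, 1e-1]}}
    experiment.config.tuner.name = 'TPE'
    experiment.config.tuner.class_args = {'optimize_mode': 'maximize'}
    experiment.run(8080)   # blocks until the experiment finishes (wait_completion=True)
    experiment.stop()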
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import contextlib
 from dataclasses import dataclass, fields
 from datetime import datetime
@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):

 return proc

-def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
+def _start_rest_server(nni_manager_args, run_mode) -> Popen:
 import nni_node
-node_dir = Path(nni_node.__path__[0])
+node_dir = Path(nni_node.__path__[0]) # type: ignore
 node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
 main_js = str(node_dir / 'main.js')
 cmd = [node, '--max-old-space-size=4096', main_js]
@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
 from subprocess import CREATE_NEW_PROCESS_GROUP
 return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
 else:
-return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
+return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore


-def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
+def start_experiment_retiarii(exp_id, config, port, debug):
 pipe = None
 proc = None

@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
 args['dispatcher_pipe'] = pipe_path

 import nni_node
-node_dir = Path(nni_node.__path__[0])
+node_dir = Path(nni_node.__path__[0]) # type: ignore
 node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
 main_js = str(node_dir / 'main.js')
 cmd = [node, '--max-old-space-size=4096', main_js]
@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,


 def get_stopped_experiment_config(exp_id, exp_dir=None):
-config_json = get_stopped_experiment_config_json(exp_id, exp_dir)
-config = ExperimentConfig(**config_json)
+config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
+config = ExperimentConfig(**config_json) # type: ignore
 if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
 msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
 _logger.warning(msg, exp_dir, config.experiment_working_directory)
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from pathlib import Path
 import random
 import string
@@ -1,4 +1,6 @@
-from io import BufferedIOBase
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 import os
 import sys
@@ -25,7 +27,7 @@ if sys.platform == 'win32':
 _winapi.NULL
 )

-def connect(self) -> BufferedIOBase:
+def connect(self):
 _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
 fd = msvcrt.open_osfhandle(self._handle, 0)
 self.file = os.fdopen(fd, 'w+b')
@@ -55,7 +57,7 @@ else:
 self._socket.bind(self.path)
 self._socket.listen(1) # only accepts one connection

-def connect(self) -> BufferedIOBase:
+def connect(self):
 conn, _ = self._socket.accept()
 self.file = conn.makefile('rwb')
 return self.file
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 from typing import Any, Optional

@@ -1,17 +1,19 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import os

 class Recoverable:

-def load_checkpoint(self):
+def load_checkpoint(self) -> None:
 pass

-def save_checkpoint(self):
+def save_checkpoint(self) -> None:
 pass

-def get_checkpoint_path(self):
+def get_checkpoint_path(self) -> str | None:
 ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
 if ckp_path is not None and os.path.isdir(ckp_path):
 return ckp_path
@@ -14,7 +14,7 @@ def get_config_directory() -> Path:
 Create it if not exist.
 """
 if os.getenv('NNI_CONFIG_DIR') is not None:
-config_dir = Path(os.getenv('NNI_CONFIG_DIR'))
+config_dir = Path(os.getenv('NNI_CONFIG_DIR')) # type: ignore
 elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir():
 config_dir = Path(sys.prefix, 'nni')
 elif sys.platform == 'win32':
@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
 """
 Get a readonly builtin config file.
 """
-return Path(nni.__path__[0], 'runtime/default_config', name)
+return Path(nni.__path__[0], 'runtime/default_config', name) # type: ignore
@@ -1,3 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
 import logging
 import sys
 from datetime import datetime
@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
 _register_handler(StreamHandler(sys.stdout), logging.INFO)


-def _prepare_log_dir(path: Optional[str]) -> Path:
+def _prepare_log_dir(path: Path | str) -> Path:
 if path is None:
 return Path()
 ret = Path(path)
@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
 def __init__(self, log_file: TextIOBase):
 self.file: TextIOBase = log_file
 self.line_buffer: Optional[str] = None
-self.line_start_time: Optional[datetime] = None
+self.line_start_time: datetime = datetime.fromtimestamp(0)

 def write(self, s: str) -> int:
 cur_time = datetime.now()
@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
 except Exception as e:
 _logger.error('Assessor error')
 _logger.exception(e)
+raise

 if isinstance(result, bool):
 result = AssessResult.Good if result else AssessResult.Bad
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from . import proxy

 load_jupyter_server_extension = proxy.setup
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import json
 from pathlib import Path
 import shutil
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import json
 from pathlib import Path

@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import importlib
 import json

nni/trial.py (47 changed lines)
@@ -1,13 +1,18 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
+from typing import Any
+
 from .common.serializer import dump
 from .runtime.env_vars import trial_env_vars
 from .runtime import platform

+from .typehint import Parameters, TrialMetric
+
 __all__ = [
 'get_next_parameter',
+'get_next_parameters',
 'get_current_parameter',
 'report_intermediate_result',
 'report_final_result',
@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
 _sequence_id = platform.get_sequence_id()


-def get_next_parameter():
+def get_next_parameter() -> Parameters:
 """
 Get the hyperparameters generated by tuner.

@@ -32,7 +37,7 @@ def get_next_parameter():

 Examples
 --------
-Assuming the search space is:
+Assuming the :doc:`search space </hpo/search_space>` is:

 .. code-block::

@@ -52,16 +57,22 @@ def get_next_parameter():

 Returns
 -------
-dict
+:class:`~nni.typehint.Parameters`
 A hyperparameter set sampled from search space.
 """
 global _params
 _params = platform.get_next_parameter()
 if _params is None:
-return None
+return None # type: ignore
 return _params['parameters']

-def get_current_parameter(tag=None):
+def get_next_parameters() -> Parameters:
+"""
+Alias of :func:`get_next_parameter`
+"""
+return get_next_parameter()
+
+def get_current_parameter(tag: str | None = None) -> Any:
 global _params
 if _params is None:
 return None
@@ -94,13 +105,13 @@ def get_sequence_id() -> int:
 _intermediate_seq = 0


-def overwrite_intermediate_seq(value):
+def overwrite_intermediate_seq(value: int) -> None:
 assert isinstance(value, int)
 global _intermediate_seq
 _intermediate_seq = value


-def report_intermediate_result(metric):
+def report_intermediate_result(metric: TrialMetric | dict[str, Any]) -> None:
 """
 Reports intermediate result to NNI.

@@ -110,11 +121,16 @@ def report_intermediate_result(metric):
 and other items can be visualized with web portal.

 Typically ``metric`` is per-epoch accuracy or loss.
+
+Parameters
+----------
+metric : :class:`~nni.typehint.TrialMetric`
+The intermeidate result.
 """
 global _intermediate_seq
 assert _params or trial_env_vars.NNI_PLATFORM is None, \
 'nni.get_next_parameter() needs to be called before report_intermediate_result'
-metric = dump({
+dumped_metric = dump({
 'parameter_id': _params['parameter_id'] if _params else None,
 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
 'type': 'PERIODICAL',
@@ -122,9 +138,9 @@ def report_intermediate_result(metric):
 'value': dump(metric)
 })
 _intermediate_seq += 1
-platform.send_metric(metric)
+platform.send_metric(dumped_metric)

-def report_final_result(metric):
+def report_final_result(metric: TrialMetric | dict[str, Any]) -> None:
 """
 Reports final result to NNI.

@@ -134,14 +150,19 @@ def report_final_result(metric):
 and other items can be visualized with web portal.

 Typically ``metric`` is the final accuracy or loss.
+
+Parameters
+----------
+metric : :class:`~nni.typehint.TrialMetric`
+The final result.
 """
 assert _params or trial_env_vars.NNI_PLATFORM is None, \
 'nni.get_next_parameter() needs to be called before report_final_result'
-metric = dump({
+dumped_metric = dump({
 'parameter_id': _params['parameter_id'] if _params else None,
 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
 'type': 'FINAL',
 'sequence': 0,
 'value': dump(metric)
 })
-platform.send_metric(metric)
+platform.send_metric(dumped_metric)
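Illustrative note (not part of the commit): the annotated trial API above is used from training code roughly as follows; the hyperparameter names and the training loop are placeholders:

    import nni

    params = nni.get_next_parameter()            # Parameters dict sampled by the tuner
    lr = params.get('lr', 0.001)
    accuracy = 0.0
    for epoch in range(10):
        accuracy += 0.05                         # placeholder for real training/evaluation
        nni.report_intermediate_result(accuracy) # one TrialMetric per epoch
    nni.report_final_result(accuracy)            # final TrialMetric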
nni/tuner.py (23 changed lines)
@@ -8,11 +8,14 @@ A new trial will run with this configuration.
 See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
 """

+from __future__ import annotations
+
 import logging

 import nni

 from .recoverable import Recoverable
+from .typehint import Parameters, SearchSpace, TrialMetric, TrialRecord

 __all__ = ['Tuner']

@@ -67,7 +70,7 @@ class Tuner(Recoverable):
 :class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
 """

-def generate_parameters(self, parameter_id, **kwargs):
+def generate_parameters(self, parameter_id: int, **kwargs) -> Parameters:
 """
 Abstract method which provides a set of hyper-parameters.

@@ -100,7 +103,7 @@ class Tuner(Recoverable):
 # we need to design a new exception for this purpose
 raise NotImplementedError('Tuner: generate_parameters not implemented')

-def generate_multiple_parameters(self, parameter_id_list, **kwargs):
+def generate_multiple_parameters(self, parameter_id_list: list[int], **kwargs) -> list[Parameters]:
 """
 Callback method which provides multiple sets of hyper-parameters.

@@ -135,7 +138,7 @@ class Tuner(Recoverable):
 result.append(res)
 return result

-def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+def receive_trial_result(self, parameter_id: int, parameters: Parameters, value: TrialMetric, **kwargs) -> None:
 """
 Abstract method invoked when a trial reports its final result. Must override.

@@ -165,7 +168,7 @@ class Tuner(Recoverable):
 # pylint: disable=attribute-defined-outside-init
 self._accept_customized = accept

-def trial_end(self, parameter_id, success, **kwargs):
+def trial_end(self, parameter_id: int, success: bool, **kwargs) -> None:
 """
 Abstract method invoked when a trial is completed or terminated. Do nothing by default.

@@ -179,7 +182,7 @@ class Tuner(Recoverable):
 Unstable parameters which should be ignored by normal users.
 """

-def update_search_space(self, search_space):
+def update_search_space(self, search_space: SearchSpace) -> None:
 """
 Abstract method for updating the search space. Must override.

@@ -194,21 +197,21 @@ class Tuner(Recoverable):
 """
 raise NotImplementedError('Tuner: update_search_space not implemented')

-def load_checkpoint(self):
+def load_checkpoint(self) -> None:
 """
 Internal API under revising, not recommended for end users.
 """
 checkpoin_path = self.get_checkpoint_path()
 _logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)

-def save_checkpoint(self):
+def save_checkpoint(self) -> None:
 """
 Internal API under revising, not recommended for end users.
 """
 checkpoin_path = self.get_checkpoint_path()
 _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)

-def import_data(self, data):
+def import_data(self, data: list[TrialRecord]) -> None:
 """
 Internal API under revising, not recommended for end users.
 """
@@ -216,8 +219,8 @@ class Tuner(Recoverable):
 # data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
 pass

-def _on_exit(self):
+def _on_exit(self) -> None:
 pass

-def _on_error(self):
+def _on_error(self) -> None:
 pass
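Illustrative note (not part of the commit): a custom tuner subclasses the interface annotated above. A minimal random-search sketch, assuming the nni.tuner module; the class name is hypothetical and only two search-space types are handled:

    import random
    from nni.tuner import Tuner

    class MyRandomTuner(Tuner):  # hypothetical example
        def __init__(self):
            super().__init__()
            self.space = {}

        def update_search_space(self, search_space):
            self.space = search_space            # SearchSpace dict

        def generate_parameters(self, parameter_id, **kwargs):
            params = {}
            for name, spec in self.space.items():
                if spec['_type'] == 'choice':
                    params[name] = random.choice(spec['_value'])
                else:                            # treat everything else as uniform
                    low, high = spec['_value'][:2]
                    params[name] = random.uniform(low, high)
            return params

        def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
            pass                                 # a real tuner would learn from `value`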
@@ -1,10 +1,58 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import sys
-import typing
+"""
+Types for static checking.
+"""

-if typing.TYPE_CHECKING or sys.version_info >= (3, 8):
-Literal = typing.Literal
+__all__ = [
+'Literal',
+'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
+]
+
+import sys
+from typing import Any, Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING or sys.version_info >= (3, 8):
+from typing import Literal, TypedDict
 else:
-Literal = typing.Any
+from typing_extensions import Literal, TypedDict
+
+Parameters = Dict[str, Any]
+"""
+Return type of :func:`nni.get_next_parameter`.
+
+For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space </hpo/search_space>`.
+
+Customized tuners do not need to follow the constraint and can use anything serializable.
+"""
+
+class _ParameterSearchSpace(TypedDict):
+_type: Literal[
+'choice', 'randint',
+'uniform', 'loguniform', 'quniform', 'qloguniform',
+'normal', 'lognormal', 'qnormal', 'qlognormal',
+]
+_value: List[Any]
+
+SearchSpace = Dict[str, _ParameterSearchSpace]
+"""
+Type of ``experiment.config.search_space``.
+
+For built-in tuners, the format is detailed in :doc:`/hpo/search_space`.
+
+Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``.
+"""
+
+TrialMetric = float
+"""
+Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`.
+
+For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc).
+
+Customized tuners do not need to follow this constraint and can use anything serializable.
+"""
+
+class TrialRecord(TypedDict):
+parameter: Parameters
+value: TrialMetric
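Illustrative note (not part of the commit): the aliases exported by nni.typehint are intended for annotating user code, for example:

    from nni.typehint import Parameters, TrialMetric

    def evaluate(params: Parameters) -> TrialMetric:
        # Map one sampled parameter set to a single score; keys come from the search space.
        lr = params['lr']                        # hypothetical hyperparameter
        return 1.0 - abs(lr - 0.01)              # placeholder score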
@@ -63,6 +63,9 @@ stages:
 python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
 displayName: flake8

+- script: |
+python -m pyright nni
+
 - job: typescript
 pool:
 vmImage: ubuntu-latest
@@ -0,0 +1,14 @@
+{
+"ignore": [
+"nni/algorithms",
+"nni/common/device.py",
+"nni/common/graph_utils.py",
+"nni/common/serializer.py",
+"nni/compression",
+"nni/nas",
+"nni/retiarii",
+"nni/smartparam.py",
+"nni/tools"
+],
+"reportMissingImports": false
+}