diff --git a/dependencies/develop.txt b/dependencies/develop.txt index 781f53cb8..80edb1306 100644 --- a/dependencies/develop.txt +++ b/dependencies/develop.txt @@ -5,6 +5,7 @@ ipython jupyterlab nbsphinx pylint +pyright pytest pytest-azurepipelines pytest-cov diff --git a/dependencies/required.txt b/dependencies/required.txt index e765025af..bb8d0cc8b 100644 --- a/dependencies/required.txt +++ b/dependencies/required.txt @@ -19,5 +19,5 @@ scikit-learn >= 0.24.1 scipy < 1.8 ; python_version < "3.8" scipy ; python_version >= "3.8" typeguard -typing_extensions ; python_version < "3.8" +typing_extensions >= 4.0.0 ; python_version < "3.8" websockets >= 10.1 diff --git a/docs/source/reference/others.rst b/docs/source/reference/others.rst new file mode 100644 index 000000000..435238f81 --- /dev/null +++ b/docs/source/reference/others.rst @@ -0,0 +1,8 @@ +Uncategorized Modules +===================== + +nni.typehint +------------ + +.. automodule:: nni.typehint + :members: diff --git a/docs/source/reference/python_api/others.rst b/docs/source/reference/python_api/others.rst deleted file mode 100644 index 41b58ea7d..000000000 --- a/docs/source/reference/python_api/others.rst +++ /dev/null @@ -1,11 +0,0 @@ -Others -====== - -nni ---- - -nni.common ----------- - -nni.utils ---------- diff --git a/docs/source/reference/python_api_ref.rst b/docs/source/reference/python_api_ref.rst index 76875d929..a1390c9f1 100644 --- a/docs/source/reference/python_api_ref.rst +++ b/docs/source/reference/python_api_ref.rst @@ -9,4 +9,4 @@ API Reference Model Compression Feature Engineering <./python_api/feature_engineering> Experiment - Others <./python_api/others> + Others diff --git a/nni/algorithms/hpo/bohb_advisor/__init__.py b/nni/algorithms/hpo/bohb_advisor/__init__.py index 0ebb442e5..95a7b1b99 100644 --- a/nni/algorithms/hpo/bohb_advisor/__init__.py +++ b/nni/algorithms/hpo/bohb_advisor/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .bohb_advisor import BOHB, BOHBClassArgsValidator diff --git a/nni/algorithms/hpo/dngo_tuner.py b/nni/algorithms/hpo/dngo_tuner.py index 11ec4d595..3664169b5 100644 --- a/nni/algorithms/hpo/dngo_tuner.py +++ b/nni/algorithms/hpo/dngo_tuner.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import logging import warnings diff --git a/nni/algorithms/hpo/gp_tuner/__init__.py b/nni/algorithms/hpo/gp_tuner/__init__.py index 17bedd38f..a2b489569 100644 --- a/nni/algorithms/hpo/gp_tuner/__init__.py +++ b/nni/algorithms/hpo/gp_tuner/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .gp_tuner import GPTuner, GPClassArgsValidator diff --git a/nni/algorithms/hpo/metis_tuner/__init__.py b/nni/algorithms/hpo/metis_tuner/__init__.py index f4f9ceba6..9d799fb5a 100644 --- a/nni/algorithms/hpo/metis_tuner/__init__.py +++ b/nni/algorithms/hpo/metis_tuner/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .metis_tuner import MetisTuner, MetisClassArgsValidator diff --git a/nni/algorithms/hpo/networkmorphism_tuner/__init__.py b/nni/algorithms/hpo/networkmorphism_tuner/__init__.py index b60da9c38..95a7a41d2 100644 --- a/nni/algorithms/hpo/networkmorphism_tuner/__init__.py +++ b/nni/algorithms/hpo/networkmorphism_tuner/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .networkmorphism_tuner import NetworkMorphismTuner, NetworkMorphismClassArgsValidator diff --git a/nni/algorithms/hpo/ppo_tuner/__init__.py b/nni/algorithms/hpo/ppo_tuner/__init__.py index 854090c93..c2ea27403 100644 --- a/nni/algorithms/hpo/ppo_tuner/__init__.py +++ b/nni/algorithms/hpo/ppo_tuner/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .ppo_tuner import PPOTuner, PPOClassArgsValidator diff --git a/nni/algorithms/hpo/regularized_evolution_tuner.py b/nni/algorithms/hpo/regularized_evolution_tuner.py index ac756c338..e3139b8e4 100644 --- a/nni/algorithms/hpo/regularized_evolution_tuner.py +++ b/nni/algorithms/hpo/regularized_evolution_tuner.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import copy import logging import random diff --git a/nni/assessor.py b/nni/assessor.py index 7cd83e923..36c78bbbf 100644 --- a/nni/assessor.py +++ b/nni/assessor.py @@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not. See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details. """ +from __future__ import annotations + from enum import Enum import logging from .recoverable import Recoverable +from .typehint import TrialMetric __all__ = ['AssessResult', 'Assessor'] @@ -54,7 +57,7 @@ class Assessor(Recoverable): :class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor` """ - def assess_trial(self, trial_job_id, trial_history): + def assess_trial(self, trial_job_id: str, trial_history: list[TrialMetric]) -> AssessResult: """ Abstract method for determining whether a trial should be killed. Must override. @@ -91,7 +94,7 @@ class Assessor(Recoverable): """ raise NotImplementedError('Assessor: assess_trial not implemented') - def trial_end(self, trial_job_id, success): + def trial_end(self, trial_job_id: str, success: bool) -> None: """ Abstract method invoked when a trial is completed or terminated. Do nothing by default. @@ -103,22 +106,22 @@ class Assessor(Recoverable): True if the trial successfully completed; False if failed or terminated. """ - def load_checkpoint(self): + def load_checkpoint(self) -> None: """ Internal API under revising, not recommended for end users. """ checkpoin_path = self.get_checkpoint_path() _logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) - def save_checkpoint(self): + def save_checkpoint(self) -> None: """ Internal API under revising, not recommended for end users. """ checkpoin_path = self.get_checkpoint_path() _logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) - def _on_exit(self): + def _on_exit(self) -> None: pass - def _on_error(self): + def _on_error(self) -> None: pass diff --git a/nni/common/__init__.py b/nni/common/__init__.py index f18054727..c29323fd6 100644 --- a/nni/common/__init__.py +++ b/nni/common/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .serializer import trace, dump, load, is_traceable diff --git a/nni/common/hpo_utils/formatting.py b/nni/common/hpo_utils/formatting.py index 2a2674b9f..b86e14038 100644 --- a/nni/common/hpo_utils/formatting.py +++ b/nni/common/hpo_utils/formatting.py @@ -2,6 +2,8 @@ # Licensed under the MIT license. """ +Helper class and functions for tuners to deal with search space. + This script provides a more program-friendly representation of HPO search space. The format is considered internal helper and is not visible to end users. @@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space. The random tuner is an intuitive example for this utility. You should check its code before reading docstrings in this file. + +.. attention:: + + This module does not guarantee forward-compatibility. + + If you want to use it outside official NNI repo, it is recommended to copy the script. """ +from __future__ import annotations + __all__ = [ 'ParameterSpec', 'deformat_parameters', @@ -20,10 +30,16 @@ __all__ = [ import math from types import SimpleNamespace -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, NamedTuple, Tuple, cast import numpy as np +from nni.typehint import Parameters, SearchSpace + +ParameterKey = Tuple['str | int', ...] +FormattedParameters = Dict[ParameterKey, 'float | int'] +FormattedSearchSpace = Dict[ParameterKey, 'ParameterSpec'] + class ParameterSpec(NamedTuple): """ Specification (aka space / range / domain) of one single parameter. @@ -33,29 +49,31 @@ class ParameterSpec(NamedTuple): name: str # The object key in JSON type: str # "_type" in JSON - values: List[Any] # "_value" in JSON + values: list[Any] # "_value" in JSON - key: Tuple[str] # The "path" of this parameter + key: ParameterKey # The "path" of this parameter categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered) - size: int = None # If it's categorical, how many candidates it has + size: int = cast(int, None) # If it's categorical, how many candidates it has # uniform distributed - low: float = None # Lower bound of uniform parameter - high: float = None # Upper bound of uniform parameter + low: float = cast(float, None) # Lower bound of uniform parameter + high: float = cast(float, None) # Upper bound of uniform parameter - normal_distributed: bool = None # Whether this parameter is uniform or normal distrubuted - mu: float = None # µ of normal parameter - sigma: float = None # σ of normal parameter + normal_distributed: bool = cast(bool, None) + # Whether this parameter is uniform or normal distrubuted + mu: float = cast(float, None) # µ of normal parameter + sigma: float = cast(float, None)# σ of normal parameter - q: Optional[float] = None # If not `None`, the parameter value should be an integer multiple of this - clip: Optional[Tuple[float, float]] = None + q: float | None = None # If not `None`, the parameter value should be an integer multiple of this + clip: tuple[float, float] | None = None # For q(log)uniform, this equals to "values[:2]"; for others this is None - log_distributed: bool = None # Whether this parameter is log distributed + log_distributed: bool = cast(bool, None) + # Whether this parameter is log distributed # When true, low/high/mu/sigma describes log of parameter value (like np.lognormal) - def is_activated_in(self, partial_parameters): + def is_activated_in(self, partial_parameters: FormattedParameters) -> bool: """ For nested search space, check whether this parameter should be skipped for current set of paremters. This function must be used in a pattern similar to random tuner. Otherwise it will misbehave. @@ -64,7 +82,7 @@ class ParameterSpec(NamedTuple): return True return partial_parameters[self.key[:-2]] == self.key[-2] -def format_search_space(search_space): +def format_search_space(search_space: SearchSpace) -> FormattedSearchSpace: """ Convert user provided search space into a dict of ParameterSpec. The dict key is dict value's `ParameterSpec.key`. @@ -76,7 +94,9 @@ def format_search_space(search_space): # Remove these comments when we drop 3.6 support. return {spec.key: spec for spec in formatted} -def deformat_parameters(formatted_parameters, formatted_search_space): +def deformat_parameters( + formatted_parameters: FormattedParameters, + formatted_search_space: FormattedSearchSpace) -> Parameters: """ Convert internal format parameters to users' expected format. @@ -88,10 +108,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space): 3. For "q*", convert x to `round(x / q) * q`, then clip into range. 4. For nested choices, convert flatten key-value pairs into nested structure. """ - ret = {} + ret: Parameters = {} for key, x in formatted_parameters.items(): spec = formatted_search_space[key] if spec.categorical: + x = cast(int, x) if spec.type == 'randint': lower = min(math.ceil(float(x)) for x in spec.values) _assign(ret, key, int(lower + x)) @@ -112,7 +133,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space): _assign(ret, key, x) return ret -def format_parameters(parameters, formatted_search_space): +def format_parameters(parameters: Parameters, formatted_search_space: FormattedSearchSpace) -> FormattedParameters: """ Convert end users' parameter format back to internal format, mainly for resuming experiments. @@ -123,7 +144,7 @@ def format_parameters(parameters, formatted_search_space): for key, spec in formatted_search_space.items(): if not spec.is_activated_in(ret): continue - value = parameters + value: Any = parameters for name in key: if isinstance(name, str): value = value[name] @@ -142,8 +163,8 @@ def format_parameters(parameters, formatted_search_space): ret[key] = value return ret -def _format_search_space(parent_key, space): - formatted = [] +def _format_search_space(parent_key: ParameterKey, space: SearchSpace) -> list[ParameterSpec]: + formatted: list[ParameterSpec] = [] for name, spec in space.items(): if name == '_name': continue @@ -155,7 +176,7 @@ def _format_search_space(parent_key, space): formatted += _format_search_space(key, sub_space) return formatted -def _format_parameter(key, type_, values): +def _format_parameter(key: ParameterKey, type_: str, values: list[Any]): spec = SimpleNamespace( name = key[-1], type = type_, @@ -197,7 +218,7 @@ def _format_parameter(key, type_, values): return ParameterSpec(**spec.__dict__) -def _is_nested_choices(values): +def _is_nested_choices(values: list[Any]) -> bool: assert values # choices should not be empty for value in values: if not isinstance(value, dict): @@ -206,9 +227,9 @@ def _is_nested_choices(values): return False return True -def _assign(params, key, x): +def _assign(params: Parameters, key: ParameterKey, x: Any) -> None: if len(key) == 1: - params[key[0]] = x + params[cast(str, key[0])] = x elif isinstance(key[0], int): _assign(params, key[1:], x) else: diff --git a/nni/common/hpo_utils/optimize_mode.py b/nni/common/hpo_utils/optimize_mode.py index 6d7a034c8..e91834d73 100644 --- a/nni/common/hpo_utils/optimize_mode.py +++ b/nni/common/hpo_utils/optimize_mode.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from enum import Enum class OptimizeMode(Enum): diff --git a/nni/common/hpo_utils/validation.py b/nni/common/hpo_utils/validation.py index 3c1ea71d5..729c0b517 100644 --- a/nni/common/hpo_utils/validation.py +++ b/nni/common/hpo_utils/validation.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging -from typing import Any, List, Optional +from typing import Any common_search_space_types = [ 'choice', @@ -19,7 +21,7 @@ common_search_space_types = [ def validate_search_space( search_space: Any, - support_types: Optional[List[str]] = None, + support_types: list[str] | None = None, raise_exception: bool = False # for now, in case false positive ) -> bool: diff --git a/nni/common/serializer.py b/nni/common/serializer.py index f0a0e0c18..eec7d9c41 100644 --- a/nni/common/serializer.py +++ b/nni/common/serializer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import abc import base64 import collections.abc diff --git a/nni/common/version.py b/nni/common/version.py index b8881f48a..504a52967 100644 --- a/nni/common/version.py +++ b/nni/common/version.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import logging try: import torch diff --git a/nni/experiment/config/algorithm.py b/nni/experiment/config/algorithm.py index bd79f90e5..96b8d3560 100644 --- a/nni/experiment/config/algorithm.py +++ b/nni/experiment/config/algorithm.py @@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase): else: # custom algorithm assert self.name is None assert self.class_name + assert self.code_directory is not None if not Path(self.code_directory).is_dir(): raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory') diff --git a/nni/experiment/config/convert.py b/nni/experiment/config/convert.py index 860b3c1ff..f1a98e492 100644 --- a/nni/experiment/config/convert.py +++ b/nni/experiment/config/convert.py @@ -37,6 +37,8 @@ def to_v2(v1): _move_field(v1_trial, v2, 'command', 'trialCommand') _move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory') _move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber') + else: + v1_trial = {} for algo_type in ['tuner', 'assessor', 'advisor']: v1_algo = v1.pop(algo_type, None) diff --git a/nni/experiment/config/training_services/remote.py b/nni/experiment/config/training_services/remote.py index 4a1a49d81..5b5c476d5 100644 --- a/nni/experiment/config/training_services/remote.py +++ b/nni/experiment/config/training_services/remote.py @@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase): if self.password is not None: warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.') - elif not Path(self.ssh_key_file).is_file(): + elif not Path(self.ssh_key_file).is_file(): # type: ignore raise ValueError( f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"' ) diff --git a/nni/experiment/data.py b/nni/experiment/data.py index fd75520cc..026c9c04c 100644 --- a/nni/experiment/data.py +++ b/nni/experiment/data.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from dataclasses import dataclass import json from typing import List diff --git a/nni/experiment/experiment.py b/nni/experiment/experiment.py index 223829b08..b1ddd105e 100644 --- a/nni/experiment/experiment.py +++ b/nni/experiment/experiment.py @@ -10,7 +10,7 @@ from pathlib import Path import socket from subprocess import Popen import time -from typing import Optional, Any +from typing import Any import colorama import psutil @@ -34,8 +34,7 @@ class RunMode(Enum): - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout. - Detach: do not stop NNI manager when Python script exits. - NOTE: - This API is non-stable and is likely to get refactored in next release. + NOTE: This API is non-stable and is likely to get refactored in upcoming release. """ # TODO: # NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose. @@ -72,15 +71,15 @@ class Experiment: Web portal port. Or ``None`` if the experiment is not running. """ - def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None: + def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None): nni.runtime.log.init_logger_for_command_line() - self.config: Optional[ExperimentConfig] = None + self.config: ExperimentConfig | None = None self.id: str = management.generate_experiment_id() - self.port: Optional[int] = None - self._proc: Optional[Popen] = None - self.action = 'create' - self.url_prefix: Optional[str] = None + self.port: int | None = None + self._proc: Popen | psutil.Process | None = None + self._action = 'create' + self.url_prefix: str | None = None if isinstance(config_or_platform, (str, list)): self.config = ExperimentConfig(config_or_platform) @@ -101,6 +100,7 @@ class Experiment: debug Whether to start in debug mode. """ + assert self.config is not None if run_mode is not RunMode.Detach: atexit.register(self.stop) @@ -114,7 +114,7 @@ class Experiment: log_dir = Path.home() / f'nni-experiments/{self.id}/log' nni.runtime.log.start_experiment_log(self.id, log_dir, debug) - self._proc = launcher.start_experiment(self.action, self.id, config, port, debug, run_mode, self.url_prefix) + self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix) assert self._proc is not None self.port = port # port will be None if start up failed @@ -144,16 +144,16 @@ class Experiment: _logger.warning('Cannot gracefully stop experiment, killing NNI process...') kill_command(self._proc.pid) - self.id = None + self.id = None # type: ignore self.port = None self._proc = None _logger.info('Experiment stopped') - def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool: + def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None: """ Run the experiment. - If ``wait_completion`` is True, this function will block until experiment finish or error. + If ``wait_completion`` is ``True``, this function will block until experiment finish or error. Return ``True`` when experiment done; or return ``False`` when experiment failed. @@ -247,7 +247,7 @@ class Experiment: def _resume(exp_id, exp_dir=None): exp = Experiment(None) exp.id = exp_id - exp.action = 'resume' + exp._action = 'resume' exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir) return exp @@ -255,7 +255,7 @@ class Experiment: def _view(exp_id, exp_dir=None): exp = Experiment(None) exp.id = exp_id - exp.action = 'view' + exp._action = 'view' exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir) return exp diff --git a/nni/experiment/launcher.py b/nni/experiment/launcher.py index f7abbf253..cebf9becb 100644 --- a/nni/experiment/launcher.py +++ b/nni/experiment/launcher.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import contextlib from dataclasses import dataclass, fields from datetime import datetime @@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): return proc -def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]: +def _start_rest_server(nni_manager_args, run_mode) -> Popen: import nni_node - node_dir = Path(nni_node.__path__[0]) + node_dir = Path(nni_node.__path__[0]) # type: ignore node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) main_js = str(node_dir / 'main.js') cmd = [node, '--max-old-space-size=4096', main_js] @@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]: from subprocess import CREATE_NEW_PROCESS_GROUP return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP) else: - return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) + return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore -def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen: +def start_experiment_retiarii(exp_id, config, port, debug): pipe = None proc = None @@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool args['dispatcher_pipe'] = pipe_path import nni_node - node_dir = Path(nni_node.__path__[0]) + node_dir = Path(nni_node.__path__[0]) # type: ignore node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) main_js = str(node_dir / 'main.js') cmd = [node, '--max-old-space-size=4096', main_js] @@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int, def get_stopped_experiment_config(exp_id, exp_dir=None): - config_json = get_stopped_experiment_config_json(exp_id, exp_dir) - config = ExperimentConfig(**config_json) + config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore + config = ExperimentConfig(**config_json) # type: ignore if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory): msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)' _logger.warning(msg, exp_dir, config.experiment_working_directory) diff --git a/nni/experiment/management.py b/nni/experiment/management.py index b15c4d6d2..b74cd033b 100644 --- a/nni/experiment/management.py +++ b/nni/experiment/management.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from pathlib import Path import random import string diff --git a/nni/experiment/pipe.py b/nni/experiment/pipe.py index e59fd8270..7a29ca41e 100644 --- a/nni/experiment/pipe.py +++ b/nni/experiment/pipe.py @@ -1,4 +1,6 @@ -from io import BufferedIOBase +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import logging import os import sys @@ -25,7 +27,7 @@ if sys.platform == 'win32': _winapi.NULL ) - def connect(self) -> BufferedIOBase: + def connect(self): _winapi.ConnectNamedPipe(self._handle, _winapi.NULL) fd = msvcrt.open_osfhandle(self._handle, 0) self.file = os.fdopen(fd, 'w+b') @@ -55,7 +57,7 @@ else: self._socket.bind(self.path) self._socket.listen(1) # only accepts one connection - def connect(self) -> BufferedIOBase: + def connect(self): conn, _ = self._socket.accept() self.file = conn.makefile('rwb') return self.file diff --git a/nni/experiment/rest.py b/nni/experiment/rest.py index c2fbaa824..3ab266311 100644 --- a/nni/experiment/rest.py +++ b/nni/experiment/rest.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import logging from typing import Any, Optional diff --git a/nni/recoverable.py b/nni/recoverable.py index 70f11e634..4ff419b8f 100644 --- a/nni/recoverable.py +++ b/nni/recoverable.py @@ -1,17 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import os class Recoverable: - def load_checkpoint(self): + def load_checkpoint(self) -> None: pass - def save_checkpoint(self): + def save_checkpoint(self) -> None: pass - def get_checkpoint_path(self): + def get_checkpoint_path(self) -> str | None: ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY') if ckp_path is not None and os.path.isdir(ckp_path): return ckp_path diff --git a/nni/runtime/config.py b/nni/runtime/config.py index 8c353c8f2..10f077cb9 100644 --- a/nni/runtime/config.py +++ b/nni/runtime/config.py @@ -14,7 +14,7 @@ def get_config_directory() -> Path: Create it if not exist. """ if os.getenv('NNI_CONFIG_DIR') is not None: - config_dir = Path(os.getenv('NNI_CONFIG_DIR')) + config_dir = Path(os.getenv('NNI_CONFIG_DIR')) # type: ignore elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir(): config_dir = Path(sys.prefix, 'nni') elif sys.platform == 'win32': @@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path: """ Get a readonly builtin config file. """ - return Path(nni.__path__[0], 'runtime/default_config', name) + return Path(nni.__path__[0], 'runtime/default_config', name) # type: ignore diff --git a/nni/runtime/log.py b/nni/runtime/log.py index 4f9822207..3d6ecde9c 100644 --- a/nni/runtime/log.py +++ b/nni/runtime/log.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + import logging import sys from datetime import datetime @@ -105,7 +110,7 @@ def _init_logger_standalone() -> None: _register_handler(StreamHandler(sys.stdout), logging.INFO) -def _prepare_log_dir(path: Optional[str]) -> Path: +def _prepare_log_dir(path: Path | str) -> Path: if path is None: return Path() ret = Path(path) @@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase): def __init__(self, log_file: TextIOBase): self.file: TextIOBase = log_file self.line_buffer: Optional[str] = None - self.line_start_time: Optional[datetime] = None + self.line_start_time: datetime = datetime.fromtimestamp(0) def write(self, s: str) -> int: cur_time = datetime.now() diff --git a/nni/runtime/msg_dispatcher.py b/nni/runtime/msg_dispatcher.py index 9f8481daf..442d50f9c 100644 --- a/nni/runtime/msg_dispatcher.py +++ b/nni/runtime/msg_dispatcher.py @@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase): except Exception as e: _logger.error('Assessor error') _logger.exception(e) + raise if isinstance(result, bool): result = AssessResult.Good if result else AssessResult.Bad diff --git a/nni/tools/jupyter_extension/__init__.py b/nni/tools/jupyter_extension/__init__.py index df86a70a8..ac2a0e6e8 100644 --- a/nni/tools/jupyter_extension/__init__.py +++ b/nni/tools/jupyter_extension/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from . import proxy load_jupyter_server_extension = proxy.setup diff --git a/nni/tools/jupyter_extension/management.py b/nni/tools/jupyter_extension/management.py index 276160b66..d3e70a3f6 100644 --- a/nni/tools/jupyter_extension/management.py +++ b/nni/tools/jupyter_extension/management.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import json from pathlib import Path import shutil diff --git a/nni/tools/jupyter_extension/proxy.py b/nni/tools/jupyter_extension/proxy.py index f5457379d..c0bd571d7 100644 --- a/nni/tools/jupyter_extension/proxy.py +++ b/nni/tools/jupyter_extension/proxy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import json from pathlib import Path diff --git a/nni/tools/nnictl/ts_management.py b/nni/tools/nnictl/ts_management.py index 4dffe21eb..f3145e015 100644 --- a/nni/tools/nnictl/ts_management.py +++ b/nni/tools/nnictl/ts_management.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import importlib import json diff --git a/nni/trial.py b/nni/trial.py index 48fa52b08..e8afc036a 100644 --- a/nni/trial.py +++ b/nni/trial.py @@ -1,13 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + +from typing import Any + from .common.serializer import dump from .runtime.env_vars import trial_env_vars from .runtime import platform - +from .typehint import Parameters, TrialMetric __all__ = [ 'get_next_parameter', + 'get_next_parameters', 'get_current_parameter', 'report_intermediate_result', 'report_final_result', @@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id() _sequence_id = platform.get_sequence_id() -def get_next_parameter(): +def get_next_parameter() -> Parameters: """ Get the hyperparameters generated by tuner. @@ -32,7 +37,7 @@ def get_next_parameter(): Examples -------- - Assuming the search space is: + Assuming the :doc:`search space ` is: .. code-block:: @@ -52,16 +57,22 @@ def get_next_parameter(): Returns ------- - dict + :class:`~nni.typehint.Parameters` A hyperparameter set sampled from search space. """ global _params _params = platform.get_next_parameter() if _params is None: - return None + return None # type: ignore return _params['parameters'] -def get_current_parameter(tag=None): +def get_next_parameters() -> Parameters: + """ + Alias of :func:`get_next_parameter` + """ + return get_next_parameter() + +def get_current_parameter(tag: str | None = None) -> Any: global _params if _params is None: return None @@ -94,13 +105,13 @@ def get_sequence_id() -> int: _intermediate_seq = 0 -def overwrite_intermediate_seq(value): +def overwrite_intermediate_seq(value: int) -> None: assert isinstance(value, int) global _intermediate_seq _intermediate_seq = value -def report_intermediate_result(metric): +def report_intermediate_result(metric: TrialMetric | dict[str, Any]) -> None: """ Reports intermediate result to NNI. @@ -110,11 +121,16 @@ def report_intermediate_result(metric): and other items can be visualized with web portal. Typically ``metric`` is per-epoch accuracy or loss. + + Parameters + ---------- + metric : :class:`~nni.typehint.TrialMetric` + The intermeidate result. """ global _intermediate_seq assert _params or trial_env_vars.NNI_PLATFORM is None, \ 'nni.get_next_parameter() needs to be called before report_intermediate_result' - metric = dump({ + dumped_metric = dump({ 'parameter_id': _params['parameter_id'] if _params else None, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'type': 'PERIODICAL', @@ -122,9 +138,9 @@ def report_intermediate_result(metric): 'value': dump(metric) }) _intermediate_seq += 1 - platform.send_metric(metric) + platform.send_metric(dumped_metric) -def report_final_result(metric): +def report_final_result(metric: TrialMetric | dict[str, Any]) -> None: """ Reports final result to NNI. @@ -134,14 +150,19 @@ def report_final_result(metric): and other items can be visualized with web portal. Typically ``metric`` is the final accuracy or loss. + + Parameters + ---------- + metric : :class:`~nni.typehint.TrialMetric` + The final result. """ assert _params or trial_env_vars.NNI_PLATFORM is None, \ 'nni.get_next_parameter() needs to be called before report_final_result' - metric = dump({ + dumped_metric = dump({ 'parameter_id': _params['parameter_id'] if _params else None, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'type': 'FINAL', 'sequence': 0, 'value': dump(metric) }) - platform.send_metric(metric) + platform.send_metric(dumped_metric) diff --git a/nni/tuner.py b/nni/tuner.py index 4fbcc011d..87b168db6 100644 --- a/nni/tuner.py +++ b/nni/tuner.py @@ -8,11 +8,14 @@ A new trial will run with this configuration. See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details. """ +from __future__ import annotations + import logging import nni from .recoverable import Recoverable +from .typehint import Parameters, SearchSpace, TrialMetric, TrialRecord __all__ = ['Tuner'] @@ -67,7 +70,7 @@ class Tuner(Recoverable): :class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner` """ - def generate_parameters(self, parameter_id, **kwargs): + def generate_parameters(self, parameter_id: int, **kwargs) -> Parameters: """ Abstract method which provides a set of hyper-parameters. @@ -100,7 +103,7 @@ class Tuner(Recoverable): # we need to design a new exception for this purpose raise NotImplementedError('Tuner: generate_parameters not implemented') - def generate_multiple_parameters(self, parameter_id_list, **kwargs): + def generate_multiple_parameters(self, parameter_id_list: list[int], **kwargs) -> list[Parameters]: """ Callback method which provides multiple sets of hyper-parameters. @@ -135,7 +138,7 @@ class Tuner(Recoverable): result.append(res) return result - def receive_trial_result(self, parameter_id, parameters, value, **kwargs): + def receive_trial_result(self, parameter_id: int, parameters: Parameters, value: TrialMetric, **kwargs) -> None: """ Abstract method invoked when a trial reports its final result. Must override. @@ -165,7 +168,7 @@ class Tuner(Recoverable): # pylint: disable=attribute-defined-outside-init self._accept_customized = accept - def trial_end(self, parameter_id, success, **kwargs): + def trial_end(self, parameter_id: int, success: bool, **kwargs) -> None: """ Abstract method invoked when a trial is completed or terminated. Do nothing by default. @@ -179,7 +182,7 @@ class Tuner(Recoverable): Unstable parameters which should be ignored by normal users. """ - def update_search_space(self, search_space): + def update_search_space(self, search_space: SearchSpace) -> None: """ Abstract method for updating the search space. Must override. @@ -194,21 +197,21 @@ class Tuner(Recoverable): """ raise NotImplementedError('Tuner: update_search_space not implemented') - def load_checkpoint(self): + def load_checkpoint(self) -> None: """ Internal API under revising, not recommended for end users. """ checkpoin_path = self.get_checkpoint_path() _logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path) - def save_checkpoint(self): + def save_checkpoint(self) -> None: """ Internal API under revising, not recommended for end users. """ checkpoin_path = self.get_checkpoint_path() _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path) - def import_data(self, data): + def import_data(self, data: list[TrialRecord]) -> None: """ Internal API under revising, not recommended for end users. """ @@ -216,8 +219,8 @@ class Tuner(Recoverable): # data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' pass - def _on_exit(self): + def _on_exit(self) -> None: pass - def _on_error(self): + def _on_error(self) -> None: pass diff --git a/nni/typehint.py b/nni/typehint.py index 504004047..33b82a22c 100644 --- a/nni/typehint.py +++ b/nni/typehint.py @@ -1,10 +1,58 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import sys -import typing +""" +Types for static checking. +""" -if typing.TYPE_CHECKING or sys.version_info >= (3, 8): - Literal = typing.Literal +__all__ = [ + 'Literal', + 'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord', +] + +import sys +from typing import Any, Dict, List, TYPE_CHECKING + +if TYPE_CHECKING or sys.version_info >= (3, 8): + from typing import Literal, TypedDict else: - Literal = typing.Any + from typing_extensions import Literal, TypedDict + +Parameters = Dict[str, Any] +""" +Return type of :func:`nni.get_next_parameter`. + +For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space `. + +Customized tuners do not need to follow the constraint and can use anything serializable. +""" + +class _ParameterSearchSpace(TypedDict): + _type: Literal[ + 'choice', 'randint', + 'uniform', 'loguniform', 'quniform', 'qloguniform', + 'normal', 'lognormal', 'qnormal', 'qlognormal', + ] + _value: List[Any] + +SearchSpace = Dict[str, _ParameterSearchSpace] +""" +Type of ``experiment.config.search_space``. + +For built-in tuners, the format is detailed in :doc:`/hpo/search_space`. + +Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``. +""" + +TrialMetric = float +""" +Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`. + +For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc). + +Customized tuners do not need to follow this constraint and can use anything serializable. +""" + +class TrialRecord(TypedDict): + parameter: Parameters + value: TrialMetric diff --git a/pipelines/fast-test.yml b/pipelines/fast-test.yml index ca4d7f6f6..657492cb8 100644 --- a/pipelines/fast-test.yml +++ b/pipelines/fast-test.yml @@ -63,6 +63,9 @@ stages: python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics displayName: flake8 + - script: | + python -m pyright nni + - job: typescript pool: vmImage: ubuntu-latest diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 000000000..ebe2e6cee --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,14 @@ +{ + "ignore": [ + "nni/algorithms", + "nni/common/device.py", + "nni/common/graph_utils.py", + "nni/common/serializer.py", + "nni/compression", + "nni/nas", + "nni/retiarii", + "nni/smartparam.py", + "nni/tools" + ], + "reportMissingImports": false +}