зеркало из https://github.com/microsoft/nni.git
Fix search space compatibility with JSON (#4455)
This commit is contained in:
Родитель
452e69f3f2
Коммит
31f11f5124
|
@ -8,7 +8,9 @@ Top level experiement configuration class, ``ExperimentConfig``.
|
|||
__all__ = ['ExperimentConfig']
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
@ -113,6 +115,16 @@ class ExperimentConfig(ConfigBase):
|
|||
|
||||
super()._canonicalize([self])
|
||||
|
||||
if self.search_space_file is not None:
|
||||
yaml_error = None
|
||||
try:
|
||||
self.search_space = _load_search_space_file(self.search_space_file)
|
||||
except Exception as e:
|
||||
yaml_error = repr(e)
|
||||
if yaml_error is not None: # raise it outside except block to make stack trace clear
|
||||
msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}'
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.nni_manager_ip is None:
|
||||
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
|
||||
# the simple detection logic won't work for hybrid, but advanced users should not need it
|
||||
|
@ -133,10 +145,6 @@ class ExperimentConfig(ConfigBase):
|
|||
if not self.use_annotation and space_cnt < 1:
|
||||
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
|
||||
|
||||
if self.search_space_file is not None:
|
||||
with open(self.search_space_file) as ss_file:
|
||||
self.search_space = yaml.safe_load(ss_file)
|
||||
|
||||
# to make the error message clear, ideally it should be:
|
||||
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')`
|
||||
# but I believe there will be hardy few users make this kind of mistakes, so let's keep it simple
|
||||
|
@ -156,3 +164,13 @@ class ExperimentConfig(ConfigBase):
|
|||
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
|
||||
if tuner_cnt != 1:
|
||||
raise ValueError('ExperimentConfig: tuner and advisor must be set one')
|
||||
|
||||
def _load_search_space_file(search_space_path):
|
||||
# FIXME
|
||||
# we need this because PyYAML 6.0 does not support YAML 1.2,
|
||||
# which means it is not fully compatible with JSON
|
||||
content = Path(search_space_path).read_text(encoding='utf8')
|
||||
try:
|
||||
return json.loads(content)
|
||||
except Exception:
|
||||
return yaml.safe_load(content)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
pool_type:
|
||||
_type: choice
|
||||
_value:
|
||||
- max
|
||||
- min
|
||||
- avg
|
||||
学习率: # test unicode
|
||||
_type: loguniform
|
||||
_value: [ 0.0000001, 0.1 ]
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"pool_type": {
|
||||
"_type": "choice",
|
||||
"_value": [ "max", "min", "avg" ],
|
||||
},
|
||||
"学习率": {
|
||||
"_type": "loguniform",
|
||||
"_value": [ 0.0000001, 0.1 ],
|
||||
},
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"pool_type": {
|
||||
"_type": "choice",
|
||||
"_value": [ "max", "min", "avg" ]
|
||||
},
|
||||
"学习率": {
|
||||
"_type": "loguniform",
|
||||
"_value": [ 1e-7, 0.1 ]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"pool_type": {
|
||||
"_type": "choice",
|
||||
"_value": [ "max", "min", "avg" ],
|
||||
},
|
||||
"学习率": {
|
||||
"_type": "loguniform",
|
||||
"_value": [ 1e-7, 0.1 ],
|
||||
},
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
pool_type:
|
||||
_type: choice
|
||||
_value:
|
||||
- max
|
||||
- min
|
||||
- avg
|
||||
学习率: # test unicode
|
||||
_type: loguniform
|
||||
_value: [ 1e-7, 0.1 ] # test scientific notation
|
|
@ -0,0 +1,52 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig
|
||||
|
||||
## template ##
|
||||
|
||||
config = ExperimentConfig(
|
||||
search_space_file = '',
|
||||
trial_command = 'echo hello',
|
||||
trial_concurrency = 1,
|
||||
tuner = AlgorithmConfig(name='randomm'),
|
||||
training_service = LocalConfig()
|
||||
)
|
||||
|
||||
space_correct = {
|
||||
'pool_type': {
|
||||
'_type': 'choice',
|
||||
'_value': ['max', 'min', 'avg']
|
||||
},
|
||||
'学习率': {
|
||||
'_type': 'loguniform',
|
||||
'_value': [1e-7, 0.1]
|
||||
}
|
||||
}
|
||||
|
||||
# FIXME
|
||||
# PyYAML 6.0 (YAML 1.1) does not support tab and scientific notation
|
||||
# JSON does not support comment and extra comma
|
||||
# So some combinations will fail to load
|
||||
formats = [
|
||||
('ss_tab.json', 'JSON (tabs + scientific notation)'),
|
||||
('ss_comma.json', 'JSON with extra comma'),
|
||||
#('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
|
||||
('ss.yaml', 'YAML'),
|
||||
#('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
|
||||
]
|
||||
|
||||
def test_search_space():
|
||||
for space_file, description in formats:
|
||||
try:
|
||||
config.search_space_file = Path(__file__).parent / 'assets' / space_file
|
||||
space = config.json()['searchSpace']
|
||||
assert space == space_correct
|
||||
except Exception as e:
|
||||
print('Failed to load search space format: ' + description)
|
||||
raise e
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_search_space()
|
Загрузка…
Ссылка в новой задаче