Fix search space compatibility with JSON (#4455)

This commit is contained in:
liuzhe-lz 2022-01-11 01:24:22 +08:00 коммит произвёл GitHub
Родитель 452e69f3f2
Коммит 31f11f5124
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 122 добавлений и 4 удалений

Просмотреть файл

@ -8,7 +8,9 @@ Top level experiement configuration class, ``ExperimentConfig``.
__all__ = ['ExperimentConfig']
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Any, List, Optional, Union
import yaml
@ -113,6 +115,16 @@ class ExperimentConfig(ConfigBase):
super()._canonicalize([self])
if self.search_space_file is not None:
yaml_error = None
try:
self.search_space = _load_search_space_file(self.search_space_file)
except Exception as e:
yaml_error = repr(e)
if yaml_error is not None: # raise it outside except block to make stack trace clear
msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}'
raise ValueError(msg)
if self.nni_manager_ip is None:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it
@ -133,10 +145,6 @@ class ExperimentConfig(ConfigBase):
if not self.use_annotation and space_cnt < 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
if self.search_space_file is not None:
with open(self.search_space_file) as ss_file:
self.search_space = yaml.safe_load(ss_file)
# to make the error message clear, ideally it should be:
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')`
# but I believe there will be hardy few users make this kind of mistakes, so let's keep it simple
@ -156,3 +164,13 @@ class ExperimentConfig(ConfigBase):
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
if tuner_cnt != 1:
raise ValueError('ExperimentConfig: tuner and advisor must be set one')
def _load_search_space_file(search_space_path):
# FIXME
# we need this because PyYAML 6.0 does not support YAML 1.2,
# which means it is not fully compatible with JSON
content = Path(search_space_path).read_text(encoding='utf8')
try:
return json.loads(content)
except Exception:
return yaml.safe_load(content)

Просмотреть файл

@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 0.0000001, 0.1 ]

Просмотреть файл

@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 0.0000001, 0.1 ],
},
}

Просмотреть файл

@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ]
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ]
}
}

Просмотреть файл

@ -0,0 +1,10 @@
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ],
},
}

Просмотреть файл

@ -0,0 +1,9 @@
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 1e-7, 0.1 ] # test scientific notation

Просмотреть файл

@ -0,0 +1,52 @@
import json
from pathlib import Path
import yaml
from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig
## template ##
config = ExperimentConfig(
search_space_file = '',
trial_command = 'echo hello',
trial_concurrency = 1,
tuner = AlgorithmConfig(name='randomm'),
training_service = LocalConfig()
)
space_correct = {
'pool_type': {
'_type': 'choice',
'_value': ['max', 'min', 'avg']
},
'学习率': {
'_type': 'loguniform',
'_value': [1e-7, 0.1]
}
}
# FIXME
# PyYAML 6.0 (YAML 1.1) does not support tab and scientific notation
# JSON does not support comment and extra comma
# So some combinations will fail to load
formats = [
('ss_tab.json', 'JSON (tabs + scientific notation)'),
('ss_comma.json', 'JSON with extra comma'),
#('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
('ss.yaml', 'YAML'),
#('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
]
def test_search_space():
for space_file, description in formats:
try:
config.search_space_file = Path(__file__).parent / 'assets' / space_file
space = config.json()['searchSpace']
assert space == space_correct
except Exception as e:
print('Failed to load search space format: ' + description)
raise e
if __name__ == '__main__':
test_search_space()