Support overriding by args along with redirects

This commit is contained in:
Shital Shah 2020-05-03 04:22:40 -07:00
Родитель 1904a7ab65
Коммит 6be9f214ab
2 изменённых файлов: 29 добавлений и 16 удалений

Просмотреть файл

@ -5,14 +5,16 @@ from typing import Sequence
from argparse import ArgumentError
from collections.abc import Mapping, MutableMapping
import os
import yaml
from distutils.util import strtobool
import copy
from os import stat
import yaml
from . import yaml_utils
# global config instance
_config:'Config' = None
@ -66,16 +68,20 @@ class Config(UserDict):
for filepath in config_filepath.strip().split(';'):
self._load_from_file(filepath.strip())
# replace _copy paths
# HACK: using file names to detect root config, may be there should be a flag?
# Create a copy of ourselves and do the resolution over it.
# This resolved_conf then can be used to search for overrides that
# wouldn't have existed before resolution.
resolved_conf = copy.deepcopy(self)
if resolve_redirects:
yaml_utils.resolve_all(resolved_conf)
# Let's do final overrides from args
self._update_from_args(param_args, resolved_conf) # merge from params
self._update_from_args(self.extra_args, resolved_conf) # merge from command line
if resolve_redirects:
yaml_utils.resolve_all(self)
# TODO: currently update from args is applied after _copy resolution which
# limits to only fine grain overriding and probably can be made better
self._update_from_args(param_args) # merge from params
self._update_from_args(self.extra_args) # merge from command line
self.config_filepath = config_filepath
def _load_from_file(self, filepath:Optional[str])->None:
@ -92,13 +98,13 @@ class Config(UserDict):
deep_update(self, config_yaml, lambda: Config(resolve_redirects=False))
print('config loaded from: ', filepath)
def _update_from_args(self, args:Sequence)->None:
def _update_from_args(self, args:Sequence, resolved_section:'Config')->None:
i = 0
while i < len(args)-1:
arg = args[i]
if arg.startswith(("--")):
path = arg[len("--"):].split('.')
i += Config._update_section(self, path, args[i+1])
i += Config._update_section(self, path, args[i+1], resolved_section)
else: # some other arg
i += 1
@ -106,19 +112,22 @@ class Config(UserDict):
return deep_update({}, self, lambda: dict()) # type: ignore
@staticmethod
def _update_section(section:'Config', path:List[str], val:Any)->int:
def _update_section(section:'Config', path:List[str], val:Any, resolved_section:'Config')->int:
for p in range(len(path)-1):
sub_path = path[p]
if sub_path in section:
if sub_path in resolved_section:
resolved_section = resolved_section[sub_path]
if not sub_path in section:
section[sub_path] = Config(resolve_redirects=False)
section = section[sub_path]
else:
return 1 # path not found, ignore this
key = path[-1] # final leaf node value
if key in section:
if key in resolved_section:
original_val, original_type = None, None
try:
original_val = section[key]
original_val = resolved_section[key]
original_type = type(original_val)
if original_type == bool: # bool('False') is True :(
original_type = lambda x: strtobool(x)==1

Просмотреть файл

@ -1,9 +1,13 @@
from archai.common.config import Config
def test_param_override():
conf = Config('confs/algos/darts.yaml;confs/datasets/cifar10.yaml')
assert not conf['nas']['eval']['trainer']['apex']['enabled']
assert not conf['nas']['eval']['loader']['apex']['enabled']
conf = Config('confs/algos/darts.yaml;confs/datasets/cifar10.yaml',
param_args=["--nas.eval.trainer.apex.enabled", "True"])
assert conf['nas']['eval']['trainer']['apex']['enabled']
assert conf['nas']['eval']['loader']['apex']['enabled']
test_param_override()