Mirror of https://github.com/microsoft/archai.git
config.get->get_inst, collate option in report, serialize tests
This commit is contained in:
Parent: 9c9ef5ed22
Commit: 9a7d01bdbd
@@ -33,7 +33,7 @@ _tb_writer: SummaryWriterAny = None
 _atexit_reg = False # is hook for atexit registered?

 def get_conf()->Config:
-    return Config.get()
+    return Config.get_inst()

 def get_conf_common()->Config:
     return get_conf()['common']
@@ -96,7 +96,7 @@ def common_init(config_filepath: Optional[str]=None,
     conf = Config(config_filepath=config_filepath,
                   param_args=param_overrides,
                   use_args=use_args)
-    Config.set(conf)
+    Config.set_inst(conf)

     # create experiment dir
     _setup_dirs()

@@ -51,8 +51,6 @@ class Config(UserDict):
         resolve_redirects -- [if True then _copy commands in yaml are executed]
         """
         super(Config, self).__init__()
-        # without below Python would let static method override instance method
-        self.get = super(Config, self).get

         self.args, self.extra_args = None, []

@@ -145,12 +143,12 @@ class Config(UserDict):
         return super().get(key, default_val)

     @staticmethod
-    def set(instance:'Config')->None:
+    def set_inst(instance:'Config')->None:
         global _config
         _config = instance

     @staticmethod
-    def get()->'Config':
+    def get_inst()->'Config':
         global _config
         return _config

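The two changes above work together. Storing `self.get = super(Config, self).get` put a bound method into the instance `__dict__`; PyYAML serializes a bound method as `getattr` applied to its owner, which yields a self-referential node that fails with a recursion error at load time (the same `get: !!python/object/apply:builtins.getattr` construct the report script's temporary workaround strips out below). With that attribute gone, the static accessors can no longer be named `get`/`set` without colliding with the dict-style `get(key, default_val)` method, hence the `get_inst`/`set_inst` rename. A minimal sketch of the old failure mode, assuming PyYAML's default Dumper/Loader and not taken from this commit:

    import yaml
    from collections import UserDict

    class BadConfig(UserDict):
        def __init__(self):
            super().__init__()
            self.get = super().get   # bound method lands in instance __dict__

    c = BadConfig()
    c['decay'] = 1
    s = yaml.dump(c)
    # s now contains a cyclic node roughly like:
    #   get: !!python/object/apply:builtins.getattr
    #   - *id001
    #   - get
    # and yaml.load(s, Loader=yaml.Loader) fails while constructing it
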
@@ -62,7 +62,7 @@ gorilla.apply(patch)
 @ray.remote(num_gpus=torch.cuda.device_count(), max_calls=1)
 def _train_model(conf, dataroot, augment, val_ratio, val_fold, save_path=None,
                  only_eval=False):
-    Config.set(conf)
+    Config.set_inst(conf)
     conf['autoaug']['loader']['aug'] = augment
     model_type = conf['autoaug']['model']['type']
@@ -310,7 +310,7 @@ def search(conf):

 def _eval_tta(conf, augment, reporter):
-    Config.set(conf)
+    Config.set_inst(conf)

     # region conf vars
     conf_data = conf['dataset']

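Both call sites run as Ray remote functions. The module-level `_config` singleton is per-process, so a worker starts with an empty one; the driver's config therefore travels in as an argument and is re-installed first thing. A sketch of the pattern, with names assumed rather than taken from this diff:

    import ray
    from archai.common.config import Config

    @ray.remote
    def worker(conf):
        # the worker runs in its own process, so install the driver's
        # Config instance into this process's singleton first
        Config.set_inst(conf)
        # ... code below can now use module-level accessors like get_conf()
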
@@ -9,5 +9,6 @@ def get_filepath(suffix):
     ])
     return utils.full_path(os.path.join('$expdir', 'somefile.txt'))

 print(get_filepath('search'))
+print(get_filepath('eval'))

@@ -22,12 +22,26 @@ from archai.common.ordereddict_logger import OrderedDictLogger
 import re


+"""
+Temp workaround for yaml construction recursion error:
+Search and replace following with blank:
+
+get: !!python/object/apply:builtins.getattr
+- *id001
+- get
+"""
+
 def main():
     parser = argparse.ArgumentParser(description='Report creator')
-    parser.add_argument('--results-dir', '-d', type=str, default=r'D:\GitHubSrc\archaiphilly\phillytools\darts_baseline_20200411',
+    parser.add_argument('--results-dir', '-d', type=str,
+                        #default=r'D:\GitHubSrc\archaiphilly\phillytools\darts_baseline_20200411',
+                        default=r'D:\GitHubSrc\archaiphilly\phillytools\c10_auxs_droppaths_b96',
                         help='folder with experiment results from pt')
     parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
                         help='folder to output reports')
+    parser.add_argument('--collate', '-c', type=lambda x:x.lower()=='true',
+                        nargs='?', const=True, default=False,
+                        help='Collate epochs metrics from jobs, useful if jobs are for different seeds')
     args, extra_args = parser.parse_known_args()

     # root dir where all results are stored

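The new `--collate` flag uses the optional-value boolean idiom: `nargs='?'` plus `const=True` makes the bare flag mean true, the `type` lambda parses an explicit true/false word, and `default=False` covers the flag being absent. The three spellings behave like this (assumed invocations, not a transcript from the repo):

    # python report.py                   -> args.collate == False  (default)
    # python report.py --collate         -> args.collate == True   (const)
    # python report.py --collate false   -> args.collate == False  (type lambda)
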
@@ -43,24 +57,33 @@ def main():
     os.makedirs(out_dir, exist_ok=True)

     # get list of all structured logs for each job
-    logs = []
+    logs = {}
     job_count = 0
     for job_dir in results_dir.iterdir():
         job_count += 1
         for subdir in job_dir.iterdir():
             # currently we expect that each job was an ExperimentRunner job which should have
             # _search or _eval folders
-            is_search = subdir.stem.endswith('_search')
-            is_eval = subdir.stem.endswith('_eval')
-            assert is_search or is_eval
+            if subdir.stem.endswith('_search'):
+                sub_job = 'search'
+            elif subdir.stem.endswith('_eval'):
+                sub_job = 'eval'
+            else:
+                raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
+                                   'end with either _search or _eval which '
+                                   'should be the case if ExperimentRunner was used.')

             logs_filepath = os.path.join(str(subdir), 'logs.yaml')
             if os.path.isfile(logs_filepath):
                 with open(logs_filepath, 'r') as f:
-                    logs.append(yaml.load(f, Loader=yaml.Loader))
+                    key = job_dir.stem + ':' + sub_job
+                    logs[key] = yaml.load(f, Loader=yaml.Loader)

-    collated_logs = collate_epoch_nodes(logs)
+    # create list of epoch nodes having same path in the logs
+    collated_logs = collect_epoch_nodes(logs, args.collate)
     summary_text, details_text = '', ''

     # for each path for epochs nodes, compute stats
     for node_path, logs_epochs_nodes in collated_logs.items():
         collated_epoch_stats = get_epoch_stats(node_path, logs_epochs_nodes)
         summary_text += get_summary_text(out_dir, node_path, collated_epoch_stats)

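Switching `logs` from a list to a dict keyed by `jobdir:sub_job` keeps results from different jobs distinguishable downstream, which is what makes the non-collated report possible. Schematically, with made-up job names:

    # logs after the loop, for two hypothetical jobs:
    # {
    #   'darts_seed1:search': <parsed logs.yaml>,
    #   'darts_seed1:eval':   <parsed logs.yaml>,
    #   'darts_seed2:eval':   <parsed logs.yaml>,
    # }
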
@@ -70,25 +93,30 @@ def main():
     write_report('details.md', **vars())

 def epoch_nodes(node:OrderedDict, path=[])->Iterator[Tuple[List[str], OrderedDictLogger]]:
-    for k,v in node.items():
+    """Search nodes recursively for nodes named 'epochs' and return them along with their paths"""
+    for k, v in node.items():
         if k == 'epochs' and isinstance(v, OrderedDict) and len(v) and '0' in v:
             yield path, v
         elif isinstance(v, OrderedDict):
             for p, en in epoch_nodes(v, path=path+[k]):
                 yield p, en

-def collate_epoch_nodes(logs:List[OrderedDict])->Dict[str, List[OrderedDict]]:
+def collect_epoch_nodes(logs:Dict[str, OrderedDict], collate:bool)->Dict[str, List[OrderedDict]]:
+    """Make a list of epoch nodes at the same path in each of the logs if collate=True, else
+    it's just a list of epoch nodes with jobdir and path as the key."""
     collated = OrderedDict()
-    for log in logs:
+    for log_key, log in logs.items():
         for path, epoch_node in epoch_nodes(log):
             # for each path get the list where we can put epoch node
             path_key = '/'.join(path)
+            if not collate:
+                path_key += ' : ' + log_key
             if not path_key in collated:
                 collated[path_key] = []
             v = collated[path_key]
             v.append(epoch_node)
     return collated


 class EpochStats:
     def __init__(self) -> None:
         self.start_lr = Statistics()

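The `collate` switch only changes the grouping key. With `collate=True`, epoch nodes that share a path are pooled across jobs, so downstream stats average over seeds; with `collate=False`, the log key is appended and every job keeps its own entry. A worked example under assumed inputs (note the yielded path does not include the 'epochs' key itself):

    # logs = {'job1:eval': log1, 'job2:eval': log2}, each with an epochs
    # node at nas/eval:
    # collate=True  -> {'nas/eval': [epochs_job1, epochs_job2]}
    # collate=False -> {'nas/eval : job1:eval': [epochs_job1],
    #                   'nas/eval : job2:eval': [epochs_job2]}
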
@@ -99,8 +127,11 @@ class EpochStats:
     def update(self, epoch_node:OrderedDict)->None:
         self.start_lr.push(epoch_node['start_lr'])
         self.end_lr.push(epoch_node['train']['end_lr'])
-        self.train_fold.update(epoch_node['train'])
-        self.val_fold.update(epoch_node['val'])
+
+        if 'train' in epoch_node:
+            self.train_fold.update(epoch_node['train'])
+        if 'val' in epoch_node:
+            self.val_fold.update(epoch_node['val'])

 class FoldStats:
     def __init__(self) -> None:

@@ -161,9 +192,12 @@ def get_summary_text(out_dir:str, node_path:str, epoch_stats:List[EpochStats])->
     lines.append(f'Train epoch time: {stat2str(train_duration)}')
     lines.append('')
-    for milestone in [35-1, 200-1, 600-1, 1500-1]:
+    milestones = [35, 200, 600, 1500]
+    for milestone in milestones:
         if len(epoch_stats) >= milestone:
-            lines.append(f'{stat2str(epoch_stats[milestone].val_fold.top1)} val top1 @ {milestone} epochs')
+            lines.append(f'{stat2str(epoch_stats[milestone-1].val_fold.top1)} val top1 @ {milestone} epochs')
+    if not len(epoch_stats) in milestones:
+        lines.append(f'{stat2str(epoch_stats[-1].val_fold.top1)} val top1 @ {len(epoch_stats)} epochs')

     return '\n'.join(lines)

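The milestone loop previously mixed indices with labels: it iterated over `35-1, 200-1, ...`, compared the run length against the index, and printed the index as the epoch count. Now the milestones are plain epoch counts, the lookup uses `milestone-1`, and a run whose length is not itself a milestone still reports its final epoch. For a run of assumed length 200:

    # milestone 35  -> epoch_stats[34]  reported as 'val top1 @ 35 epochs'
    # milestone 200 -> epoch_stats[199] reported as 'val top1 @ 200 epochs'
    # a 100-epoch run would additionally report epoch_stats[-1]
    # as 'val top1 @ 100 epochs', since 100 is not in milestones
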
@@ -1,3 +1,4 @@
+import yaml
 from archai.common.config import Config

 def test_param_override1():

@@ -20,5 +21,35 @@ def test_param_override2():
     assert not conf['nas']['eval']['trainer']['apex']['distributed_enabled']
     assert not conf['nas']['eval']['loader']['apex']['distributed_enabled']

+def test_serialize():
+    conf = Config()
+    conf['decay'] = 1
+    s = yaml.dump(conf)
+    conf2 = yaml.load(s, Loader=yaml.Loader)
+    assert len(conf2)==1
+
+def test_serialize_str():
+    s = """
+!!python/object/apply:collections.OrderedDict
+- - - conf_optim
+    - &id001 !!python/object:archai.common.config.Config
+      args: null
+      config_filepath: null
+      data:
+        decay: 0.0003
+        decay_bn: .nan
+        lr: 0.025
+        momentum: 0.9
+        nesterov: false
+        type: sgd
+      extra_args: []
+  - - steps_per_epoch
+    - 521
+"""
+    o = yaml.load(s, Loader=yaml.Loader)
+    assert o is not None
+
+test_serialize_str()
+test_serialize()
+test_param_override1()
+test_param_override2()

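These tests exercise exactly the failure fixed above: `test_serialize` round-trips a live `Config`, and `test_serialize_str` loads what looks like a dump captured from a real run. Both need `yaml.Loader` because the documents carry `!!python/object` / `!!python/object/apply` tags that the safe loader refuses to construct. A quick illustration, assuming PyYAML:

    import yaml
    from archai.common.config import Config

    conf = Config()
    conf['decay'] = 1
    s = yaml.dump(conf)               # emits a !!python/object:...Config tag
    yaml.load(s, Loader=yaml.Loader)  # full loader rebuilds the Config
    # yaml.safe_load(s) would raise yaml.constructor.ConstructorError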