зеркало из https://github.com/microsoft/archai.git
Natsbench SSS regular evaluation code added.
This commit is contained in:
Родитель
9be602cc5d
Коммит
eb0abc880c
|
@ -384,6 +384,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "natsbench_regular_eval"]
|
||||
},
|
||||
{
|
||||
"name": "Natsbench-Regular-SSS-Eval-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "natsbench_sss_regular_eval", "--datasets", "cifar10"]
|
||||
},
|
||||
{
|
||||
"name": "Nb101-Regular-Eval-Full",
|
||||
"type": "python",
|
||||
|
@ -422,7 +430,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main_proxynas_nb_wrapper.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
|
||||
"args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
|
||||
},
|
||||
{
|
||||
"name": "Main Proxynas Natsbench SSS Wrapper",
|
||||
|
@ -430,7 +438,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main_proxynas_nb_sss_wrapper.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "proxynas_natsbench_sss_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
|
||||
"args": ["--algos", "natsbench_sss_regular_eval", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
|
||||
},
|
||||
{
|
||||
"name": "Main Proxynas Nb101 Wrapper",
|
||||
|
@ -438,7 +446,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main_proxynas_nb101_wrapper.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
|
||||
"args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
|
||||
},
|
||||
{
|
||||
"name": "Resnet-Toy",
|
||||
|
@ -713,6 +721,15 @@
|
|||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_fb1024_ftlr0.1_fte10_ct256_ftt0.6_scu",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "Analysis Freeze Natsbench SSS Space",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "Analysis Freeze Addon No Cond",
|
||||
"type": "python",
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.nas.evaluater import Evaluater
|
||||
|
||||
from nats_bench import create
|
||||
from archai.algos.natsbench.lib.models import get_cell_based_tiny_net
|
||||
|
||||
class NatsbenchSSSRegularEvaluater(Evaluater):
|
||||
@overrides
|
||||
def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
|
||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||
# region conf vars
|
||||
dataset_name = conf_eval['loader']['dataset']['name']
|
||||
self.num_classes = conf_eval['loader']['dataset']['n_classes']
|
||||
|
||||
# if explicitly passed in then don't get from conf
|
||||
if not final_desc_filename:
|
||||
final_desc_filename = conf_eval['final_desc_filename']
|
||||
arch_index = conf_eval['natsbench']['arch_index']
|
||||
|
||||
dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
|
||||
natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
|
||||
# endregion
|
||||
|
||||
assert arch_index
|
||||
assert natsbench_location
|
||||
|
||||
return self._model_from_natsbench(arch_index, dataset_name, natsbench_location)
|
||||
|
||||
def _model_from_natsbench(self, arch_index:int, dataset_name:str, natsbench_location:str)->Model:
|
||||
|
||||
# create natsbench api
|
||||
api = create(natsbench_location, 'tss', fast_mode=True, verbose=True)
|
||||
|
||||
if arch_index >= 32768 or arch_index < 0:
|
||||
logger.warn(f'architecture id {arch_index} is invalid ')
|
||||
|
||||
supported_datasets = {'cifar10', 'cifar100', 'ImageNet16-120'}
|
||||
|
||||
# force natsbench to use cifar10 archs
|
||||
# since it doesn't know about other datasets
|
||||
if dataset_name not in supported_datasets:
|
||||
dataset_name = 'cifar10'
|
||||
|
||||
config = api.get_net_config(arch_index, dataset_name)
|
||||
|
||||
# fill in the num classes from conf again
|
||||
# to account for unsupported datasets may
|
||||
# have number of classes different from cifar10
|
||||
config['num_classes'] = self.num_classes
|
||||
|
||||
# network is a nn.Module subclass. the last few modules have names
|
||||
# lastact, lastact.0, lastact.1, global_pooling, classifier
|
||||
# which we can freeze train as usual
|
||||
model = get_cell_based_tiny_net(config)
|
||||
|
||||
return model
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional, Type
|
||||
from copy import deepcopy
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from archai.common.common import get_expdir, logger
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from archai.algos.natsbench.natsbench_sss_regular_evaluater import NatsbenchSSSRegularEvaluater
|
||||
|
||||
from nats_bench import create
|
||||
class NatsbenchSSSRegularExperimentRunner(ExperimentRunner):
|
||||
"""Runs regular training on architectures from natsbench"""
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->Optional[ModelDescBuilder]:
|
||||
return None
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
return None # no search trainer
|
||||
|
||||
@overrides
|
||||
def searcher(self)->ManualFreezeSearcher:
|
||||
return ManualFreezeSearcher() # no searcher basically
|
||||
|
||||
@overrides
|
||||
def copy_search_to_eval(self)->None:
|
||||
pass
|
||||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
|
||||
# regular evaluation of the architecture
|
||||
# where we simply lookup the result
|
||||
# --------------------------------------
|
||||
dataset_name = conf_eval['loader']['dataset']['name']
|
||||
|
||||
logger.pushd('regular_evaluate')
|
||||
if dataset_name in {'cifar10', 'cifar100', 'ImageNet16-120'}:
|
||||
arch_id = conf_eval['natsbench']['arch_index']
|
||||
dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
|
||||
natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
|
||||
logger.info(natsbench_location)
|
||||
|
||||
api = create(natsbench_location, 'tss', fast_mode=True, verbose=True)
|
||||
|
||||
if arch_id >= 32768 or arch_id < 0:
|
||||
logger.warn(f'architecture id {arch_id} is invalid ')
|
||||
|
||||
info = api.get_more_info(arch_id, dataset_name, hp=90, is_random=False)
|
||||
test_accuracy = info['test-accuracy']
|
||||
logger.info(f'Regular training top1 test accuracy is {test_accuracy}')
|
||||
logger.info({'regtrainingtop1': float(test_accuracy)})
|
||||
else:
|
||||
logger.info({'regtrainingtop1': -1})
|
||||
logger.popd()
|
||||
|
||||
# regular evaluation of n epochs
|
||||
evaler = NatsbenchSSSRegularEvaluater()
|
||||
return evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||
|
||||
eval:
|
||||
natsbench:
|
||||
arch_index: 9682
|
||||
natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
|
||||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
loader:
|
||||
train_batch: 256 # natsbench uses 256
|
||||
aug: '' # random flip and crop are already there in default params
|
||||
trainer: # matching natsbench paper closely
|
||||
plotsdir: ''
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
|
||||
drop_path_prob: 0.2 # probability that given edge will be dropped
|
||||
grad_clip: 5.0 # grads above this value is clipped
|
||||
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
|
||||
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
|
||||
title: 'eval_train'
|
||||
epochs: 5
|
||||
batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
|
||||
lossfn:
|
||||
type: 'CrossEntropyLoss'
|
||||
optimizer:
|
||||
type: 'sgd'
|
||||
lr: 0.1 # init learning rate
|
||||
decay: 5.0e-4 # pytorch default is 0.0
|
||||
momentum: 0.9 # pytorch default is 0.0
|
||||
nesterov: True # pytorch default is False
|
||||
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
|
||||
lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.000 # min learning rate to be set in eta_min param of scheduler
|
||||
warmup: # increases LR for 0 to current in specified epochs and then hands over to main scheduler
|
||||
multiplier: 1
|
||||
epochs: 0 # 0 disables warmup
|
||||
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ nas:
|
|||
eval:
|
||||
natsbench:
|
||||
arch_index: 23000
|
||||
natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the tss fast mode folder
|
||||
natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
|
||||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
loader:
|
||||
|
|
|
@ -21,6 +21,7 @@ from archai.algos.proxynas.freeze_nasbench101_experiment_runner import FreezeNas
|
|||
from archai.algos.proxynas.freeze_manual_experiment_runner import ManualFreezeExperimentRunner
|
||||
from archai.algos.naswotrain.naswotrain_natsbench_conditional_experiment_runner import NaswotConditionalNatsbenchExperimentRunner
|
||||
from archai.algos.natsbench.natsbench_regular_experiment_runner import NatsbenchRegularExperimentRunner
|
||||
from archai.algos.natsbench.natsbench_sss_regular_experiment_runner import NatsbenchSSSRegularExperimentRunner
|
||||
from archai.algos.nasbench101.nasbench101_exp_runner import Nb101RegularExperimentRunner
|
||||
from archai.algos.proxynas.phased_freeze_natsbench_experiment_runner import PhasedFreezeNatsbenchExperimentRunner
|
||||
from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import FreezeAddonNasbench101ExperimentRunner
|
||||
|
@ -54,6 +55,7 @@ def main():
|
|||
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
|
||||
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
|
||||
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
|
||||
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
|
||||
'nb101_regular_eval': Nb101RegularExperimentRunner,
|
||||
'phased_freezetrain_natsbench_space': PhasedFreezeNatsbenchExperimentRunner,
|
||||
'freezeaddon_nasbench101_space': FreezeAddonNasbench101ExperimentRunner,
|
||||
|
@ -85,6 +87,7 @@ def main():
|
|||
zerocost_conditional_natsbench_space,
|
||||
zerocost_natsbench_epochs_space,
|
||||
natsbench_regular_eval,
|
||||
natsbench_sss_regular_eval,
|
||||
nb101_regular_eval,
|
||||
phased_freezetrain_natsbench_space,
|
||||
random_natsbench_tss_far,
|
||||
|
|
|
@ -11,7 +11,8 @@ from archai.common.utils import exec_shell_command
|
|||
def main():
|
||||
parser = argparse.ArgumentParser(description='Proxynas SSS Wrapper Main')
|
||||
parser.add_argument('--algos', type=str, default='''
|
||||
proxynas_natsbench_sss_space,
|
||||
proxynas_natsbench_sss_space,
|
||||
natsbench_sss_regular_eval,
|
||||
''',
|
||||
help='NAS algos to run, separated by comma')
|
||||
parser.add_argument('--top1-acc-threshold', type=float)
|
||||
|
|
|
@ -0,0 +1,461 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from typing import Dict, List, Type, Iterator, Tuple
|
||||
import glob
|
||||
import os
|
||||
import pathlib
|
||||
from collections import OrderedDict, defaultdict
|
||||
from scipy.stats.stats import _two_sample_transform
|
||||
import yaml
|
||||
from inspect import getsourcefile
|
||||
import seaborn as sns
|
||||
import math as ma
|
||||
|
||||
|
||||
|
||||
import plotly.express as px
|
||||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from scipy.stats import kendalltau, spearmanr, sem
|
||||
|
||||
from runstats import Statistics
|
||||
|
||||
#import matplotlib
|
||||
#matplotlib.use('TkAgg')
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from multiprocessing import Pool
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
from archai.common import utils
|
||||
from archai.common.ordereddict_logger import OrderedDictLogger
|
||||
from archai.common.analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Report creator')
|
||||
parser.add_argument('--results-dir', '-d', type=str,
|
||||
default=r'~/logdir/proxynas_test_0001',
|
||||
help='folder with experiment results')
|
||||
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
|
||||
help='folder to output reports')
|
||||
parser.add_argument('--reg-evals-file', '-r', type=str, default=None,
|
||||
help='optional yaml file which contains full evaluation \
|
||||
of architectures on new datasets not part of natsbench')
|
||||
args, extra_args = parser.parse_known_args()
|
||||
|
||||
# root dir where all results are stored
|
||||
results_dir = pathlib.Path(utils.full_path(args.results_dir))
|
||||
print(f'results_dir: {results_dir}')
|
||||
|
||||
# extract experiment name which is top level directory
|
||||
exp_name = results_dir.parts[-1]
|
||||
|
||||
# create results dir for experiment
|
||||
out_dir = utils.full_path(os.path.join(args.out_dir, exp_name))
|
||||
print(f'out_dir: {out_dir}')
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# if optional regular evaluation lookup file is provided
|
||||
if args.reg_evals_file:
|
||||
with open(args.reg_evals_file, 'r') as f:
|
||||
reg_evals_data = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
# get list of all structured logs for each job
|
||||
logs = {}
|
||||
confs = {}
|
||||
job_dirs = list(results_dir.iterdir())
|
||||
|
||||
# # test single job parsing for debugging
|
||||
# # WARNING: very slow, just use for debugging
|
||||
# for job_dir in job_dirs:
|
||||
# a = parse_a_job(job_dir)
|
||||
|
||||
# parallel parsing of yaml logs
|
||||
num_workers = 8
|
||||
with Pool(num_workers) as p:
|
||||
a = p.map(parse_a_job, job_dirs)
|
||||
|
||||
for storage in a:
|
||||
for key, val in storage.items():
|
||||
logs[key] = val[0]
|
||||
confs[key] = val[1]
|
||||
|
||||
# examples of accessing logs
|
||||
# logs['proxynas_blahblah:eval']['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining']
|
||||
# logs['proxynas_blahblah:eval']['regular_evaluate']['regtrainingtop1']
|
||||
# logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['9']['val']['top1']
|
||||
# last_epoch_key = list(logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
|
||||
# last_val_top1 = logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_epoch_key]['val']['top1']
|
||||
# epoch_duration = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['0']['train']['duration']
|
||||
|
||||
# remove all search jobs
|
||||
for key in list(logs.keys()):
|
||||
if 'search' in key:
|
||||
logs.pop(key)
|
||||
|
||||
# remove all arch_ids which did not finish
|
||||
for key in list(logs.keys()):
|
||||
to_delete = False
|
||||
|
||||
# it might have died early
|
||||
if 'freeze_evaluate' not in list(logs[key].keys()):
|
||||
to_delete = True
|
||||
|
||||
if 'regular_evaluate' not in list(logs[key].keys()):
|
||||
to_delete = True
|
||||
|
||||
if to_delete:
|
||||
print(f'arch id {key} did not finish. removing from calculations.')
|
||||
logs.pop(key)
|
||||
continue
|
||||
|
||||
if 'freeze_training'not in list(logs[key]['freeze_evaluate']['eval_arch'].keys()):
|
||||
print(f'arch id {key} did not finish. removing from calculations.')
|
||||
logs.pop(key)
|
||||
continue
|
||||
|
||||
# freeze train may not have finished
|
||||
num_freeze_epochs = confs[key]['nas']['eval']['freeze_trainer']['epochs']
|
||||
last_freeze_epoch_key = int(list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1])
|
||||
if last_freeze_epoch_key != num_freeze_epochs - 1:
|
||||
print(f'arch id {key} did not finish. removing from calculations.')
|
||||
logs.pop(key)
|
||||
|
||||
|
||||
all_arch_ids = []
|
||||
|
||||
all_reg_evals = []
|
||||
|
||||
all_freeze_evals_last = []
|
||||
all_cond_evals_last = []
|
||||
|
||||
all_freeze_flops_last = []
|
||||
all_cond_flops_last = []
|
||||
|
||||
all_freeze_time_last = []
|
||||
all_cond_time_last = []
|
||||
all_partial_time_last = []
|
||||
|
||||
all_freeze_evals = defaultdict(list)
|
||||
|
||||
num_archs_unmet_cond = 0
|
||||
|
||||
for key in logs.keys():
|
||||
if 'eval' in key:
|
||||
try:
|
||||
|
||||
# if at the end of conditional training train accuracy has not gone above target then don't consider it
|
||||
# important to get this first
|
||||
last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
|
||||
use_val = confs[key]['nas']['eval']['trainer']['use_val']
|
||||
threshold = confs[key]['nas']['eval']['trainer']['top1_acc_threshold']
|
||||
if use_val:
|
||||
val_or_train = 'val'
|
||||
else:
|
||||
val_or_train = 'train'
|
||||
end_cond = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
|
||||
if end_cond < threshold:
|
||||
num_archs_unmet_cond += 1
|
||||
continue
|
||||
|
||||
# regular evaluation
|
||||
# important to get this first since if it is not
|
||||
# available for non-benchmark datasets we need to
|
||||
# remove it from consideration
|
||||
# --------------------
|
||||
if not args.reg_evals_file:
|
||||
reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
|
||||
else:
|
||||
# lookup from the provided file since this dataset is not part of the
|
||||
# benchmark and hence we have to provide the info separately
|
||||
if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
|
||||
arch_id_in_bench = confs[key]['nas']['eval']['natsbench']['arch_index']
|
||||
elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
|
||||
arch_id_in_bench = confs[key]['nas']['eval']['nasbench101']['arch_index']
|
||||
|
||||
if arch_id_in_bench not in list(reg_evals_data.keys()):
|
||||
# if the dataset used is not part of the standard benchmark some of the architectures
|
||||
# may not have full evaluation accuracies available. Remove them from consideration.
|
||||
continue
|
||||
reg_eval_top1 = reg_evals_data[arch_id_in_bench]
|
||||
all_reg_evals.append(reg_eval_top1)
|
||||
|
||||
# freeze evaluation
|
||||
#--------------------
|
||||
|
||||
# at last epoch
|
||||
last_freeze_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
|
||||
freeze_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_freeze_epoch_key][val_or_train]['top1']
|
||||
all_freeze_evals_last.append(freeze_eval_top1)
|
||||
|
||||
# collect evals at other epochs
|
||||
for epoch in range(int(last_freeze_epoch_key)):
|
||||
all_freeze_evals[epoch].append(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][str(epoch)][val_or_train]['top1'])
|
||||
|
||||
# collect flops used for conditional training and freeze training
|
||||
freeze_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['total_mega_flops_epoch']
|
||||
freeze_mega_flops_used = freeze_mega_flops_epoch * int(last_freeze_epoch_key)
|
||||
all_freeze_flops_last.append(freeze_mega_flops_used)
|
||||
|
||||
last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
|
||||
cond_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['total_mega_flops_epoch']
|
||||
cond_mega_flops_used = cond_mega_flops_epoch * int(last_cond_epoch_key)
|
||||
all_cond_flops_last.append(cond_mega_flops_used)
|
||||
|
||||
# collect training error at end of conditional training
|
||||
cond_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
|
||||
all_cond_evals_last.append(cond_eval_top1)
|
||||
|
||||
# collect duration for conditional training and freeze training
|
||||
# NOTE: don't use val_or_train here since we are really interested in the duration of training
|
||||
freeze_duration = 0.0
|
||||
for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']:
|
||||
freeze_duration += logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][epoch_key]['train']['duration']
|
||||
|
||||
cond_duration = 0.0
|
||||
for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs']:
|
||||
cond_duration += logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][epoch_key]['train']['duration']
|
||||
|
||||
all_freeze_time_last.append(freeze_duration + cond_duration)
|
||||
all_cond_time_last.append(cond_duration)
|
||||
all_partial_time_last.append(freeze_duration)
|
||||
|
||||
# record the arch id
|
||||
# --------------------
|
||||
if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
|
||||
all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
|
||||
elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
|
||||
all_arch_ids.append(confs[key]['nas']['eval']['nasbench101']['arch_index'])
|
||||
|
||||
except KeyError as err:
|
||||
print(f'KeyError {err} not in {key}!')
|
||||
sys.exit()
|
||||
|
||||
|
||||
# Store some key numbers in results.txt
|
||||
results_savename = os.path.join(out_dir, 'results.txt')
|
||||
with open(results_savename, 'w') as f:
|
||||
f.write(f'Number of archs which did not reach condition: {num_archs_unmet_cond} \n')
|
||||
f.write(f'Total valid archs processed: {len(all_reg_evals)} \n')
|
||||
|
||||
print(f'Number of archs which did not reach condition: {num_archs_unmet_cond}')
|
||||
print(f'Total valid archs processed: {len(all_reg_evals)}')
|
||||
|
||||
# Sanity check
|
||||
assert len(all_reg_evals) == len(all_freeze_evals_last)
|
||||
assert len(all_reg_evals) == len(all_cond_evals_last)
|
||||
assert len(all_reg_evals) == len(all_cond_time_last)
|
||||
assert len(all_reg_evals) == len(all_freeze_flops_last)
|
||||
assert len(all_reg_evals) == len(all_cond_flops_last)
|
||||
assert len(all_reg_evals) == len(all_freeze_time_last)
|
||||
assert len(all_reg_evals) == len(all_arch_ids)
|
||||
|
||||
# scatter plot between time to threshold accuracy and regular evaluation
|
||||
fig = px.scatter(x=all_cond_time_last, y=all_reg_evals, labels={'x': 'Time to reach threshold train accuracy (s)', 'y': 'Final Accuracy'})
|
||||
fig.update_layout(font=dict(
|
||||
size=48,
|
||||
))
|
||||
|
||||
savename = os.path.join(out_dir, 'cond_time_vs_final_acc.html')
|
||||
fig.write_html(savename)
|
||||
|
||||
savename_pdf = os.path.join(out_dir, 'cond_time_vs_final_acc.pdf')
|
||||
fig.write_image(savename_pdf, engine="kaleido", width=1500, height=1500, scale=1)
|
||||
|
||||
fig.show()
|
||||
|
||||
# histogram of training accuracies
|
||||
fig = px.histogram(all_reg_evals, labels={'x': 'Test Accuracy', 'y': 'Counts'})
|
||||
savename = os.path.join(out_dir, 'distribution_of_reg_evals.html')
|
||||
fig.write_html(savename)
|
||||
fig.show()
|
||||
|
||||
# Freeze training results at last epoch
|
||||
freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last)
|
||||
freeze_spe, freeze_sp_value = spearmanr(all_reg_evals, all_freeze_evals_last)
|
||||
print(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f}')
|
||||
print(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f}')
|
||||
with open(results_savename, 'a') as f:
|
||||
f.write(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f} \n')
|
||||
f.write(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f} \n')
|
||||
|
||||
plt.clf()
|
||||
sns.scatterplot(x=all_reg_evals, y=all_freeze_evals_last)
|
||||
plt.xlabel('Test top1 at natsbench full training')
|
||||
plt.ylabel('Freeze training')
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, 'proxynas_freeze_training_epochs.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
# Conditional training results at last epoch
|
||||
cond_tau, cond_p_value = kendalltau(all_reg_evals, all_cond_evals_last)
|
||||
cond_spe, cond_sp_value = spearmanr(all_reg_evals, all_cond_evals_last)
|
||||
print(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f}')
|
||||
print(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f}')
|
||||
with open(results_savename, 'a') as f:
|
||||
f.write(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f} \n')
|
||||
f.write(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f} \n')
|
||||
|
||||
plt.clf()
|
||||
sns.scatterplot(x=all_reg_evals, y=all_cond_evals_last)
|
||||
plt.xlabel('Test top1 at natsbench full training')
|
||||
plt.ylabel('Conditional training')
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, 'proxynas_cond_training_epochs.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
# Report average runtime and average flops consumed
|
||||
total_freeze_flops = np.array(all_freeze_flops_last) + np.array(all_cond_flops_last)
|
||||
avg_freeze_flops = np.mean(total_freeze_flops)
|
||||
std_freeze_flops = np.std(total_freeze_flops)
|
||||
stderr_freeze_flops = std_freeze_flops / np.sqrt(len(all_freeze_flops_last))
|
||||
|
||||
avg_freeze_runtime = np.mean(np.array(all_freeze_time_last))
|
||||
std_freeze_runtime = np.std(np.array(all_freeze_time_last))
|
||||
stderr_freeze_runtime = std_freeze_runtime / np.sqrt(len(all_freeze_time_last))
|
||||
|
||||
avg_cond_runtime = np.mean(np.array(all_cond_time_last))
|
||||
std_cond_runtime = np.std(np.array(all_cond_time_last))
|
||||
stderr_cond_runtime = std_cond_runtime / np.sqrt(len(all_cond_time_last))
|
||||
|
||||
avg_partial_runtime = np.mean(np.array(all_partial_time_last))
|
||||
std_partial_runtime = np.std(np.array(all_partial_time_last))
|
||||
stderr_partial_runtime = std_partial_runtime / np.sqrt(len(all_partial_time_last))
|
||||
|
||||
with open(results_savename, 'a') as f:
|
||||
f.write(f'Avg. Freeze MFlops: {avg_freeze_flops:.03f}, std {std_freeze_flops}, stderr {stderr_freeze_flops:.03f} \n')
|
||||
f.write(f'Avg. Freeze Runtime: {avg_freeze_runtime:.03f}, std {std_freeze_runtime}, stderr {stderr_freeze_runtime:.03f} \n')
|
||||
f.write(f'Avg. Conditional Runtime: {avg_cond_runtime:.03f}, std {std_cond_runtime}, stderr {stderr_cond_runtime:.03f} \n')
|
||||
f.write(f'Avg. Partial Runtime: {avg_partial_runtime:.03f}, std {std_partial_runtime}, stderr {stderr_partial_runtime:.03f} \n')
|
||||
|
||||
# Plot freeze training rank correlations if cutoff at various epochs
|
||||
freeze_taus = {}
|
||||
freeze_spes = {}
|
||||
for epoch_key in all_freeze_evals.keys():
|
||||
tau, _ = kendalltau(all_reg_evals, all_freeze_evals[epoch_key])
|
||||
spe, _ = spearmanr(all_reg_evals, all_freeze_evals[epoch_key])
|
||||
freeze_taus[epoch_key] = tau
|
||||
freeze_spes[epoch_key] = spe
|
||||
|
||||
plt.clf()
|
||||
for epoch_key in freeze_taus.keys():
|
||||
plt.scatter(epoch_key, freeze_taus[epoch_key])
|
||||
plt.xlabel('Epochs of freeze training')
|
||||
plt.ylabel('Kendall Tau')
|
||||
plt.ylim((-1.0, 1.0))
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, 'proxynas_freeze_training_kendall_taus.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
plt.clf()
|
||||
for epoch_key in freeze_taus.keys():
|
||||
plt.scatter(epoch_key, freeze_spes[epoch_key])
|
||||
plt.xlabel('Epochs of freeze training')
|
||||
plt.ylabel('Spearman Correlation')
|
||||
plt.ylim((-1.0, 1.0))
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, 'proxynas_freeze_training_spearman_corrs.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
|
||||
# Rank correlations at top n percent of architectures
|
||||
#-----------------------------------------------------
|
||||
reg_freezelast_naswot_evals = [(all_reg_evals[i], all_freeze_evals_last[i], all_freeze_time_last[i]) for i in range(len(all_reg_evals))]
|
||||
|
||||
# sort in descending order of accuracy of regular evaluation
|
||||
reg_freezelast_naswot_evals.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
top_percent_freeze_times_avg = []
|
||||
top_percent_freeze_times_std = []
|
||||
top_percent_freeze_times_stderr = []
|
||||
|
||||
spe_freeze_top_percents = []
|
||||
top_percents = []
|
||||
top_percent_range = range(2, 101, 2)
|
||||
for top_percent in top_percent_range:
|
||||
top_percents.append(top_percent)
|
||||
num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01))
|
||||
top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep]
|
||||
top_percent_reg = [x[0] for x in top_percent_evals]
|
||||
top_percent_freeze = [x[1] for x in top_percent_evals]
|
||||
top_percent_freeze_times = [x[2] for x in top_percent_evals]
|
||||
|
||||
top_percent_freeze_times_avg.append(np.mean(np.array(top_percent_freeze_times)))
|
||||
top_percent_freeze_times_std.append(np.std(np.array(top_percent_freeze_times)))
|
||||
top_percent_freeze_times_stderr.append(sem(np.array(top_percent_freeze_times)))
|
||||
|
||||
spe_freeze, _ = spearmanr(top_percent_reg, top_percent_freeze)
|
||||
spe_freeze_top_percents.append(spe_freeze)
|
||||
|
||||
|
||||
plt.clf()
|
||||
sns.scatterplot(top_percents, spe_freeze_top_percents)
|
||||
plt.legend(labels=['Freeze Train'])
|
||||
plt.ylim((-1.0, 1.0))
|
||||
plt.xlabel('Top percent of architectures')
|
||||
plt.ylabel('Spearman Correlation')
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, f'spe_top_archs.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
plt.clf()
|
||||
plt.errorbar(top_percents, top_percent_freeze_times_avg, yerr=np.array(top_percent_freeze_times_std)/2, marker='s', mfc='red', ms=10, mew=4)
|
||||
plt.xlabel('Top percent of architectures')
|
||||
plt.ylabel('Avg. time (s)')
|
||||
plt.yticks(np.arange(0,600, step=50))
|
||||
plt.grid()
|
||||
savename = os.path.join(out_dir, f'freeze_train_duration_top_archs.png')
|
||||
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
|
||||
|
||||
# how much overlap in top x% of architectures between method and groundtruth
|
||||
# ----------------------------------------------------------------------------
|
||||
arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
|
||||
arch_id_freezetrain_evals = [(arch_id, freeze_eval) for arch_id, freeze_eval in zip(all_arch_ids, all_freeze_evals_last)]
|
||||
|
||||
arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
|
||||
arch_id_freezetrain_evals.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
assert len(arch_id_reg_evals) == len(arch_id_freezetrain_evals)
|
||||
|
||||
top_percents = []
|
||||
freezetrain_ratio_common = []
|
||||
for top_percent in top_percent_range:
|
||||
top_percents.append(top_percent)
|
||||
num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
|
||||
top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
|
||||
top_percent_arch_id_freezetrain_evals = arch_id_freezetrain_evals[:num_to_keep]
|
||||
|
||||
# take the set of arch_ids in each method and find overlap with top archs
|
||||
set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
|
||||
set_ft = set([x[0] for x in top_percent_arch_id_freezetrain_evals])
|
||||
ft_num_common = len(set_reg.intersection(set_ft))
|
||||
freezetrain_ratio_common.append(ft_num_common/num_to_keep)
|
||||
|
||||
|
||||
# save raw data for other aggregate plots over experiments
|
||||
raw_data_dict = {}
|
||||
raw_data_dict['top_percents'] = top_percents
|
||||
raw_data_dict['spe_freeze'] = spe_freeze_top_percents
|
||||
raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg
|
||||
raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std
|
||||
raw_data_dict['freeze_times_stderr'] = top_percent_freeze_times_stderr
|
||||
raw_data_dict['freeze_ratio_common'] = freezetrain_ratio_common
|
||||
|
||||
savename = os.path.join(out_dir, 'raw_data.yaml')
|
||||
with open(savename, 'w') as f:
|
||||
yaml.dump(raw_data_dict, f)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Загрузка…
Ссылка в новой задаче