Natsbench SSS regular evaluation code added.

Debadeepta Dey 2021-07-29 14:29:46 -07:00 committed by Gustavo Rosa
Parent 9be602cc5d
Commit eb0abc880c
8 changed files with 688 additions and 5 deletions

.vscode/launch.json (23 changes, vendored)
View file

@@ -384,6 +384,14 @@
"console": "integratedTerminal",
"args": ["--algos", "natsbench_regular_eval"]
},
{
"name": "Natsbench-Regular-SSS-Eval-Full",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "natsbench_sss_regular_eval", "--datasets", "cifar10"]
},
{
"name": "Nb101-Regular-Eval-Full",
"type": "python",
@@ -422,7 +430,7 @@
"request": "launch",
"program": "${cwd}/scripts/main_proxynas_nb_wrapper.py",
"console": "integratedTerminal",
"args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
"args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
},
{
"name": "Main Proxynas Natsbench SSS Wrapper",
@@ -430,7 +438,7 @@
"request": "launch",
"program": "${cwd}/scripts/main_proxynas_nb_sss_wrapper.py",
"console": "integratedTerminal",
"args": ["--algos", "proxynas_natsbench_sss_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
"args": ["--algos", "natsbench_sss_regular_eval", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
},
{
"name": "Main Proxynas Nb101 Wrapper",
@@ -438,7 +446,7 @@
"request": "launch",
"program": "${cwd}/scripts/main_proxynas_nb101_wrapper.py",
"console": "integratedTerminal",
"args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
"args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
},
{
"name": "Resnet-Toy",
@@ -713,6 +721,15 @@
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_fb1024_ftlr0.1_fte10_ct256_ftt0.6_scu",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "Analysis Freeze Natsbench SSS Space",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py",
"console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "Analysis Freeze Addon No Cond",
"type": "python",

archai/algos/natsbench/natsbench_sss_regular_evaluater.py (new file)
View file

@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
from typing import Optional
import importlib
import sys
import string
import os
from overrides import overrides
import torch
from torch import nn
from overrides import overrides, EnforceOverrides
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas import nas_utils
from archai.common import ml_utils, utils
from archai.common.metrics import EpochMetrics, Metrics
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import Evaluater
from nats_bench import create
from archai.algos.natsbench.lib.models import get_cell_based_tiny_net
class NatsbenchSSSRegularEvaluater(Evaluater):
@overrides
def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
final_desc_filename=None, full_desc_filename=None)->nn.Module:
# region conf vars
dataset_name = conf_eval['loader']['dataset']['name']
self.num_classes = conf_eval['loader']['dataset']['n_classes']
# if explicitly passed in then don't get from conf
if not final_desc_filename:
final_desc_filename = conf_eval['final_desc_filename']
arch_index = conf_eval['natsbench']['arch_index']
dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
# endregion
        assert arch_index is not None # index 0 is a valid architecture, so compare against None
assert natsbench_location
return self._model_from_natsbench(arch_index, dataset_name, natsbench_location)
def _model_from_natsbench(self, arch_index:int, dataset_name:str, natsbench_location:str)->Model:
        # create natsbench api; this evaluater targets the size search space, so pass 'sss'
        api = create(natsbench_location, 'sss', fast_mode=True, verbose=True)
        if arch_index >= 32768 or arch_index < 0:
            logger.warn(f'architecture id {arch_index} is invalid; the sss space has 32768 architectures')
supported_datasets = {'cifar10', 'cifar100', 'ImageNet16-120'}
# force natsbench to use cifar10 archs
# since it doesn't know about other datasets
if dataset_name not in supported_datasets:
dataset_name = 'cifar10'
config = api.get_net_config(arch_index, dataset_name)
        # fill in the num classes from conf again
        # since unsupported datasets may have a
        # different number of classes than cifar10
config['num_classes'] = self.num_classes
# network is a nn.Module subclass. the last few modules have names
# lastact, lastact.0, lastact.1, global_pooling, classifier
# which we can freeze train as usual
model = get_cell_based_tiny_net(config)
return model
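
The comment above about the trailing module names (lastact, global_pooling, classifier) is what the freeze-training experiments elsewhere in this repo rely on: everything before those modules can be frozen. A minimal sketch of that idea, assuming the module names listed in the comment; freeze_but_last is a hypothetical helper, not part of this commit:

import torch.nn as nn

def freeze_but_last(model: nn.Module,
                    trainable_prefixes=('lastact', 'global_pooling', 'classifier')) -> nn.Module:
    # hypothetical helper: freeze every parameter, then re-enable gradients
    # only for parameters belonging to the trailing modules
    for name, param in model.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in trainable_prefixes)
    return model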

archai/algos/natsbench/natsbench_sss_regular_experiment_runner.py (new file)
View file

@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Type
from copy import deepcopy
import os
from overrides import overrides
from archai.common.config import Config
from archai.nas import nas_utils
from archai.common import utils
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.evaluater import EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from archai.algos.natsbench.natsbench_sss_regular_evaluater import NatsbenchSSSRegularEvaluater
from nats_bench import create
class NatsbenchSSSRegularExperimentRunner(ExperimentRunner):
"""Runs regular training on architectures from natsbench"""
@overrides
def model_desc_builder(self)->Optional[ModelDescBuilder]:
return None
@overrides
def trainer_class(self)->TArchTrainer:
return None # no search trainer
@overrides
def searcher(self)->ManualFreezeSearcher:
        return ManualFreezeSearcher() # effectively a no-op searcher
@overrides
def copy_search_to_eval(self)->None:
pass
@overrides
def run_eval(self, conf_eval:Config)->EvalResult:
# regular evaluation of the architecture
# where we simply lookup the result
# --------------------------------------
dataset_name = conf_eval['loader']['dataset']['name']
logger.pushd('regular_evaluate')
if dataset_name in {'cifar10', 'cifar100', 'ImageNet16-120'}:
arch_id = conf_eval['natsbench']['arch_index']
dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
logger.info(natsbench_location)
            api = create(natsbench_location, 'sss', fast_mode=True, verbose=True)
            if arch_id >= 32768 or arch_id < 0:
                logger.warn(f'architecture id {arch_id} is invalid; the sss space has 32768 architectures')
info = api.get_more_info(arch_id, dataset_name, hp=90, is_random=False)
test_accuracy = info['test-accuracy']
logger.info(f'Regular training top1 test accuracy is {test_accuracy}')
logger.info({'regtrainingtop1': float(test_accuracy)})
else:
logger.info({'regtrainingtop1': -1})
logger.popd()
# regular evaluation of n epochs
evaler = NatsbenchSSSRegularEvaluater()
return evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
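
For a quick sanity check outside the runner, the same benchmark lookup that run_eval performs can be done directly with the nats_bench package. A sketch under the same assumptions as the code above; the dataset path is illustrative and arch index 9682 is the default from the config below:

from nats_bench import create

# path to the extracted NATS-sss-v1_0-50262-simple folder under dataroot/natsbench (illustrative)
api = create('/data/natsbench/NATS-sss-v1_0-50262-simple', 'sss', fast_mode=True, verbose=False)
# statistics of the 90-epoch training run; is_random=False averages over seeds
info = api.get_more_info(9682, 'cifar10', hp='90', is_random=False)
print(info['test-accuracy'])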

View file

@@ -0,0 +1,46 @@
__include__: 'darts.yaml' # just use darts defaults
nas:
search:
model_desc:
num_edges_to_sample: 2 # number of edges each node will take input from
eval:
natsbench:
arch_index: 9682
natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
model_desc:
num_edges_to_sample: 2
loader:
train_batch: 256 # natsbench uses 256
aug: '' # random flip and crop are already there in default params
trainer: # matching natsbench paper closely
plotsdir: ''
apex:
_copy: '/common/apex'
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
drop_path_prob: 0.2 # probability that given edge will be dropped
      grad_clip: 5.0 # grads above this value are clipped
      l1_alphas: 0.0 # weight applied to the sum(abs(alphas)) term in the loss
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
title: 'eval_train'
epochs: 5
      batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with less RAM
lossfn:
type: 'CrossEntropyLoss'
optimizer:
type: 'sgd'
lr: 0.1 # init learning rate
decay: 5.0e-4 # pytorch default is 0.0
momentum: 0.9 # pytorch default is 0.0
nesterov: True # pytorch default is False
        decay_bn: .NaN # if NaN, use the same value as decay; otherwise apply this decay to BN layers
lr_schedule:
type: 'cosine'
min_lr: 0.000 # min learning rate to be set in eta_min param of scheduler
        warmup: # increases LR from 0 to current over the specified epochs, then hands over to the main scheduler
multiplier: 1
epochs: 0 # 0 disables warmup
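
The trainer block above pins down the natsbench training recipe. In plain PyTorch, the optimizer and LR schedule it describes would look roughly like the sketch below; this is illustrative only, not the code path archai uses to parse this config, and model is a stand-in:

import torch

model = torch.nn.Linear(10, 10)  # stand-in for the natsbench architecture
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=5.0e-4, nesterov=True)
# cosine decay from lr=0.1 to min_lr=0.0 over the configured 5 epochs; warmup epochs=0 disables warmup
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.0)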

View file

@@ -9,7 +9,7 @@ nas:
eval:
natsbench:
arch_index: 23000
-      natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the tss fast mode folder
+      natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
model_desc:
num_edges_to_sample: 2
loader:

scripts/main.py
View file

@@ -21,6 +21,7 @@ from archai.algos.proxynas.freeze_nasbench101_experiment_runner import FreezeNas
from archai.algos.proxynas.freeze_manual_experiment_runner import ManualFreezeExperimentRunner
from archai.algos.naswotrain.naswotrain_natsbench_conditional_experiment_runner import NaswotConditionalNatsbenchExperimentRunner
from archai.algos.natsbench.natsbench_regular_experiment_runner import NatsbenchRegularExperimentRunner
from archai.algos.natsbench.natsbench_sss_regular_experiment_runner import NatsbenchSSSRegularExperimentRunner
from archai.algos.nasbench101.nasbench101_exp_runner import Nb101RegularExperimentRunner
from archai.algos.proxynas.phased_freeze_natsbench_experiment_runner import PhasedFreezeNatsbenchExperimentRunner
from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import FreezeAddonNasbench101ExperimentRunner
@@ -54,6 +55,7 @@ def main():
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
'nb101_regular_eval': Nb101RegularExperimentRunner,
'phased_freezetrain_natsbench_space': PhasedFreezeNatsbenchExperimentRunner,
'freezeaddon_nasbench101_space': FreezeAddonNasbench101ExperimentRunner,
@@ -85,6 +87,7 @@ def main():
zerocost_conditional_natsbench_space,
zerocost_natsbench_epochs_space,
natsbench_regular_eval,
natsbench_sss_regular_eval,
nb101_regular_eval,
phased_freezetrain_natsbench_space,
random_natsbench_tss_far,

scripts/main_proxynas_nb_sss_wrapper.py
View file

@@ -11,7 +11,8 @@ from archai.common.utils import exec_shell_command
def main():
parser = argparse.ArgumentParser(description='Proxynas SSS Wrapper Main')
parser.add_argument('--algos', type=str, default='''
-                        proxynas_natsbench_sss_space,
+                        proxynas_natsbench_sss_space,
+                        natsbench_sss_regular_eval,
''',
help='NAS algos to run, separated by comma')
parser.add_argument('--top1-acc-threshold', type=float)

scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py (new file)
View file

@@ -0,0 +1,461 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import argparse
from typing import Dict, List, Type, Iterator, Tuple
import glob
import os
import pathlib
from collections import OrderedDict, defaultdict
import yaml
from inspect import getsourcefile
import seaborn as sns
import math as ma
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scipy.stats import kendalltau, spearmanr, sem
from runstats import Statistics
#import matplotlib
#matplotlib.use('TkAgg')
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool
from collections import namedtuple
from archai.common import utils
from archai.common.ordereddict_logger import OrderedDictLogger
from archai.common.analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
import re
def main():
parser = argparse.ArgumentParser(description='Report creator')
parser.add_argument('--results-dir', '-d', type=str,
default=r'~/logdir/proxynas_test_0001',
help='folder with experiment results')
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
help='folder to output reports')
parser.add_argument('--reg-evals-file', '-r', type=str, default=None,
help='optional yaml file which contains full evaluation \
of architectures on new datasets not part of natsbench')
args, extra_args = parser.parse_known_args()
# root dir where all results are stored
results_dir = pathlib.Path(utils.full_path(args.results_dir))
print(f'results_dir: {results_dir}')
# extract experiment name which is top level directory
exp_name = results_dir.parts[-1]
# create results dir for experiment
out_dir = utils.full_path(os.path.join(args.out_dir, exp_name))
print(f'out_dir: {out_dir}')
os.makedirs(out_dir, exist_ok=True)
# if optional regular evaluation lookup file is provided
if args.reg_evals_file:
with open(args.reg_evals_file, 'r') as f:
reg_evals_data = yaml.load(f, Loader=yaml.Loader)
# get list of all structured logs for each job
logs = {}
confs = {}
job_dirs = list(results_dir.iterdir())
# # test single job parsing for debugging
# # WARNING: very slow, just use for debugging
# for job_dir in job_dirs:
# a = parse_a_job(job_dir)
# parallel parsing of yaml logs
num_workers = 8
with Pool(num_workers) as p:
a = p.map(parse_a_job, job_dirs)
for storage in a:
for key, val in storage.items():
logs[key] = val[0]
confs[key] = val[1]
# examples of accessing logs
# logs['proxynas_blahblah:eval']['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining']
# logs['proxynas_blahblah:eval']['regular_evaluate']['regtrainingtop1']
# logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['9']['val']['top1']
# last_epoch_key = list(logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
# last_val_top1 = logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_epoch_key]['val']['top1']
# epoch_duration = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['0']['train']['duration']
# remove all search jobs
for key in list(logs.keys()):
if 'search' in key:
logs.pop(key)
# remove all arch_ids which did not finish
for key in list(logs.keys()):
to_delete = False
# it might have died early
if 'freeze_evaluate' not in list(logs[key].keys()):
to_delete = True
if 'regular_evaluate' not in list(logs[key].keys()):
to_delete = True
if to_delete:
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
        if 'freeze_training' not in list(logs[key]['freeze_evaluate']['eval_arch'].keys()):
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
# freeze train may not have finished
num_freeze_epochs = confs[key]['nas']['eval']['freeze_trainer']['epochs']
last_freeze_epoch_key = int(list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1])
if last_freeze_epoch_key != num_freeze_epochs - 1:
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
all_arch_ids = []
all_reg_evals = []
all_freeze_evals_last = []
all_cond_evals_last = []
all_freeze_flops_last = []
all_cond_flops_last = []
all_freeze_time_last = []
all_cond_time_last = []
all_partial_time_last = []
all_freeze_evals = defaultdict(list)
num_archs_unmet_cond = 0
for key in logs.keys():
if 'eval' in key:
try:
# if at the end of conditional training train accuracy has not gone above target then don't consider it
# important to get this first
last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
use_val = confs[key]['nas']['eval']['trainer']['use_val']
threshold = confs[key]['nas']['eval']['trainer']['top1_acc_threshold']
if use_val:
val_or_train = 'val'
else:
val_or_train = 'train'
end_cond = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
if end_cond < threshold:
num_archs_unmet_cond += 1
continue
# regular evaluation
# important to get this first since if it is not
# available for non-benchmark datasets we need to
# remove it from consideration
# --------------------
if not args.reg_evals_file:
reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
else:
# lookup from the provided file since this dataset is not part of the
# benchmark and hence we have to provide the info separately
if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
arch_id_in_bench = confs[key]['nas']['eval']['natsbench']['arch_index']
elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
arch_id_in_bench = confs[key]['nas']['eval']['nasbench101']['arch_index']
if arch_id_in_bench not in list(reg_evals_data.keys()):
# if the dataset used is not part of the standard benchmark some of the architectures
# may not have full evaluation accuracies available. Remove them from consideration.
continue
reg_eval_top1 = reg_evals_data[arch_id_in_bench]
all_reg_evals.append(reg_eval_top1)
# freeze evaluation
#--------------------
# at last epoch
last_freeze_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
freeze_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_freeze_epoch_key][val_or_train]['top1']
all_freeze_evals_last.append(freeze_eval_top1)
# collect evals at other epochs
for epoch in range(int(last_freeze_epoch_key)):
all_freeze_evals[epoch].append(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][str(epoch)][val_or_train]['top1'])
# collect flops used for conditional training and freeze training
freeze_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['total_mega_flops_epoch']
freeze_mega_flops_used = freeze_mega_flops_epoch * int(last_freeze_epoch_key)
all_freeze_flops_last.append(freeze_mega_flops_used)
last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
cond_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['total_mega_flops_epoch']
cond_mega_flops_used = cond_mega_flops_epoch * int(last_cond_epoch_key)
all_cond_flops_last.append(cond_mega_flops_used)
# collect training error at end of conditional training
cond_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
all_cond_evals_last.append(cond_eval_top1)
# collect duration for conditional training and freeze training
# NOTE: don't use val_or_train here since we are really interested in the duration of training
freeze_duration = 0.0
for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']:
freeze_duration += logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][epoch_key]['train']['duration']
cond_duration = 0.0
for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs']:
cond_duration += logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][epoch_key]['train']['duration']
all_freeze_time_last.append(freeze_duration + cond_duration)
all_cond_time_last.append(cond_duration)
all_partial_time_last.append(freeze_duration)
# record the arch id
# --------------------
if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
all_arch_ids.append(confs[key]['nas']['eval']['nasbench101']['arch_index'])
except KeyError as err:
print(f'KeyError {err} not in {key}!')
sys.exit()
# Store some key numbers in results.txt
results_savename = os.path.join(out_dir, 'results.txt')
with open(results_savename, 'w') as f:
f.write(f'Number of archs which did not reach condition: {num_archs_unmet_cond} \n')
f.write(f'Total valid archs processed: {len(all_reg_evals)} \n')
print(f'Number of archs which did not reach condition: {num_archs_unmet_cond}')
print(f'Total valid archs processed: {len(all_reg_evals)}')
# Sanity check
assert len(all_reg_evals) == len(all_freeze_evals_last)
assert len(all_reg_evals) == len(all_cond_evals_last)
assert len(all_reg_evals) == len(all_cond_time_last)
assert len(all_reg_evals) == len(all_freeze_flops_last)
assert len(all_reg_evals) == len(all_cond_flops_last)
assert len(all_reg_evals) == len(all_freeze_time_last)
assert len(all_reg_evals) == len(all_arch_ids)
# scatter plot between time to threshold accuracy and regular evaluation
fig = px.scatter(x=all_cond_time_last, y=all_reg_evals, labels={'x': 'Time to reach threshold train accuracy (s)', 'y': 'Final Accuracy'})
fig.update_layout(font=dict(
size=48,
))
savename = os.path.join(out_dir, 'cond_time_vs_final_acc.html')
fig.write_html(savename)
savename_pdf = os.path.join(out_dir, 'cond_time_vs_final_acc.pdf')
fig.write_image(savename_pdf, engine="kaleido", width=1500, height=1500, scale=1)
fig.show()
# histogram of training accuracies
fig = px.histogram(all_reg_evals, labels={'x': 'Test Accuracy', 'y': 'Counts'})
savename = os.path.join(out_dir, 'distribution_of_reg_evals.html')
fig.write_html(savename)
fig.show()
# Freeze training results at last epoch
freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last)
freeze_spe, freeze_sp_value = spearmanr(all_reg_evals, all_freeze_evals_last)
print(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f}')
print(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f}')
with open(results_savename, 'a') as f:
f.write(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f} \n')
f.write(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f} \n')
plt.clf()
sns.scatterplot(x=all_reg_evals, y=all_freeze_evals_last)
plt.xlabel('Test top1 at natsbench full training')
plt.ylabel('Freeze training')
plt.grid()
savename = os.path.join(out_dir, 'proxynas_freeze_training_epochs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# Conditional training results at last epoch
cond_tau, cond_p_value = kendalltau(all_reg_evals, all_cond_evals_last)
cond_spe, cond_sp_value = spearmanr(all_reg_evals, all_cond_evals_last)
print(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f}')
print(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f}')
with open(results_savename, 'a') as f:
f.write(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f} \n')
f.write(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f} \n')
plt.clf()
sns.scatterplot(x=all_reg_evals, y=all_cond_evals_last)
plt.xlabel('Test top1 at natsbench full training')
plt.ylabel('Conditional training')
plt.grid()
savename = os.path.join(out_dir, 'proxynas_cond_training_epochs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# Report average runtime and average flops consumed
total_freeze_flops = np.array(all_freeze_flops_last) + np.array(all_cond_flops_last)
avg_freeze_flops = np.mean(total_freeze_flops)
std_freeze_flops = np.std(total_freeze_flops)
stderr_freeze_flops = std_freeze_flops / np.sqrt(len(all_freeze_flops_last))
avg_freeze_runtime = np.mean(np.array(all_freeze_time_last))
std_freeze_runtime = np.std(np.array(all_freeze_time_last))
stderr_freeze_runtime = std_freeze_runtime / np.sqrt(len(all_freeze_time_last))
avg_cond_runtime = np.mean(np.array(all_cond_time_last))
std_cond_runtime = np.std(np.array(all_cond_time_last))
stderr_cond_runtime = std_cond_runtime / np.sqrt(len(all_cond_time_last))
avg_partial_runtime = np.mean(np.array(all_partial_time_last))
std_partial_runtime = np.std(np.array(all_partial_time_last))
stderr_partial_runtime = std_partial_runtime / np.sqrt(len(all_partial_time_last))
with open(results_savename, 'a') as f:
f.write(f'Avg. Freeze MFlops: {avg_freeze_flops:.03f}, std {std_freeze_flops}, stderr {stderr_freeze_flops:.03f} \n')
f.write(f'Avg. Freeze Runtime: {avg_freeze_runtime:.03f}, std {std_freeze_runtime}, stderr {stderr_freeze_runtime:.03f} \n')
f.write(f'Avg. Conditional Runtime: {avg_cond_runtime:.03f}, std {std_cond_runtime}, stderr {stderr_cond_runtime:.03f} \n')
f.write(f'Avg. Partial Runtime: {avg_partial_runtime:.03f}, std {std_partial_runtime}, stderr {stderr_partial_runtime:.03f} \n')
# Plot freeze training rank correlations if cutoff at various epochs
freeze_taus = {}
freeze_spes = {}
for epoch_key in all_freeze_evals.keys():
tau, _ = kendalltau(all_reg_evals, all_freeze_evals[epoch_key])
spe, _ = spearmanr(all_reg_evals, all_freeze_evals[epoch_key])
freeze_taus[epoch_key] = tau
freeze_spes[epoch_key] = spe
plt.clf()
for epoch_key in freeze_taus.keys():
plt.scatter(epoch_key, freeze_taus[epoch_key])
plt.xlabel('Epochs of freeze training')
plt.ylabel('Kendall Tau')
plt.ylim((-1.0, 1.0))
plt.grid()
savename = os.path.join(out_dir, 'proxynas_freeze_training_kendall_taus.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
plt.clf()
for epoch_key in freeze_taus.keys():
plt.scatter(epoch_key, freeze_spes[epoch_key])
plt.xlabel('Epochs of freeze training')
plt.ylabel('Spearman Correlation')
plt.ylim((-1.0, 1.0))
plt.grid()
savename = os.path.join(out_dir, 'proxynas_freeze_training_spearman_corrs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# Rank correlations at top n percent of architectures
#-----------------------------------------------------
reg_freezelast_naswot_evals = [(all_reg_evals[i], all_freeze_evals_last[i], all_freeze_time_last[i]) for i in range(len(all_reg_evals))]
# sort in descending order of accuracy of regular evaluation
reg_freezelast_naswot_evals.sort(key=lambda x: x[0], reverse=True)
top_percent_freeze_times_avg = []
top_percent_freeze_times_std = []
top_percent_freeze_times_stderr = []
spe_freeze_top_percents = []
top_percents = []
top_percent_range = range(2, 101, 2)
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01))
top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep]
top_percent_reg = [x[0] for x in top_percent_evals]
top_percent_freeze = [x[1] for x in top_percent_evals]
top_percent_freeze_times = [x[2] for x in top_percent_evals]
top_percent_freeze_times_avg.append(np.mean(np.array(top_percent_freeze_times)))
top_percent_freeze_times_std.append(np.std(np.array(top_percent_freeze_times)))
top_percent_freeze_times_stderr.append(sem(np.array(top_percent_freeze_times)))
spe_freeze, _ = spearmanr(top_percent_reg, top_percent_freeze)
spe_freeze_top_percents.append(spe_freeze)
plt.clf()
    sns.scatterplot(x=top_percents, y=spe_freeze_top_percents)
plt.legend(labels=['Freeze Train'])
plt.ylim((-1.0, 1.0))
plt.xlabel('Top percent of architectures')
plt.ylabel('Spearman Correlation')
plt.grid()
    savename = os.path.join(out_dir, 'spe_top_archs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
plt.clf()
plt.errorbar(top_percents, top_percent_freeze_times_avg, yerr=np.array(top_percent_freeze_times_std)/2, marker='s', mfc='red', ms=10, mew=4)
plt.xlabel('Top percent of architectures')
plt.ylabel('Avg. time (s)')
plt.yticks(np.arange(0,600, step=50))
plt.grid()
    savename = os.path.join(out_dir, 'freeze_train_duration_top_archs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# how much overlap in top x% of architectures between method and groundtruth
# ----------------------------------------------------------------------------
arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
arch_id_freezetrain_evals = [(arch_id, freeze_eval) for arch_id, freeze_eval in zip(all_arch_ids, all_freeze_evals_last)]
arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
arch_id_freezetrain_evals.sort(key=lambda x: x[1], reverse=True)
assert len(arch_id_reg_evals) == len(arch_id_freezetrain_evals)
top_percents = []
freezetrain_ratio_common = []
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
top_percent_arch_id_freezetrain_evals = arch_id_freezetrain_evals[:num_to_keep]
# take the set of arch_ids in each method and find overlap with top archs
set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
set_ft = set([x[0] for x in top_percent_arch_id_freezetrain_evals])
ft_num_common = len(set_reg.intersection(set_ft))
freezetrain_ratio_common.append(ft_num_common/num_to_keep)
# save raw data for other aggregate plots over experiments
raw_data_dict = {}
raw_data_dict['top_percents'] = top_percents
raw_data_dict['spe_freeze'] = spe_freeze_top_percents
raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg
raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std
raw_data_dict['freeze_times_stderr'] = top_percent_freeze_times_stderr
raw_data_dict['freeze_ratio_common'] = freezetrain_ratio_common
savename = os.path.join(out_dir, 'raw_data.yaml')
with open(savename, 'w') as f:
yaml.dump(raw_data_dict, f)
if __name__ == '__main__':
main()
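
The top-percent overlap computed near the end of main() (the fraction of architectures that appear in the top x% under both regular evaluation and freeze training) is a reusable idea. As a self-contained function it could look like the sketch below; top_percent_overlap is a hypothetical refactor, not part of this commit:

import math

def top_percent_overlap(reg_ranked_ids, proxy_ranked_ids, top_percent: int) -> float:
    # both lists must already be sorted in descending order of their respective scores
    num_to_keep = int(math.floor(len(reg_ranked_ids) * top_percent * 0.01))
    if num_to_keep == 0:
        return 0.0
    top_reg = set(reg_ranked_ids[:num_to_keep])
    top_proxy = set(proxy_ranked_ids[:num_to_keep])
    return len(top_reg & top_proxy) / num_to_keep

# e.g.: top_percent_overlap([a for a, _ in arch_id_reg_evals],
#                           [a for a, _ in arch_id_freezetrain_evals], top_percent=10)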