diff --git a/.vscode/launch.json b/.vscode/launch.json
index c7812187..a94a6d0d 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -384,6 +384,14 @@
         "console": "integratedTerminal",
         "args": ["--algos", "natsbench_regular_eval"]
     },
+    {
+        "name": "Natsbench-Regular-SSS-Eval-Full",
+        "type": "python",
+        "request": "launch",
+        "program": "${cwd}/scripts/main.py",
+        "console": "integratedTerminal",
+        "args": ["--full", "--algos", "natsbench_sss_regular_eval", "--datasets", "cifar10"]
+    },
     {
         "name": "Nb101-Regular-Eval-Full",
         "type": "python",
         "request": "launch",
         "program": "${cwd}/scripts/main_proxynas_nb_wrapper.py",
         "console": "integratedTerminal",
@@ -422,7 +430,7 @@
         "request": "launch",
         "program": "${cwd}/scripts/main_proxynas_nb_wrapper.py",
         "console": "integratedTerminal",
-        "args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
+        "args": ["--algos", "proxynas_natsbench_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
     },
     {
         "name": "Main Proxynas Natsbench SSS Wrapper",
@@ -430,7 +438,7 @@
         "request": "launch",
         "program": "${cwd}/scripts/main_proxynas_nb_sss_wrapper.py",
         "console": "integratedTerminal",
-        "args": ["--algos", "proxynas_natsbench_sss_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
+        "args": ["--algos", "natsbench_sss_regular_eval", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
     },
     {
         "name": "Main Proxynas Nb101 Wrapper",
@@ -438,7 +446,7 @@
         "request": "launch",
         "program": "${cwd}/scripts/main_proxynas_nb101_wrapper.py",
         "console": "integratedTerminal",
-        "args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "20", "--top1-acc-threshold", "0.6"]
+        "args": ["--algos", "proxynas_nasbench101_space", "--arch-list-index", "0", "--num-archs", "50", "--top1-acc-threshold", "0.6"]
     },
     {
         "name": "Resnet-Toy",
@@ -713,6 +721,15 @@
         "args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_fb1024_ftlr0.1_fte10_ct256_ftt0.6_scu",
                  "--out-dir", "F:\\archai_experiment_reports"]
     },
+    {
+        "name": "Analysis Freeze Natsbench SSS Space",
+        "type": "python",
+        "request": "launch",
+        "program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py",
+        "console": "integratedTerminal",
+        "args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
+                 "--out-dir", "F:\\archai_experiment_reports"]
+    },
     {
         "name": "Analysis Freeze Addon No Cond",
         "type": "python",
diff --git a/archai/algos/natsbench/natsbench_sss_regular_evaluater.py b/archai/algos/natsbench/natsbench_sss_regular_evaluater.py
new file mode 100644
index 00000000..0da326c8
--- /dev/null
+++ b/archai/algos/natsbench/natsbench_sss_regular_evaluater.py
@@ -0,0 +1,83 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
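+
+# Evaluater that runs regular (full-network) training of a single
+# architecture drawn from the NATS-Bench size search space (SSS).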
+
+from copy import deepcopy
+from typing import Optional
+import importlib
+import sys
+import string
+import os
+
+from overrides import overrides
+
+import torch
+from torch import nn
+
+from archai.common.trainer import Trainer
+from archai.common.config import Config
+from archai.common.common import logger
+from archai.datasets import data
+from archai.nas.model_desc import ModelDesc
+from archai.nas.model_desc_builder import ModelDescBuilder
+from archai.nas import nas_utils
+from archai.common import ml_utils, utils
+from archai.common.metrics import EpochMetrics, Metrics
+from archai.nas.model import Model
+from archai.common.checkpoint import CheckPoint
+from archai.nas.evaluater import Evaluater
+
+from nats_bench import create
+from archai.algos.natsbench.lib.models import get_cell_based_tiny_net
+
+class NatsbenchSSSRegularEvaluater(Evaluater):
+    @overrides
+    def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
+                     final_desc_filename=None, full_desc_filename=None)->nn.Module:
+        # region conf vars
+        dataset_name = conf_eval['loader']['dataset']['name']
+        self.num_classes = conf_eval['loader']['dataset']['n_classes']
+
+        # if explicitly passed in then don't get from conf
+        if not final_desc_filename:
+            final_desc_filename = conf_eval['final_desc_filename']
+        arch_index = conf_eval['natsbench']['arch_index']
+
+        dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
+        natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
+        # endregion
+
+        # arch_index may legitimately be 0, so check for None explicitly
+        assert arch_index is not None
+        assert natsbench_location
+
+        return self._model_from_natsbench(arch_index, dataset_name, natsbench_location)
+
+    def _model_from_natsbench(self, arch_index:int, dataset_name:str, natsbench_location:str)->Model:
+
+        # create natsbench api for the size search space (sss)
+        api = create(natsbench_location, 'sss', fast_mode=True, verbose=True)
+
+        # the size search space contains 32768 architectures
+        if arch_index >= 32768 or arch_index < 0:
+            logger.warn(f'architecture id {arch_index} is invalid')
+
+        supported_datasets = {'cifar10', 'cifar100', 'ImageNet16-120'}
+
+        # force natsbench to use cifar10 archs
+        # since it doesn't know about other datasets
+        if dataset_name not in supported_datasets:
+            dataset_name = 'cifar10'
+
+        config = api.get_net_config(arch_index, dataset_name)
+
+        # fill in the num classes from conf again since unsupported
+        # datasets may have a number of classes different from cifar10
+        config['num_classes'] = self.num_classes
+
+        # network is a nn.Module subclass. the last few modules have names
+        # lastact, lastact.0, lastact.1, global_pooling, classifier
+        # which we can freeze train as usual
+        model = get_cell_based_tiny_net(config)
+
+        return model
\ No newline at end of file
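For context, the model construction in _model_from_natsbench reduces to the standalone sketch below, using the NATS-Bench fast-mode API; the benchmark path is an assumed example location, and 9682 is the arch_index used in the eval config later in this patch:

    from nats_bench import create
    from archai.algos.natsbench.lib.models import get_cell_based_tiny_net

    # assumed example path to the SSS fast-mode files under dataroot
    api = create('/path/to/dataroot/natsbench/NATS-sss-v1_0-50262-simple', 'sss',
                 fast_mode=True, verbose=True)
    config = api.get_net_config(9682, 'cifar10')  # per-arch config, incl. channel sizes
    config['num_classes'] = 10                    # the evaluater overrides this from conf
    model = get_cell_based_tiny_net(config)       # plain torch.nn.Module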
diff --git a/archai/algos/natsbench/natsbench_sss_regular_experiment_runner.py b/archai/algos/natsbench/natsbench_sss_regular_experiment_runner.py
new file mode 100644
index 00000000..ffa4c4e8
--- /dev/null
+++ b/archai/algos/natsbench/natsbench_sss_regular_experiment_runner.py
@@ -0,0 +1,72 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Optional, Type
+from copy import deepcopy
+import os
+
+from overrides import overrides
+
+from archai.common.config import Config
+from archai.nas import nas_utils
+from archai.common import utils
+from archai.nas.exp_runner import ExperimentRunner
+from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
+from archai.nas.model_desc_builder import ModelDescBuilder
+from archai.nas.evaluater import EvalResult
+from archai.common.common import get_expdir, logger
+from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
+from archai.algos.natsbench.natsbench_sss_regular_evaluater import NatsbenchSSSRegularEvaluater
+
+from nats_bench import create
+
+class NatsbenchSSSRegularExperimentRunner(ExperimentRunner):
+    """Runs regular training on architectures from the natsbench size search space (SSS)"""
+
+    @overrides
+    def model_desc_builder(self)->Optional[ModelDescBuilder]:
+        return None
+
+    @overrides
+    def trainer_class(self)->TArchTrainer:
+        return None # no search trainer
+
+    @overrides
+    def searcher(self)->ManualFreezeSearcher:
+        return ManualFreezeSearcher() # no searcher basically
+
+    @overrides
+    def copy_search_to_eval(self)->None:
+        pass
+
+    @overrides
+    def run_eval(self, conf_eval:Config)->EvalResult:
+
+        # regular evaluation of the architecture
+        # where we simply look up the benchmark result
+        # --------------------------------------------
+        dataset_name = conf_eval['loader']['dataset']['name']
+
+        logger.pushd('regular_evaluate')
+        if dataset_name in {'cifar10', 'cifar100', 'ImageNet16-120'}:
+            arch_id = conf_eval['natsbench']['arch_index']
+            dataroot = utils.full_path(conf_eval['loader']['dataset']['dataroot'])
+            natsbench_location = os.path.join(dataroot, 'natsbench', conf_eval['natsbench']['natsbench_sss_fast'])
+            logger.info(natsbench_location)
+
+            # create the api for the size search space (sss)
+            api = create(natsbench_location, 'sss', fast_mode=True, verbose=True)
+
+            # the size search space contains 32768 architectures
+            if arch_id >= 32768 or arch_id < 0:
+                logger.warn(f'architecture id {arch_id} is invalid')
+
+            info = api.get_more_info(arch_id, dataset_name, hp=90, is_random=False)
+            test_accuracy = info['test-accuracy']
+            logger.info(f'Regular training top1 test accuracy is {test_accuracy}')
+            logger.info({'regtrainingtop1': float(test_accuracy)})
+        else:
+            logger.info({'regtrainingtop1': -1})
+        logger.popd()
+
+        # regular training of the architecture for n epochs
+        evaler = NatsbenchSSSRegularEvaluater()
+        return evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
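The ground-truth lookup in run_eval amounts to the following sketch; the path is again an assumed example, and hp=90 selects the 90-epoch training statistics that NATS-Bench provides for the size search space:

    from nats_bench import create

    api = create('/path/to/dataroot/natsbench/NATS-sss-v1_0-50262-simple', 'sss',
                 fast_mode=True, verbose=True)
    # is_random=False returns results averaged over seeds rather than one random trial
    info = api.get_more_info(9682, 'cifar10', hp=90, is_random=False)
    print(info['test-accuracy'])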
diff --git a/confs/algos/natsbench_sss_regular_eval.yaml b/confs/algos/natsbench_sss_regular_eval.yaml
new file mode 100644
index 00000000..85333760
--- /dev/null
+++ b/confs/algos/natsbench_sss_regular_eval.yaml
@@ -0,0 +1,46 @@
+__include__: 'darts.yaml' # just use darts defaults
+
+nas:
+  search:
+    model_desc:
+      num_edges_to_sample: 2 # number of edges each node will take input from
+
+  eval:
+    natsbench:
+      arch_index: 9682
+      natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
+    model_desc:
+      num_edges_to_sample: 2
+    loader:
+      train_batch: 256 # natsbench uses 256
+      aug: '' # random flip and crop are already there in default params
+    trainer: # matching natsbench paper closely
+      plotsdir: ''
+      apex:
+        _copy: '/common/apex'
+      aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
+      drop_path_prob: 0.2 # probability that a given edge will be dropped
+      grad_clip: 5.0 # grads above this value are clipped
+      l1_alphas: 0.0 # weight applied to sum(abs(alphas)) in the loss term
+      logger_freq: 1000 # after every N updates dump loss and other metrics in logger
+      title: 'eval_train'
+      epochs: 5
+      batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
+      lossfn:
+        type: 'CrossEntropyLoss'
+      optimizer:
+        type: 'sgd'
+        lr: 0.1 # init learning rate
+        decay: 5.0e-4 # pytorch default is 0.0
+        momentum: 0.9 # pytorch default is 0.0
+        nesterov: True # pytorch default is False
+        decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
+      lr_schedule:
+        type: 'cosine'
+        min_lr: 0.000 # min learning rate to be set in eta_min param of scheduler
+        warmup: # increases LR from 0 to current over the specified epochs and then hands over to main scheduler
+          multiplier: 1
+          epochs: 0 # 0 disables warmup
\ No newline at end of file
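The trainer block above corresponds roughly to the plain-PyTorch setup below; the model is a stand-in (archai builds the real optimizer and scheduler from this yaml internally):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for the natsbench architecture
    # optimizer section: sgd, lr 0.1, decay 5.0e-4, momentum 0.9, nesterov
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                                weight_decay=5.0e-4, nesterov=True)
    # lr_schedule section: cosine decay to min_lr 0.0 over the 5 training epochs, no warmup
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.0)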
diff --git a/confs/algos/proxynas_natsbench_sss_space.yaml b/confs/algos/proxynas_natsbench_sss_space.yaml
index 0a1ed7cf..84d782c5 100644
--- a/confs/algos/proxynas_natsbench_sss_space.yaml
+++ b/confs/algos/proxynas_natsbench_sss_space.yaml
@@ -9,7 +9,7 @@ nas:
   eval:
     natsbench:
       arch_index: 23000
-      natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the tss fast mode folder
+      natsbench_sss_fast: 'NATS-sss-v1_0-50262-simple' # folder name in dataroot/natsbench that contains the sss fast mode folder
     model_desc:
       num_edges_to_sample: 2
     loader:
diff --git a/scripts/main.py b/scripts/main.py
index 3477b851..3e03228b 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -21,6 +21,7 @@ from archai.algos.proxynas.freeze_nasbench101_experiment_runner import FreezeNas
 from archai.algos.proxynas.freeze_manual_experiment_runner import ManualFreezeExperimentRunner
 from archai.algos.naswotrain.naswotrain_natsbench_conditional_experiment_runner import NaswotConditionalNatsbenchExperimentRunner
 from archai.algos.natsbench.natsbench_regular_experiment_runner import NatsbenchRegularExperimentRunner
+from archai.algos.natsbench.natsbench_sss_regular_experiment_runner import NatsbenchSSSRegularExperimentRunner
 from archai.algos.nasbench101.nasbench101_exp_runner import Nb101RegularExperimentRunner
 from archai.algos.proxynas.phased_freeze_natsbench_experiment_runner import PhasedFreezeNatsbenchExperimentRunner
 from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import FreezeAddonNasbench101ExperimentRunner
@@ -54,6 +55,7 @@ def main():
         'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
         'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
         'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
+        'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
         'nb101_regular_eval': Nb101RegularExperimentRunner,
         'phased_freezetrain_natsbench_space': PhasedFreezeNatsbenchExperimentRunner,
         'freezeaddon_nasbench101_space': FreezeAddonNasbench101ExperimentRunner,
@@ -85,6 +87,7 @@
                         zerocost_conditional_natsbench_space,
                         zerocost_natsbench_epochs_space,
                         natsbench_regular_eval,
+                        natsbench_sss_regular_eval,
                         nb101_regular_eval,
                         phased_freezetrain_natsbench_space,
                         random_natsbench_tss_far,
diff --git a/scripts/main_proxynas_nb_sss_wrapper.py b/scripts/main_proxynas_nb_sss_wrapper.py
index 21563754..a98d8a76 100644
--- a/scripts/main_proxynas_nb_sss_wrapper.py
+++ b/scripts/main_proxynas_nb_sss_wrapper.py
@@ -11,7 +11,8 @@ from archai.common.utils import exec_shell_command
 def main():
     parser = argparse.ArgumentParser(description='Proxynas SSS Wrapper Main')
     parser.add_argument('--algos', type=str, default='''
-                        proxynas_natsbench_sss_space,
+                        proxynas_natsbench_sss_space,
+                        natsbench_sss_regular_eval,
                        ''', help='NAS algos to run, separated by comma')

     parser.add_argument('--top1-acc-threshold', type=float)
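Outside VS Code, the new "Natsbench-Regular-SSS-Eval-Full" launch entry above corresponds to roughly the following, assuming it is run from the repository root:

    import subprocess

    # mirrors the launch.json args for the algo registered in scripts/main.py
    subprocess.run(['python', 'scripts/main.py', '--full',
                    '--algos', 'natsbench_sss_regular_eval',
                    '--datasets', 'cifar10'], check=True)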
diff --git a/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py b/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py
new file mode 100644
index 00000000..994b0f65
--- /dev/null
+++ b/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py
@@ -0,0 +1,461 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import sys
+import argparse
+from typing import Dict, List, Type, Iterator, Tuple
+import glob
+import os
+import pathlib
+from collections import OrderedDict, defaultdict
+import yaml
+from inspect import getsourcefile
+import math as ma
+
+import plotly.express as px
+from plotly.subplots import make_subplots
+import plotly.graph_objects as go
+
+from scipy.stats import kendalltau, spearmanr, sem
+
+from runstats import Statistics
+
+#import matplotlib
+#matplotlib.use('TkAgg')
+import seaborn as sns
+import numpy as np
+import matplotlib.pyplot as plt
+from multiprocessing import Pool
+from collections import namedtuple
+
+from archai.common import utils
+from archai.common.ordereddict_logger import OrderedDictLogger
+from archai.common.analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
+
+import re
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Report creator')
+    parser.add_argument('--results-dir', '-d', type=str,
+                        default=r'~/logdir/proxynas_test_0001',
+                        help='folder with experiment results')
+    parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
+                        help='folder to output reports')
+    parser.add_argument('--reg-evals-file', '-r', type=str, default=None,
+                        help='optional yaml file which contains full evaluation \
+                            of architectures on new datasets not part of natsbench')
+    args, extra_args = parser.parse_known_args()
+
+    # root dir where all results are stored
+    results_dir = pathlib.Path(utils.full_path(args.results_dir))
+    print(f'results_dir: {results_dir}')
+
+    # extract experiment name which is top level directory
+    exp_name = results_dir.parts[-1]
+
+    # create results dir for experiment
+    out_dir = utils.full_path(os.path.join(args.out_dir, exp_name))
+    print(f'out_dir: {out_dir}')
+    os.makedirs(out_dir, exist_ok=True)
+
+    # if optional regular evaluation lookup file is provided
+    if args.reg_evals_file:
+        with open(args.reg_evals_file, 'r') as f:
+            reg_evals_data = yaml.load(f, Loader=yaml.Loader)
+
+    # get list of all structured logs for each job
+    logs = {}
+    confs = {}
+    job_dirs = list(results_dir.iterdir())
+
+    # # test single job parsing for debugging
+    # # WARNING: very slow, just use for debugging
+    # for job_dir in job_dirs:
+    #     a = parse_a_job(job_dir)
+
+    # parallel parsing of yaml logs
+    num_workers = 8
+    with Pool(num_workers) as p:
+        a = p.map(parse_a_job, job_dirs)
+
+    for storage in a:
+        for key, val in storage.items():
+            logs[key] = val[0]
+            confs[key] = val[1]
+
+    # examples of accessing logs
+    # logs['proxynas_blahblah:eval']['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining']
+    # logs['proxynas_blahblah:eval']['regular_evaluate']['regtrainingtop1']
+    # logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['9']['val']['top1']
+    # last_epoch_key = list(logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
+    # last_val_top1 = logs['proxynas_blahblah:eval']['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_epoch_key]['val']['top1']
+    # epoch_duration = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']['0']['train']['duration']
+
+    # remove all search jobs
+    for key in list(logs.keys()):
+        if 'search' in key:
+            logs.pop(key)
+
+    # remove all arch_ids which did not finish
+    for key in list(logs.keys()):
+        to_delete = False
+
+        # it might have died early
+        if 'freeze_evaluate' not in list(logs[key].keys()):
+            to_delete = True
+
+        if 'regular_evaluate' not in list(logs[key].keys()):
+            to_delete = True
+
+        if to_delete:
+            print(f'arch id {key} did not finish. removing from calculations.')
+            logs.pop(key)
+            continue
+
+        if 'freeze_training' not in list(logs[key]['freeze_evaluate']['eval_arch'].keys()):
+            print(f'arch id {key} did not finish. removing from calculations.')
+            logs.pop(key)
+            continue
+
+        # freeze train may not have finished
+        num_freeze_epochs = confs[key]['nas']['eval']['freeze_trainer']['epochs']
+        last_freeze_epoch_key = int(list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1])
+        if last_freeze_epoch_key != num_freeze_epochs - 1:
+            print(f'arch id {key} did not finish. removing from calculations.')
+            logs.pop(key)
+
+    all_arch_ids = []
+
+    all_reg_evals = []
+
+    all_freeze_evals_last = []
+    all_cond_evals_last = []
+
+    all_freeze_flops_last = []
+    all_cond_flops_last = []
+
+    all_freeze_time_last = []
+    all_cond_time_last = []
+    all_partial_time_last = []
+
+    all_freeze_evals = defaultdict(list)
+
+    num_archs_unmet_cond = 0
+
+    for key in logs.keys():
+        if 'eval' in key:
+            try:
+
+                # if at the end of conditional training train accuracy has not gone above target then don't consider it
+                # important to get this first
+                last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
+                use_val = confs[key]['nas']['eval']['trainer']['use_val']
+                threshold = confs[key]['nas']['eval']['trainer']['top1_acc_threshold']
+                if use_val:
+                    val_or_train = 'val'
+                else:
+                    val_or_train = 'train'
+                end_cond = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
+                if end_cond < threshold:
+                    num_archs_unmet_cond += 1
+                    continue
+
+                # regular evaluation
+                # important to get this first since if it is not
+                # available for non-benchmark datasets we need to
+                # remove it from consideration
+                # --------------------
+                if not args.reg_evals_file:
+                    reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
+                else:
+                    # lookup from the provided file since this dataset is not part of the
+                    # benchmark and hence we have to provide the info separately
+                    if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
+                        arch_id_in_bench = confs[key]['nas']['eval']['natsbench']['arch_index']
+                    elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
+                        arch_id_in_bench = confs[key]['nas']['eval']['nasbench101']['arch_index']
+
+                    if arch_id_in_bench not in list(reg_evals_data.keys()):
+                        # if the dataset used is not part of the standard benchmark some of the architectures
+                        # may not have full evaluation accuracies available. Remove them from consideration.
+                        continue
+                    reg_eval_top1 = reg_evals_data[arch_id_in_bench]
+                all_reg_evals.append(reg_eval_top1)
+
+                # freeze evaluation
+                # --------------------
+
+                # at last epoch
+                last_freeze_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'].keys())[-1]
+                freeze_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][last_freeze_epoch_key][val_or_train]['top1']
+                all_freeze_evals_last.append(freeze_eval_top1)
+
+                # collect evals at other epochs
+                for epoch in range(int(last_freeze_epoch_key)):
+                    all_freeze_evals[epoch].append(logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][str(epoch)][val_or_train]['top1'])
+
+                # collect flops used for conditional training and freeze training
+                freeze_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['total_mega_flops_epoch']
+                freeze_mega_flops_used = freeze_mega_flops_epoch * int(last_freeze_epoch_key)
+                all_freeze_flops_last.append(freeze_mega_flops_used)
+
+                last_cond_epoch_key = list(logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'].keys())[-1]
+                cond_mega_flops_epoch = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['total_mega_flops_epoch']
+                cond_mega_flops_used = cond_mega_flops_epoch * int(last_cond_epoch_key)
+                all_cond_flops_last.append(cond_mega_flops_used)
+
+                # collect training error at end of conditional training
+                cond_eval_top1 = logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][last_cond_epoch_key][val_or_train]['top1']
+                all_cond_evals_last.append(cond_eval_top1)
+
+                # collect duration for conditional training and freeze training
+                # NOTE: don't use val_or_train here since we are really interested in the duration of training
+                freeze_duration = 0.0
+                for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs']:
+                    freeze_duration += logs[key]['freeze_evaluate']['eval_arch']['freeze_training']['eval_train']['epochs'][epoch_key]['train']['duration']
+
+                cond_duration = 0.0
+                for epoch_key in logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs']:
+                    cond_duration += logs[key]['freeze_evaluate']['eval_arch']['conditional_training']['eval_train']['epochs'][epoch_key]['train']['duration']
+
+                all_freeze_time_last.append(freeze_duration + cond_duration)
+                all_cond_time_last.append(cond_duration)
+                all_partial_time_last.append(freeze_duration)
+
+                # record the arch id
+                # --------------------
+                if 'natsbench' in list(confs[key]['nas']['eval'].keys()):
+                    all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
+                elif 'nasbench101' in list(confs[key]['nas']['eval'].keys()):
+                    all_arch_ids.append(confs[key]['nas']['eval']['nasbench101']['arch_index'])
+
+            except KeyError as err:
+                print(f'KeyError {err} not in {key}!')
+                sys.exit()
+
+    # Store some key numbers in results.txt
+    results_savename = os.path.join(out_dir, 'results.txt')
+    with open(results_savename, 'w') as f:
+        f.write(f'Number of archs which did not reach condition: {num_archs_unmet_cond} \n')
+        f.write(f'Total valid archs processed: {len(all_reg_evals)} \n')
+
+    print(f'Number of archs which did not reach condition: {num_archs_unmet_cond}')
+    print(f'Total valid archs processed: {len(all_reg_evals)}')
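+
+    # note: the lists above must stay index-aligned with all_reg_evals since
+    # they are paired with it below; the sanity-check asserts that follow
+    # guard against silent misalignment when architectures are skipped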
+
+    # Sanity check
+    assert len(all_reg_evals) == len(all_freeze_evals_last)
+    assert len(all_reg_evals) == len(all_cond_evals_last)
+    assert len(all_reg_evals) == len(all_cond_time_last)
+    assert len(all_reg_evals) == len(all_freeze_flops_last)
+    assert len(all_reg_evals) == len(all_cond_flops_last)
+    assert len(all_reg_evals) == len(all_freeze_time_last)
+    assert len(all_reg_evals) == len(all_arch_ids)
+
+    # scatter plot between time to threshold accuracy and regular evaluation
+    fig = px.scatter(x=all_cond_time_last, y=all_reg_evals, labels={'x': 'Time to reach threshold train accuracy (s)', 'y': 'Final Accuracy'})
+    fig.update_layout(font=dict(
+        size=48,
+    ))
+
+    savename = os.path.join(out_dir, 'cond_time_vs_final_acc.html')
+    fig.write_html(savename)
+
+    savename_pdf = os.path.join(out_dir, 'cond_time_vs_final_acc.pdf')
+    fig.write_image(savename_pdf, engine="kaleido", width=1500, height=1500, scale=1)
+
+    fig.show()
+
+    # histogram of training accuracies
+    fig = px.histogram(all_reg_evals, labels={'x': 'Test Accuracy', 'y': 'Counts'})
+    savename = os.path.join(out_dir, 'distribution_of_reg_evals.html')
+    fig.write_html(savename)
+    fig.show()
+
+    # Freeze training results at last epoch
+    freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last)
+    freeze_spe, freeze_sp_value = spearmanr(all_reg_evals, all_freeze_evals_last)
+    print(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f}')
+    print(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f}')
+    with open(results_savename, 'a') as f:
+        f.write(f'Freeze Kendall Tau score: {freeze_tau:3.03f}, p_value {freeze_p_value:3.03f} \n')
+        f.write(f'Freeze Spearman corr: {freeze_spe:3.03f}, p_value {freeze_sp_value:3.03f} \n')
+
+    plt.clf()
+    sns.scatterplot(x=all_reg_evals, y=all_freeze_evals_last)
+    plt.xlabel('Test top1 at natsbench full training')
+    plt.ylabel('Freeze training')
+    plt.grid()
+    savename = os.path.join(out_dir, 'proxynas_freeze_training_epochs.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
+
+    # Conditional training results at last epoch
+    cond_tau, cond_p_value = kendalltau(all_reg_evals, all_cond_evals_last)
+    cond_spe, cond_sp_value = spearmanr(all_reg_evals, all_cond_evals_last)
+    print(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f}')
+    print(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f}')
+    with open(results_savename, 'a') as f:
+        f.write(f'Conditional Kendall Tau score: {cond_tau:3.03f}, p_value {cond_p_value:3.03f} \n')
+        f.write(f'Conditional Spearman corr: {cond_spe:3.03f}, p_value {cond_sp_value:3.03f} \n')
+
+    plt.clf()
+    sns.scatterplot(x=all_reg_evals, y=all_cond_evals_last)
+    plt.xlabel('Test top1 at natsbench full training')
+    plt.ylabel('Conditional training')
+    plt.grid()
+    savename = os.path.join(out_dir, 'proxynas_cond_training_epochs.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
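+
+    # note: stderr below is the standard error of the mean,
+    # i.e. std / sqrt(number of processed architectures)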
+
+    # Report average runtime and average flops consumed
+    total_freeze_flops = np.array(all_freeze_flops_last) + np.array(all_cond_flops_last)
+    avg_freeze_flops = np.mean(total_freeze_flops)
+    std_freeze_flops = np.std(total_freeze_flops)
+    stderr_freeze_flops = std_freeze_flops / np.sqrt(len(all_freeze_flops_last))
+
+    avg_freeze_runtime = np.mean(np.array(all_freeze_time_last))
+    std_freeze_runtime = np.std(np.array(all_freeze_time_last))
+    stderr_freeze_runtime = std_freeze_runtime / np.sqrt(len(all_freeze_time_last))
+
+    avg_cond_runtime = np.mean(np.array(all_cond_time_last))
+    std_cond_runtime = np.std(np.array(all_cond_time_last))
+    stderr_cond_runtime = std_cond_runtime / np.sqrt(len(all_cond_time_last))
+
+    avg_partial_runtime = np.mean(np.array(all_partial_time_last))
+    std_partial_runtime = np.std(np.array(all_partial_time_last))
+    stderr_partial_runtime = std_partial_runtime / np.sqrt(len(all_partial_time_last))
+
+    with open(results_savename, 'a') as f:
+        f.write(f'Avg. Freeze MFlops: {avg_freeze_flops:.03f}, std {std_freeze_flops}, stderr {stderr_freeze_flops:.03f} \n')
+        f.write(f'Avg. Freeze Runtime: {avg_freeze_runtime:.03f}, std {std_freeze_runtime}, stderr {stderr_freeze_runtime:.03f} \n')
+        f.write(f'Avg. Conditional Runtime: {avg_cond_runtime:.03f}, std {std_cond_runtime}, stderr {stderr_cond_runtime:.03f} \n')
+        f.write(f'Avg. Partial Runtime: {avg_partial_runtime:.03f}, std {std_partial_runtime}, stderr {stderr_partial_runtime:.03f} \n')
+
+    # Plot freeze training rank correlations if cutoff at various epochs
+    freeze_taus = {}
+    freeze_spes = {}
+    for epoch_key in all_freeze_evals.keys():
+        tau, _ = kendalltau(all_reg_evals, all_freeze_evals[epoch_key])
+        spe, _ = spearmanr(all_reg_evals, all_freeze_evals[epoch_key])
+        freeze_taus[epoch_key] = tau
+        freeze_spes[epoch_key] = spe
+
+    plt.clf()
+    for epoch_key in freeze_taus.keys():
+        plt.scatter(epoch_key, freeze_taus[epoch_key])
+    plt.xlabel('Epochs of freeze training')
+    plt.ylabel('Kendall Tau')
+    plt.ylim((-1.0, 1.0))
+    plt.grid()
+    savename = os.path.join(out_dir, 'proxynas_freeze_training_kendall_taus.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
+
+    plt.clf()
+    for epoch_key in freeze_taus.keys():
+        plt.scatter(epoch_key, freeze_spes[epoch_key])
+    plt.xlabel('Epochs of freeze training')
+    plt.ylabel('Spearman Correlation')
+    plt.ylim((-1.0, 1.0))
+    plt.grid()
+    savename = os.path.join(out_dir, 'proxynas_freeze_training_spearman_corrs.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
+
+    # Rank correlations at top n percent of architectures
+    # ----------------------------------------------------
+    reg_freezelast_naswot_evals = [(all_reg_evals[i], all_freeze_evals_last[i], all_freeze_time_last[i]) for i in range(len(all_reg_evals))]
+
+    # sort in descending order of accuracy of regular evaluation
+    reg_freezelast_naswot_evals.sort(key=lambda x: x[0], reverse=True)
+
+    top_percent_freeze_times_avg = []
+    top_percent_freeze_times_std = []
+    top_percent_freeze_times_stderr = []
+
+    spe_freeze_top_percents = []
+    top_percents = []
+    top_percent_range = range(2, 101, 2)
+    for top_percent in top_percent_range:
+        top_percents.append(top_percent)
+        num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01))
+        top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep]
+        top_percent_reg = [x[0] for x in top_percent_evals]
+        top_percent_freeze = [x[1] for x in top_percent_evals]
+        top_percent_freeze_times = [x[2] for x in top_percent_evals]
+
+        top_percent_freeze_times_avg.append(np.mean(np.array(top_percent_freeze_times)))
+        top_percent_freeze_times_std.append(np.std(np.array(top_percent_freeze_times)))
+        top_percent_freeze_times_stderr.append(sem(np.array(top_percent_freeze_times)))
+
+        spe_freeze, _ = spearmanr(top_percent_reg, top_percent_freeze)
+        spe_freeze_top_percents.append(spe_freeze)
+
+    plt.clf()
+    sns.scatterplot(x=top_percents, y=spe_freeze_top_percents)
+    plt.legend(labels=['Freeze Train'])
+    plt.ylim((-1.0, 1.0))
+    plt.xlabel('Top percent of architectures')
+    plt.ylabel('Spearman Correlation')
+    plt.grid()
+    savename = os.path.join(out_dir, 'spe_top_archs.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
+
+    plt.clf()
+    plt.errorbar(top_percents, top_percent_freeze_times_avg, yerr=np.array(top_percent_freeze_times_std)/2, marker='s', mfc='red', ms=10, mew=4)
+    plt.xlabel('Top percent of architectures')
+    plt.ylabel('Avg. time (s)')
+    plt.yticks(np.arange(0, 600, step=50))
+    plt.grid()
+    savename = os.path.join(out_dir, 'freeze_train_duration_top_archs.png')
+    plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
+
+    # how much overlap in top x% of architectures between method and groundtruth
+    # ----------------------------------------------------------------------------
+    arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
+    arch_id_freezetrain_evals = [(arch_id, freeze_eval) for arch_id, freeze_eval in zip(all_arch_ids, all_freeze_evals_last)]
+
+    arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
+    arch_id_freezetrain_evals.sort(key=lambda x: x[1], reverse=True)
+
+    assert len(arch_id_reg_evals) == len(arch_id_freezetrain_evals)
+
+    top_percents = []
+    freezetrain_ratio_common = []
+    for top_percent in top_percent_range:
+        top_percents.append(top_percent)
+        num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
+        top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
+        top_percent_arch_id_freezetrain_evals = arch_id_freezetrain_evals[:num_to_keep]
+
+        # take the set of arch_ids in each method and find overlap with top archs
+        set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
+        set_ft = set([x[0] for x in top_percent_arch_id_freezetrain_evals])
+        ft_num_common = len(set_reg.intersection(set_ft))
+        freezetrain_ratio_common.append(ft_num_common/num_to_keep)
+
+    # save raw data for other aggregate plots over experiments
+    raw_data_dict = {}
+    raw_data_dict['top_percents'] = top_percents
+    raw_data_dict['spe_freeze'] = spe_freeze_top_percents
+    raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg
+    raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std
+    raw_data_dict['freeze_times_stderr'] = top_percent_freeze_times_stderr
+    raw_data_dict['freeze_ratio_common'] = freezetrain_ratio_common
+
+    savename = os.path.join(out_dir, 'raw_data.yaml')
+    with open(savename, 'w') as f:
+        yaml.dump(raw_data_dict, f)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
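The rank-correlation metrics reported throughout the analysis script can be sanity-checked in isolation with a minimal sketch; the values below are made up purely for illustration:

    from scipy.stats import kendalltau, spearmanr

    ground_truth = [93.2, 91.7, 94.1, 90.3]   # benchmark test accuracies (made up)
    proxy_scores = [71.0, 68.5, 72.3, 66.9]   # freeze-training top1 (made up)
    tau, tau_p = kendalltau(ground_truth, proxy_scores)
    spe, spe_p = spearmanr(ground_truth, proxy_scores)
    print(f'tau={tau:.3f} (p={tau_p:.3f}), spearman={spe:.3f} (p={spe_p:.3f})')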