Added script to analyse DARTS space regular evaluation.

This commit is contained in:
Debadeepta Dey 2021-09-10 16:01:03 -07:00 коммит произвёл Gustavo Rosa
Родитель cc1f417171
Коммит 1affe0cdc8
2 изменённых файлов: 293 добавлений и 1 удалений

12
.vscode/launch.json поставляемый
Просмотреть файл

@ -698,7 +698,17 @@
"request": "launch", "request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py", "program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6", "args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly",
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
},
{
"name": "Analysis Regular Darts Space",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_regular_darts_space.py",
"console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\dt_reg_b96_e20",
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file", "--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"] "F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
}, },

Просмотреть файл

@ -0,0 +1,282 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from typing import Dict, List, Type, Iterator, Tuple
import glob
import os
import pathlib
from collections import OrderedDict, defaultdict
from scipy.stats.stats import _two_sample_transform
import yaml
from inspect import getsourcefile
import re
from tqdm import tqdm
import seaborn as sns
import math as ma
from scipy.stats import kendalltau, spearmanr, sem
from runstats import Statistics
#import matplotlib
#matplotlib.use('TkAgg')
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import multiprocessing
from multiprocessing import Pool
from archai.common import utils
from archai.common.ordereddict_logger import OrderedDictLogger
from archai.common.analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
import re
def main():
parser = argparse.ArgumentParser(description='Report creator')
parser.add_argument('--results-dir', '-d', type=str,
default=r'~/logdir/proxynas_test_0001',
help='folder with experiment results from pt')
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
help='folder to output reports')
parser.add_argument('--reg-evals-file', '-r', type=str, default=None,
help='optional yaml file which contains full evaluation \
of architectures on new datasets not part of natsbench')
args, extra_args = parser.parse_known_args()
# root dir where all results are stored
results_dir = pathlib.Path(utils.full_path(args.results_dir))
print(f'results_dir: {results_dir}')
# extract experiment name which is top level directory
exp_name = results_dir.parts[-1]
# create results dir for experiment
out_dir = utils.full_path(os.path.join(args.out_dir, exp_name))
print(f'out_dir: {out_dir}')
os.makedirs(out_dir, exist_ok=True)
# if optional regular evaluation lookup file is provided
if args.reg_evals_file:
with open(args.reg_evals_file, 'r') as f:
reg_evals_data = yaml.load(f, Loader=yaml.Loader)
# get list of all structured logs for each job
logs = {}
confs = {}
job_dirs = list(results_dir.iterdir())
# # test single job parsing for debugging
# # WARNING: very slow, just use for debugging
# for job_dir in job_dirs:
# a = parse_a_job(job_dir)
# parallel parsing of yaml logs
num_workers = 6
with Pool(num_workers) as p:
a = p.map(parse_a_job, job_dirs)
for storage in a:
for key, val in storage.items():
logs[key] = val[0]
confs[key] = val[1]
# examples of accessing logs
# best_test = logs[key]['eval_arch']['eval_train']['best_test']['top1']
# best_train = logs[key]['eval_arch']['eval_train']['best_train']['top1']
# remove all search jobs
for key in list(logs.keys()):
if 'search' in key:
logs.pop(key)
# sometimes a job has not even written logs to yaml
for key in list(logs.keys()):
if not logs[key]:
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
# remove all arch_ids which did not finish
for key in list(logs.keys()):
to_delete = False
# it might have died early
if 'regular_evaluate' not in list(logs[key].keys()):
to_delete = True
if to_delete:
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
if 'best_train' not in list(logs[key]['regular_evaluate']['eval_arch']['eval_train'].keys()):
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
all_arch_ids = []
all_reg_evals = []
all_short_reg_evals = []
all_short_reg_time = []
for key in logs.keys():
if 'eval' in key:
try:
# regular evaluation
# important to get this first since if it is not
# available for in the created darts benchmark we need to
# remove it from consideration
# --------------------
# lookup from the provided darts benchmark file
arch_id = confs[key]['nas']['eval']['dartsspace']['arch_index']
if arch_id not in list(reg_evals_data.keys()):
# if the dataset used is not part of the standard benchmark some of the architectures
# may not have full evaluation accuracies available. Remove them from consideration.
continue
reg_eval_top1 = reg_evals_data[arch_id]
all_reg_evals.append(reg_eval_top1)
best_train = logs[key]['regular_evaluate']['eval_arch']['eval_train']['best_train']['top1']
all_short_reg_evals.append(best_train)
# collect duration
duration = 0.0
for epoch_key in logs[key]['regular_evaluate']['eval_arch']['eval_train']['epochs']:
duration += logs[key]['regular_evaluate']['eval_arch']['eval_train']['epochs'][epoch_key]['train']['duration']
all_short_reg_time.append(duration)
# record the arch id
# --------------------
all_arch_ids.append(arch_id)
except KeyError as err:
print(f'KeyError {err} not in {key}!')
# Store some key numbers in results.txt
results_savename = os.path.join(out_dir, 'results.txt')
# Sanity check
assert len(all_reg_evals) == len(all_short_reg_evals)
assert len(all_reg_evals) == len(all_short_reg_time)
# Shortened training results
short_reg_tau, short_reg_p_value = kendalltau(all_reg_evals, all_short_reg_evals)
short_reg_spe, short_reg_sp_value = spearmanr(all_reg_evals, all_short_reg_evals)
print(f'Short reg Kendall Tau score: {short_reg_tau:3.03f}, p_value {short_reg_p_value:3.03f}')
print(f'Short reg Spearman corr: {short_reg_spe:3.03f}, p_value {short_reg_sp_value:3.03f}')
print(f'Valid archs: {len(all_reg_evals)}')
with open(results_savename, 'w') as f:
f.write(f'Short reg Kendall Tau score: {short_reg_tau:3.03f}, p_value {short_reg_p_value:3.03f} \n')
f.write(f'Short reg Spearman corr: {short_reg_spe:3.03f}, p_value {short_reg_sp_value:3.03f} \n')
plt.clf()
sns.scatterplot(x=all_reg_evals, y=all_short_reg_evals)
plt.xlabel('Test top1 at natsbench full training')
plt.ylabel('Regular training with less epochs')
plt.grid()
savename = os.path.join(out_dir, 'shortened_training.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# Rank correlations at top n percent of architectures
reg_shortreg_evals = [(all_reg_evals[i], all_short_reg_evals[i], all_short_reg_time[i]) for i in range(len(all_reg_evals))]
# sort in descending order of accuracy of regular evaluation
reg_shortreg_evals.sort(key=lambda x: x[0], reverse=True)
top_percent_shortreg_times_avg = []
top_percent_shortreg_times_std = []
top_percent_shortreg_times_stderr = []
spe_shortreg_top_percents = []
top_percents = []
top_percent_range = range(2, 101, 2)
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(reg_shortreg_evals) * top_percent * 0.01))
top_percent_evals = reg_shortreg_evals[:num_to_keep]
top_percent_reg = [x[0] for x in top_percent_evals]
top_percent_shortreg = [x[1] for x in top_percent_evals]
top_percent_shortreg_times = [x[2] for x in top_percent_evals]
top_percent_shortreg_times_avg.append(np.mean(np.array(top_percent_shortreg_times)))
top_percent_shortreg_times_std.append(np.std(np.array(top_percent_shortreg_times)))
top_percent_shortreg_times_stderr.append(sem(np.array(top_percent_shortreg_times)))
spe_shortreg, _ = spearmanr(top_percent_reg, top_percent_shortreg)
spe_shortreg_top_percents.append(spe_shortreg)
plt.clf()
sns.scatterplot(top_percents, spe_shortreg_top_percents)
plt.legend(labels=['Shortened Regular Training'])
plt.ylim((-1.0, 1.0))
plt.xlim((0,100))
plt.xlabel('Top percent of architectures')
plt.ylabel('Spearman Correlation')
plt.grid()
savename = os.path.join(out_dir, f'spe_top_archs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
plt.clf()
plt.errorbar(top_percents, top_percent_shortreg_times_avg, yerr=np.array(top_percent_shortreg_times_std)/2, marker='s', mfc='red', ms=10, mew=4)
plt.xlabel('Top percent of architectures')
plt.ylabel('Avg. time (s)')
plt.yticks(np.arange(0,600, step=50))
plt.grid()
savename = os.path.join(out_dir, f'shortreg_train_duration_top_archs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# time taken
avg_shortreg_runtime = np.mean(np.array(all_short_reg_time))
stderr_shortreg_runtime = np.std(np.array(all_short_reg_time)) / np.sqrt(len(all_short_reg_time))
with open(results_savename, 'a') as f:
f.write(f'Avg. Shortened Training Runtime: {avg_shortreg_runtime:.03f}, stderr {stderr_shortreg_runtime:.03f} \n')
# how much overlap in top x% of architectures between method and groundtruth
# ----------------------------------------------------------------------------
arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
arch_id_shortreg_evals = [(arch_id, shortreg_eval) for arch_id, shortreg_eval in zip(all_arch_ids, all_short_reg_evals)]
arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
arch_id_shortreg_evals.sort(key=lambda x: x[1], reverse=True)
assert len(arch_id_reg_evals) == len(arch_id_shortreg_evals)
top_percents = []
shortreg_ratio_common = []
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
top_percent_arch_id_shortreg_evals = arch_id_shortreg_evals[:num_to_keep]
# take the set of arch_ids in each method and find overlap with top archs
set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
set_ft = set([x[0] for x in top_percent_arch_id_shortreg_evals])
ft_num_common = len(set_reg.intersection(set_ft))
shortreg_ratio_common.append(ft_num_common/num_to_keep)
# save raw data for other aggregate plots over experiments
raw_data_dict = {}
raw_data_dict['top_percents'] = top_percents
raw_data_dict['spe_shortreg'] = spe_shortreg_top_percents
raw_data_dict['shortreg_times_avg'] = top_percent_shortreg_times_avg
raw_data_dict['shortreg_times_std'] = top_percent_shortreg_times_std
raw_data_dict['shortreg_times_stderr'] = top_percent_shortreg_times_stderr
raw_data_dict['shortreg_ratio_common'] = shortreg_ratio_common
savename = os.path.join(out_dir, 'raw_data.yaml')
with open(savename, 'w') as f:
yaml.dump(raw_data_dict, f)
if __name__ == '__main__':
main()