Added more plots and figures to analyzing zero-cost proxies.

This commit is contained in:
Debadeepta Dey 2022-01-05 20:59:18 -08:00 коммит произвёл Gustavo Rosa
Родитель 925232f1fa
Коммит 5d6598fe38
4 изменённых файлов: 113 добавлений и 17 удалений

12
.vscode/launch.json поставляемый
Просмотреть файл

@ -753,7 +753,7 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_create_darts_space_benchmark.py",
"console": "integratedTerminal",
"args": ["--results-dir", "/home/dedey/archaiphilly/amlt/darts_constant_random_darcyflow",
"args": ["--results-dir", "/home/dedey/archaiphilly/amlt/darts_constant_random_synthetic_cifar10",
"--out-dir", "/home/dedey/archai_experiment_reports"]
},
{
@ -873,12 +873,12 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_natsbench_zerocost.py",
"console": "integratedTerminal",
"args": ["--results-dir", "/home/dedey/archaiphilly/amlt/zc_synthetic_cifar10",
"args": ["--results-dir", "/home/dedey/archaiphilly/amlt/natsbench_constant_random_zerocost_scifar100",
"--out-dir", "/home/dedey/archai_experiment_reports",
"--reg-evals-file",
"/home/dedey/archai_experiment_reports/nb_reg_b256_e200_sc10/arch_id_test_accuracy.yaml",
"/home/dedey/archai_experiment_reports/natsbench_constant_random_scifar100/arch_id_test_accuracy.yaml",
"--params-flops-file",
"/home/dedey/archai_experiment_reports/nb_reg_b256_e200_sc10/arch_id_params_flops.yaml"
"/home/dedey/archai_experiment_reports/natsbench_constant_random_scifar100/arch_id_params_flops.yaml"
]
},
{
@ -887,8 +887,8 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_natsbench_zerocost_epochs.py",
"console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\zc_eachepoch_c10",
"--out-dir", "F:\\archai_experiment_reports"]
"args": ["--results-dir", "/home/dedey/archaiphilly/amlt/zc_eachepoch_c10",
"--out-dir", "/home/dedey/archai_experiment_reports"]
},
{
"name": "Analysis Zero Cost Conditional Natsbench Experiments",

Просмотреть файл

@ -114,10 +114,22 @@ def main():
print(f'KeyError {err} not in {key}!')
sys.exit()
savename = os.path.join(out_dir, 'darts_benchmark.yaml')
# save accuracies
savename = os.path.join(out_dir, 'arch_id_test_accuracy.yaml')
with open(savename, 'w') as f:
yaml.dump(archid_testacc, f)
# save params flops
arch_id_params_flops = dict()
savename = os.path.join(out_dir, 'arch_id_params_flops.yaml')
for archid in archid_params.keys():
num_params = archid_params[archid]
num_flops = archid_flops[archid]
arch_id_params_flops[archid] = {'params': num_params, 'flops': num_flops}
with open(savename, 'w') as f:
yaml.dump(arch_id_params_flops, f)
# plot test accuracy vs. number of params
# to see how the distribution looks
testaccs = []
@ -132,14 +144,17 @@ def main():
flops.append(num_flops)
fig = go.Figure()
fig.add_trace(go.Scatter(x=params, y=testaccs, mode='markers'))
fig.update_layout(title_text="Test accuracy vs. number of params on Darts Space Samples",
xaxis_title="Params",
yaxis_title="Test Accuracy")
fig.add_trace(go.Scatter(x=testaccs, y=params, mode='markers'))
fig.update_layout(xaxis_title="Test Accuracy",
yaxis_title="Parameters")
fig.update_layout(font=dict(size=36)) # font size
fig.update_traces(marker=dict(size=20)) # marker size
savename_html = os.path.join(out_dir, 'darts_space_params_vs_test_acc.html')
fig.write_html(savename_html)
fig.show()
savename_png = os.path.join(out_dir, 'darts_space_params_vs_test_acc.png')
fig.write_image(savename_png, width=1500, height=1500, scale=1)
# compute spearman correlation of #params vs. test accuracy
param_spe, param_sp_value = spearmanr(testaccs, params)

Просмотреть файл

@ -16,7 +16,7 @@ import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scipy.stats import kendalltau, spearmanr
from scipy.stats import kendalltau, spearmanr, pearsonr
from runstats import Statistics
@ -27,6 +27,7 @@ import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool
from collections import namedtuple
from itertools import product
from archai.common import utils
@ -180,6 +181,9 @@ def main():
for measure in ZEROCOST_MEASURES:
assert len(all_reg_evals) == len(all_zerocost_init_evals[measure])
assert len(all_reg_evals) == len(all_arch_ids)
# if params flops is present compute spearman wrt params and flops
# also compute scatter plots for params vs synflow
if params_flops_data:
assert len(all_reg_evals) == len(all_params_evals)
assert len(all_reg_evals) == len(all_flops_evals)
@ -198,7 +202,46 @@ def main():
print(f'Spearman wrt params: {spe_params} \n')
print(f'Spearman wrt flops: {spe_flops} \n')
# Store some key numbers in results.txt
# scatter params vs. synflow
fig = go.Figure()
fig.add_trace(go.Scatter(x=all_params_evals, y=all_zerocost_init_evals['synflow'], mode='markers'))
fig.update_layout(xaxis_title="Parameters",
yaxis_title="Synflow")
fig.update_layout(font=dict(size=36)) # font size
fig.update_traces(marker=dict(size=20)) # marker size
savename_html = os.path.join(out_dir, f'params_vs_synflow.html')
savename_png = os.path.join(out_dir, f'params_vs_synflow.png')
fig.write_html(savename_html)
fig.write_image(savename_png, width=1500, height=1500, scale=1)
# create heatmap of all pairs of proxies along with params, flops, gt
ZEROCOST_MEASURES_PF = ZEROCOST_MEASURES + ['params', 'flops', 'gt']
all_zerocost_init_evals['params'] = all_params_evals
all_zerocost_init_evals['flops'] = all_flops_evals
all_zerocost_init_evals['gt'] = all_reg_evals
hm = np.zeros((len(ZEROCOST_MEASURES_PF), len(ZEROCOST_MEASURES_PF)))
for i, m1 in enumerate(ZEROCOST_MEASURES_PF):
for j, m2 in enumerate(ZEROCOST_MEASURES_PF):
# sometimes jacob_cov has a nan here and there. ignore those.
m1_scores = all_zerocost_init_evals[m1]
m2_scores = all_zerocost_init_evals[m2]
valid_scores = [x for x in zip(m1_scores, m2_scores) if not ma.isnan(x[0]) and not ma.isnan(x[1])]
m1_valid = [x[0] for x in valid_scores]
m2_valid = [x[1] for x in valid_scores]
spe, _ = spearmanr(m1_valid, m2_valid)
hm[i][j] = spe
fig = px.imshow(hm, text_auto="0.1f", x=ZEROCOST_MEASURES_PF, y=ZEROCOST_MEASURES_PF)
savename_html = os.path.join(out_dir, f'all_pairs_zc_spe.html')
savename_png = os.path.join(out_dir, f'all_pairs_zc_spe.png')
fig.write_html(savename_html)
fig.write_image(savename_png, width=1500, height=1500, scale=1)
results_savename = os.path.join(out_dir, 'results.txt')
with open(results_savename, 'a') as f:
f.write(f'Total valid archs processed: {len(all_reg_evals)} \n')
@ -233,6 +276,22 @@ def main():
assert len(top_percent_reg) == len(top_percent_init)
spe_init, _ = spearmanr(top_percent_reg, top_percent_init)
# for the entire bin of archs scatter plot
# groundtruth accuracy (x-axis) vs. measure and save
if top_percent == 100:
fig = go.Figure()
fig.add_trace(go.Scatter(x=top_percent_reg, y=top_percent_init, mode='markers'))
fig.update_layout(xaxis_title="Test Accuracy",
yaxis_title=f"{measure}")
fig.update_layout(font=dict(size=36)) # font size
fig.update_traces(marker=dict(size=20)) # marker size
savename_html = os.path.join(out_dir, f'test_accuracy_vs_{measure}.html')
savename_png = os.path.join(out_dir, f'test_accuracy_vs_{measure}.png')
fig.write_html(savename_html)
fig.write_image(savename_png, width=1500, height=1500, scale=1)
#fig.show()
spe_top_percents_init[measure].append(spe_init)
spe_top_percents_init['top_percents'] = top_percents

Просмотреть файл

@ -78,7 +78,7 @@ def main():
# a = parse_a_job(job_dir)
# parallel parsing of yaml logs
num_workers = 12
num_workers = 64
with Pool(num_workers) as p:
a = p.map(parse_a_job, job_dirs)
@ -188,7 +188,7 @@ def main():
epoch_num_spe[epoch_num] = spe
measures_res[measure] = epoch_num_spe
# plot
# plot static image first
fig = go.Figure()
for measure in ZEROCOST_MEASURES:
xs = [key for key in measures_res[measure].keys()]
@ -206,7 +206,29 @@ def main():
savename_pdf = os.path.join(out_dir, 'zerocost_epochs_vs_spe.pdf')
fig.write_image(savename_pdf, engine="kaleido", width=1500, height=750, scale=1)
fig.show()
#fig.show()
# plot images for animation
# using a particular measure to get epoch numbers
# assuming all measures have same total epochs
for epoch_num in measures_res['synflow'].keys():
fig_anim = go.Figure()
for measure in ZEROCOST_MEASURES:
xs = [i for i in range(epoch_num)]
ys = [measures_res[measure][e_num] for e_num in xs]
fig_anim.add_trace(go.Scatter(x=xs, y=ys, name=measure, mode='markers+lines', showlegend=True))
fig_anim.update_layout(xaxis_title='Epochs',
yaxis_title='Spearman Corr.')
fig_anim.update_layout(font=dict(size=48))
fig_anim.update_xaxes(range=[0,200])
fig_anim.update_yaxes(range=[-1,1])
#fig_anim.update_xaxes(type="log")
savename_png = os.path.join(out_dir, f'epoch_{epoch_num:06d}.png')
fig_anim.write_image(savename_png, width=1500, height=1000, scale=1)