More progress on fear simulation on nb301.

This commit is contained in:
Debadeepta Dey 2021-10-19 20:48:57 -07:00 коммит произвёл Gustavo Rosa
Родитель 41c2fe018f
Коммит bcfec6c4ac
2 изменённых файлов: 19 добавлений и 10 удалений

4
.vscode/launch.json поставляемый
Просмотреть файл

@ -928,7 +928,7 @@
"program": "${cwd}/scripts/reports/fear_analysis/simulate_fear_on_nb301.py",
"console": "integratedTerminal",
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
"--out-dir", "F:\\archai_experiment_reports"]
"--out-dir", "F:\\archai_experiment_reports", "--scorer", "train_accuracy"]
},
{
"name": "Analysis Simulate FEAR on Nasbench301 Toy",
@ -937,7 +937,7 @@
"program": "${cwd}/scripts/reports/fear_analysis/simulate_fear_on_nb301.py",
"console": "integratedTerminal",
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_toy",
"--out-dir", "F:\\archai_experiment_reports"]
"--out-dir", "F:\\archai_experiment_reports", "--scorer", "train_accuracy"]
},
{
"name": "CurrentFile",

Просмотреть файл

@ -1,4 +1,5 @@
from collections import defaultdict
from enum import Enum
import json
import argparse
import os
@ -11,6 +12,9 @@ import plotly.graph_objects as go
from scipy.stats import kendalltau, spearmanr, sem
import statistics
SCORERS = {'train_accuracy', 'train_loss', 'train_cross_entropy', 'val_accuracy'}
def plot_spearman_top_percents(results:Dict[str, list],
plotly_fig_handle,
legend_text:str,
@ -91,8 +95,15 @@ def main():
parser = argparse.ArgumentParser(description='Nasbench301 time to threshold vs. test accuracy')
parser.add_argument('--nb301-logs-dir', '-d', type=str, help='folder with nasbench301 architecture training logs')
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports', help='folder to output reports')
parser.add_argument('--scorer', '-s', type=str, default='train_accuracy',
help='one of train_accuracy, train_loss, train_cross_entropy, val_accuracy')
args, extra_args = parser.parse_known_args()
if args.scorer not in SCORERS:
raise argparse.ArgumentError
scorer_key = "Train/" + args.scorer
# TODO: make these into cmd line arguments
train_thresh = 60.0
post_thresh_epochs = 10
@ -109,10 +120,10 @@ def main():
log_name = os.path.join(root, name)
with open(log_name, 'r') as f:
log_data = json.load(f)
num_epochs = len(log_data['learning_curves']['Train/train_accuracy'])
num_epochs = len(log_data['learning_curves'][scorer_key])
test_acc = log_data['test_accuracy']
per_epoch_time = log_data['runtime'] / num_epochs
num_epochs_to_thresh = find_train_thresh_epochs(log_data['learning_curves']['Train/train_accuracy'],
num_epochs_to_thresh = find_train_thresh_epochs(log_data['learning_curves'][scorer_key],
train_thresh)
# many weak architectures will never reach threshold
if not num_epochs_to_thresh:
@ -120,7 +131,7 @@ def main():
simulated_stage2_epoch = num_epochs_to_thresh + post_thresh_epochs
fear_time = per_epoch_time * simulated_stage2_epoch
try:
train_acc_stage2 = log_data['learning_curves']['Train/train_accuracy'][simulated_stage2_epoch]
train_acc_stage2 = log_data['learning_curves'][scorer_key][simulated_stage2_epoch]
except:
continue
@ -130,15 +141,11 @@ def main():
# get training acc at all epochs for regular
# evaluation baseline
for epoch_num, train_acc in enumerate(log_data['learning_curves']['Train/train_accuracy']):
for epoch_num, train_acc in enumerate(log_data['learning_curves'][scorer_key]):
all_reg_train_acc[epoch_num].append(train_acc)
all_reg_train_time_per_epoch[epoch_num].append((epoch_num + 1) * per_epoch_time)
fear_train_acc_spe, _ = spearmanr(all_test_acc, all_fear_end_acc)
print(f'FEAR Spearman training accuracy: {fear_train_acc_spe}')
print(f'FEAR avg time: {statistics.mean(all_fear_time)}')
spes_train_acc_vs_epoch = {}
avg_time_train_acc_vs_epoch = {}
for epoch_num, train_accs_epoch in all_reg_train_acc.items():
@ -166,6 +173,8 @@ def main():
for epoch_num in all_reg_train_acc.keys():
all_reg = all_reg_train_acc[epoch_num]
all_reg_times = all_reg_train_time_per_epoch[epoch_num]
if len(all_test_acc) != len(all_reg):
continue
reg_results[epoch_num] = top_buckets_spearmans(all_reg_evals=all_test_acc,
all_proxy_evals=all_reg,
all_proxy_times=all_reg_times)