Added analysis script for ranking in nb301.

This commit is contained in:
Debadeepta Dey 2021-10-15 20:01:46 -07:00 коммит произвёл Gustavo Rosa
Родитель 524d174c21
Коммит d63d6f178f
3 изменённых файлов: 79 добавлений и 2 удалений

13
.vscode/launch.json поставляемый
Просмотреть файл

@ -698,7 +698,7 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py",
"console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly",
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
},
@ -708,7 +708,7 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_regular_darts_space.py",
"console": "integratedTerminal",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\dt_reg_b96_e20",
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_fixed",
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
},
@ -912,6 +912,15 @@
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "Analysis Nasbench301 Ranking",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_nasbench301_ranking.py",
"console": "integratedTerminal",
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "CurrentFile",
"type": "python",

Просмотреть файл

@ -0,0 +1,58 @@
import json
import argparse
import os
from typing import List
from tqdm import tqdm
from scipy.stats import kendalltau, spearmanr
import plotly.graph_objects as go
def find_train_thresh_epochs(train_acc:List[float], train_thresh:float)->int:
for i, t in enumerate(train_acc):
if t >= train_thresh:
return i
def main():
parser = argparse.ArgumentParser(description='Nasbench301 Ranking Experiments')
parser.add_argument('--nb301-logs-dir', '-d', type=str, help='folder with nasbench301 architecture training logs')
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports', help='folder to output reports')
args, extra_args = parser.parse_known_args()
all_test_acc = []
all_train_acc_end = []
all_train_loss_end = []
all_val_acc_end = []
# collect all the json file names in the log dir recursively
for root, dir, files in os.walk(args.nb301_logs_dir):
for name in tqdm(files):
log_name = os.path.join(root, name)
with open(log_name, 'r') as f:
log_data = json.load(f)
test_acc = log_data['test_accuracy']
train_acc_end = log_data['learning_curves']['Train/train_accuracy'][-1]
train_loss_end = log_data['learning_curves']['Train/train_loss'][-1]
val_acc_end = log_data['learning_curves']['Train/val_accuracy'][-1]
all_test_acc.append(test_acc)
all_train_acc_end.append(train_acc_end)
all_train_loss_end.append(train_loss_end)
all_val_acc_end.append(val_acc_end)
# negate the training loss for ranking purposes
all_train_loss_neg_end = [-x for x in all_train_loss_end]
train_acc_spe, _ = spearmanr(all_test_acc, all_train_acc_end)
train_loss_spe, _ = spearmanr(all_test_acc, all_train_loss_neg_end)
val_loss_spe, _ = spearmanr(all_test_acc, all_val_acc_end)
print(f'Spearman training accuracy end: {train_acc_spe}')
print(f'Spearman training loss end: {train_loss_spe}')
print(f'Spearman val loss end: {val_loss_spe}')
if __name__ == '__main__':
main()

Просмотреть файл

@ -4,6 +4,16 @@ exp_folder: 'F:\\archai_experiment_reports'
darts_cifar10:
freezetrain:
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
<<<<<<< HEAD
=======
ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly'
ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly'
>>>>>>> eaaf94d5 (Added analysis script for ranking in nb301.)
shortreg:
dt_reg_b96_e5: 'dt_reg_b96_e5'