зеркало из https://github.com/microsoft/archai.git
Added analysis script for ranking in nb301.
This commit is contained in:
Родитель
524d174c21
Коммит
d63d6f178f
|
@ -698,7 +698,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly",
|
||||
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
||||
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
||||
},
|
||||
|
@ -708,7 +708,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_regular_darts_space.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\dt_reg_b96_e20",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_fixed",
|
||||
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
||||
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
||||
},
|
||||
|
@ -912,6 +912,15 @@
|
|||
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "Analysis Nasbench301 Ranking",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_nasbench301_ranking.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "CurrentFile",
|
||||
"type": "python",
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import json
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
from scipy.stats import kendalltau, spearmanr
|
||||
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
def find_train_thresh_epochs(train_acc:List[float], train_thresh:float)->int:
|
||||
for i, t in enumerate(train_acc):
|
||||
if t >= train_thresh:
|
||||
return i
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Nasbench301 Ranking Experiments')
|
||||
parser.add_argument('--nb301-logs-dir', '-d', type=str, help='folder with nasbench301 architecture training logs')
|
||||
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports', help='folder to output reports')
|
||||
args, extra_args = parser.parse_known_args()
|
||||
|
||||
|
||||
all_test_acc = []
|
||||
all_train_acc_end = []
|
||||
all_train_loss_end = []
|
||||
all_val_acc_end = []
|
||||
|
||||
# collect all the json file names in the log dir recursively
|
||||
for root, dir, files in os.walk(args.nb301_logs_dir):
|
||||
for name in tqdm(files):
|
||||
log_name = os.path.join(root, name)
|
||||
with open(log_name, 'r') as f:
|
||||
log_data = json.load(f)
|
||||
test_acc = log_data['test_accuracy']
|
||||
train_acc_end = log_data['learning_curves']['Train/train_accuracy'][-1]
|
||||
train_loss_end = log_data['learning_curves']['Train/train_loss'][-1]
|
||||
val_acc_end = log_data['learning_curves']['Train/val_accuracy'][-1]
|
||||
|
||||
all_test_acc.append(test_acc)
|
||||
all_train_acc_end.append(train_acc_end)
|
||||
all_train_loss_end.append(train_loss_end)
|
||||
all_val_acc_end.append(val_acc_end)
|
||||
|
||||
# negate the training loss for ranking purposes
|
||||
all_train_loss_neg_end = [-x for x in all_train_loss_end]
|
||||
|
||||
train_acc_spe, _ = spearmanr(all_test_acc, all_train_acc_end)
|
||||
train_loss_spe, _ = spearmanr(all_test_acc, all_train_loss_neg_end)
|
||||
val_loss_spe, _ = spearmanr(all_test_acc, all_val_acc_end)
|
||||
|
||||
print(f'Spearman training accuracy end: {train_acc_spe}')
|
||||
print(f'Spearman training loss end: {train_loss_spe}')
|
||||
print(f'Spearman val loss end: {val_loss_spe}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -4,6 +4,16 @@ exp_folder: 'F:\\archai_experiment_reports'
|
|||
darts_cifar10:
|
||||
freezetrain:
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly'
|
||||
|
||||
>>>>>>> eaaf94d5 (Added analysis script for ranking in nb301.)
|
||||
|
||||
shortreg:
|
||||
dt_reg_b96_e5: 'dt_reg_b96_e5'
|
||||
|
|
Загрузка…
Ссылка в новой задаче