зеркало из https://github.com/microsoft/archai.git
Added analysis script for ranking in nb301.
This commit is contained in:
Родитель
524d174c21
Коммит
d63d6f178f
|
@ -698,7 +698,7 @@
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py",
|
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_darts_space.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly",
|
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly",
|
||||||
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
||||||
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
||||||
},
|
},
|
||||||
|
@ -708,7 +708,7 @@
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_regular_darts_space.py",
|
"program": "${cwd}/scripts/reports/fear_analysis/analysis_regular_darts_space.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\dt_reg_b96_e20",
|
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_fixed",
|
||||||
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
"--out-dir", "F:\\archai_experiment_reports", "--reg-evals-file",
|
||||||
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
"F:\\archai_experiment_reports\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6\\darts_benchmark.yaml"]
|
||||||
},
|
},
|
||||||
|
@ -912,6 +912,15 @@
|
||||||
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||||
"--out-dir", "F:\\archai_experiment_reports"]
|
"--out-dir", "F:\\archai_experiment_reports"]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "Analysis Nasbench301 Ranking",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${cwd}/scripts/reports/fear_analysis/analysis_nasbench301_ranking.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||||
|
"--out-dir", "F:\\archai_experiment_reports"]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "CurrentFile",
|
"name": "CurrentFile",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from scipy.stats import kendalltau, spearmanr
|
||||||
|
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
|
||||||
|
|
||||||
|
def find_train_thresh_epochs(train_acc:List[float], train_thresh:float)->int:
|
||||||
|
for i, t in enumerate(train_acc):
|
||||||
|
if t >= train_thresh:
|
||||||
|
return i
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Nasbench301 Ranking Experiments')
|
||||||
|
parser.add_argument('--nb301-logs-dir', '-d', type=str, help='folder with nasbench301 architecture training logs')
|
||||||
|
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports', help='folder to output reports')
|
||||||
|
args, extra_args = parser.parse_known_args()
|
||||||
|
|
||||||
|
|
||||||
|
all_test_acc = []
|
||||||
|
all_train_acc_end = []
|
||||||
|
all_train_loss_end = []
|
||||||
|
all_val_acc_end = []
|
||||||
|
|
||||||
|
# collect all the json file names in the log dir recursively
|
||||||
|
for root, dir, files in os.walk(args.nb301_logs_dir):
|
||||||
|
for name in tqdm(files):
|
||||||
|
log_name = os.path.join(root, name)
|
||||||
|
with open(log_name, 'r') as f:
|
||||||
|
log_data = json.load(f)
|
||||||
|
test_acc = log_data['test_accuracy']
|
||||||
|
train_acc_end = log_data['learning_curves']['Train/train_accuracy'][-1]
|
||||||
|
train_loss_end = log_data['learning_curves']['Train/train_loss'][-1]
|
||||||
|
val_acc_end = log_data['learning_curves']['Train/val_accuracy'][-1]
|
||||||
|
|
||||||
|
all_test_acc.append(test_acc)
|
||||||
|
all_train_acc_end.append(train_acc_end)
|
||||||
|
all_train_loss_end.append(train_loss_end)
|
||||||
|
all_val_acc_end.append(val_acc_end)
|
||||||
|
|
||||||
|
# negate the training loss for ranking purposes
|
||||||
|
all_train_loss_neg_end = [-x for x in all_train_loss_end]
|
||||||
|
|
||||||
|
train_acc_spe, _ = spearmanr(all_test_acc, all_train_acc_end)
|
||||||
|
train_loss_spe, _ = spearmanr(all_test_acc, all_train_loss_neg_end)
|
||||||
|
val_loss_spe, _ = spearmanr(all_test_acc, all_val_acc_end)
|
||||||
|
|
||||||
|
print(f'Spearman training accuracy end: {train_acc_spe}')
|
||||||
|
print(f'Spearman training loss end: {train_loss_spe}')
|
||||||
|
print(f'Spearman val loss end: {val_loss_spe}')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -4,6 +4,16 @@ exp_folder: 'F:\\archai_experiment_reports'
|
||||||
darts_cifar10:
|
darts_cifar10:
|
||||||
freezetrain:
|
freezetrain:
|
||||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
|
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly'
|
||||||
|
ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly'
|
||||||
|
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly'
|
||||||
|
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c4_ftonly'
|
||||||
|
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly'
|
||||||
|
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly'
|
||||||
|
|
||||||
|
>>>>>>> eaaf94d5 (Added analysis script for ranking in nb301.)
|
||||||
|
|
||||||
shortreg:
|
shortreg:
|
||||||
dt_reg_b96_e5: 'dt_reg_b96_e5'
|
dt_reg_b96_e5: 'dt_reg_b96_e5'
|
||||||
|
|
Загрузка…
Ссылка в новой задаче