Mirror of https://github.com/microsoft/archai.git

Added ratio of common architectures for ranking.

Parent: 8511d25699
Commit: c6396d3ec6
@@ -27,3 +27,19 @@ python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\User
 python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte5_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports
 python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte10_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports
 
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e01 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e02 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e04 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e06 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e08 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e10 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e01 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e02 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e04 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e06 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e08 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e10 --out-dir C:\\Users\\dedey\\archai_experiment_reports
+
@@ -31,66 +31,11 @@ from collections import namedtuple
 from archai.common import utils
 from archai.common.ordereddict_logger import OrderedDictLogger
-from analysis_utils import epoch_nodes, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
+from analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
 
 import re
 
-
-def find_valid_log(subdir:str)->str:
-    # originally log should be in base folder of eval or search
-    logs_filepath_og = os.path.join(str(subdir), 'log.yaml')
-    if os.path.isfile(logs_filepath_og):
-        return logs_filepath_og
-    else:
-        # look in the 'dist' folder for any yaml file
-        dist_folder = os.path.join(str(subdir), 'dist')
-        for f in os.listdir(dist_folder):
-            if f.endswith(".yaml"):
-                return os.path.join(dist_folder, f)
-
-
-def parse_a_job(job_dir:str)->Dict:
-    if job_dir.is_dir():
-        storage = {}
-        for subdir in job_dir.iterdir():
-            if not subdir.is_dir():
-                continue
-            # currently we expect that each job was ExperimentRunner job which should have
-            # _search or _eval folders
-            if subdir.stem.endswith('_search'):
-                sub_job = 'search'
-            elif subdir.stem.endswith('_eval'):
-                sub_job = 'eval'
-            else:
-                raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
-                                   'end with either _search or _eval which '
-                                   'should be the case if ExperimentRunner was used.')
-
-            logs_filepath = find_valid_log(subdir)
-            # if no valid logfile found, ignore this job as it probably
-            # didn't finish or errored out or is yet to run
-            if not logs_filepath:
-                continue
-
-            config_used_filepath = os.path.join(subdir, 'config_used.yaml')
-
-            if os.path.isfile(logs_filepath):
-                fix_yaml(logs_filepath)
-                key = job_dir.name + subdir.name + ':' + sub_job
-                # parse log
-                with open(logs_filepath, 'r') as f:
-                    data = yaml.load(f, Loader=yaml.Loader)
-                # parse config used
-                with open(config_used_filepath, 'r') as f:
-                    confs = yaml.load(f, Loader=yaml.Loader)
-                storage[key] = (data, confs)
-
-        return storage
-
-
-def myfunc(x):
-    return x*x
-
-
 def main():
     parser = argparse.ArgumentParser(description='Report creator')
     parser.add_argument('--results-dir', '-d', type=str,
@@ -176,6 +121,8 @@ def main():
             logs.pop(key)
 
 
+    all_arch_ids = []
+
     all_reg_evals = []
     all_naswotrain_evals = []
@@ -243,8 +190,6 @@ def main():
             all_cond_time_last.append(cond_duration)
             all_partial_time_last.append(freeze_duration)
 
-
-
             # naswotrain
             # --------------
             naswotrain_top1 = logs[key]['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining']
@@ -255,6 +200,10 @@ def main():
             reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
             all_reg_evals.append(reg_eval_top1)
 
+            # record the arch id
+            # --------------------
+            all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
+
         except KeyError as err:
             print(f'KeyError {err} not in {key}!')
 
@@ -275,6 +224,7 @@ def main():
     assert len(all_reg_evals) == len(all_freeze_flops_last)
     assert len(all_reg_evals) == len(all_cond_flops_last)
     assert len(all_reg_evals) == len(all_freeze_time_last)
+    assert len(all_reg_evals) == len(all_arch_ids)
 
     # Freeze training results at last epoch
     freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last)
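For reference, `kendalltau` above is `scipy.stats.kendalltau`, which scores how well a cheap proxy (freeze training, naswot) ranks architectures relative to full training. A minimal sketch with toy numbers (the variable names below are illustrative, not from this commit):

```python
# Toy demonstration of the rank-correlation metric used above.
from scipy.stats import kendalltau

ground_truth = [92.1, 90.5, 88.3, 85.0, 80.2]  # e.g. regular-eval top1 per arch
proxy_scores = [0.93, 0.88, 0.90, 0.71, 0.60]  # e.g. freeze-train score per arch

tau, p_value = kendalltau(ground_truth, proxy_scores)
print(f'tau={tau:.3f}, p={p_value:.3g}')  # tau close to 1 => proxy preserves the ranking
```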
@@ -396,7 +346,8 @@ def main():
     spe_freeze_top_percents = []
     spe_naswot_top_percents = []
     top_percents = []
-    for top_percent in range(2, 101, 2):
+    top_percent_range = range(2, 101, 2)
+    for top_percent in top_percent_range:
         top_percents.append(top_percent)
         num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01))
         top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep]
@@ -433,6 +384,39 @@ def main():
     savename = os.path.join(out_dir, f'freeze_train_duration_top_archs.png')
     plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
 
+    # how much overlap in top x% of architectures between method and groundtruth
+    # ----------------------------------------------------------------------------
+    arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
+    arch_id_freezetrain_evals = [(arch_id, freeze_eval) for arch_id, freeze_eval in zip(all_arch_ids, all_freeze_evals_last)]
+    arch_id_naswot_evals = [(arch_id, naswot_eval) for arch_id, naswot_eval in zip(all_arch_ids, all_naswotrain_evals)]
+
+    arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
+    arch_id_freezetrain_evals.sort(key=lambda x: x[1], reverse=True)
+    arch_id_naswot_evals.sort(key=lambda x: x[1], reverse=True)
+
+    assert len(arch_id_reg_evals) == len(arch_id_freezetrain_evals)
+    assert len(arch_id_reg_evals) == len(arch_id_naswot_evals)
+
+    top_percents = []
+    freezetrain_ratio_common = []
+    naswot_ratio_common = []
+    for top_percent in top_percent_range:
+        top_percents.append(top_percent)
+        num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
+        top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
+        top_percent_arch_id_freezetrain_evals = arch_id_freezetrain_evals[:num_to_keep]
+        top_percent_arch_id_naswot_evals = arch_id_naswot_evals[:num_to_keep]
+
+        # take the set of arch_ids in each method and find overlap with top archs
+        set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
+        set_ft = set([x[0] for x in top_percent_arch_id_freezetrain_evals])
+        ft_num_common = len(set_reg.intersection(set_ft))
+        freezetrain_ratio_common.append(ft_num_common/num_to_keep)
+
+        set_naswot = set([x[0] for x in top_percent_arch_id_naswot_evals])
+        naswot_num_common = len(set_reg.intersection(set_naswot))
+        naswot_ratio_common.append(naswot_num_common/num_to_keep)
+
     # save raw data for other aggregate plots over experiments
     raw_data_dict = {}
     raw_data_dict['top_percents'] = top_percents
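The added block above is the commit's headline change: for every top x%, it measures what fraction of the ground-truth top x% of architectures also lands in the proxy's top x%. A self-contained sketch of the same computation (function and variable names here are mine, not the script's):

```python
from typing import Dict

def ratio_common(reg_evals: Dict[int, float], proxy_evals: Dict[int, float],
                 top_percent: int) -> float:
    # rank arch ids descending by ground-truth and by proxy score
    reg_ranked = sorted(reg_evals, key=reg_evals.get, reverse=True)
    proxy_ranked = sorted(proxy_evals, key=proxy_evals.get, reverse=True)
    num_to_keep = max(1, int(len(reg_ranked) * top_percent / 100))
    # overlap of the two top-x% sets, normalized by the set size
    common = set(reg_ranked[:num_to_keep]) & set(proxy_ranked[:num_to_keep])
    return len(common) / num_to_keep

# toy usage: the proxy agrees on the single best arch at top-25%
reg = {0: 94.0, 1: 92.0, 2: 90.0, 3: 88.0}
proxy = {0: 0.9, 1: 0.5, 2: 0.7, 3: 0.6}
print(ratio_common(reg, proxy, 25))  # -> 1.0
```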
@@ -440,6 +424,9 @@ def main():
     raw_data_dict['spe_naswot'] = spe_naswot_top_percents
     raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg
     raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std
+    raw_data_dict['freeze_ratio_common'] = freezetrain_ratio_common
+    raw_data_dict['naswot_ratio_common'] = naswot_ratio_common
+
 
     savename = os.path.join(out_dir, 'raw_data.yaml')
     with open(savename, 'w') as f:
@@ -30,34 +30,10 @@ from multiprocessing import Pool
 from archai.common import utils
 from archai.common.ordereddict_logger import OrderedDictLogger
-from analysis_utils import epoch_nodes, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
+from analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
 
 import re
 
-def parse_a_job(job_dir:str)->OrderedDict:
-    if job_dir.is_dir():
-        for subdir in job_dir.iterdir():
-            if not subdir.is_dir():
-                continue
-            # currently we expect that each job was ExperimentRunner job which should have
-            # _search or _eval folders
-            if subdir.stem.endswith('_search'):
-                sub_job = 'search'
-            elif subdir.stem.endswith('_eval'):
-                sub_job = 'eval'
-            else:
-                raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
-                                   'end with either _search or _eval which '
-                                   'should be the case if ExperimentRunner was used.')
-
-            logs_filepath = os.path.join(str(subdir), 'log.yaml')
-            if os.path.isfile(logs_filepath):
-                fix_yaml(logs_filepath)
-                with open(logs_filepath, 'r') as f:
-                    key = job_dir.name + ':' + sub_job
-                    data = yaml.load(f, Loader=yaml.Loader)
-                return (key, data)
-
-
 def main():
     parser = argparse.ArgumentParser(description='Report creator')
@@ -82,21 +58,56 @@ def main():
 
     # get list of all structured logs for each job
     logs = {}
+    confs = {}
     job_dirs = list(results_dir.iterdir())
 
-    # parallel parssing of yaml logs
+    # # test single job parsing for debugging
+    # # WARNING: very slow, just use for debugging
+    # for job_dir in job_dirs:
+    #     a = parse_a_job(job_dir)
+
+    # parallel parsing of yaml logs
     with Pool(18) as p:
         a = p.map(parse_a_job, job_dirs)
 
-    for key, data in a:
-        logs[key] = data
+    for storage in a:
+        for key, val in storage.items():
+            logs[key] = val[0]
+            confs[key] = val[1]
 
     # examples of accessing logs
     # best_test = logs[key]['eval_arch']['eval_train']['best_test']['top1']
     # best_train = logs[key]['eval_arch']['eval_train']['best_train']['top1']
 
+    # remove all search jobs
+    for key in list(logs.keys()):
+        if 'search' in key:
+            logs.pop(key)
+
+    # remove all arch_ids which did not finish
+    for key in list(logs.keys()):
+        to_delete = False
+
+        # it might have died early
+        if 'eval_arch' not in list(logs[key].keys()):
+            to_delete = True
+
+        if 'regular_evaluate' not in list(logs[key].keys()):
+            to_delete = True
+
+        if to_delete:
+            print(f'arch id {key} did not finish. removing from calculations.')
+            logs.pop(key)
+            continue
+
+        if 'eval_train' not in list(logs[key]['eval_arch'].keys()):
+            print(f'arch id {key} did not finish. removing from calculations.')
+            logs.pop(key)
+            continue
+
+    all_arch_ids = []
     all_reg_evals = []
     all_short_reg_evals = []
     all_short_reg_time = []
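Because `parse_a_job` now returns a dict with one `(log, config)` tuple per eval/search subdir, the `Pool.map` result is a list of dicts that must be flattened into `logs` and `confs`. A runnable sketch of that merge pattern, with a stub in place of the real YAML parsing:

```python
from multiprocessing import Pool

def parse_stub(job_id: int) -> dict:
    # stand-in for parse_a_job: one (log, conf) tuple per key
    return {f'job{job_id}:eval': ({'top1': 0.9}, {'arch_index': job_id})}

if __name__ == '__main__':
    with Pool(4) as p:
        results = p.map(parse_stub, range(3))

    logs, confs = {}, {}
    for storage in results:              # one dict per job dir
        for key, (data, conf) in storage.items():
            logs[key] = data             # parsed log.yaml
            confs[key] = conf            # parsed config_used.yaml
    print(sorted(logs))                  # ['job0:eval', 'job1:eval', 'job2:eval']
```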
@@ -121,6 +132,10 @@ def main():
             reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
             all_reg_evals.append(reg_eval_top1)
 
+            # record the arch id
+            # --------------------
+            all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
+
         except KeyError as err:
             print(f'KeyError {err} not in {key}!')
 
@@ -160,7 +175,8 @@ def main():
 
     spe_shortreg_top_percents = []
     top_percents = []
-    for top_percent in range(2, 101, 2):
+    top_percent_range = range(2, 101, 2)
+    for top_percent in top_percent_range:
         top_percents.append(top_percent)
         num_to_keep = int(ma.floor(len(reg_shortreg_evals) * top_percent * 0.01))
         top_percent_evals = reg_shortreg_evals[:num_to_keep]
@@ -201,12 +217,37 @@ def main():
     with open(results_savename, 'a') as f:
         f.write(f'Avg. Shortened Training Runtime: {avg_shortreg_runtime:.03f}, stderr {stderr_shortreg_runtime:.03f} \n')
 
+    # how much overlap in top x% of architectures between method and groundtruth
+    # ----------------------------------------------------------------------------
+    arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
+    arch_id_shortreg_evals = [(arch_id, shortreg_eval) for arch_id, shortreg_eval in zip(all_arch_ids, all_short_reg_evals)]
+
+    arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
+    arch_id_shortreg_evals.sort(key=lambda x: x[1], reverse=True)
+
+    assert len(arch_id_reg_evals) == len(arch_id_shortreg_evals)
+
+    top_percents = []
+    shortreg_ratio_common = []
+    for top_percent in top_percent_range:
+        top_percents.append(top_percent)
+        num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
+        top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
+        top_percent_arch_id_shortreg_evals = arch_id_shortreg_evals[:num_to_keep]
+
+        # take the set of arch_ids in each method and find overlap with top archs
+        set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
+        set_ft = set([x[0] for x in top_percent_arch_id_shortreg_evals])
+        ft_num_common = len(set_reg.intersection(set_ft))
+        shortreg_ratio_common.append(ft_num_common/num_to_keep)
+
     # save raw data for other aggregate plots over experiments
     raw_data_dict = {}
     raw_data_dict['top_percents'] = top_percents
     raw_data_dict['spe_shortreg'] = spe_shortreg_top_percents
     raw_data_dict['shortreg_times_avg'] = top_percent_shortreg_times_avg
     raw_data_dict['shortreg_times_std'] = top_percent_shortreg_times_std
+    raw_data_dict['shortreg_ratio_common'] = shortreg_ratio_common
+
     savename = os.path.join(out_dir, 'raw_data.yaml')
     with open(savename, 'w') as f:
@@ -300,3 +300,56 @@ def write_report(template_filename:str, **kwargs)->None:
     with open(outfilepath, 'w', encoding='utf-8') as f:
         f.write(report)
     print(f'report written to: {outfilepath}')
+
+
+def find_valid_log(subdir:str)->str:
+    # originally log should be in base folder of eval or search
+    logs_filepath_og = os.path.join(str(subdir), 'log.yaml')
+    if os.path.isfile(logs_filepath_og):
+        return logs_filepath_og
+    else:
+        # look in the 'dist' folder for any yaml file
+        dist_folder = os.path.join(str(subdir), 'dist')
+        for f in os.listdir(dist_folder):
+            if f.endswith(".yaml"):
+                return os.path.join(dist_folder, f)
+
+
+def parse_a_job(job_dir:str)->Dict:
+    if job_dir.is_dir():
+        storage = {}
+        for subdir in job_dir.iterdir():
+            if not subdir.is_dir():
+                continue
+            # currently we expect that each job was ExperimentRunner job which should have
+            # _search or _eval folders
+            if subdir.stem.endswith('_search'):
+                sub_job = 'search'
+            elif subdir.stem.endswith('_eval'):
+                sub_job = 'eval'
+            else:
+                raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
+                                   'end with either _search or _eval which '
+                                   'should be the case if ExperimentRunner was used.')
+
+            logs_filepath = find_valid_log(subdir)
+            # if no valid logfile found, ignore this job as it probably
+            # didn't finish or errored out or is yet to run
+            if not logs_filepath:
+                continue
+
+            config_used_filepath = os.path.join(subdir, 'config_used.yaml')
+
+            if os.path.isfile(logs_filepath):
+                fix_yaml(logs_filepath)
+                key = job_dir.name + subdir.name + ':' + sub_job
+                # parse log
+                with open(logs_filepath, 'r') as f:
+                    data = yaml.load(f, Loader=yaml.Loader)
+                # parse config used
+                with open(config_used_filepath, 'r') as f:
+                    confs = yaml.load(f, Loader=yaml.Loader)
+                storage[key] = (data, confs)
+
+        return storage
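As the commented-out block in the regular-analysis script suggests, `parse_a_job` can also be driven serially for debugging. A sketch, assuming `parse_a_job` from above is in scope and using a placeholder results path; note it implicitly returns `None` when `job_dir` is not a directory, hence the guard below:

```python
from pathlib import Path

results_dir = Path(r'C:\path\to\results')    # placeholder, not a real path

for job_dir in results_dir.iterdir():
    storage = parse_a_job(job_dir)           # slow, but easy to step through
    if not storage:
        continue
    for key, (data, confs) in storage.items():
        print(key, confs['nas']['eval']['natsbench']['arch_index'])
```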
@@ -73,7 +73,19 @@ def main():
                 'ft_fb512_ftlr0.1_fte5_ct256_ftt0.6']
 
 
-    shortreg_exp_list = ['nb_reg_b1024_e10', 'nb_reg_b1024_e20', 'nb_reg_b1024_e30']
+    shortreg_exp_list = ['nb_reg_b1024_e01',
+                         'nb_reg_b1024_e02',
+                         'nb_reg_b1024_e04',
+                         'nb_reg_b1024_e06',
+                         'nb_reg_b1024_e08',
+                         'nb_reg_b1024_e10',
+                         'nb_reg_b512_e01',
+                         'nb_reg_b512_e02',
+                         'nb_reg_b512_e04',
+                         'nb_reg_b512_e06',
+                         'nb_reg_b512_e08',
+                         'nb_reg_b512_e10']
+
 
     # parse raw data from all processed experiments
     data = parse_raw_data(exp_folder, exp_list)
@@ -84,6 +96,7 @@ def main():
     colors = [cmap(i) for i in np.linspace(0, 1, len(exp_list)*2)]
     linestyles = ['solid', 'dashdot', 'dotted', 'dashed']
     markers = ['.', 'v', '^', '<', '>', '1', 's', 'p', '*', '+', 'x', 'X', 'D', 'd']
+    mathy_markers = ["$a$", "$b$", "$c$", "$d$", "$e$", "$f$", "$g$", "$h$", "$i$", "$j$", "$k$", "$l$", "$m$", "$n$", "$o$", "$p$", "$q$", "$r$", "$s$", "$t$", "$u$", "$v$", "$w$", "$x$", "$y$", "$z$"]
 
     cc = cycler(color=colors) * cycler(linestyle=linestyles) * cycler(marker=markers)
 
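The new `mathy_markers` list works because matplotlib accepts mathtext strings such as `"$a$"` as marker styles, which yields 26 extra mutually distinguishable markers on top of the built-in ones. A quick sketch:

```python
import matplotlib.pyplot as plt

# letters rendered as scatter markers via matplotlib mathtext
for i, m in enumerate(["$a$", "$b$", "$c$"]):
    plt.scatter([i], [0], marker=m, s=200)
plt.show()
```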
@@ -112,8 +125,10 @@ def main():
             break
 
     # plot shortreg data
+    counter = 0
     for i, key in enumerate(shortreg_data.keys()):
-        plt.plot(shortreg_data[key]['top_percents'], shortreg_data[key]['spe_shortreg'], marker = '8', mfc='green', ms=10)
+        plt.plot(shortreg_data[key]['top_percents'], shortreg_data[key]['spe_shortreg'], marker = mathy_markers[counter], mfc='green', ms=10)
+        counter += 1
         legend_labels.append(key)
 
     # annotate the shortreg data points with time information
@@ -166,23 +181,35 @@ def main():
         tp_info[tp] = this_tp_info
 
     # now plot each top percent
-    cc = cycler(marker=markers)
+    # markers the same size as number of freezetrain experiments
 
     fig, axs = plt.subplots(5, 10)
     handles = None
     labels = None
     for tp_key, ax in zip(tp_info.keys(), axs.flat):
+        counter = 0
+        counter_reg = 0
         for exp in tp_info[tp_key].keys():
             duration = tp_info[tp_key][exp][0]
             spe = tp_info[tp_key][exp][1]
-            #ax.set_prop_cycle(cc)
-            ax.scatter(duration, spe, label=exp)
+            if 'ft_fb' in exp:
+                marker = markers[counter]
+                counter += 1
+            elif 'nb_reg' in exp:
+                marker = mathy_markers[counter_reg]
+                counter_reg += 1
+            else:
+                raise NotImplementedError
+
+            ax.scatter(duration, spe, label=exp, marker=marker)
             ax.set_title(str(tp_key))
             #ax.set(xlabel='Duration (s)', ylabel='SPE')
             ax.set_ylim([0, 1])
 
     handles, labels = axs.flat[-1].get_legend_handles_labels()
-    fig.legend(handles, labels, loc='upper center')
+    fig.legend(handles, labels, loc='center right')
 
     plt.show()
@@ -191,7 +218,6 @@ def main():
     # ------------------------------------------------------------
     plt.clf()
     time_legend_labels = []
-
    for key in data.keys():
         plt.errorbar(data[key]['top_percents'], data[key]['freeze_times_avg'], yerr=np.array(data[key]['freeze_times_std'])/2, marker='*', mfc='red', ms=5)
         time_legend_labels.append(key + '_freezetrain')