Added ratio of common architectures for ranking.

This commit is contained in:
Debadeepta Dey 2021-02-04 21:56:54 -08:00 коммит произвёл Gustavo Rosa
Родитель 8511d25699
Коммит c6396d3ec6
5 изменённых файлов: 220 добавлений и 97 удалений

Просмотреть файл

@ -27,3 +27,19 @@ python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\User
python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte5_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte5_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte10_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports python scripts/reports/analysis_freeze_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\ft_fb256_ftlr0.1_fte10_ct256_ftt0.6 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e01 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e02 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e04 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e06 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e08 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b1024_e10 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e01 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e02 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e04 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e06 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e08 --out-dir C:\\Users\\dedey\\archai_experiment_reports
python scripts/reports/analysis_regular_natsbench_space.py --results-dir C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\nb_reg_b512_e10 --out-dir C:\\Users\\dedey\\archai_experiment_reports

Просмотреть файл

@ -31,66 +31,11 @@ from collections import namedtuple
from archai.common import utils from archai.common import utils
from archai.common.ordereddict_logger import OrderedDictLogger from archai.common.ordereddict_logger import OrderedDictLogger
from analysis_utils import epoch_nodes, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report from analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
import re import re
def find_valid_log(subdir:str)->str:
# originally log should be in base folder of eval or search
logs_filepath_og = os.path.join(str(subdir), 'log.yaml')
if os.path.isfile(logs_filepath_og):
return logs_filepath_og
else:
# look in the 'dist' folder for any yaml file
dist_folder = os.path.join(str(subdir), 'dist')
for f in os.listdir(dist_folder):
if f.endswith(".yaml"):
return os.path.join(dist_folder, f)
def parse_a_job(job_dir:str)->Dict:
if job_dir.is_dir():
storage = {}
for subdir in job_dir.iterdir():
if not subdir.is_dir():
continue
# currently we expect that each job was ExperimentRunner job which should have
# _search or _eval folders
if subdir.stem.endswith('_search'):
sub_job = 'search'
elif subdir.stem.endswith('_eval'):
sub_job = 'eval'
else:
raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
'end with either _search or _eval which '
'should be the case if ExperimentRunner was used.')
logs_filepath = find_valid_log(subdir)
# if no valid logfile found, ignore this job as it probably
# didn't finish or errored out or is yet to run
if not logs_filepath:
continue
config_used_filepath = os.path.join(subdir, 'config_used.yaml')
if os.path.isfile(logs_filepath):
fix_yaml(logs_filepath)
key = job_dir.name + subdir.name + ':' + sub_job
# parse log
with open(logs_filepath, 'r') as f:
data = yaml.load(f, Loader=yaml.Loader)
# parse config used
with open(config_used_filepath, 'r') as f:
confs = yaml.load(f, Loader=yaml.Loader)
storage[key] = (data, confs)
return storage
def myfunc(x):
return x*x
def main(): def main():
parser = argparse.ArgumentParser(description='Report creator') parser = argparse.ArgumentParser(description='Report creator')
parser.add_argument('--results-dir', '-d', type=str, parser.add_argument('--results-dir', '-d', type=str,
@ -176,6 +121,8 @@ def main():
logs.pop(key) logs.pop(key)
all_arch_ids = []
all_reg_evals = [] all_reg_evals = []
all_naswotrain_evals = [] all_naswotrain_evals = []
@ -242,9 +189,7 @@ def main():
all_freeze_time_last.append(freeze_duration + cond_duration) all_freeze_time_last.append(freeze_duration + cond_duration)
all_cond_time_last.append(cond_duration) all_cond_time_last.append(cond_duration)
all_partial_time_last.append(freeze_duration) all_partial_time_last.append(freeze_duration)
# naswotrain # naswotrain
# -------------- # --------------
naswotrain_top1 = logs[key]['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining'] naswotrain_top1 = logs[key]['naswotrain_evaluate']['eval_arch']['eval_train']['naswithouttraining']
@ -254,6 +199,10 @@ def main():
# -------------------- # --------------------
reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1'] reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
all_reg_evals.append(reg_eval_top1) all_reg_evals.append(reg_eval_top1)
# record the arch id
# --------------------
all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
except KeyError as err: except KeyError as err:
print(f'KeyError {err} not in {key}!') print(f'KeyError {err} not in {key}!')
@ -275,6 +224,7 @@ def main():
assert len(all_reg_evals) == len(all_freeze_flops_last) assert len(all_reg_evals) == len(all_freeze_flops_last)
assert len(all_reg_evals) == len(all_cond_flops_last) assert len(all_reg_evals) == len(all_cond_flops_last)
assert len(all_reg_evals) == len(all_freeze_time_last) assert len(all_reg_evals) == len(all_freeze_time_last)
assert len(all_reg_evals) == len(all_arch_ids)
# Freeze training results at last epoch # Freeze training results at last epoch
freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last) freeze_tau, freeze_p_value = kendalltau(all_reg_evals, all_freeze_evals_last)
@ -396,7 +346,8 @@ def main():
spe_freeze_top_percents = [] spe_freeze_top_percents = []
spe_naswot_top_percents = [] spe_naswot_top_percents = []
top_percents = [] top_percents = []
for top_percent in range(2, 101, 2): top_percent_range = range(2, 101, 2)
for top_percent in top_percent_range:
top_percents.append(top_percent) top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01)) num_to_keep = int(ma.floor(len(reg_freezelast_naswot_evals) * top_percent * 0.01))
top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep] top_percent_evals = reg_freezelast_naswot_evals[:num_to_keep]
@ -433,6 +384,39 @@ def main():
savename = os.path.join(out_dir, f'freeze_train_duration_top_archs.png') savename = os.path.join(out_dir, f'freeze_train_duration_top_archs.png')
plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight') plt.savefig(savename, dpi=plt.gcf().dpi, bbox_inches='tight')
# how much overlap in top x% of architectures between method and groundtruth
# ----------------------------------------------------------------------------
arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
arch_id_freezetrain_evals = [(arch_id, freeze_eval) for arch_id, freeze_eval in zip(all_arch_ids, all_freeze_evals_last)]
arch_id_naswot_evals = [(arch_id, naswot_eval) for arch_id, naswot_eval in zip(all_arch_ids, all_naswotrain_evals)]
arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
arch_id_freezetrain_evals.sort(key=lambda x: x[1], reverse=True)
arch_id_naswot_evals.sort(key=lambda x: x[1], reverse=True)
assert len(arch_id_reg_evals) == len(arch_id_freezetrain_evals)
assert len(arch_id_reg_evals) == len(arch_id_naswot_evals)
top_percents = []
freezetrain_ratio_common = []
naswot_ratio_common = []
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
top_percent_arch_id_freezetrain_evals = arch_id_freezetrain_evals[:num_to_keep]
top_percent_arch_id_naswot_evals = arch_id_naswot_evals[:num_to_keep]
# take the set of arch_ids in each method and find overlap with top archs
set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
set_ft = set([x[0] for x in top_percent_arch_id_freezetrain_evals])
ft_num_common = len(set_reg.intersection(set_ft))
freezetrain_ratio_common.append(ft_num_common/num_to_keep)
set_naswot = set([x[0] for x in top_percent_arch_id_naswot_evals])
naswot_num_common = len(set_reg.intersection(set_naswot))
naswot_ratio_common.append(naswot_num_common/num_to_keep)
# save raw data for other aggregate plots over experiments # save raw data for other aggregate plots over experiments
raw_data_dict = {} raw_data_dict = {}
raw_data_dict['top_percents'] = top_percents raw_data_dict['top_percents'] = top_percents
@ -440,6 +424,9 @@ def main():
raw_data_dict['spe_naswot'] = spe_naswot_top_percents raw_data_dict['spe_naswot'] = spe_naswot_top_percents
raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg raw_data_dict['freeze_times_avg'] = top_percent_freeze_times_avg
raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std raw_data_dict['freeze_times_std'] = top_percent_freeze_times_std
raw_data_dict['freeze_ratio_common'] = freezetrain_ratio_common
raw_data_dict['naswot_ratio_common'] = naswot_ratio_common
savename = os.path.join(out_dir, 'raw_data.yaml') savename = os.path.join(out_dir, 'raw_data.yaml')
with open(savename, 'w') as f: with open(savename, 'w') as f:

Просмотреть файл

@ -30,34 +30,10 @@ from multiprocessing import Pool
from archai.common import utils from archai.common import utils
from archai.common.ordereddict_logger import OrderedDictLogger from archai.common.ordereddict_logger import OrderedDictLogger
from analysis_utils import epoch_nodes, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report from analysis_utils import epoch_nodes, parse_a_job, fix_yaml, remove_seed_part, group_multi_runs, collect_epoch_nodes, EpochStats, FoldStats, stat2str, get_epoch_stats, get_summary_text, get_details_text, plot_epochs, write_report
import re import re
def parse_a_job(job_dir:str)->OrderedDict:
if job_dir.is_dir():
for subdir in job_dir.iterdir():
if not subdir.is_dir():
continue
# currently we expect that each job was ExperimentRunner job which should have
# _search or _eval folders
if subdir.stem.endswith('_search'):
sub_job = 'search'
elif subdir.stem.endswith('_eval'):
sub_job = 'eval'
else:
raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
'end with either _search or _eval which '
'should be the case if ExperimentRunner was used.')
logs_filepath = os.path.join(str(subdir), 'log.yaml')
if os.path.isfile(logs_filepath):
fix_yaml(logs_filepath)
with open(logs_filepath, 'r') as f:
key = job_dir.name + ':' + sub_job
data = yaml.load(f, Loader=yaml.Loader)
return (key, data)
def main(): def main():
parser = argparse.ArgumentParser(description='Report creator') parser = argparse.ArgumentParser(description='Report creator')
@ -82,21 +58,56 @@ def main():
# get list of all structured logs for each job # get list of all structured logs for each job
logs = {} logs = {}
confs = {}
job_dirs = list(results_dir.iterdir()) job_dirs = list(results_dir.iterdir())
# parallel parssing of yaml logs # # test single job parsing for debugging
# # WARNING: very slow, just use for debugging
# for job_dir in job_dirs:
# a = parse_a_job(job_dir)
# parallel parsing of yaml logs
with Pool(18) as p: with Pool(18) as p:
a = p.map(parse_a_job, job_dirs) a = p.map(parse_a_job, job_dirs)
for key, data in a: for storage in a:
logs[key] = data for key, val in storage.items():
logs[key] = val[0]
confs[key] = val[1]
# examples of accessing logs # examples of accessing logs
# best_test = logs[key]['eval_arch']['eval_train']['best_test']['top1'] # best_test = logs[key]['eval_arch']['eval_train']['best_test']['top1']
# best_train = logs[key]['eval_arch']['eval_train']['best_train']['top1'] # best_train = logs[key]['eval_arch']['eval_train']['best_train']['top1']
# remove all search jobs
for key in list(logs.keys()):
if 'search' in key:
logs.pop(key)
# remove all arch_ids which did not finish
# remove all arch_ids which did not finish
for key in list(logs.keys()):
to_delete = False
# it might have died early
if 'eval_arch' not in list(logs[key].keys()):
to_delete = True
if 'regular_evaluate' not in list(logs[key].keys()):
to_delete = True
if to_delete:
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
if 'eval_train' not in list(logs[key]['eval_arch'].keys()):
print(f'arch id {key} did not finish. removing from calculations.')
logs.pop(key)
continue
all_arch_ids = []
all_reg_evals = [] all_reg_evals = []
all_short_reg_evals = [] all_short_reg_evals = []
all_short_reg_time = [] all_short_reg_time = []
@ -120,7 +131,11 @@ def main():
# -------------------- # --------------------
reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1'] reg_eval_top1 = logs[key]['regular_evaluate']['regtrainingtop1']
all_reg_evals.append(reg_eval_top1) all_reg_evals.append(reg_eval_top1)
# record the arch id
# --------------------
all_arch_ids.append(confs[key]['nas']['eval']['natsbench']['arch_index'])
except KeyError as err: except KeyError as err:
print(f'KeyError {err} not in {key}!') print(f'KeyError {err} not in {key}!')
@ -160,7 +175,8 @@ def main():
spe_shortreg_top_percents = [] spe_shortreg_top_percents = []
top_percents = [] top_percents = []
for top_percent in range(2, 101, 2): top_percent_range = range(2, 101, 2)
for top_percent in top_percent_range:
top_percents.append(top_percent) top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(reg_shortreg_evals) * top_percent * 0.01)) num_to_keep = int(ma.floor(len(reg_shortreg_evals) * top_percent * 0.01))
top_percent_evals = reg_shortreg_evals[:num_to_keep] top_percent_evals = reg_shortreg_evals[:num_to_keep]
@ -201,12 +217,37 @@ def main():
with open(results_savename, 'a') as f: with open(results_savename, 'a') as f:
f.write(f'Avg. Shortened Training Runtime: {avg_shortreg_runtime:.03f}, stderr {stderr_shortreg_runtime:.03f} \n') f.write(f'Avg. Shortened Training Runtime: {avg_shortreg_runtime:.03f}, stderr {stderr_shortreg_runtime:.03f} \n')
# how much overlap in top x% of architectures between method and groundtruth
# ----------------------------------------------------------------------------
arch_id_reg_evals = [(arch_id, reg_eval) for arch_id, reg_eval in zip(all_arch_ids, all_reg_evals)]
arch_id_shortreg_evals = [(arch_id, shortreg_eval) for arch_id, shortreg_eval in zip(all_arch_ids, all_short_reg_evals)]
arch_id_reg_evals.sort(key=lambda x: x[1], reverse=True)
arch_id_shortreg_evals.sort(key=lambda x: x[1], reverse=True)
assert len(arch_id_reg_evals) == len(arch_id_shortreg_evals)
top_percents = []
shortreg_ratio_common = []
for top_percent in top_percent_range:
top_percents.append(top_percent)
num_to_keep = int(ma.floor(len(arch_id_reg_evals) * top_percent * 0.01))
top_percent_arch_id_reg_evals = arch_id_reg_evals[:num_to_keep]
top_percent_arch_id_shortreg_evals = arch_id_shortreg_evals[:num_to_keep]
# take the set of arch_ids in each method and find overlap with top archs
set_reg = set([x[0] for x in top_percent_arch_id_reg_evals])
set_ft = set([x[0] for x in top_percent_arch_id_shortreg_evals])
ft_num_common = len(set_reg.intersection(set_ft))
shortreg_ratio_common.append(ft_num_common/num_to_keep)
# save raw data for other aggregate plots over experiments # save raw data for other aggregate plots over experiments
raw_data_dict = {} raw_data_dict = {}
raw_data_dict['top_percents'] = top_percents raw_data_dict['top_percents'] = top_percents
raw_data_dict['spe_shortreg'] = spe_shortreg_top_percents raw_data_dict['spe_shortreg'] = spe_shortreg_top_percents
raw_data_dict['shortreg_times_avg'] = top_percent_shortreg_times_avg raw_data_dict['shortreg_times_avg'] = top_percent_shortreg_times_avg
raw_data_dict['shortreg_times_std'] = top_percent_shortreg_times_std raw_data_dict['shortreg_times_std'] = top_percent_shortreg_times_std
raw_data_dict['shortreg_ratio_common'] = shortreg_ratio_common
savename = os.path.join(out_dir, 'raw_data.yaml') savename = os.path.join(out_dir, 'raw_data.yaml')
with open(savename, 'w') as f: with open(savename, 'w') as f:

Просмотреть файл

@ -300,3 +300,56 @@ def write_report(template_filename:str, **kwargs)->None:
with open(outfilepath, 'w', encoding='utf-8') as f: with open(outfilepath, 'w', encoding='utf-8') as f:
f.write(report) f.write(report)
print(f'report written to: {outfilepath}') print(f'report written to: {outfilepath}')
def find_valid_log(subdir:str)->str:
# originally log should be in base folder of eval or search
logs_filepath_og = os.path.join(str(subdir), 'log.yaml')
if os.path.isfile(logs_filepath_og):
return logs_filepath_og
else:
# look in the 'dist' folder for any yaml file
dist_folder = os.path.join(str(subdir), 'dist')
for f in os.listdir(dist_folder):
if f.endswith(".yaml"):
return os.path.join(dist_folder, f)
def parse_a_job(job_dir:str)->Dict:
if job_dir.is_dir():
storage = {}
for subdir in job_dir.iterdir():
if not subdir.is_dir():
continue
# currently we expect that each job was ExperimentRunner job which should have
# _search or _eval folders
if subdir.stem.endswith('_search'):
sub_job = 'search'
elif subdir.stem.endswith('_eval'):
sub_job = 'eval'
else:
raise RuntimeError(f'Sub directory "{subdir}" in job "{job_dir}" must '
'end with either _search or _eval which '
'should be the case if ExperimentRunner was used.')
logs_filepath = find_valid_log(subdir)
# if no valid logfile found, ignore this job as it probably
# didn't finish or errored out or is yet to run
if not logs_filepath:
continue
config_used_filepath = os.path.join(subdir, 'config_used.yaml')
if os.path.isfile(logs_filepath):
fix_yaml(logs_filepath)
key = job_dir.name + subdir.name + ':' + sub_job
# parse log
with open(logs_filepath, 'r') as f:
data = yaml.load(f, Loader=yaml.Loader)
# parse config used
with open(config_used_filepath, 'r') as f:
confs = yaml.load(f, Loader=yaml.Loader)
storage[key] = (data, confs)
return storage

Просмотреть файл

@ -73,7 +73,19 @@ def main():
'ft_fb512_ftlr0.1_fte5_ct256_ftt0.6'] 'ft_fb512_ftlr0.1_fte5_ct256_ftt0.6']
shortreg_exp_list = ['nb_reg_b1024_e10', 'nb_reg_b1024_e20', 'nb_reg_b1024_e30'] shortreg_exp_list = ['nb_reg_b1024_e01', \
'nb_reg_b1024_e02', \
'nb_reg_b1024_e04', \
'nb_reg_b1024_e06', \
'nb_reg_b1024_e08', \
'nb_reg_b1024_e10', \
'nb_reg_b512_e01', \
'nb_reg_b512_e02', \
'nb_reg_b512_e04', \
'nb_reg_b512_e06', \
'nb_reg_b512_e08', \
'nb_reg_b512_e10' ]
# parse raw data from all processed experiments # parse raw data from all processed experiments
data = parse_raw_data(exp_folder, exp_list) data = parse_raw_data(exp_folder, exp_list)
@ -84,6 +96,7 @@ def main():
colors = [cmap(i) for i in np.linspace(0, 1, len(exp_list)*2)] colors = [cmap(i) for i in np.linspace(0, 1, len(exp_list)*2)]
linestyles = ['solid', 'dashdot', 'dotted', 'dashed'] linestyles = ['solid', 'dashdot', 'dotted', 'dashed']
markers = ['.', 'v', '^', '<', '>', '1', 's', 'p', '*', '+', 'x', 'X', 'D', 'd'] markers = ['.', 'v', '^', '<', '>', '1', 's', 'p', '*', '+', 'x', 'X', 'D', 'd']
mathy_markers = ["$a$", "$b$", "$c$", "$d$", "$e$", "$f$", "$g$", "$h$", "$i$", "$j$", "$k$", "$l$", "$m$", "$n$", "$o$", "$p$", "$q$", "$r$", "$s$", "$t$", "$u$", "$v$", "$w$", "$x$", "$y$", "$z$"]
cc = cycler(color=colors) * cycler(linestyle=linestyles) * cycler(marker=markers) cc = cycler(color=colors) * cycler(linestyle=linestyles) * cycler(marker=markers)
@ -112,8 +125,10 @@ def main():
break break
# plot shortreg data # plot shortreg data
counter = 0
for i, key in enumerate(shortreg_data.keys()): for i, key in enumerate(shortreg_data.keys()):
plt.plot(shortreg_data[key]['top_percents'], shortreg_data[key]['spe_shortreg'], marker = '8', mfc='green', ms=10) plt.plot(shortreg_data[key]['top_percents'], shortreg_data[key]['spe_shortreg'], marker = mathy_markers[counter], mfc='green', ms=10)
counter += 1
legend_labels.append(key) legend_labels.append(key)
# annotate the shortreg data points with time information # annotate the shortreg data points with time information
@ -166,23 +181,35 @@ def main():
tp_info[tp] = this_tp_info tp_info[tp] = this_tp_info
# now plot each top percent # now plot each top percent
cc = cycler(marker=markers) # markers the same size as number of freezetrain experiments
fig, axs = plt.subplots(5, 10) fig, axs = plt.subplots(5, 10)
handles = None handles = None
labels = None labels = None
for tp_key, ax in zip(tp_info.keys(), axs.flat): for tp_key, ax in zip(tp_info.keys(), axs.flat):
counter = 0
counter_reg = 0
for exp in tp_info[tp_key].keys(): for exp in tp_info[tp_key].keys():
duration = tp_info[tp_key][exp][0] duration = tp_info[tp_key][exp][0]
spe = tp_info[tp_key][exp][1] spe = tp_info[tp_key][exp][1]
#ax.set_prop_cycle(cc)
ax.scatter(duration, spe, label=exp) if 'ft_fb' in exp:
marker = markers[counter]
counter += 1
elif 'nb_reg' in exp:
marker = mathy_markers[counter_reg]
counter_reg += 1
else:
raise NotImplementedError
ax.scatter(duration, spe, label=exp, marker=marker)
ax.set_title(str(tp_key)) ax.set_title(str(tp_key))
#ax.set(xlabel='Duration (s)', ylabel='SPE') #ax.set(xlabel='Duration (s)', ylabel='SPE')
ax.set_ylim([0, 1]) ax.set_ylim([0, 1])
handles, labels = axs.flat[-1].get_legend_handles_labels() handles, labels = axs.flat[-1].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center') fig.legend(handles, labels, loc='center right')
plt.show() plt.show()
@ -191,7 +218,6 @@ def main():
# ------------------------------------------------------------ # ------------------------------------------------------------
plt.clf() plt.clf()
time_legend_labels = [] time_legend_labels = []
for key in data.keys(): for key in data.keys():
plt.errorbar(data[key]['top_percents'], data[key]['freeze_times_avg'], yerr=np.array(data[key]['freeze_times_std'])/2, marker='*', mfc='red', ms=5) plt.errorbar(data[key]['top_percents'], data[key]['freeze_times_avg'], yerr=np.array(data[key]['freeze_times_std'])/2, marker='*', mfc='red', ms=5)
time_legend_labels.append(key + '_freezetrain') time_legend_labels.append(key + '_freezetrain')