Refactored an analysis script.

This commit is contained in:
Debadeepta Dey 2021-03-22 09:56:31 -07:00 committed by Gustavo Rosa
Parent f598ba3a35
Commit 325d44b3d4
2 changed files with 3 additions and 23 deletions

5
.vscode/launch.json vendored
View file

@@ -607,13 +607,12 @@
"--out-dir", "D:\\archai_experiment_reports"]
},
{
"name": "Analysis Natsbench ZeroCost Synthetic vs Cifar10",
"name": "Analysis Natsbench Nonstandard Generate Benchmark",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/reports/analysis_natsbench_zerocost_synthetic_vs_cifar10.py",
"program": "${cwd}/scripts/reports/analysis_natsbench_nonstandard_generate_benchmark.py",
"console": "integratedTerminal",
"args": ["--results-dir", "D:\\archaiphilly\\phillytools\\nb_reg_b256_e200_sc10",
"--zerocost_measures_file", "C:\\Users\\dedey\\Documents\\zero-cost-nas\\precomputed_results\\nasbench201_cifar10_subset_all_results.yaml",
"--out-dir", "D:\\archai_experiment_reports"]
},
{

View file

@@ -42,9 +42,6 @@ def main():
parser.add_argument('--results-dir', '-d', type=str,
default=r'~/logdir/proxynas_test_0001',
help='folder with experiment results from pt')
parser.add_argument('--zerocost_measures_file', type=str, default=r'~/nasbench201_cifar10_subset_all_results.yaml',
help='yaml file with zerocost measures for the same set of architectures. This is \
generated by calling scripts in zero-cost-nas repo.')
parser.add_argument('--out-dir', '-o', type=str, default=r'~/logdir/reports',
help='folder to output reports')
args, extra_args = parser.parse_known_args()
@@ -106,18 +103,6 @@ def main():
logs.pop(key)
# load the zero-cost measures for the same architectures in natsbench
# computed on any dataset since we are interested in synflow which is data agnostic
with open(args.zerocost_measures_file, 'r') as f:
zero_cost_data = yaml.load(f, Loader=yaml.Loader)
# create a dict with arch_id: synflow score as entries
arch_id_synflow = {}
for entry in zero_cost_data:
arch_id = entry['i']
synflow = entry['logmeasures']['synflow']
arch_id_synflow[arch_id] = synflow
# create a dict with arch_id: regular eval score as entries
# and save since synthetic cifar10 is not part of the benchmark
arch_id_reg_eval = {}
@@ -136,12 +121,8 @@ def main():
all_synflow = []
for arch_id in arch_id_reg_eval.keys():
all_reg_evals.append(arch_id_reg_eval[arch_id])
all_synflow.append(arch_id_synflow[arch_id])
assert(len(all_reg_evals) == len(all_synflow))
synflow_spe, _ = spearmanr(all_reg_evals, all_synflow)
print(f'num valid architectures used for analysis {len(logs)}')
print(f'synflow spearman on synthetic cifar10 is {synflow_spe}')
# plot histogram of regular evaluation scores
fig = px.histogram(all_reg_evals, labels={'x': 'Test Accuracy', 'y': 'Counts'})