Fixed analysis script for random search in FastArchRank.

This commit is contained in:
Debadeepta Dey 2021-04-14 19:43:28 -07:00 коммит произвёл Gustavo Rosa
Родитель 098a8d7891
Коммит 5e440cc95c
3 изменённых файлов: 6 добавлений и 8 удалений

Просмотреть файл

@ -1,5 +1,3 @@
from overrides import overrides
from typing import Optional, Type, Tuple

Просмотреть файл

@ -107,7 +107,6 @@ class RandomNatsbenchTssFarSearcher(Searcher):
# get the full evaluation result from natsbench
info = api.get_more_info(archid, dataset_name, hp=200, is_random=False)
this_arch_top1_test = info['test-accuracy']
best_tests.append((archid, this_arch_top1_test))
# dump important things to log

Просмотреть файл

@ -77,17 +77,16 @@ def main():
logs[key] = val[0]
confs[key] = val[1]
raw_data = {}
for key in logs.keys():
# Get total duration of the run
# which is the sum of all conditional and freeze
# trainings over all architectures
duration = 0.0
best_tests = None
best_trains = None
for skey in logs[key].keys():
if 'conditional' in skey or 'freeze' in skey:
@ -101,12 +100,14 @@ def main():
# Get the last best_trains_tests
if 'best_trains_tests' in skey:
best_trains = logs[key][skey]['best_trains']
best_tests = logs[key][skey]['best_tests']
assert len(best_trains) == len(best_tests)
best_tests.sort(key= lambda x: x[1])
best_test_encountered = best_tests[-1][1]
# find the test error of the best train
best_test = best_tests[-1][1]
raw_data[key] = (duration, best_test_encountered)
raw_data[key] = (duration, best_test)
run_durations = [raw_data[key][0] for key in raw_data.keys()]
max_accs = [raw_data[key][1] for key in raw_data.keys()]