зеркало из https://github.com/microsoft/archai.git
Fixed analysis script for random search in FastArchRank.
This commit is contained in:
Родитель
098a8d7891
Коммит
5e440cc95c
|
@ -1,5 +1,3 @@
|
|||
|
||||
|
||||
from overrides import overrides
|
||||
from typing import Optional, Type, Tuple
|
||||
|
||||
|
|
|
@ -107,7 +107,6 @@ class RandomNatsbenchTssFarSearcher(Searcher):
|
|||
# get the full evaluation result from natsbench
|
||||
info = api.get_more_info(archid, dataset_name, hp=200, is_random=False)
|
||||
this_arch_top1_test = info['test-accuracy']
|
||||
|
||||
best_tests.append((archid, this_arch_top1_test))
|
||||
|
||||
# dump important things to log
|
||||
|
|
|
@ -77,17 +77,16 @@ def main():
|
|||
logs[key] = val[0]
|
||||
confs[key] = val[1]
|
||||
|
||||
|
||||
raw_data = {}
|
||||
|
||||
for key in logs.keys():
|
||||
|
||||
# Get total duration of the run
|
||||
# which is the sum of all conditional and freeze
|
||||
# trainings over all architectures
|
||||
duration = 0.0
|
||||
|
||||
best_tests = None
|
||||
best_trains = None
|
||||
|
||||
for skey in logs[key].keys():
|
||||
if 'conditional' in skey or 'freeze' in skey:
|
||||
|
@ -101,12 +100,14 @@ def main():
|
|||
|
||||
# Get the last best_trains_tests
|
||||
if 'best_trains_tests' in skey:
|
||||
best_trains = logs[key][skey]['best_trains']
|
||||
best_tests = logs[key][skey]['best_tests']
|
||||
assert len(best_trains) == len(best_tests)
|
||||
|
||||
best_tests.sort(key= lambda x: x[1])
|
||||
best_test_encountered = best_tests[-1][1]
|
||||
# find the test error of the best train
|
||||
best_test = best_tests[-1][1]
|
||||
|
||||
raw_data[key] = (duration, best_test_encountered)
|
||||
raw_data[key] = (duration, best_test)
|
||||
|
||||
run_durations = [raw_data[key][0] for key in raw_data.keys()]
|
||||
max_accs = [raw_data[key][1] for key in raw_data.keys()]
|
||||
|
|
Загрузка…
Ссылка в новой задаче