Still debugging local search on DARTS.

This commit is contained in:
Debadeepta Dey 2021-11-30 08:45:29 -08:00 коммит произвёл Gustavo Rosa
Родитель 476bb01429
Коммит dcd3cb8dd0
3 изменённых файлов: 29 добавлений и 14 удалений

Просмотреть файл

@ -1,5 +1,6 @@
import os
import random
import math as ma
from overrides.overrides import overrides
from typing import List, Tuple, Optional, Dict
from archai.nas.discrete_search_space import DiscreteSearchSpace
@ -78,6 +79,19 @@ class LocalSearchDartsReg(LocalSearch):
if self.local_minima:
best_minimum = max(self.local_minima, key=lambda x:x[1])
return best_minimum
else:
# if no local minima encountered, return the best
# encountered so far
logger.info('No local minima encountered! Returning best of visited.')
max_train_top1 = -ma.inf
argmax_arch = None
for _, arch in self.eval_cache.items():
this_train_top1 = arch.metadata['train_top1']
if this_train_top1 > max_train_top1:
max_train_top1 = this_train_top1
argmax_arch = arch
return argmax_arch, max_train_top1
@overrides
@ -106,9 +120,9 @@ class LocalSearchDartsReg(LocalSearch):
train_top1 = trainer_metrics.best_train_top1()
arch.metadata['train_top1'] = train_top1
# DEBUG: simulate architecture evaluation
train_top1 = random.random()
arch.metadata['train_top1'] = train_top1
# # DEBUG: simulate architecture evaluation
# train_top1 = random.random()
# arch.metadata['train_top1'] = train_top1
# cache it
self.eval_cache[arch_flat_rep] = arch

Просмотреть файл

@ -160,20 +160,21 @@ class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
op_nbrs = op_nbrs_regular + op_nbrs_reduction
assert len(op_nbrs) == 96
# now create the edge neighbors where the
# only difference is in one of the input edges
# there should be 24 of them
edge_nbrs_regular = self._get_edge_neighbors(central_regular_cell, central_desc)
edge_nbrs_reduction = self._get_edge_neighbors(central_reduction_cell, central_desc)
# # now create the edge neighbors where the
# # only difference is in one of the input edges
# # there should be 24 of them
# edge_nbrs_regular = self._get_edge_neighbors(central_regular_cell, central_desc)
# edge_nbrs_reduction = self._get_edge_neighbors(central_reduction_cell, central_desc)
assert len(edge_nbrs_regular) == 12
assert len(edge_nbrs_reduction) == 12
# assert len(edge_nbrs_regular) == 12
# assert len(edge_nbrs_reduction) == 12
edge_nbrs = edge_nbrs_regular + edge_nbrs_reduction
assert len(edge_nbrs) == 24
# edge_nbrs = edge_nbrs_regular + edge_nbrs_reduction
# assert len(edge_nbrs) == 24
# Now convert all the model descs to actual Models
all_nbrs = op_nbrs + edge_nbrs
# all_nbrs = op_nbrs + edge_nbrs
all_nbrs = op_nbrs
all_models = [Model(nbr_desc, self.drop_path_prob, affine=True) for nbr_desc in all_nbrs]
all_arch_meta = []

Просмотреть файл

@ -9,7 +9,7 @@ nas:
loader:
val_ratio: 0.0 # don't need val for local search
trainer:
epochs: 2
epochs: 1
eval:
model_desc: