зеркало из https://github.com/microsoft/archai.git
Still debugging local search on DARTS.
This commit is contained in:
Родитель
476bb01429
Коммит
dcd3cb8dd0
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import random
|
||||
import math as ma
|
||||
from overrides.overrides import overrides
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from archai.nas.discrete_search_space import DiscreteSearchSpace
|
||||
|
@ -78,6 +79,19 @@ class LocalSearchDartsReg(LocalSearch):
|
|||
if self.local_minima:
|
||||
best_minimum = max(self.local_minima, key=lambda x:x[1])
|
||||
return best_minimum
|
||||
else:
|
||||
# if no local minima encountered, return the best
|
||||
# encountered so far
|
||||
logger.info('No local minima encountered! Returning best of visited.')
|
||||
max_train_top1 = -ma.inf
|
||||
argmax_arch = None
|
||||
for _, arch in self.eval_cache.items():
|
||||
this_train_top1 = arch.metadata['train_top1']
|
||||
if this_train_top1 > max_train_top1:
|
||||
max_train_top1 = this_train_top1
|
||||
argmax_arch = arch
|
||||
return argmax_arch, max_train_top1
|
||||
|
||||
|
||||
|
||||
@overrides
|
||||
|
@ -106,9 +120,9 @@ class LocalSearchDartsReg(LocalSearch):
|
|||
train_top1 = trainer_metrics.best_train_top1()
|
||||
arch.metadata['train_top1'] = train_top1
|
||||
|
||||
# DEBUG: simulate architecture evaluation
|
||||
train_top1 = random.random()
|
||||
arch.metadata['train_top1'] = train_top1
|
||||
# # DEBUG: simulate architecture evaluation
|
||||
# train_top1 = random.random()
|
||||
# arch.metadata['train_top1'] = train_top1
|
||||
|
||||
# cache it
|
||||
self.eval_cache[arch_flat_rep] = arch
|
||||
|
|
|
@ -160,20 +160,21 @@ class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
|
|||
op_nbrs = op_nbrs_regular + op_nbrs_reduction
|
||||
assert len(op_nbrs) == 96
|
||||
|
||||
# now create the edge neighbors where the
|
||||
# only difference is in one of the input edges
|
||||
# there should be 24 of them
|
||||
edge_nbrs_regular = self._get_edge_neighbors(central_regular_cell, central_desc)
|
||||
edge_nbrs_reduction = self._get_edge_neighbors(central_reduction_cell, central_desc)
|
||||
# # now create the edge neighbors where the
|
||||
# # only difference is in one of the input edges
|
||||
# # there should be 24 of them
|
||||
# edge_nbrs_regular = self._get_edge_neighbors(central_regular_cell, central_desc)
|
||||
# edge_nbrs_reduction = self._get_edge_neighbors(central_reduction_cell, central_desc)
|
||||
|
||||
assert len(edge_nbrs_regular) == 12
|
||||
assert len(edge_nbrs_reduction) == 12
|
||||
# assert len(edge_nbrs_regular) == 12
|
||||
# assert len(edge_nbrs_reduction) == 12
|
||||
|
||||
edge_nbrs = edge_nbrs_regular + edge_nbrs_reduction
|
||||
assert len(edge_nbrs) == 24
|
||||
# edge_nbrs = edge_nbrs_regular + edge_nbrs_reduction
|
||||
# assert len(edge_nbrs) == 24
|
||||
|
||||
# Now convert all the model descs to actual Models
|
||||
all_nbrs = op_nbrs + edge_nbrs
|
||||
# all_nbrs = op_nbrs + edge_nbrs
|
||||
all_nbrs = op_nbrs
|
||||
all_models = [Model(nbr_desc, self.drop_path_prob, affine=True) for nbr_desc in all_nbrs]
|
||||
|
||||
all_arch_meta = []
|
||||
|
|
|
@ -9,7 +9,7 @@ nas:
|
|||
loader:
|
||||
val_ratio: 0.0 # don't need val for local search
|
||||
trainer:
|
||||
epochs: 2
|
||||
epochs: 1
|
||||
|
||||
eval:
|
||||
model_desc:
|
||||
|
|
Загрузка…
Ссылка в новой задаче