sync divnas with new arch params design

Shital Shah 2020-05-18 01:07:50 -07:00
Parent 57a10a2dac
Commit 9f8a4b3cfb
3 changed files with 31 additions and 49 deletions

View file

@@ -9,8 +9,8 @@ class DivnasCellBuilder(CellBuilder):
     @overrides
     def register_ops(self) -> None:
         Op.register_op('div_op',
-                       lambda op_desc, alphas, affine:
-                           DivOp(op_desc, alphas, affine))
+                       lambda op_desc, arch_params, affine:
+                           DivOp(op_desc, arch_params, affine))

     @overrides
     def build(self, model_desc:ModelDesc, search_iter:int)->None:
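For reference, the lambda registered here is the factory that `Op.create` dispatches to for a `'div_op'` descriptor, so under the new design callers hand over an `ArchParams` bundle (or `None`) instead of a raw `alphas` iterable. A rough sketch of such a call, mirroring the `Op.create(...)` and `OpDesc(...)` usage visible later in this diff; the `params` dict is a placeholder and a loaded Archai search config is assumed, since DivOp reads it in its constructor:

    from archai.nas.model_desc import OpDesc
    from archai.nas.operations import Op

    # 'div_op' resolves to the factory registered above. arch_params=None asks
    # DivOp to create and own its 'alphas'; passing an existing ArchParams
    # instance would make it share those parameters with another op.
    op_desc = OpDesc('div_op', {'ch': 16}, in_len=1, trainables=None)  # hypothetical params dict
    div_op = Op.create(op_desc, affine=True, arch_params=None)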

View file

@@ -27,7 +27,7 @@ class DivnasFinalizers(Finalizers):
         # get config and train data loader
         # TODO: confirm this is correct in case you get silent bugs
         conf = get_conf()
         conf_loader = conf['nas']['search']['loader']
         train_dl, val_dl, test_dl = get_data(conf_loader)

         # wrap all cells in the model
@@ -36,12 +36,12 @@ class DivnasFinalizers(Finalizers):
             divnas_cell = Divnas_Cell(cell)
             self._divnas_cells[id(cell)] = divnas_cell

         # go through all edges in the DAG and if they are of divop
         # type then set them to collect activations
         sigma = conf['nas']['search']['divnas']['sigma']
         for _, dcell in enumerate(self._divnas_cells.values()):
             dcell.collect_activations(DivOp, sigma)

         # now we need to run one evaluation epoch to collect activations
         # we do it on cpu otherwise we might run into memory issues
         # later we can redo the whole logic in pytorch itself
@@ -53,7 +53,7 @@ class DivnasFinalizers(Finalizers):
         for _ in range(1):
             for _, (x, _) in enumerate(train_dl):
                 _, _ = model(x), None
         # now you can go through and update the
         # node covariances in every cell
         for dcell in self._divnas_cells.values():
             dcell.update_covs()
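As an aside, the activation-collection step above is just a plain evaluation sweep over the training loader. A generic PyTorch sketch of that pattern; the `eval()`/`no_grad()` guards are standard practice added for illustration, not something this commit shows:

    import torch

    def collect_activations_pass(model: torch.nn.Module, train_dl) -> None:
        # One forward sweep so that ops configured to record activations
        # (DivOp edges with collection enabled) see real inputs.
        # Done on CPU and without gradients to keep memory usage low.
        model = model.cpu().eval()
        with torch.no_grad():
            for x, _ in train_dl:
                model(x)

Running on CPU trades speed for memory headroom, which is the trade-off the comment in the hunk above calls out.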
@@ -83,7 +83,7 @@ class DivnasFinalizers(Finalizers):
                         nodes = node_descs,
                         s0_op=cell.s0_op.finalize()[0],
                         s1_op=cell.s1_op.finalize()[0],
-                        alphas_from = cell.desc.alphas_from,
+                        template_cell = cell.desc.template_cell,
                         max_final_edges=cell.desc.max_final_edges,
                         node_ch_out=cell.desc.node_ch_out,
                         post_op=cell.post_op.finalize()[0]
@@ -101,17 +101,17 @@ class DivnasFinalizers(Finalizers):
         assert cov.shape[0] == cov.shape[1]

         # the number of primitive operators has to be greater
         # than equal to the maximum number of final edges
         # allowed
         assert cov.shape[0] >= max_final_edges

         # get total number of ops incoming to this node
         num_ops = sum([edge._op.num_valid_div_ops for edge in node])

         # and collect some bookkeeping indices
         edge_num_and_op_ind = []
         for j, edge in enumerate(node):
             if type(edge._op) == DivOp:
                 for k in range(edge._op.num_valid_div_ops):
                     edge_num_and_op_ind.append((j, k))
@@ -127,7 +127,7 @@ class DivnasFinalizers(Finalizers):
             op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
             new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
             selected_edges.append(new_edge)

         # for edge in selected_edges:
         #     self.finalize_edge(edge)
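The bookkeeping above flattens every (edge, valid-op) pair feeding a node into one index space, which is what lets selected rows and columns of the covariance matrix be mapped back to concrete edges. A small self-contained toy of that index mapping; the two-edge setup is hypothetical:

    # edge 0 contributes 3 valid ops, edge 1 contributes 2 (made-up numbers)
    num_valid_div_ops_per_edge = [3, 2]
    edge_num_and_op_ind = [(j, k)
                           for j, n in enumerate(num_valid_div_ops_per_edge)
                           for k in range(n)]
    assert edge_num_and_op_ind == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]

    # a flat index chosen by the diversity-based selection maps back like this:
    edge_ind, op_ind = edge_num_and_op_ind[3]
    assert (edge_ind, op_ind) == (1, 0)   # edge 1, its 0th valid op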

View file

@@ -13,6 +13,7 @@ from overrides import overrides
 from archai.nas.model_desc import OpDesc
 from archai.nas.operations import Op
 from archai.common.common import get_conf
+from archai.nas.arch_params import ArchParams

 # TODO: reduction cell might have output reduced by 2^1=2X due to
 # stride 2 through input nodes however FactorizedReduce does only
@@ -55,7 +56,7 @@ class DivOp(Op):
             else:
                 self._valid_to_orig.append(i)

-    def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
+    def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
                  affine:bool):
         super().__init__()
@@ -67,10 +68,10 @@ class DivOp(Op):
         finalizer = conf['nas']['search']['finalizer']

         if trainer == 'noalpha' and finalizer == 'default':
-            raise NotImplementedError
+            raise NotImplementedError('noalpha trainer is not implemented for the default finalizer')

         if trainer != 'noalpha':
-            self._set_alphas(alphas)
+            self._setup_arch_params(arch_params)
         else:
             self._alphas = None
@@ -78,7 +79,7 @@ class DivOp(Op):
         for primitive in DivOp.PRIMITIVES:
             op = Op.create(
                 OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
-                affine=affine, alphas=alphas)
+                affine=affine, arch_params=None)
             self._ops.append(op)

         # various state variables for diversity
@@ -125,43 +126,24 @@ class DivOp(Op):
         return result

     @overrides
-    def alphas(self) -> Iterable[nn.Parameter]:
-        if self._alphas:
-            for alpha in self._alphas:
-                yield alpha
-
-    @overrides
-    def weights(self) -> Iterable[nn.Parameter]:
-        for op in self._ops:
-            for w in op.parameters():
-                yield w
-
-    @overrides
-    def ops(self)->Iterator['Op']: # type: ignore
-        return iter(self._ops)
+    def ops(self)->Iterator['Op']:
+        return iter(self._ops) # type: ignore

     def get_valid_op_desc(self, index:int)->OpDesc:
         ''' index: index in the valid index list '''
         assert index <= self.num_valid_div_ops
         orig_index = self._valid_to_orig[index]
         desc, _ = self._ops[orig_index].finalize()
         return desc

     @overrides
     def can_drop_path(self) -> bool:
         return False

-    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
-        # must call before adding other ops
-        assert len(list(self.parameters())) == 0
-        self._alphas = list(alphas)
-        if not len(self._alphas):
-            new_p = nn.Parameter( # TODO: use better init than uniform random?
-                1.0e-3*torch.randn(len(DivOp.PRIMITIVES)), requires_grad=True)
-            # NOTE: This is a way to register parameters with PyTorch.
-            # One creates a dummy variable with the parameters and then
-            # asks back for the parameters in the object from Pytorch
-            # which automagically registers the just created parameters.
-            self._reg_alphas = new_p
-            self._alphas = [p for p in self.parameters()]
+    def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
+        # do we have shared arch params?
+        if arch_params is None:
+            # create our own arch params
+            new_p = nn.Parameter( # TODO: use better init than uniform random?
+                1.0e-3*torch.randn(len(self.PRIMITIVES)), requires_grad=True)
+            self.create_arch_params([('alphas', new_p)])
+        else:
+            assert arch_params.has_kind('alphas')
+            self.set_arch_params(arch_params)
+
+        # we store alphas in list so Pytorch don't register them
+        self._alphas = list(self.arch_params().param_by_kind('alphas'))
+        assert len(self._alphas)==1
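Taken together, `_setup_arch_params` gives DivOp two modes: create and own an `alphas` parameter when `arch_params` is `None`, or validate and adopt a shared `ArchParams` instance otherwise. A rough sketch of the sharing pattern, using only the calls visible in this diff; `op_desc` is a placeholder, a loaded search config is assumed, and the final assertion assumes sharing reuses the same underlying `nn.Parameter`:

    # owner creates and registers its own 'alphas'
    owner = DivOp(op_desc, arch_params=None, affine=True)
    # shared validates the kind and adopts the owner's parameters
    shared = DivOp(op_desc, arch_params=owner.arch_params(), affine=True)

    a_owner = list(owner.arch_params().param_by_kind('alphas'))
    a_shared = list(shared.arch_params().param_by_kind('alphas'))
    assert a_owner[0] is a_shared[0]   # same nn.Parameter object in both ops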