diff --git a/archai/algos/divnas/divnas_cell_builder.py b/archai/algos/divnas/divnas_cell_builder.py
index bae11a6d..08359662 100644
--- a/archai/algos/divnas/divnas_cell_builder.py
+++ b/archai/algos/divnas/divnas_cell_builder.py
@@ -9,8 +9,8 @@ class DivnasCellBuilder(CellBuilder):
     @overrides
     def register_ops(self) -> None:
         Op.register_op('div_op',
-                       lambda op_desc, alphas, affine:
-                           DivOp(op_desc, alphas, affine))
+                       lambda op_desc, arch_params, affine:
+                           DivOp(op_desc, arch_params, affine))

     @overrides
     def build(self, model_desc:ModelDesc, search_iter:int)->None:
diff --git a/archai/algos/divnas/divnas_finalizers.py b/archai/algos/divnas/divnas_finalizers.py
index 2c5d8443..9b4c92a6 100644
--- a/archai/algos/divnas/divnas_finalizers.py
+++ b/archai/algos/divnas/divnas_finalizers.py
@@ -27,7 +27,7 @@ class DivnasFinalizers(Finalizers):
         # get config and train data loader
         # TODO: confirm this is correct in case you get silent bugs
         conf = get_conf()
-        conf_loader = conf['nas']['search']['loader'] 
+        conf_loader = conf['nas']['search']['loader']
         train_dl, val_dl, test_dl = get_data(conf_loader)

         # wrap all cells in the model
@@ -36,12 +36,12 @@ class DivnasFinalizers(Finalizers):
             divnas_cell = Divnas_Cell(cell)
             self._divnas_cells[id(cell)] = divnas_cell

-        # go through all edges in the DAG and if they are of divop 
+        # go through all edges in the DAG and if they are of divop
         # type then set them to collect activations
         sigma = conf['nas']['search']['divnas']['sigma']
         for _, dcell in enumerate(self._divnas_cells.values()):
             dcell.collect_activations(DivOp, sigma)
-        
+
         # now we need to run one evaluation epoch to collect activations
         # we do it on cpu otherwise we might run into memory issues
         # later we can redo the whole logic in pytorch itself
@@ -53,7 +53,7 @@ class DivnasFinalizers(Finalizers):
         for _ in range(1):
             for _, (x, _) in enumerate(train_dl):
                 _, _ = model(x), None
-        # now you can go through and update the 
+        # now you can go through and update the
         # node covariances in every cell
         for dcell in self._divnas_cells.values():
             dcell.update_covs()
@@ -83,7 +83,7 @@ class DivnasFinalizers(Finalizers):
             nodes = node_descs,
             s0_op=cell.s0_op.finalize()[0],
             s1_op=cell.s1_op.finalize()[0],
-            alphas_from = cell.desc.alphas_from,
+            template_cell = cell.desc.template_cell,
             max_final_edges=cell.desc.max_final_edges,
             node_ch_out=cell.desc.node_ch_out,
             post_op=cell.post_op.finalize()[0]
@@ -101,17 +101,17 @@ class DivnasFinalizers(Finalizers):
         assert cov.shape[0] == cov.shape[1]

         # the number of primitive operators has to be greater
-        # than equal to the maximum number of final edges 
+        # than equal to the maximum number of final edges
         # allowed
         assert cov.shape[0] >= max_final_edges
-        
+
         # get total number of ops incoming to this node
         num_ops = sum([edge._op.num_valid_div_ops for edge in node])

         # and collect some bookkeeping indices
         edge_num_and_op_ind = []
         for j, edge in enumerate(node):
-            if type(edge._op) == DivOp: 
+            if type(edge._op) == DivOp:
                 for k in range(edge._op.num_valid_div_ops):
                     edge_num_and_op_ind.append((j, k))

@@ -127,7 +127,7 @@ class DivnasFinalizers(Finalizers):
             op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
             new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
             selected_edges.append(new_edge)
-        
+
         # for edge in selected_edges:
         #     self.finalize_edge(edge)

diff --git a/archai/algos/divnas/divop.py b/archai/algos/divnas/divop.py
index 5318bb20..f13024f9 100644
--- a/archai/algos/divnas/divop.py
+++ b/archai/algos/divnas/divop.py
@@ -13,6 +13,7 @@ from overrides import overrides
 from archai.nas.model_desc import OpDesc
 from archai.nas.operations import Op
 from archai.common.common import get_conf
+from archai.nas.arch_params import ArchParams

 # TODO: reduction cell might have output reduced by 2^1=2X due to
 # stride 2 through input nodes however FactorizedReduce does only
@@ -55,7 +56,7 @@ class DivOp(Op):
             else:
                 self._valid_to_orig.append(i)

-    def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
+    def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
                  affine:bool):
         super().__init__()

@@ -67,10 +68,10 @@ class DivOp(Op):
         finalizer = conf['nas']['search']['finalizer']

         if trainer == 'noalpha' and finalizer == 'default':
-            raise NotImplementedError
+            raise NotImplementedError('noalpha trainer is not implemented for the default finalizer')

         if trainer != 'noalpha':
-            self._set_alphas(alphas)
+            self._setup_arch_params(arch_params)
         else:
             self._alphas = None

@@ -78,7 +79,7 @@ class DivOp(Op):
         for primitive in DivOp.PRIMITIVES:
             op = Op.create(
                 OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
-                affine=affine, alphas=alphas)
+                affine=affine, arch_params=None)
             self._ops.append(op)

         # various state variables for diversity
@@ -125,43 +126,24 @@ class DivOp(Op):
         return result

     @overrides
-    def alphas(self) -> Iterable[nn.Parameter]:
-        if self._alphas:
-            for alpha in self._alphas:
-                yield alpha
-
-    @overrides
-    def weights(self) -> Iterable[nn.Parameter]:
-        for op in self._ops:
-            for w in op.parameters():
-                yield w
-
-    @overrides
-    def ops(self)->Iterator['Op']: # type: ignore
-        return iter(self._ops)
-
-    def get_valid_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the valid index list '''
-        assert index <= self.num_valid_div_ops
-        orig_index = self._valid_to_orig[index]
-        desc, _ = self._ops[orig_index].finalize()
-        return desc
-
+    def ops(self)->Iterator['Op']:
+        return iter(self._ops) # type: ignore

     @overrides
     def can_drop_path(self) -> bool:
         return False

-    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
-        # must call before adding other ops
-        assert len(list(self.parameters())) == 0
-        self._alphas = list(alphas)
-        if not len(self._alphas):
+    def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
+        # do we have shared arch params?
+        if arch_params is None:
+            # create our own arch params
             new_p = nn.Parameter(  # TODO: use better init than uniform random?
-                1.0e-3*torch.randn(len(DivOp.PRIMITIVES)), requires_grad=True)
-            # NOTE: This is a way to register parameters with PyTorch.
-            # One creates a dummy variable with the parameters and then
-            # asks back for the parameters in the object from Pytorch
-            # which automagically registers the just created parameters.
-            self._reg_alphas = new_p
-            self._alphas = [p for p in self.parameters()]
+                1.0e-3*torch.randn(len(self.PRIMITIVES)), requires_grad=True)
+            self.create_arch_params([('alphas', new_p)])
+        else:
+            assert arch_params.has_kind('alphas')
+            self.set_arch_params(arch_params)
+
+        # we store alphas in list so Pytorch don't register them
+        self._alphas = list(self.arch_params().param_by_kind('alphas'))
+        assert len(self._alphas)==1
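Editor's note (not part of the patch): the core change above swaps raw alphas lists for the ArchParams abstraction, with each DivOp either creating its own 'alphas' or reusing shared ones. Below is a minimal, dependency-light sketch of that shared-vs-owned pattern; ArchParamsSketch and setup_arch_params are hypothetical stand-ins for archai's ArchParams and the Op methods used in the diff (create_arch_params / set_arch_params / param_by_kind), shown only to illustrate the control flow in _setup_arch_params.

    import torch
    from torch import nn
    from typing import Iterable, List, Optional, Tuple


    class ArchParamsSketch:
        """Hypothetical stand-in for archai.nas.arch_params.ArchParams:
        a collection of architecture parameters grouped by 'kind'."""
        def __init__(self, kinds: List[Tuple[str, nn.Parameter]]) -> None:
            self._by_kind = {}
            for kind, p in kinds:
                self._by_kind.setdefault(kind, []).append(p)

        def has_kind(self, kind: str) -> bool:
            return kind in self._by_kind

        def param_by_kind(self, kind: str) -> Iterable[nn.Parameter]:
            return iter(self._by_kind.get(kind, []))


    def setup_arch_params(num_primitives: int,
                          shared: Optional[ArchParamsSketch]) -> ArchParamsSketch:
        """Mirrors the branch in DivOp._setup_arch_params: create owned alphas
        when no shared arch params are passed in, otherwise reuse the shared ones."""
        if shared is None:
            # owned alphas: one logit per candidate primitive, small random init
            new_p = nn.Parameter(1.0e-3 * torch.randn(num_primitives),
                                 requires_grad=True)
            return ArchParamsSketch([('alphas', new_p)])
        assert shared.has_kind('alphas')
        return shared


    # owned case: the op creates its own alphas
    owned = setup_arch_params(num_primitives=7, shared=None)
    alphas = list(owned.param_by_kind('alphas'))
    assert len(alphas) == 1 and alphas[0].shape == (7,)

    # shared case: a second op reuses the first op's alphas (parameter sharing)
    reused = setup_arch_params(num_primitives=7, shared=owned)
    assert list(reused.param_by_kind('alphas'))[0] is alphas[0]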