Mirror of https://github.com/microsoft/archai.git

Commit: 9f8a4b3cfb
Parent: 57a10a2dac

sync divnas with new arch params design
@@ -9,8 +9,8 @@ class DivnasCellBuilder(CellBuilder):
     @overrides
     def register_ops(self) -> None:
         Op.register_op('div_op',
-                       lambda op_desc, alphas, affine:
-                           DivOp(op_desc, alphas, affine))
+                       lambda op_desc, arch_params, affine:
+                           DivOp(op_desc, arch_params, affine))
 
     @overrides
     def build(self, model_desc:ModelDesc, search_iter:int)->None:
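Note: under the new design the factory registered for 'div_op' receives an ArchParams container (or None) instead of a bare iterable of alpha parameters. Below is a minimal sketch of the factory contract this hunk implies; the OpFactory alias, the registry, and the stub classes are illustrative stand-ins, not archai's actual implementation.

    from typing import Callable, Dict, Optional

    class OpDesc: ...          # stand-in for archai.nas.model_desc.OpDesc
    class ArchParams: ...      # stand-in for archai.nas.arch_params.ArchParams
    class Op: ...              # stand-in for archai.nas.operations.Op

    # the factory now takes Optional[ArchParams] instead of Iterable[nn.Parameter]
    OpFactory = Callable[[OpDesc, Optional[ArchParams], bool], Op]

    _op_factories: Dict[str, OpFactory] = {}

    def register_op(name: str, factory: OpFactory) -> None:
        # looked up later when a cell builds its edges
        _op_factories[name] = factory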
@@ -27,7 +27,7 @@ class DivnasFinalizers(Finalizers):
         # get config and train data loader
         # TODO: confirm this is correct in case you get silent bugs
         conf = get_conf()
         conf_loader = conf['nas']['search']['loader']
         train_dl, val_dl, test_dl = get_data(conf_loader)
 
         # wrap all cells in the model
@@ -36,12 +36,12 @@ class DivnasFinalizers(Finalizers):
             divnas_cell = Divnas_Cell(cell)
             self._divnas_cells[id(cell)] = divnas_cell
 
         # go through all edges in the DAG and if they are of divop
         # type then set them to collect activations
         sigma = conf['nas']['search']['divnas']['sigma']
         for _, dcell in enumerate(self._divnas_cells.values()):
             dcell.collect_activations(DivOp, sigma)
 
         # now we need to run one evaluation epoch to collect activations
         # we do it on cpu otherwise we might run into memory issues
         # later we can redo the whole logic in pytorch itself
@@ -53,7 +53,7 @@ class DivnasFinalizers(Finalizers):
         for _ in range(1):
             for _, (x, _) in enumerate(train_dl):
                 _, _ = model(x), None
         # now you can go through and update the
         # node covariances in every cell
         for dcell in self._divnas_cells.values():
             dcell.update_covs()
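Note: the loop above only drives forward passes; the hooks installed by collect_activations record each primitive's output, and update_covs folds those into per-node covariance statistics. The sigma config key suggests an RBF kernel; the sketch below works under that assumption and is illustrative only, since the actual Divnas_Cell.update_covs logic lives in archai and is incremental.

    import numpy as np

    def rbf_cov(acts: np.ndarray, sigma: float) -> np.ndarray:
        # acts: (num_ops, num_samples, num_features), collected in eval mode
        num_ops = acts.shape[0]
        cov = np.empty((num_ops, num_ops))
        for i in range(num_ops):
            for j in range(num_ops):
                # mean squared distance between the two ops' activations
                d2 = ((acts[i] - acts[j]) ** 2).sum(axis=-1).mean()
                cov[i, j] = np.exp(-d2 / (2.0 * sigma ** 2))
        return cov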
@@ -83,7 +83,7 @@ class DivnasFinalizers(Finalizers):
                             nodes = node_descs,
                             s0_op=cell.s0_op.finalize()[0],
                             s1_op=cell.s1_op.finalize()[0],
-                            alphas_from = cell.desc.alphas_from,
+                            template_cell = cell.desc.template_cell,
                             max_final_edges=cell.desc.max_final_edges,
                             node_ch_out=cell.desc.node_ch_out,
                             post_op=cell.post_op.finalize()[0]
@@ -101,17 +101,17 @@ class DivnasFinalizers(Finalizers):
         assert cov.shape[0] == cov.shape[1]
 
         # the number of primitive operators has to be greater
         # than equal to the maximum number of final edges
         # allowed
         assert cov.shape[0] >= max_final_edges
 
         # get total number of ops incoming to this node
         num_ops = sum([edge._op.num_valid_div_ops for edge in node])
 
         # and collect some bookkeeping indices
         edge_num_and_op_ind = []
         for j, edge in enumerate(node):
             if type(edge._op) == DivOp:
                 for k in range(edge._op.num_valid_div_ops):
                     edge_num_and_op_ind.append((j, k))
 
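Note: the bookkeeping above flattens the node's incoming (edge, primitive) pairs into a single index space that lines up with the rows of the covariance matrix. A small worked example with hypothetical sizes:

    # hypothetical node with 2 DivOp edges, each exposing 3 valid primitives
    edge_num_and_op_ind = [(j, k) for j in range(2) for k in range(3)]
    # -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]

    # a row index chosen from the 6x6 covariance matrix maps straight back
    # to which edge and which primitive it came from:
    edge_ind, op_ind = edge_num_and_op_ind[4]    # edge 1, primitive 1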
@@ -127,7 +127,7 @@ class DivnasFinalizers(Finalizers):
                 op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
                 new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
                 selected_edges.append(new_edge)
 
         # for edge in selected_edges:
         #     self.finalize_edge(edge)
 
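Note: the (edge_ind, op_ind) pairs consumed above come from a selection step over the covariance matrix. The sketch below shows one plausible greedy, diversity-driven selection; it is illustrative only and not necessarily the criterion archai implements.

    import numpy as np

    def greedy_diverse(cov: np.ndarray, k: int) -> list:
        # start with the op least similar to everything else overall
        chosen = [int(np.argmin(cov.sum(axis=1)))]
        while len(chosen) < k:
            rest = [i for i in range(cov.shape[0]) if i not in chosen]
            # add the candidate least similar to the ops already chosen
            chosen.append(min(rest, key=lambda i: cov[i, chosen].sum()))
        return chosen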
@@ -13,6 +13,7 @@ from overrides import overrides
 from archai.nas.model_desc import OpDesc
 from archai.nas.operations import Op
 from archai.common.common import get_conf
+from archai.nas.arch_params import ArchParams
 
 # TODO: reduction cell might have output reduced by 2^1=2X due to
 # stride 2 through input nodes however FactorizedReduce does only
@@ -55,7 +56,7 @@ class DivOp(Op):
             else:
                 self._valid_to_orig.append(i)
 
-    def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
+    def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
                  affine:bool):
         super().__init__()
 
@@ -67,10 +68,10 @@ class DivOp(Op):
         finalizer = conf['nas']['search']['finalizer']
 
         if trainer == 'noalpha' and finalizer == 'default':
-            raise NotImplementedError
+            raise NotImplementedError('noalpha trainer is not implemented for the default finalizer')
 
         if trainer != 'noalpha':
-            self._set_alphas(alphas)
+            self._setup_arch_params(arch_params)
         else:
             self._alphas = None
 
@@ -78,7 +79,7 @@ class DivOp(Op):
         for primitive in DivOp.PRIMITIVES:
             op = Op.create(
                 OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
-                affine=affine, alphas=alphas)
+                affine=affine, arch_params=None)
             self._ops.append(op)
 
         # various state variables for diversity
@@ -125,43 +126,24 @@ class DivOp(Op):
         return result
 
     @overrides
-    def alphas(self) -> Iterable[nn.Parameter]:
-        if self._alphas:
-            for alpha in self._alphas:
-                yield alpha
-
-    @overrides
-    def weights(self) -> Iterable[nn.Parameter]:
-        for op in self._ops:
-            for w in op.parameters():
-                yield w
-
-    @overrides
-    def ops(self)->Iterator['Op']: # type: ignore
-        return iter(self._ops)
-
-    def get_valid_op_desc(self, index:int)->OpDesc:
-        ''' index: index in the valid index list '''
-        assert index <= self.num_valid_div_ops
-        orig_index = self._valid_to_orig[index]
-        desc, _ = self._ops[orig_index].finalize()
-        return desc
+    def ops(self)->Iterator['Op']:
+        return iter(self._ops) # type: ignore
 
 
     @overrides
     def can_drop_path(self) -> bool:
         return False
 
-    def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
-        # must call before adding other ops
-        assert len(list(self.parameters())) == 0
-        self._alphas = list(alphas)
-        if not len(self._alphas):
-            new_p = nn.Parameter( # TODO: use better init than uniform random?
-                1.0e-3*torch.randn(len(DivOp.PRIMITIVES)), requires_grad=True)
-            # NOTE: This is a way to register parameters with PyTorch.
-            # One creates a dummy variable with the parameters and then
-            # asks back for the parameters in the object from Pytorch
-            # which automagically registers the just created parameters.
-            self._reg_alphas = new_p
-            self._alphas = [p for p in self.parameters()]
+    def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
+        # do we have shared arch params?
+        if arch_params is None:
+            # create our own arch params
+            new_p = nn.Parameter( # TODO: use better init than uniform random?
+                1.0e-3*torch.randn(len(self.PRIMITIVES)), requires_grad=True)
+            self.create_arch_params([('alphas', new_p)])
+        else:
+            assert arch_params.has_kind('alphas')
+            self.set_arch_params(arch_params)
+
+        # we store alphas in list so Pytorch don't register them
+        self._alphas = list(self.arch_params().param_by_kind('alphas'))
+        assert len(self._alphas)==1
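Note: the new comment "we store alphas in list so Pytorch don't register them" is the crux of the shared arch params design: only the owning op registers the nn.Parameter, while sharers keep a plain-list reference so optimizers never see duplicates. A self-contained sketch of that PyTorch behavior (the module names here are illustrative, not archai classes):

    import torch
    from torch import nn

    shared = nn.Parameter(1.0e-3 * torch.randn(3))

    class OwnerOp(nn.Module):
        def __init__(self):
            super().__init__()
            self.alphas = shared        # attribute: PyTorch registers it

    class SharingOp(nn.Module):
        def __init__(self):
            super().__init__()
            self._alphas = [shared]     # plain list: PyTorch does NOT register it

    owner, sharer = OwnerOp(), SharingOp()
    assert len(list(owner.parameters())) == 1
    assert len(list(sharer.parameters())) == 0   # no duplicate for the optimizer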