зеркало из https://github.com/microsoft/archai.git
general arch params implementation, support for multiple optimizers
This commit is contained in:
Родитель
cf0ce350fe
Коммит
57a10a2dac
|
@ -53,6 +53,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "darts"]
|
||||
},
|
||||
{
|
||||
"name": "DiDarts-E2E-Toy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "didarts"]
|
||||
},
|
||||
{
|
||||
"name": "Darts-Food101-Toy",
|
||||
"type": "python",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Mapping, Optional, Union
|
||||
from typing import Iterator, Mapping, Optional, Union
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
@ -10,6 +10,14 @@ from archai.common.config import Config
|
|||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model import Model
|
||||
from archai.common.common import logger
|
||||
from archai.common.utils import zip_eq
|
||||
|
||||
def _get_loss(model:Model, lossfn, x, y):
|
||||
logits, *_ = model(x) # might also return aux tower logits
|
||||
return lossfn(logits, y)
|
||||
|
||||
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
|
||||
return model.all_owned().param_by_kind('alphas')
|
||||
|
||||
class BilevelOptimizer:
|
||||
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
|
||||
|
@ -25,8 +33,12 @@ class BilevelOptimizer:
|
|||
# to compute grads for alphas without disturbing
|
||||
# original weights
|
||||
self._vmodel = copy.deepcopy(model).to(device)
|
||||
|
||||
self._alphas = list(_get_alphas(self._model))
|
||||
self._valphas = list(_get_alphas(self._vmodel))
|
||||
|
||||
# this is the optimizer to optimize alphas parameter
|
||||
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
|
||||
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, self._alphas)
|
||||
|
||||
def state_dict(self)->dict:
|
||||
return {
|
||||
|
@ -47,16 +59,11 @@ class BilevelOptimizer:
|
|||
def _vmodel_params(self):
|
||||
return self._vmodel.parameters()
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(model, lossfn, x, y):
|
||||
logits, *_ = model(x) # might also return aux tower logits
|
||||
return lossfn(logits, y)
|
||||
|
||||
def _update_vmodel(self, x, y, lr: float, w_optim: Optimizer) -> None:
|
||||
""" Update vmodel with w' (main model has w) """
|
||||
|
||||
# TODO: should this loss be stored for later use?
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
gradients = autograd.grad(loss, self._model_params())
|
||||
|
||||
"""update weights in vmodel so we leave main model undisturbed
|
||||
|
@ -76,7 +83,7 @@ class BilevelOptimizer:
|
|||
vw.copy_(w - lr * (m + g + self._w_weight_decay*w))
|
||||
|
||||
# synchronize alphas
|
||||
for a, va in zip(self._model.alphas(), self._vmodel.alphas()):
|
||||
for a, va in zip_eq(self._alphas, self._valphas):
|
||||
va.copy_(a)
|
||||
|
||||
def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor,
|
||||
|
@ -114,11 +121,11 @@ class BilevelOptimizer:
|
|||
# compute loss on validation set for model with w'
|
||||
# wrt alphas. The autograd.grad is used instead of backward()
|
||||
# to avoid having to loop through params
|
||||
vloss = BilevelOptimizer._get_loss(
|
||||
self._vmodel, self._lossfn, x_valid, y_valid)
|
||||
vloss = _get_loss(self._vmodel, self._lossfn, x_valid, y_valid)
|
||||
|
||||
v_alphas = tuple(self._vmodel.alphas())
|
||||
v_alphas = tuple(self._valphas)
|
||||
v_weights = tuple(self._vmodel_params())
|
||||
# TODO: if v_weights = all params then below does double counting of alpahs
|
||||
v_grads = autograd.grad(vloss, v_alphas + v_weights)
|
||||
|
||||
# grad(L(w', a), a), part of Eq. 6
|
||||
|
@ -133,7 +140,7 @@ class BilevelOptimizer:
|
|||
# update final gradient = dalpha - xi*hessian
|
||||
# TODO: currently alphas lr is same as w lr
|
||||
with torch.no_grad():
|
||||
for alpha, da, h in zip(self._model.alphas(), dalpha, hessian):
|
||||
for alpha, da, h in zip(self._alphas, dalpha, hessian):
|
||||
alpha.grad = da - lr*h
|
||||
# now that model has both w and alpha grads,
|
||||
# we can run w_optim.step() to update the param values
|
||||
|
@ -164,9 +171,9 @@ class BilevelOptimizer:
|
|||
|
||||
# Now that we have model with w+, we need to compute grads wrt alphas
|
||||
# This loss needs to be on train set, not validation set
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_plus = autograd.grad(
|
||||
loss, self._model.alphas()) # dalpha{L_trn(w+)}
|
||||
loss, self._alphas) # dalpha{L_trn(w+)}
|
||||
|
||||
# get model with w- and then compute grads wrt alphas
|
||||
# w- = w - eps*dw`
|
||||
|
@ -176,8 +183,8 @@ class BilevelOptimizer:
|
|||
p -= 2. * epsilon * v
|
||||
|
||||
# similarly get dalpha_minus
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_minus = autograd.grad(loss, self._model.alphas())
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_minus = autograd.grad(loss, self._alphas)
|
||||
|
||||
# reset back params to original values by adding dw
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Mapping, Optional, Union
|
||||
from typing import Mapping, Optional, Union, Iterator
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
@ -10,7 +10,7 @@ from archai.common.config import Config
|
|||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model import Model
|
||||
from archai.common.common import logger
|
||||
|
||||
from archai.common.utils import zip_eq
|
||||
|
||||
def _flatten_concate(xs):
|
||||
"""
|
||||
|
@ -21,6 +21,13 @@ def _flatten_concate(xs):
|
|||
"""
|
||||
return torch.cat([x.view(-1) for x in xs])
|
||||
|
||||
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
|
||||
return model.all_owned().param_by_kind('alphas')
|
||||
|
||||
def _get_loss(model:Model, lossfn, x, y):
|
||||
logits, *_ = model(x) # might also return aux tower logits
|
||||
return lossfn(logits, y)
|
||||
|
||||
class BilevelOptimizer:
|
||||
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
|
||||
model: Model, lossfn: _Loss) -> None:
|
||||
|
@ -29,8 +36,10 @@ class BilevelOptimizer:
|
|||
self._lossfn = lossfn
|
||||
self._model = model # main model with respect to w and alpha
|
||||
|
||||
self._alphas = list(_get_alphas(self._model))
|
||||
|
||||
# this is the optimizer to optimize alphas parameter
|
||||
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
|
||||
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, self._alphas)
|
||||
|
||||
def state_dict(self)->dict:
|
||||
return {
|
||||
|
@ -40,14 +49,9 @@ class BilevelOptimizer:
|
|||
def load_state_dict(self, state_dict)->None:
|
||||
self._alpha_optim.load_state_dict(state_dict['alpha_optim'])
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(model, lossfn, x, y):
|
||||
logits, *_ = model(x) # might also return aux tower logits
|
||||
return lossfn(logits, y)
|
||||
|
||||
def _unrolled_model(self, x, y, lr: float, main_optim: Optimizer)->Model:
|
||||
# TODO: should this loss be stored for later use?
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
params = _flatten_concate(self._model.parameters()).detach()
|
||||
|
||||
try:
|
||||
|
@ -113,9 +117,9 @@ class BilevelOptimizer:
|
|||
# compute loss on validation set for model with w'
|
||||
# wrt alphas. The autograd.grad is used instead of backward()
|
||||
# to avoid having to loop through params
|
||||
vloss = BilevelOptimizer._get_loss(unrolled_model, self._lossfn, x_valid, y_valid)
|
||||
vloss = _get_loss(unrolled_model, self._lossfn, x_valid, y_valid)
|
||||
vloss.backward()
|
||||
dalpha = [v.grad for v in unrolled_model.alphas()]
|
||||
dalpha = [v.grad for v in _get_alphas(unrolled_model)]
|
||||
dparams = [v.grad.data for v in unrolled_model.parameters()]
|
||||
|
||||
hessian = self._hessian_vector_product(dparams, x_train, y_train)
|
||||
|
@ -125,7 +129,7 @@ class BilevelOptimizer:
|
|||
# update final gradient = dalpha - xi*hessian
|
||||
# TODO: currently alphas lr is same as w lr
|
||||
with torch.no_grad():
|
||||
for alpha, da, h in zip(self._model.alphas(), dalpha, hessian):
|
||||
for alpha, da, h in zip_eq(self._alphas, dalpha, hessian):
|
||||
alpha.grad = da - lr*h
|
||||
# now that model has both w and alpha grads,
|
||||
# we can run main_optim.step() to update the param values
|
||||
|
@ -151,32 +155,32 @@ class BilevelOptimizer:
|
|||
|
||||
# w+ = w + epsilon * grad(w')
|
||||
with torch.no_grad():
|
||||
for p, v in zip(self._model.parameters(), dw):
|
||||
for p, v in zip_eq(self._model.parameters(), dw):
|
||||
p += epsilon * v
|
||||
|
||||
# Now that we have model with w+, we need to compute grads wrt alphas
|
||||
# This loss needs to be on train set, not validation set
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_plus = autograd.grad(
|
||||
loss, self._model.alphas()) # dalpha{L_trn(w+)}
|
||||
loss, self._alphas) # dalpha{L_trn(w+)}
|
||||
|
||||
# get model with w- and then compute grads wrt alphas
|
||||
# w- = w - eps*dw`
|
||||
with torch.no_grad():
|
||||
for p, v in zip(self._model.parameters(), dw):
|
||||
for p, v in zip_eq(self._model.parameters(), dw):
|
||||
# we had already added dw above so sutracting twice gives w-
|
||||
p -= 2. * epsilon * v
|
||||
|
||||
# similarly get dalpha_minus
|
||||
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_minus = autograd.grad(loss, self._model.alphas())
|
||||
loss = _get_loss(self._model, self._lossfn, x, y)
|
||||
dalpha_minus = autograd.grad(loss, self._alphas)
|
||||
|
||||
# reset back params to original values by adding dw
|
||||
with torch.no_grad():
|
||||
for p, v in zip(self._model.parameters(), dw):
|
||||
for p, v in zip_eq(self._model.parameters(), dw):
|
||||
p += epsilon * v
|
||||
|
||||
# apply eq 8, final difference to compute hessian
|
||||
h = [(p - m) / (2. * epsilon)
|
||||
for p, m in zip(dalpha_plus, dalpha_minus)]
|
||||
for p, m in zip_eq(dalpha_plus, dalpha_minus)]
|
||||
return h
|
||||
|
|
|
@ -9,8 +9,8 @@ class DartsCellBuilder(CellBuilder):
|
|||
@overrides
|
||||
def register_ops(self) -> None:
|
||||
Op.register_op('mixed_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
MixedOp(op_desc, alphas, affine))
|
||||
lambda op_desc, arch_params, affine:
|
||||
MixedOp(op_desc, arch_params, affine))
|
||||
|
||||
@overrides
|
||||
def build(self, model_desc:ModelDesc, search_iter:int)->None:
|
||||
|
|
|
@ -8,6 +8,7 @@ from overrides import overrides
|
|||
|
||||
from archai.nas.model_desc import OpDesc
|
||||
from archai.nas.operations import Op
|
||||
from archai.nas.arch_params import ArchParams
|
||||
|
||||
# TODO: reduction cell might have output reduced by 2^1=2X due to
|
||||
# stride 2 through input nodes however FactorizedReduce does only
|
||||
|
@ -29,41 +30,32 @@ class MixedOp(Op):
|
|||
'none' # this must be at the end so top1 doesn't chose it
|
||||
]
|
||||
|
||||
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
|
||||
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
|
||||
affine:bool):
|
||||
super().__init__()
|
||||
|
||||
# assume last PRIMITIVE is 'none'
|
||||
assert MixedOp.PRIMITIVES[-1] == 'none'
|
||||
|
||||
self._set_alphas(alphas)
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in MixedOp.PRIMITIVES:
|
||||
op = Op.create(
|
||||
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
|
||||
affine=affine, alphas=alphas)
|
||||
affine=affine, arch_params=None)
|
||||
self._ops.append(op)
|
||||
# we do this at the end so that we can capture all arch params registered by
|
||||
# any previous child modules
|
||||
self._setup_arch_params(arch_params)
|
||||
|
||||
@overrides
|
||||
def forward(self, x):
|
||||
asm = F.softmax(self._alphas[0], dim=0)
|
||||
return sum(w * op(x) for w, op in zip(asm, self._ops))
|
||||
|
||||
@overrides
|
||||
def alphas(self) -> Iterable[nn.Parameter]:
|
||||
for alpha in self._alphas:
|
||||
yield alpha
|
||||
|
||||
@overrides
|
||||
def weights(self) -> Iterable[nn.Parameter]:
|
||||
for op in self._ops:
|
||||
for w in op.parameters():
|
||||
yield w
|
||||
|
||||
@overrides
|
||||
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
|
||||
# select except 'none' op
|
||||
with torch.no_grad():
|
||||
# select except 'none' op
|
||||
val, i = torch.topk(self._alphas[0][:-1], 1)
|
||||
desc, _ = self._ops[i].finalize()
|
||||
return desc, float(val.item())
|
||||
|
@ -76,16 +68,17 @@ class MixedOp(Op):
|
|||
def ops(self)->Iterator['Op']: # type: ignore
|
||||
return iter(self._ops)
|
||||
|
||||
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
|
||||
# must call before adding other ops
|
||||
assert len(list(self.parameters())) == 0
|
||||
self._alphas = list(alphas)
|
||||
if not len(self._alphas):
|
||||
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
|
||||
# do we have shared arch params?
|
||||
if arch_params is None:
|
||||
# create our own arch params
|
||||
new_p = nn.Parameter( # TODO: use better init than uniform random?
|
||||
1.0e-3*torch.randn(len(MixedOp.PRIMITIVES)), requires_grad=True)
|
||||
# NOTE: This is a way to register parameters with PyTorch.
|
||||
# One creates a dummy variable with the parameters and then
|
||||
# asks back for the parameters in the object from Pytorch
|
||||
# which automagically registers the just created parameters.
|
||||
self._reg_alphas = new_p
|
||||
self._alphas = [p for p in self.parameters()]
|
||||
self.create_arch_params([('alphas', new_p)])
|
||||
else:
|
||||
assert arch_params.has_kind('alphas')
|
||||
self.set_arch_params(arch_params)
|
||||
|
||||
# we store alphas in list so Pytorch don't register them
|
||||
self._alphas = list(self.arch_params().param_by_kind('alphas'))
|
||||
assert len(self._alphas)==1
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Mapping, Optional, Union
|
||||
from typing import Mapping, Optional, Union, Tuple
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
@ -16,7 +16,7 @@ from archai.common import utils, ml_utils
|
|||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.common.common import logger
|
||||
from archai.algos.darts.bilevel_optimizer import BilevelOptimizer
|
||||
from archai.common.multi_optim import MultiOptim, OptimSched
|
||||
|
||||
class DidartsArchTrainer(ArchTrainer):
|
||||
"""Train network using different optimizers for alphas and other parameters"""
|
||||
|
@ -25,65 +25,29 @@ class DidartsArchTrainer(ArchTrainer):
|
|||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
self._conf_w_lossfn = conf_train['lossfn']
|
||||
self._conf_alpha_optim = conf_train['alpha_optimizer']
|
||||
self._conf_alpha_sched = conf_train['alpha_lr_schedule']
|
||||
|
||||
|
||||
@overrides
|
||||
def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
|
||||
super().pre_fit(train_dl, val_dl)
|
||||
|
||||
def create_multi_optim(self, train_len:int)->MultiOptim:
|
||||
# optimizers, schedulers needs to be recreated for each fit call
|
||||
# as they have state
|
||||
assert val_dl is not None
|
||||
w_momentum = self._conf_w_optim['momentum']
|
||||
w_decay = self._conf_w_optim['decay']
|
||||
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())
|
||||
# as they have state specific to each run
|
||||
optim = self.create_optimizer(self.conf_optim, self.model.nonarch_params(recurse=True))
|
||||
# create scheduler for optim before applying amp
|
||||
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
|
||||
|
||||
self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
|
||||
w_decay, self.model, lossfn,
|
||||
self.get_device(), self.batch_chunks)
|
||||
alpha_optim = self.create_optimizer(self._conf_alpha_optim,
|
||||
self.model.all_owned().param_by_kind(None))
|
||||
alpha_sched, alpha_sched_on_epoch = self.create_scheduler(self._conf_alpha_sched, alpha_optim, train_len)
|
||||
|
||||
@overrides
|
||||
def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
|
||||
# delete state we created in pre_fit
|
||||
del self._bilevel_optim
|
||||
return super().post_fit(train_dl, val_dl)
|
||||
multi_optim = MultiOptim()
|
||||
multi_optim.append(OptimSched(optim, sched, sched_on_epoch))
|
||||
multi_optim.append(OptimSched(alpha_optim, alpha_sched, alpha_sched_on_epoch))
|
||||
|
||||
@overrides
|
||||
def pre_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
|
||||
super().pre_epoch(train_dl, val_dl)
|
||||
logger.info({'multi_optim_len': len(multi_optim)})
|
||||
|
||||
# prep val set to train alphas
|
||||
self._valid_iter = iter(val_dl) # type: ignore
|
||||
return multi_optim
|
||||
|
||||
@overrides
|
||||
def post_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
|
||||
del self._valid_iter # clean up
|
||||
super().post_epoch(train_dl, val_dl)
|
||||
|
||||
@overrides
|
||||
def pre_step(self, x: Tensor, y: Tensor) -> None:
|
||||
super().pre_step(x, y)
|
||||
|
||||
# reset val loader if we exausted it
|
||||
try:
|
||||
x_val, y_val = next(self._valid_iter)
|
||||
except StopIteration:
|
||||
# reinit iterator
|
||||
self._valid_iter = iter(self._val_dl)
|
||||
x_val, y_val = next(self._valid_iter)
|
||||
|
||||
# update alphas
|
||||
self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
|
||||
|
||||
@overrides
|
||||
def update_checkpoint(self, check_point:CheckPoint)->None:
|
||||
super().update_checkpoint(check_point)
|
||||
check_point['bilevel_optim'] = self._bilevel_optim.state_dict()
|
||||
|
||||
@overrides
|
||||
def restore_checkpoint(self)->None:
|
||||
super().restore_checkpoint()
|
||||
self._bilevel_optim.load_state_dict(self.check_point['bilevel_optim'])
|
||||
|
||||
|
|
|
@ -19,16 +19,11 @@ from archai.common.common import logger
|
|||
|
||||
|
||||
class GsArchTrainer(ArchTrainer):
|
||||
def __init__(self, conf_train: Config, model: Model,
|
||||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
# self._conf_w_lossfn = conf_train['lossfn']
|
||||
|
||||
@overrides
|
||||
def create_optimizer(self) -> Optimizer:
|
||||
# in this case we don't need to differentiate between alphas and weights
|
||||
def create_optimizer(self, conf_optim:Config, params) -> Optimizer:
|
||||
# in this case we don't need to differentiate between arch_params and weights
|
||||
# as the same optimizer will update both
|
||||
param_groups = [{'params': self.model.weights()}, {'params': self.model.alphas()}]
|
||||
return ml_utils.create_optimizer(self._conf_w_optim, param_groups)
|
||||
arch_params = list(self.model.all_owned().param_by_kind('alphas'))
|
||||
nonarch_params = list(self.model.nonarch_params(recurse=True))
|
||||
param_groups = [{'params': nonarch_params}, {'params': arch_params}]
|
||||
return ml_utils.create_optimizer(conf_optim, param_groups)
|
||||
|
|
|
@ -9,8 +9,8 @@ class GsCellBuilder(CellBuilder):
|
|||
@overrides
|
||||
def register_ops(self) -> None:
|
||||
Op.register_op('gs_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
GsOp(op_desc, alphas, affine))
|
||||
lambda op_desc, arch_params, affine:
|
||||
GsOp(op_desc, arch_params, affine))
|
||||
|
||||
@overrides
|
||||
def build(self, model_desc:ModelDesc, search_iter:int)->None:
|
||||
|
|
|
@ -8,6 +8,7 @@ from overrides import overrides
|
|||
|
||||
from archai.nas.model_desc import OpDesc
|
||||
from archai.nas.operations import Op
|
||||
from archai.nas.arch_params import ArchParams
|
||||
|
||||
# TODO: reduction cell might have output reduced by 2^1=2X due to
|
||||
# stride 2 through input nodes however FactorizedReduce does only
|
||||
|
@ -29,7 +30,7 @@ class GsOp(Op):
|
|||
'none' # this must be at the end so top1 doesn't chose it
|
||||
]
|
||||
|
||||
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
|
||||
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
|
||||
affine:bool):
|
||||
super().__init__()
|
||||
|
||||
|
@ -38,20 +39,23 @@ class GsOp(Op):
|
|||
|
||||
self._gs_num_sample = op_desc.params['gs_num_sample']
|
||||
|
||||
self._set_alphas(alphas)
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in GsOp.PRIMITIVES:
|
||||
op = Op.create(
|
||||
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
|
||||
affine=affine, alphas=alphas)
|
||||
affine=affine, arch_params=None)
|
||||
self._ops.append(op)
|
||||
# we do this at the end so that we can capture all arch params registered by
|
||||
# any previous child modules
|
||||
self._setup_arch_params(arch_params)
|
||||
|
||||
@overrides
|
||||
def forward(self, x):
|
||||
# soft sample from the categorical distribution
|
||||
# via gumbel softmax distribution
|
||||
# TODO: should we be normalizing the ensemble?
|
||||
#sampled = torch.zeros(self._alphas[0].size(), requires_grad=True)
|
||||
#sampled = torch.zeros(alphas.size(), requires_grad=True)
|
||||
|
||||
sample_storage = []
|
||||
for _ in range(self._gs_num_sample):
|
||||
sampled = F.gumbel_softmax(self._alphas[0], tau=1, hard=False, eps=1e-10, dim=-1)
|
||||
|
@ -60,18 +64,6 @@ class GsOp(Op):
|
|||
samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)
|
||||
return sum(w * op(x) for w, op in zip(samples_summed, self._ops))
|
||||
|
||||
|
||||
@overrides
|
||||
def alphas(self) -> Iterable[nn.Parameter]:
|
||||
for alpha in self._alphas:
|
||||
yield alpha
|
||||
|
||||
@overrides
|
||||
def weights(self) -> Iterable[nn.Parameter]:
|
||||
for op in self._ops:
|
||||
for w in op.parameters():
|
||||
yield w
|
||||
|
||||
@overrides
|
||||
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
|
||||
# finalization where each edge gets to keep as many
|
||||
|
@ -110,7 +102,6 @@ class GsOp(Op):
|
|||
|
||||
return final_op_desc, None
|
||||
|
||||
|
||||
@overrides
|
||||
def can_drop_path(self) -> bool:
|
||||
return False
|
||||
|
@ -119,16 +110,17 @@ class GsOp(Op):
|
|||
def ops(self)->Iterator['Op']: # type: ignore
|
||||
return iter(self._ops)
|
||||
|
||||
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
|
||||
# must call before adding other ops
|
||||
assert len(list(self.parameters())) == 0
|
||||
self._alphas = list(alphas)
|
||||
if not len(self._alphas):
|
||||
# TODO: Better initialization than random?
|
||||
new_p = nn.Parameter(1.0e-3*torch.randn(len(GsOp.PRIMITIVES)), requires_grad=True)
|
||||
# NOTE: This is a way to register parameters with PyTorch.
|
||||
# One creates a dummy variable with the parameters and then
|
||||
# asks back for the parameters in the object from Pytorch
|
||||
# which automagically registers the just created parameters.
|
||||
self._reg_alphas = new_p
|
||||
self._alphas = [p for p in self.parameters()]
|
||||
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
|
||||
# do we have shared arch params?
|
||||
if arch_params is None:
|
||||
# create our own arch params
|
||||
new_p = nn.Parameter( # TODO: use better init than uniform random?
|
||||
1.0e-3*torch.randn(len(GsOp.PRIMITIVES)), requires_grad=True)
|
||||
self.create_arch_params([('alphas', new_p)])
|
||||
else:
|
||||
assert arch_params.has_kind('alphas')
|
||||
self.set_arch_params(arch_params)
|
||||
|
||||
# we store alphas in list so Pytorch don't register them
|
||||
self._alphas = list(self.arch_params().param_by_kind('alphas'))
|
||||
assert len(self._alphas)==1
|
||||
|
|
|
@ -11,13 +11,13 @@ class PetridishCellBuilder(CellBuilder):
|
|||
@overrides
|
||||
def register_ops(self) -> None:
|
||||
Op.register_op('petridish_normal_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
PetridishOp(op_desc, alphas, False, affine))
|
||||
lambda op_desc, arch_params, affine:
|
||||
PetridishOp(op_desc, arch_params, False, affine))
|
||||
Op.register_op('petridish_reduction_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
PetridishOp(op_desc, alphas, True, affine))
|
||||
lambda op_desc, arch_params, affine:
|
||||
PetridishOp(op_desc, arch_params, True, affine))
|
||||
Op.register_op('temp_identity_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
lambda op_desc, arch_params, affine:
|
||||
TempIdentityOp(op_desc))
|
||||
|
||||
@overrides
|
||||
|
|
|
@ -10,7 +10,8 @@ from overrides import overrides
|
|||
|
||||
from archai.nas.model_desc import ConvMacroParams, OpDesc
|
||||
from archai.nas.operations import Identity, Op, FactorizedReduce
|
||||
|
||||
from archai.common.utils import zip_eq
|
||||
from archai.nas.arch_params import ArchParams
|
||||
|
||||
class StopForward(Op):
|
||||
def __init__(self):
|
||||
|
@ -79,16 +80,13 @@ class PetridishOp(Op):
|
|||
'none' # this must be at the end so top1 doesn't chose it
|
||||
]
|
||||
|
||||
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
|
||||
def __init__(self, op_desc:OpDesc, arch_params: Optional[ArchParams],
|
||||
reduction:bool, affine:bool):
|
||||
super().__init__()
|
||||
|
||||
# assume last PRIMITIVE is 'none' (this is used for finalize)
|
||||
assert PetridishOp.PRIMITIVES[-1] == 'none'
|
||||
|
||||
# create alphas for the op
|
||||
self._set_alphas(alphas, op_desc.in_len)
|
||||
|
||||
# create edges for the op, each edge connects input state,
|
||||
# within each edge we will have all N primitives
|
||||
self._edges = nn.ModuleList()
|
||||
|
@ -107,7 +105,7 @@ class PetridishOp(Op):
|
|||
for primitive in PetridishOp.PRIMITIVES:
|
||||
primitive_op = Op.create(OpDesc(primitive, params=params,
|
||||
in_len=1, trainables=None),
|
||||
affine=affine, alphas=alphas)
|
||||
affine=affine, arch_params=None)
|
||||
# wrap primitive with sg
|
||||
op = nn.Sequential(StopGradient(), primitive_op)
|
||||
edge.append(op)
|
||||
|
@ -120,31 +118,23 @@ class PetridishOp(Op):
|
|||
# won't match. So you have to use StopGradientReductionOp on s_1 to make it match.
|
||||
self._sf = StopForward()
|
||||
|
||||
# we do this at the end so that we can capture all arch params registered by
|
||||
# any previous child modules
|
||||
self._setup_arch_params(arch_params, op_desc.in_len)
|
||||
|
||||
@overrides
|
||||
def forward(self, x:List[Tensor]):
|
||||
assert not isinstance(x, torch.Tensor)
|
||||
|
||||
s = 0.0
|
||||
# apply each input in the list to associated edge
|
||||
for i, (xi, edge) in enumerate(zip(x, self._edges)):
|
||||
for i, (xi, edge) in enumerate(zip_eq(x, self._edges)):
|
||||
# apply input to each primitive within edge
|
||||
# TODO: is avg better idea than sum here? sum can explode as
|
||||
# number of primitives goes up
|
||||
s = sum(a * op(xi) for a, op in zip(self._alphas[i], edge)) + s
|
||||
s = sum(a * op(xi) for a, op in zip_eq(self._alphas[0][i], edge)) + s
|
||||
return self._sf(s)
|
||||
|
||||
@overrides
|
||||
def alphas(self) -> Iterable[nn.Parameter]:
|
||||
for alpha in self._alphas:
|
||||
yield alpha
|
||||
|
||||
@overrides
|
||||
def weights(self) -> Iterable[nn.Parameter]:
|
||||
#TODO: cache this?
|
||||
for edge in self._edges:
|
||||
for op in edge:
|
||||
for w in op.parameters():
|
||||
yield w
|
||||
|
||||
@overrides
|
||||
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
|
||||
with torch.no_grad(): # probably this is not needed
|
||||
|
@ -152,10 +142,10 @@ class PetridishOp(Op):
|
|||
# Here op should be nn.Sequence of sg followed by primitive.
|
||||
# First for loop gets edge and associated alphas.
|
||||
# Second for loop gets op and associated alpha.
|
||||
l = ((a, i, op[1]) \
|
||||
l = ((a, i, op[1]) \
|
||||
for edge_alphas, i, edge in \
|
||||
zip(self._alphas, range(self.desc.in_len), self._edges) \
|
||||
for a, op in zip(edge_alphas, edge)) # op is nn.Sequence
|
||||
zip_eq(self._alphas[0], range(self.desc.in_len), self._edges) \
|
||||
for a, op in zip_eq(edge_alphas, edge)) # op is nn.Sequence
|
||||
|
||||
# select 3 largest ops by alpha
|
||||
sel = heapq.nlargest(3, l, key=lambda t: t[0]) # TODO: add config
|
||||
|
@ -184,12 +174,9 @@ class PetridishOp(Op):
|
|||
def ops(self)->Iterator['Op']: # type: ignore
|
||||
return iter(self._ops)
|
||||
|
||||
def _set_alphas(self, alphas: Iterable[nn.Parameter], in_len:int) -> None:
|
||||
assert len(list(self.parameters()))==0 # must call before adding other ops
|
||||
|
||||
# If we are using shared alphas from another cell, don't create our own
|
||||
self._alphas = list(alphas)
|
||||
if not len(self._alphas):
|
||||
def _setup_arch_params(self, arch_params:Optional[ArchParams], in_len:int)->None:
|
||||
# do we have shared arch params?
|
||||
if arch_params is None:
|
||||
# Each nn.Parameter is tensor with alphas for entire edge.
|
||||
# We will create same numbers of nn.Parameter as number of edges
|
||||
n_primitives = len(PetridishOp.PRIMITIVES)
|
||||
|
@ -199,7 +186,12 @@ class PetridishOp(Op):
|
|||
requires_grad=True)
|
||||
for _ in range(in_len)
|
||||
))
|
||||
# register parameters with module
|
||||
self._reg_alphas = pl
|
||||
# save PyTorch registered alphas into list for later use
|
||||
self._alphas = [p for p in self.parameters()]
|
||||
self.create_arch_params([('alphas', pl)])
|
||||
else:
|
||||
assert arch_params.has_kind('alphas')
|
||||
self.set_arch_params(arch_params)
|
||||
|
||||
# we store alphas in list so Pytorch don't register them
|
||||
self._alphas = list(self.arch_params().paramlist_by_kind('alphas'))
|
||||
assert len(self._alphas)==1
|
||||
|
||||
|
|
|
@ -23,14 +23,14 @@ class XnasArchTrainer(ArchTrainer):
|
|||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
self._conf_w_lossfn = conf_train['lossfn']
|
||||
self._conf_alpha_optim = conf_train['alpha_optimizer']
|
||||
|
||||
@overrides
|
||||
def create_optimizer(self) -> Optimizer:
|
||||
def create_optimizer(self, conf_optim:Config, params) -> Optimizer:
|
||||
# return optim that only operates on w, not alphas
|
||||
return ml_utils.create_optimizer(self._conf_w_optim, self.model.weights())
|
||||
return ml_utils.create_optimizer(conf_optim,
|
||||
self.model.nonarch_params(recurse=True))
|
||||
|
||||
@overrides
|
||||
def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
|
||||
|
@ -43,7 +43,6 @@ class XnasArchTrainer(ArchTrainer):
|
|||
|
||||
self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)
|
||||
|
||||
|
||||
@overrides
|
||||
def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
|
||||
# delete state we created in pre_fit
|
||||
|
@ -98,7 +97,6 @@ class _XnasOptimizer:
|
|||
return lossfn(logits, y)
|
||||
|
||||
def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor) -> None:
|
||||
|
||||
# put model in train mode just to be safe
|
||||
self._model.train()
|
||||
|
||||
|
|
|
@ -9,8 +9,8 @@ class XnasCellBuilder(CellBuilder):
|
|||
@overrides
|
||||
def register_ops(self) -> None:
|
||||
Op.register_op('xnas_op',
|
||||
lambda op_desc, alphas, affine:
|
||||
XnasOp(op_desc, alphas, affine))
|
||||
lambda op_desc, arch_params, affine:
|
||||
XnasOp(op_desc, arch_params, affine))
|
||||
|
||||
@overrides
|
||||
def build(self, model_desc:ModelDesc, search_iter:int)->None:
|
||||
|
|
|
@ -10,6 +10,8 @@ from overrides import overrides
|
|||
|
||||
from archai.nas.model_desc import OpDesc
|
||||
from archai.nas.operations import Op
|
||||
from archai.nas.arch_params import ArchParams
|
||||
from archai.common.utils import zip_eq
|
||||
|
||||
# TODO: reduction cell might have output reduced by 2^1=2X due to
|
||||
# stride 2 through input nodes however FactorizedReduce does only
|
||||
|
@ -31,25 +33,28 @@ class XnasOp(Op):
|
|||
'none' # this must be at the end so top1 doesn't chose it
|
||||
]
|
||||
|
||||
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
|
||||
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
|
||||
affine:bool):
|
||||
super().__init__()
|
||||
|
||||
# assume last PRIMITIVE is 'none'
|
||||
assert XnasOp.PRIMITIVES[-1] == 'none'
|
||||
|
||||
self._set_alphas(alphas)
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in XnasOp.PRIMITIVES:
|
||||
op = Op.create(
|
||||
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
|
||||
affine=affine, alphas=alphas)
|
||||
affine=affine, arch_params=None)
|
||||
self._ops.append(op)
|
||||
|
||||
# for getting gradients to non-leaf node
|
||||
self._is_first_call = True
|
||||
self._avg_grad_meter = AverageMeter()
|
||||
|
||||
# we do this at the end so that we can capture all arch params registered by
|
||||
# any previous child modules
|
||||
self._setup_arch_params(arch_params)
|
||||
|
||||
def get_avg_grad(self)->torch.Tensor:
|
||||
return self._avg_grad_meter.avg
|
||||
|
||||
|
@ -70,7 +75,7 @@ class XnasOp(Op):
|
|||
@overrides
|
||||
def forward(self, x):
|
||||
self._activs = [op(x) for op in self._ops]
|
||||
numer = sum(w * activ for w, activ in zip(self._alphas[0], self._activs))
|
||||
numer = sum(w * activ for w, activ in zip_eq(self._alphas[0], self._activs))
|
||||
denom = sum(self._alphas[0])
|
||||
self.pt = torch.div(numer, denom)
|
||||
|
||||
|
@ -81,22 +86,10 @@ class XnasOp(Op):
|
|||
|
||||
return self.pt
|
||||
|
||||
|
||||
@overrides
|
||||
def alphas(self) -> Iterable[nn.Parameter]:
|
||||
for alpha in self._alphas:
|
||||
yield alpha
|
||||
|
||||
@overrides
|
||||
def weights(self) -> Iterable[nn.Parameter]:
|
||||
for op in self._ops:
|
||||
for w in op.parameters():
|
||||
yield w
|
||||
|
||||
@overrides
|
||||
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
|
||||
# select except 'none' op
|
||||
with torch.no_grad():
|
||||
# select except 'none' op
|
||||
val, i = torch.topk(self._alphas[0][:-1], 1)
|
||||
desc, _ = self._ops[i].finalize()
|
||||
return desc, float(val.item())
|
||||
|
@ -105,18 +98,18 @@ class XnasOp(Op):
|
|||
def can_drop_path(self) -> bool:
|
||||
return False
|
||||
|
||||
# TODO: Do we even need alphas to be registered with Pytorch
|
||||
# since we don't have to compute gradients on them?
|
||||
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
|
||||
# must call before adding other ops
|
||||
assert len(list(self.parameters())) == 0
|
||||
self._alphas = list(alphas)
|
||||
if not len(self._alphas):
|
||||
# TODO: Better initialization than random?
|
||||
new_p = nn.Parameter(1.0e-3*torch.randn(len(XnasOp.PRIMITIVES)), requires_grad=False)
|
||||
# NOTE: This is a way to register parameters with PyTorch.
|
||||
# One creates a dummy variable with the parameters and then
|
||||
# asks back for the parameters in the object from Pytorch
|
||||
# which automagically registers the just created parameters.
|
||||
self._reg_alphas = new_p
|
||||
self._alphas = [p for p in self.parameters()]
|
||||
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
|
||||
# do we have shared arch params?
|
||||
if arch_params is None:
|
||||
# create our own arch params
|
||||
# TODO: dey: why requires_grad = False?
|
||||
new_p = nn.Parameter( # TODO: use better init than uniform random?
|
||||
1.0e-3*torch.randn(len(XnasOp.PRIMITIVES)), requires_grad=False)
|
||||
self.create_arch_params([('alphas', new_p)])
|
||||
else:
|
||||
assert arch_params.has_kind('alphas')
|
||||
self.set_arch_params(arch_params)
|
||||
|
||||
# we store alphas in list so Pytorch don't register them
|
||||
self._alphas = list(self.arch_params().param_by_kind('alphas'))
|
||||
assert len(self._alphas)==1
|
||||
|
|
|
@ -13,6 +13,7 @@ import psutil
|
|||
from archai.common.config import Config
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.ordereddict_logger import OrderedDictLogger
|
||||
from archai.common.multi_optim import MultiOptim
|
||||
|
||||
class ApexUtils:
|
||||
def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
|
||||
|
@ -186,15 +187,22 @@ class ApexUtils:
|
|||
else:
|
||||
return val
|
||||
|
||||
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
|
||||
def _get_optim(self, multi_optim:MultiOptim)->Optimizer:
|
||||
assert len(multi_optim)==1, \
|
||||
'Mixed precision is only supported for one optimizer' \
|
||||
f' but {len(multi_optim)} optimizers were supplied'
|
||||
return multi_optim[0].optim
|
||||
|
||||
def backward(self, loss:torch.Tensor, multi_optim:MultiOptim)->None:
|
||||
if self.is_mixed():
|
||||
optim = self._get_optim(multi_optim)
|
||||
with self._amp.scale_loss(loss, optim) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
|
||||
->Tuple[nn.Module, Optimizer]:
|
||||
def to_amp(self, model:nn.Module, multi_optim:MultiOptim, batch_size:int)\
|
||||
->nn.Module:
|
||||
# conver BNs to sync BNs in distributed mode
|
||||
if self.is_dist() and self._sync_bn:
|
||||
model = self._ddp.convert_syncbn_model(model)
|
||||
|
@ -203,6 +211,8 @@ class ApexUtils:
|
|||
model = model.to(self.device)
|
||||
|
||||
if self.is_mixed():
|
||||
optim = self._get_optim(multi_optim)
|
||||
|
||||
# scale LR
|
||||
if self.is_dist() and self._scale_lr:
|
||||
lr = ml_utils.get_optim_lr(optim)
|
||||
|
@ -215,17 +225,21 @@ class ApexUtils:
|
|||
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
|
||||
)
|
||||
|
||||
# put back amp'd optim
|
||||
multi_optim[0].optim = optim
|
||||
|
||||
if self.is_dist():
|
||||
# By default, apex.parallel.DistributedDataParallel overlaps communication with
|
||||
# computation in the backward pass.
|
||||
# delay_allreduce delays all communication to the end of the backward pass.
|
||||
model = self._ddp.DistributedDataParallel(model, delay_allreduce=True)
|
||||
|
||||
return model, optim
|
||||
return model
|
||||
|
||||
def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
|
||||
def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
|
||||
if clip > 0.0:
|
||||
if self.is_mixed():
|
||||
optim = self._get_optim(multi_optim)
|
||||
nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(model.parameters(), clip)
|
||||
|
|
|
@ -14,7 +14,6 @@ import yaml
|
|||
from . import yaml_utils
|
||||
|
||||
|
||||
|
||||
# global config instance
|
||||
_config:'Config' = None
|
||||
|
||||
|
|
|
@ -188,7 +188,8 @@ class Metrics:
|
|||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
del state['_apex'] # cannot serialize this
|
||||
if '_apex' in state:
|
||||
del state['_apex'] # cannot serialize this
|
||||
return state
|
||||
# no need to define __setstate__ because _apex should be set from constructor
|
||||
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Iterator, List, Optional
|
||||
from collections import UserList
|
||||
|
||||
from torch import nn, Tensor
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from archai.common.utils import zip_eq
|
||||
|
||||
|
||||
class OptimSched:
|
||||
"""Holds the optimizer and scheduler"""
|
||||
def __init__(self, optim:Optimizer, sched:Optional[_LRScheduler],
|
||||
sched_on_epoch:Optional[bool])->None:
|
||||
self.optim = optim
|
||||
self.sched = sched
|
||||
self.sched_on_epoch = sched_on_epoch
|
||||
|
||||
class MultiOptim:
|
||||
def __init__(self) -> None:
|
||||
self._optim_scheds:List[OptimSched] = []
|
||||
|
||||
def append(self, optim_sched:OptimSched)->None:
|
||||
self._optim_scheds.append(optim_sched)
|
||||
|
||||
def zero_grad(self)->None:
|
||||
for optim_sched in self._optim_scheds:
|
||||
optim_sched.optim.zero_grad()
|
||||
|
||||
def step(self)->None:
|
||||
for optim_sched in self._optim_scheds:
|
||||
optim_sched.optim.step()
|
||||
if optim_sched.sched and not optim_sched.sched_on_epoch:
|
||||
optim_sched.sched.step(epoch=None)
|
||||
|
||||
def epoch(self, epoch:Optional[int]=None)->None:
|
||||
for optim_sched in self._optim_scheds:
|
||||
if optim_sched.sched and optim_sched.sched_on_epoch:
|
||||
optim_sched.sched.step(epoch=epoch)
|
||||
|
||||
def get_lr(self, optim_index:int, param_index:int)->float:
|
||||
return self._optim_scheds[optim_index].optim.param_groups[param_index]['lr']
|
||||
|
||||
def state_dict(self)->dict:
|
||||
optim_states = [optim_sched.optim.state_dict() for optim_sched in self]
|
||||
sched_states = [optim_sched.sched.state_dict() if optim_sched.sched else None \
|
||||
for optim_sched in self]
|
||||
|
||||
return {'optim_states': optim_states, 'sched_states':sched_states}
|
||||
|
||||
def load_state_dict(self, state_dict:dict)->None:
|
||||
optim_states = state_dict['optim_states']
|
||||
sched_states = state_dict['sched_states']
|
||||
|
||||
for optim_sched, optim_state, sched_state in zip_eq(self, optim_states, sched_states):
|
||||
optim_sched.optim.load_state_dict(optim_state)
|
||||
if optim_sched.sched:
|
||||
assert sched_state is not None
|
||||
optim_sched.sched.load_state_dict(sched_state)
|
||||
else:
|
||||
assert sched_state is None
|
||||
|
||||
def __getitem__(self, index)->OptimSched:
|
||||
return self._optim_scheds[index]
|
||||
|
||||
def __len__(self)->int:
|
||||
return len(self._optim_scheds)
|
||||
|
||||
def __iter__(self)->Iterator[OptimSched]:
|
||||
return iter(self._optim_scheds)
|
||||
|
|
@ -1,25 +1,28 @@
|
|||
from typing import Callable, Tuple, Optional
|
||||
|
||||
from torch import nn, Tensor, torch
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from overrides import EnforceOverrides
|
||||
|
||||
from .metrics import Metrics
|
||||
from .tester import Tester
|
||||
from .config import Config
|
||||
from . import utils, ml_utils
|
||||
from ..common.common import logger
|
||||
from ..common.checkpoint import CheckPoint
|
||||
from ..common.apex_utils import ApexUtils
|
||||
from archai.common.metrics import Metrics
|
||||
from archai.common.tester import Tester
|
||||
from archai.common.config import Config
|
||||
from archai.common import utils, ml_utils
|
||||
from archai.common.common import logger
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.common.apex_utils import ApexUtils
|
||||
from archai.common.multi_optim import MultiOptim, OptimSched
|
||||
|
||||
|
||||
class Trainer(EnforceOverrides):
|
||||
def __init__(self, conf_train:Config, model:nn.Module,
|
||||
checkpoint:Optional[CheckPoint])->None:
|
||||
# region config vars
|
||||
self.conf_train = conf_train
|
||||
conf_lossfn = conf_train['lossfn']
|
||||
self._aux_weight = conf_train['aux_weight']
|
||||
self._grad_clip = conf_train['grad_clip']
|
||||
|
@ -27,8 +30,8 @@ class Trainer(EnforceOverrides):
|
|||
self._logger_freq = conf_train['logger_freq']
|
||||
self._title = conf_train['title']
|
||||
self._epochs = conf_train['epochs']
|
||||
self._conf_optim = conf_train['optimizer']
|
||||
self._conf_sched = conf_train['lr_schedule']
|
||||
self.conf_optim = conf_train['optimizer']
|
||||
self.conf_sched = conf_train['lr_schedule']
|
||||
self.batch_chunks = conf_train['batch_chunks']
|
||||
conf_validation = conf_train['validation']
|
||||
conf_apex = conf_train['apex']
|
||||
|
@ -58,14 +61,11 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)
|
||||
|
||||
# optimizers, schedulers needs to be recreated for each fit call
|
||||
# as they have state specific to each run
|
||||
optim = self.create_optimizer()
|
||||
# create scheduler for optim before applying amp
|
||||
self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
|
||||
# create optimizers and schedulers
|
||||
self._multi_optim = self.create_multi_optim(len(train_dl))
|
||||
# before checkpoint restore, convert to amp
|
||||
self.model, self._optim = self._apex.to_amp(self.model, optim,
|
||||
batch_size=train_dl.batch_size)
|
||||
self.model = self._apex.to_amp(self.model, self._multi_optim,
|
||||
batch_size=train_dl.batch_size)
|
||||
|
||||
self._lossfn = self._lossfn.to(self.get_device())
|
||||
|
||||
|
@ -101,42 +101,53 @@ class Trainer(EnforceOverrides):
|
|||
logger.pushd('epochs')
|
||||
for epoch in range(self._start_epoch, self._epochs):
|
||||
logger.pushd(epoch)
|
||||
assert self._metrics.epochs() == epoch
|
||||
self._set_drop_path(epoch, self._epochs)
|
||||
|
||||
self.pre_epoch(train_dl, val_dl)
|
||||
self._train_epoch(train_dl)
|
||||
self.post_epoch(train_dl, val_dl)
|
||||
|
||||
logger.popd()
|
||||
logger.popd()
|
||||
self.post_fit(train_dl, val_dl)
|
||||
|
||||
# make sure we don't keep references to the graph
|
||||
del self._optim
|
||||
del self._sched
|
||||
|
||||
del self._multi_optim
|
||||
|
||||
logger.popd()
|
||||
return self.get_metrics()
|
||||
|
||||
def create_optimizer(self)->Optimizer:
|
||||
optim = ml_utils.create_optimizer(self._conf_optim, self.model.parameters())
|
||||
logger.info({'conf_optim': self._conf_optim})
|
||||
def create_multi_optim(self, train_len:int)->MultiOptim:
|
||||
logger.info({'steps_per_epoch': train_len,
|
||||
'conf_sched': self.conf_sched.to_dict()})
|
||||
logger.info({'conf_optim': self.conf_optim.to_dict()})
|
||||
|
||||
# optimizers, schedulers needs to be recreated for each fit call
|
||||
# as they have state specific to each run
|
||||
optim = self.create_optimizer(self.conf_optim, self.model.parameters())
|
||||
# create scheduler for optim before applying amp
|
||||
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
|
||||
|
||||
multi_optim = MultiOptim()
|
||||
multi_optim.append(OptimSched(optim, sched, sched_on_epoch))
|
||||
|
||||
logger.info({'multi_optim_len': len(multi_optim)})
|
||||
|
||||
return multi_optim
|
||||
|
||||
def create_optimizer(self, conf_optim:Config, params)->Optimizer:
|
||||
optim = ml_utils.create_optimizer(conf_optim, params)
|
||||
return optim
|
||||
|
||||
def _create_scheduler(self, optim:Optimizer, steps_per_epoch:int) \
|
||||
def create_scheduler(self, conf_sched:Config, optim:Optimizer, steps_per_epoch:int) \
|
||||
->Tuple[Optional[_LRScheduler],bool]:
|
||||
|
||||
logger.info({'steps_per_epoch': steps_per_epoch,
|
||||
'scheduler': self._conf_sched.to_dict()})
|
||||
|
||||
return ml_utils.create_lr_scheduler(self._conf_sched, self._epochs,
|
||||
return ml_utils.create_lr_scheduler(conf_sched, self._epochs,
|
||||
optim, steps_per_epoch)
|
||||
|
||||
def get_optimizer(self)->Optimizer:
|
||||
return self._optim
|
||||
def get_scheduler(self)->Optional[_LRScheduler]:
|
||||
return self._sched
|
||||
def get_optimizer(self, index=0)->Optimizer:
|
||||
return self._multi_optim[index].optim
|
||||
def get_scheduler(self, index=0)->Optional[_LRScheduler]:
|
||||
return self._multi_optim[index].sched
|
||||
|
||||
def get_metrics(self)->Metrics:
|
||||
return self._metrics
|
||||
|
@ -149,7 +160,7 @@ class Trainer(EnforceOverrides):
|
|||
self._metrics.post_run()
|
||||
|
||||
def pre_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
|
||||
self._metrics.pre_epoch(lr=self._optim.param_groups[0]['lr'])
|
||||
self._metrics.pre_epoch(lr=self._multi_optim.get_lr(0, 0))
|
||||
|
||||
def post_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
|
||||
val_metrics = None
|
||||
|
@ -157,10 +168,15 @@ class Trainer(EnforceOverrides):
|
|||
if val_dl and self._tester and self._validation_freq > 0:
|
||||
if self._metrics.epochs() % self._validation_freq == 0 or \
|
||||
self._metrics.epochs() >= self._epochs:
|
||||
# optimizers such as bi-level may use val set for its own use
|
||||
# which causes reshuffling due to automatic epoch counting
|
||||
# here we make sure that val_dl has same epoch as train_dl
|
||||
if hasattr(val_dl.sampler, 'set_epoch'):
|
||||
val_dl.sampler.set_epoch(self._metrics.epochs())
|
||||
val_metrics = self._tester.test(val_dl)
|
||||
|
||||
# update val metrics
|
||||
self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])
|
||||
self._metrics.post_epoch(val_metrics, lr=self._multi_optim.get_lr(0, 0))
|
||||
|
||||
# checkpoint if enabled with given freq or if this is the last epoch
|
||||
if self._checkpoint is not None and self._apex.is_master() and \
|
||||
|
@ -190,11 +206,7 @@ class Trainer(EnforceOverrides):
|
|||
assert self._metrics.epochs() == last_epoch+1
|
||||
self._apex.load_state_dict(state['amp'])
|
||||
self.model.load_state_dict(state['model'])
|
||||
self._optim.load_state_dict(state['optim'])
|
||||
if self._sched:
|
||||
self._sched.load_state_dict(state['sched'])
|
||||
else:
|
||||
assert state['sched'] is None
|
||||
self._multi_optim.load_state_dict(state['multi_optim'])
|
||||
|
||||
self._start_epoch = last_epoch + 1
|
||||
|
||||
|
@ -204,8 +216,7 @@ class Trainer(EnforceOverrides):
|
|||
'last_epoch': self._metrics.epochs()-1,
|
||||
'metrics': self._metrics.state_dict(),
|
||||
'model': self.model.state_dict(),
|
||||
'optim': self._optim.state_dict(),
|
||||
'sched': self._sched.state_dict() if self._sched else None,
|
||||
'multi_optim': self._multi_optim.state_dict(),
|
||||
'amp': self._apex.state_dict()
|
||||
}
|
||||
self._checkpoint['trainer'] = state
|
||||
|
@ -221,7 +232,7 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
self.pre_step(x, y)
|
||||
|
||||
self._optim.zero_grad()
|
||||
self._multi_optim.zero_grad()
|
||||
|
||||
# divide batch in to chunks if needed so it fits in GPU RAM
|
||||
if self.batch_chunks > 1:
|
||||
|
@ -243,23 +254,20 @@ class Trainer(EnforceOverrides):
|
|||
loss_c = self.compute_loss(self._lossfn, yc, logits_c,
|
||||
self._aux_weight, aux_logits)
|
||||
|
||||
self._apex.backward(loss_c, self._optim)
|
||||
self._apex.backward(loss_c, self._multi_optim)
|
||||
|
||||
loss_sum += loss_c.item() * len(logits_c)
|
||||
loss_count += len(logits_c)
|
||||
logits_chunks.append(logits_c.detach().cpu())
|
||||
|
||||
# TODO: original darts clips alphas as well but pt.darts doesn't
|
||||
self._apex.clip_grad(self._grad_clip, self.model, self._optim)
|
||||
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
|
||||
|
||||
self._optim.step()
|
||||
self._multi_optim.step()
|
||||
|
||||
# TODO: we possibly need to sync so all replicas are upto date
|
||||
self._apex.sync_devices()
|
||||
|
||||
if self._sched and not self._sched_on_epoch:
|
||||
self._sched.step()
|
||||
|
||||
self.post_step(x, y,
|
||||
ml_utils.join_chunks(logits_chunks),
|
||||
torch.tensor(loss_sum/loss_count),
|
||||
|
@ -268,8 +276,7 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
# end of step
|
||||
|
||||
if self._sched and self._sched_on_epoch:
|
||||
self._sched.step()
|
||||
self._multi_optim.epoch()
|
||||
logger.popd()
|
||||
|
||||
def compute_loss(self, lossfn:Callable, y:Tensor, logits:Tensor,
|
||||
|
|
|
@ -7,6 +7,7 @@ import sys
|
|||
import os
|
||||
import pathlib
|
||||
import random
|
||||
from itertools import zip_longest
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
@ -230,3 +231,10 @@ def exec_shell_command(command:str, print_command=True)->None:
|
|||
print(command)
|
||||
subprocess.run(command, shell=True, check=True)
|
||||
|
||||
def zip_eq(*iterables):
|
||||
sentinel = object()
|
||||
for count, combo in enumerate(zip_longest(*iterables, fillvalue=sentinel)):
|
||||
if any(True for c in combo if sentinel is c):
|
||||
shorter_its = ','.join([str(i) for i,c in enumerate(combo) if sentinel is c])
|
||||
raise ValueError(f'Iterator {shorter_its} have length {count} which is shorter than others')
|
||||
yield combo
|
|
@ -0,0 +1,47 @@
|
|||
from abc import ABC
|
||||
from typing import Iterable, List, Optional, Tuple, Iterator
|
||||
|
||||
from torch import nn
|
||||
|
||||
from archai.nas.arch_params import ArchParams, NNTypes
|
||||
from archai.common import utils
|
||||
|
||||
class ArchModule(nn.Module, ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# these are params module should use, they may be shared or created by this module
|
||||
self._arch_params = ArchParams.empty()
|
||||
# these are the params created and registerd in this module
|
||||
self._owned_arch_params:Optional[ArchParams] = None
|
||||
|
||||
def create_arch_params(self, named_params:Iterable[Tuple[str, NNTypes]])->None:
|
||||
if len(self._arch_params):
|
||||
raise RuntimeError('Arch parameters for this module already exist')
|
||||
self._owned_arch_params = ArchParams(named_params, registrar=self)
|
||||
self.set_arch_params(self._owned_arch_params)
|
||||
|
||||
def set_arch_params(self, arch_params:ArchParams)->None:
|
||||
if len(self._arch_params):
|
||||
raise RuntimeError('Arch parameters for this module already exist')
|
||||
self._arch_params = arch_params
|
||||
|
||||
def arch_params(self, recurse=False, only_owned=False)->ArchParams:
|
||||
# note that we will cache lists on first calls, this doesn't allow
|
||||
# dynamic parameters but it makes this frequent calls much faster
|
||||
if not recurse:
|
||||
if not only_owned:
|
||||
return self._arch_params
|
||||
else:
|
||||
return ArchParams.from_module(self, recurse=False)
|
||||
else:
|
||||
if not only_owned:
|
||||
raise NotImplementedError('Recursively getting shared and owned arch params not implemented yet')
|
||||
else:
|
||||
return ArchParams.from_module(self, recurse=True)
|
||||
|
||||
def all_owned(self)->ArchParams:
|
||||
return self.arch_params(recurse=True, only_owned=True)
|
||||
|
||||
def nonarch_params(self, recurse:bool)->Iterator[nn.Parameter]:
|
||||
return ArchParams.nonarch_from_module(self, recurse)
|
|
@ -0,0 +1,75 @@
|
|||
from collections import UserDict
|
||||
from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
_param_suffix = '_arch_param' # all arch parameter names must have this suffix
|
||||
|
||||
NNTypes = Union[nn.Parameter, nn.ParameterDict, nn.ParameterList]
|
||||
|
||||
class ArchParams(UserDict):
|
||||
"""This class holds set of learnable architecture parameter(s) for a given module. For example, one instance of this class would hold alphas for one instance of MixedOp. For sharing parameters, instance of this class can be passed around. Different algorithms may add learnable parameters for their need."""
|
||||
|
||||
def __init__(self, arch_params:Iterable[Tuple[str, NNTypes]], registrar:Optional[nn.Module]=None):
|
||||
"""Create architecture parameters and register them
|
||||
|
||||
Arguments:
|
||||
registrar {Optional[nn.Module]} -- If this parameter is beingly newly created instead of being shared by other module then owner should be specified. When owner is not None, this method will create a variable in the owning module with suffix _arch_param so that the parameter gets registered with Pytorch and becomes available in module's .parameters() calls.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
for name, param in arch_params:
|
||||
self.data[name] = param
|
||||
if registrar is not None:
|
||||
setattr(registrar, name + _param_suffix, param)
|
||||
|
||||
def __setitem__(self, name:str, param:NNTypes)->None:
|
||||
raise RuntimeError(f'ArchParams is immutable hence adding/updating key {name} is not allowed.')
|
||||
|
||||
def __delitem__(self, name:str) -> None:
|
||||
raise RuntimeError(f'ArchParams is immutable hence removing key {name} is not allowed.')
|
||||
|
||||
def _by_kind(self, kind:Optional[str])->Iterator[NNTypes]:
|
||||
# TODO: may be optimize to avoid split() calls?
|
||||
for name, param in self.items():
|
||||
if kind is None or name.split('.')[-1]==kind:
|
||||
yield param
|
||||
|
||||
def param_by_kind(self, kind:Optional[str])->Iterator[nn.Parameter]:
|
||||
# TODO: enforce type checking if debugger is active?
|
||||
return self._by_kind(kind) # type: ignore
|
||||
|
||||
def paramlist_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterList]:
|
||||
# TODO: enforce type checking if debugger is active?
|
||||
return self._by_kind(kind) # type: ignore
|
||||
|
||||
def paramdict_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterDict]:
|
||||
# TODO: enforce type checking if debugger is active?
|
||||
return self._by_kind(kind) # type: ignore
|
||||
|
||||
def has_kind(self, kind:str)->bool:
|
||||
# TODO: may be optimize to avoid split() calls?
|
||||
for name in self.keys():
|
||||
if name.split('.')[-1]==kind:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def from_module(module:nn.Module, recurse:bool=False)->'ArchParams':
|
||||
suffix_len = len(_param_suffix)
|
||||
# Pytorch named params have . in name for each module, we pick last part and remove _arch_params prefix
|
||||
arch_params = ((name[:-suffix_len], param) \
|
||||
for name, param in module.named_parameters(recurse=recurse)
|
||||
if name.endswith(_param_suffix))
|
||||
return ArchParams(arch_params)
|
||||
|
||||
@staticmethod
|
||||
def nonarch_from_module(module:nn.Module, recurse:bool=False)->Iterator[nn.Parameter]:
|
||||
# Pytorch named params have . in name for each module, we pick last part and remove _arch_params prefix
|
||||
return (param for name, param in module.named_parameters(recurse=recurse)
|
||||
if not name.endswith(_param_suffix))
|
||||
|
||||
@staticmethod
|
||||
def empty()->'ArchParams':
|
||||
return ArchParams([])
|
||||
|
|
@ -27,6 +27,10 @@ class ArchTrainer(Trainer, EnforceOverrides):
|
|||
self._l1_alphas = conf_train['l1_alphas']
|
||||
self._plotsdir = conf_train['plotsdir']
|
||||
|
||||
# if l1 regularization is needed then cache alphas
|
||||
if self._l1_alphas > 0.0:
|
||||
self._alphas = list(self.model.all_owned().param_by_kind('alphas'))
|
||||
|
||||
@overrides
|
||||
def compute_loss(self, lossfn: Callable,
|
||||
y: Tensor, logits: Tensor,
|
||||
|
@ -35,7 +39,7 @@ class ArchTrainer(Trainer, EnforceOverrides):
|
|||
aux_weight, aux_logits)
|
||||
# add L1 alpha regularization
|
||||
if self._l1_alphas > 0.0:
|
||||
l_extra = sum(torch.sum(a.abs()) for a in self.model.alphas())
|
||||
l_extra = sum(torch.sum(a.abs()) for a in self._alphas)
|
||||
loss += self._l1_alphas * l_extra
|
||||
return loss
|
||||
|
||||
|
|
|
@ -1,36 +1,35 @@
|
|||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from torch import nn, tensor
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from ..common.common import logger
|
||||
from .dag_edge import DagEdge
|
||||
from .model_desc import ConvMacroParams, CellDesc, OpDesc, NodeDesc
|
||||
from .operations import Op
|
||||
from archai.nas.dag_edge import DagEdge
|
||||
from archai.nas.model_desc import ConvMacroParams, CellDesc, OpDesc, NodeDesc
|
||||
from archai.nas.operations import Op
|
||||
from archai.nas.arch_module import ArchModule
|
||||
|
||||
class Cell(nn.Module, ABC, EnforceOverrides):
|
||||
class Cell(ArchModule, EnforceOverrides):
|
||||
def __init__(self, desc:CellDesc,
|
||||
affine:bool, droppath:bool,
|
||||
alphas_cell:Optional['Cell']):
|
||||
template_cell:Optional['Cell']): # template cell, if any, to use for arch params
|
||||
super().__init__()
|
||||
|
||||
# some of these members are public as finalizer needs access
|
||||
self.shared_alphas = alphas_cell is not None
|
||||
self.desc = desc
|
||||
self.s0_op = Op.create(desc.s0_op, affine=affine)
|
||||
self.s1_op = Op.create(desc.s1_op, affine=affine)
|
||||
|
||||
self.dag = Cell._create_dag(desc.nodes(),
|
||||
affine=affine, droppath=droppath,
|
||||
alphas_cell=alphas_cell)
|
||||
template_cell=template_cell)
|
||||
|
||||
self.post_op = Op.create(desc.post_op, affine=affine)
|
||||
|
||||
@staticmethod
|
||||
def _create_dag(nodes_desc:List[NodeDesc],
|
||||
affine:bool, droppath:bool,
|
||||
alphas_cell:Optional['Cell'])->nn.ModuleList:
|
||||
template_cell:Optional['Cell'])->nn.ModuleList:
|
||||
dag = nn.ModuleList()
|
||||
for i, node_desc in enumerate(nodes_desc):
|
||||
edges:nn.ModuleList = nn.ModuleList()
|
||||
|
@ -39,21 +38,9 @@ class Cell(nn.Module, ABC, EnforceOverrides):
|
|||
for j, edge_desc in enumerate(node_desc.edges):
|
||||
edges.append(DagEdge(edge_desc,
|
||||
affine=affine, droppath=droppath,
|
||||
alphas_edge=alphas_cell.dag[i][j] if alphas_cell else None))
|
||||
template_edge=template_cell.dag[i][j] if template_cell else None))
|
||||
return dag
|
||||
|
||||
def alphas(self)->Iterable[nn.Parameter]:
|
||||
for node in self.dag:
|
||||
for edge in node:
|
||||
for alpha in edge.alphas():
|
||||
yield alpha
|
||||
|
||||
def weights(self)->Iterable[nn.Parameter]:
|
||||
for node in self.dag:
|
||||
for edge in node:
|
||||
for p in edge.weights():
|
||||
yield p
|
||||
|
||||
def ops(self)->Iterable[Op]:
|
||||
for node in self.dag:
|
||||
for edge in node:
|
||||
|
|
|
@ -4,16 +4,17 @@ import torch
|
|||
from torch import nn
|
||||
from overrides import overrides
|
||||
|
||||
from .operations import Op, DropPath_
|
||||
from .model_desc import EdgeDesc
|
||||
from archai.nas.operations import Op, DropPath_
|
||||
from archai.nas.model_desc import EdgeDesc
|
||||
from archai.nas.arch_module import ArchModule
|
||||
|
||||
class DagEdge(nn.Module):
|
||||
class DagEdge(ArchModule):
|
||||
def __init__(self, desc:EdgeDesc, affine:bool, droppath:bool,
|
||||
alphas_edge:Optional['DagEdge'])->None:
|
||||
template_edge:Optional['DagEdge'])->None:
|
||||
super().__init__()
|
||||
# we may need to wrap op is droppath is needed
|
||||
self._wrapped = self._op = Op.create(desc.op_desc, affine,
|
||||
alphas_edge.alphas() if alphas_edge else [])
|
||||
template_edge.op().arch_params() if template_edge is not None else None)
|
||||
if droppath and self._op.can_drop_path():
|
||||
assert self.training
|
||||
self._wrapped = nn.Sequential(self._op, DropPath_())
|
||||
|
@ -29,14 +30,5 @@ class DagEdge(nn.Module):
|
|||
else:
|
||||
return self._wrapped([inputs[i] for i in self.input_ids])
|
||||
|
||||
def alphas(self)->Iterable[nn.Parameter]:
|
||||
for alpha in self._op.alphas():
|
||||
if alpha is not None:
|
||||
yield alpha
|
||||
|
||||
def weights(self)->Iterable[nn.Parameter]:
|
||||
for w in self._op.weights():
|
||||
yield w
|
||||
|
||||
def op(self)->Op:
|
||||
return self._op
|
||||
|
|
|
@ -56,7 +56,7 @@ class Finalizers(EnforceOverrides):
|
|||
nodes = node_descs,
|
||||
s0_op=cell.s0_op.finalize()[0],
|
||||
s1_op=cell.s1_op.finalize()[0],
|
||||
alphas_from = cell.desc.alphas_from,
|
||||
template_cell = cell.desc.template_cell,
|
||||
max_final_edges=cell.desc.max_final_edges,
|
||||
node_ch_out=cell.desc.node_ch_out,
|
||||
post_op=cell.post_op.finalize()[0]
|
||||
|
|
|
@ -99,21 +99,21 @@ class MacroBuilder(EnforceOverrides):
|
|||
# chennels of prev-prev whole cell, prev whole cell and current cell node
|
||||
pp_ch_out, p_ch_out, node_ch_out = stem_ch_out, stem_ch_out, self.init_node_ch
|
||||
|
||||
# stores first cells of each time with whom alphas would be shared
|
||||
# stores first cells of each time with whom arch params would be shared
|
||||
first_normal, first_reduction = -1, -1
|
||||
|
||||
for ci in range(self.n_cells):
|
||||
# find cell type and output channels for this cell
|
||||
# also update if this is our first cell from which alphas will be shared
|
||||
# also update if this is our first cell from which arch params will be shared
|
||||
reduction = self._is_reduction(ci)
|
||||
if reduction:
|
||||
node_ch_out, cell_type = node_ch_out*2, CellType.Reduction
|
||||
first_reduction = ci if first_reduction < 0 else first_reduction
|
||||
alphas_from = first_reduction
|
||||
template_cell = first_reduction
|
||||
else:
|
||||
cell_type = CellType.Regular
|
||||
first_normal = ci if first_normal < 0 else first_normal
|
||||
alphas_from = first_normal
|
||||
template_cell = first_normal
|
||||
|
||||
s0_op, s1_op = self._get_cell_stems(
|
||||
node_ch_out, p_ch_out, pp_ch_out, reduction_p)
|
||||
|
@ -134,7 +134,7 @@ class MacroBuilder(EnforceOverrides):
|
|||
cell_type=cell_type, id=ci,
|
||||
nodes=nodes,
|
||||
s0_op=s0_op, s1_op=s1_op,
|
||||
alphas_from=alphas_from,
|
||||
template_cell=template_cell,
|
||||
max_final_edges=max_final_edges,
|
||||
node_ch_out=node_ch_out,
|
||||
post_op=self.cell_post_op
|
||||
|
|
|
@ -6,17 +6,17 @@ import os
|
|||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.nas.arch_params import ArchParams
|
||||
from archai.nas.cell import Cell
|
||||
from archai.nas.operations import Op, DropPath_
|
||||
from archai.nas.model_desc import ModelDesc, AuxTowerDesc, CellDesc
|
||||
from archai.common.common import logger
|
||||
from archai.common import utils, ml_utils
|
||||
from archai.nas.arch_module import ArchModule
|
||||
|
||||
from .cell import Cell
|
||||
from .operations import Op, DropPath_
|
||||
from .model_desc import ModelDesc, AuxTowerDesc, CellDesc
|
||||
from ..common.common import logger
|
||||
from ..common import utils, ml_utils
|
||||
|
||||
class Model(nn.Module):
|
||||
class Model(ArchModule):
|
||||
def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
|
||||
super().__init__()
|
||||
|
||||
|
@ -45,35 +45,26 @@ class Model(nn.Module):
|
|||
def _build_cell(self, cell_desc:CellDesc,
|
||||
aux_tower_desc:Optional[AuxTowerDesc],
|
||||
droppath:bool, affine:bool)->None:
|
||||
alphas_cell = None if cell_desc.alphas_from==cell_desc.id \
|
||||
else self.cells[cell_desc.alphas_from]
|
||||
template_cell = None if cell_desc.template_cell==cell_desc.id \
|
||||
else self.cells[cell_desc.template_cell]
|
||||
cell = Cell(cell_desc, affine=affine, droppath=droppath,
|
||||
alphas_cell=alphas_cell)
|
||||
template_cell=template_cell)
|
||||
self.cells.append(cell)
|
||||
self._aux_towers.append(AuxTower(aux_tower_desc) \
|
||||
if aux_tower_desc else None)
|
||||
|
||||
def summary(self)->dict:
|
||||
all_arch_params = list(self.all_owned()
|
||||
.param_by_kind(kind=None))
|
||||
return {
|
||||
'cell_count': len(self.cells),
|
||||
#'cell_params': [ml_utils.param_size(c) for c in self.cells]
|
||||
'params': ml_utils.param_size(self),
|
||||
'alphas_p': len(list(a for a in self.alphas())),
|
||||
'alphas': np.sum(a.numel() for a in self.alphas()),
|
||||
'arch_params_len': len(all_arch_params),
|
||||
'arch_params_numel': np.sum(a.numel() for a in all_arch_params),
|
||||
'ops': np.sum(len(n.edges) for c in self.desc.cell_descs() for n in c.nodes()),
|
||||
}
|
||||
|
||||
def alphas(self)->Iterable[nn.Parameter]:
|
||||
for cell in self.cells:
|
||||
if not cell.shared_alphas:
|
||||
for alpha in cell.alphas():
|
||||
yield alpha
|
||||
|
||||
def weights(self)->Iterable[nn.Parameter]:
|
||||
for cell in self.cells:
|
||||
for w in cell.weights():
|
||||
yield w
|
||||
|
||||
def ops(self)->Iterable[Op]:
|
||||
for cell in self.cells:
|
||||
for op in cell.ops():
|
||||
|
|
|
@ -135,7 +135,7 @@ class CellType(Enum):
|
|||
|
||||
class CellDesc:
|
||||
def __init__(self, cell_type:CellType, id:int, nodes:List[NodeDesc],
|
||||
s0_op:OpDesc, s1_op:OpDesc, alphas_from:int, max_final_edges:int,
|
||||
s0_op:OpDesc, s1_op:OpDesc, template_cell:int, max_final_edges:int,
|
||||
node_ch_out:int, post_op:Union[str,OpDesc])->None:
|
||||
assert s0_op.params['conv'].ch_out == s1_op.params['conv'].ch_out
|
||||
assert s0_op.params['conv'].ch_out == node_ch_out
|
||||
|
@ -143,7 +143,7 @@ class CellDesc:
|
|||
self.cell_type = cell_type
|
||||
self.id = id
|
||||
self.s0_op, self.s1_op = s0_op, s1_op
|
||||
self.alphas_from = alphas_from # cell id with which we share alphas
|
||||
self.template_cell = template_cell # cell id with which we share arch params
|
||||
self.max_final_edges = max_final_edges
|
||||
|
||||
self.cell_ch_out = -1 # will be set later by reset_nodes
|
||||
|
@ -151,7 +151,7 @@ class CellDesc:
|
|||
assert self.cell_ch_out > 0
|
||||
|
||||
def clone(self, id:int)->'CellDesc':
|
||||
c = copy.deepcopy(self) # note that alphas_from is also cloned
|
||||
c = copy.deepcopy(self) # note that template_cell is also cloned
|
||||
c.id = id
|
||||
return c
|
||||
|
||||
|
@ -284,7 +284,7 @@ class ModelDesc:
|
|||
def reset_cells(self, cell_descs:List[CellDesc],
|
||||
aux_tower_descs:List[Optional[AuxTowerDesc]])->None:
|
||||
assert len(cell_descs) == len(aux_tower_descs)
|
||||
# every cell should have unique ID so we can tell where alphas are shared
|
||||
# every cell should have unique ID so we can tell where arch params are shared
|
||||
assert len(set(c.id for c in cell_descs)) == len(cell_descs)
|
||||
|
||||
self._cell_descs = cell_descs
|
||||
|
|
|
@ -8,66 +8,69 @@ from overrides import overrides, EnforceOverrides
|
|||
import torch
|
||||
from torch import affine_grid_generator, nn, Tensor, strided
|
||||
|
||||
|
||||
from ..common import utils, ml_utils
|
||||
from .model_desc import OpDesc, ConvMacroParams
|
||||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model_desc import OpDesc, ConvMacroParams
|
||||
from archai.nas.arch_params import ArchParams
|
||||
from archai.nas.arch_module import ArchModule
|
||||
|
||||
# type alias
|
||||
OpFactoryFn = Callable[[OpDesc, Iterable[nn.Parameter]], 'Op']
|
||||
|
||||
# Each op is a unary tensor operator, all take same constructor params
|
||||
# TODO: swap order of arch_params and affine to match with create signature
|
||||
_ops_factory:Dict[str, Callable] = {
|
||||
'max_pool_3x3': lambda op_desc, alphas, affine:
|
||||
'max_pool_3x3': lambda op_desc, arch_params, affine:
|
||||
PoolBN('max', op_desc, affine),
|
||||
'avg_pool_3x3': lambda op_desc, alphas, affine:
|
||||
'avg_pool_3x3': lambda op_desc, arch_params, affine:
|
||||
PoolBN('avg', op_desc, affine),
|
||||
'skip_connect': lambda op_desc, alphas, affine:
|
||||
'skip_connect': lambda op_desc, arch_params, affine:
|
||||
SkipConnect(op_desc, affine),
|
||||
'sep_conv_3x3': lambda op_desc, alphas, affine:
|
||||
'sep_conv_3x3': lambda op_desc, arch_params, affine:
|
||||
SepConv(op_desc, 3, 1, affine),
|
||||
'sep_conv_5x5': lambda op_desc, alphas, affine:
|
||||
'sep_conv_5x5': lambda op_desc, arch_params, affine:
|
||||
SepConv(op_desc, 5, 2, affine),
|
||||
'dil_conv_3x3': lambda op_desc, alphas, affine:
|
||||
'dil_conv_3x3': lambda op_desc, arch_params, affine:
|
||||
DilConv(op_desc, 3, op_desc.params['stride'], 2, 2, affine),
|
||||
'dil_conv_5x5': lambda op_desc, alphas, affine:
|
||||
'dil_conv_5x5': lambda op_desc, arch_params, affine:
|
||||
DilConv(op_desc, 5, op_desc.params['stride'], 4, 2, affine),
|
||||
'none': lambda op_desc, alphas, affine:
|
||||
'none': lambda op_desc, arch_params, affine:
|
||||
Zero(op_desc),
|
||||
'identity': lambda op_desc, alphas, affine:
|
||||
'identity': lambda op_desc, arch_params, affine:
|
||||
Identity(op_desc),
|
||||
'sep_conv_7x7': lambda op_desc, alphas, affine:
|
||||
'sep_conv_7x7': lambda op_desc, arch_params, affine:
|
||||
SepConv(op_desc, 7, 3, affine),
|
||||
'conv_7x1_1x7': lambda op_desc, alphas, affine:
|
||||
'conv_7x1_1x7': lambda op_desc, arch_params, affine:
|
||||
FacConv(op_desc, 7, 3, affine),
|
||||
'prepr_reduce': lambda op_desc, alphas, affine:
|
||||
'prepr_reduce': lambda op_desc, arch_params, affine:
|
||||
FactorizedReduce(op_desc, affine),
|
||||
'prepr_normal': lambda op_desc, alphas, affine:
|
||||
'prepr_normal': lambda op_desc, arch_params, affine:
|
||||
ReLUConvBN(op_desc, 1, 1, 0, affine),
|
||||
'stem_conv3x3': lambda op_desc, alphas, affine:
|
||||
'stem_conv3x3': lambda op_desc, arch_params, affine:
|
||||
StemConv3x3(op_desc, affine),
|
||||
'stem_conv3x3_s4': lambda op_desc, alphas, affine:
|
||||
'stem_conv3x3_s4': lambda op_desc, arch_params, affine:
|
||||
StemConv3x3S4(op_desc, affine),
|
||||
'stem_conv3x3_s4s2': lambda op_desc, alphas, affine:
|
||||
'stem_conv3x3_s4s2': lambda op_desc, arch_params, affine:
|
||||
StemConv3x3S4S2(op_desc, affine),
|
||||
'pool_adaptive_avg2d': lambda op_desc, alphas, affine:
|
||||
'pool_adaptive_avg2d': lambda op_desc, arch_params, affine:
|
||||
PoolAdaptiveAvg2D(),
|
||||
'pool_avg2d7x7': lambda op_desc, alphas, affine:
|
||||
'pool_avg2d7x7': lambda op_desc, arch_params, affine:
|
||||
AvgPool2d7x7(),
|
||||
'concate_channels': lambda op_desc, alphas, affine:
|
||||
'concate_channels': lambda op_desc, arch_params, affine:
|
||||
ConcateChannelsOp(op_desc, affine),
|
||||
'proj_channels': lambda op_desc, alphas, affine:
|
||||
'proj_channels': lambda op_desc, arch_params, affine:
|
||||
ProjectChannelsOp(op_desc, affine),
|
||||
'linear': lambda op_desc, alphas, affine:
|
||||
'linear': lambda op_desc, arch_params, affine:
|
||||
LinearOp(op_desc),
|
||||
'multi_op': lambda op_desc, alphas, affine:
|
||||
'multi_op': lambda op_desc, arch_params, affine:
|
||||
MultiOp(op_desc, affine)
|
||||
}
|
||||
|
||||
class Op(nn.Module, ABC, EnforceOverrides):
|
||||
class Op(ArchModule, ABC, EnforceOverrides):
|
||||
@staticmethod
|
||||
def create(op_desc:OpDesc, affine:bool,
|
||||
alphas:Iterable[nn.Parameter]=[])->'Op':
|
||||
op = _ops_factory[op_desc.name](op_desc, alphas, affine)
|
||||
def create(op_desc:OpDesc, affine:bool, arch_params:Optional[ArchParams]=None)->'Op':
|
||||
global _ops_factory
|
||||
op = _ops_factory[op_desc.name](op_desc, arch_params, affine)
|
||||
|
||||
# TODO: annotate as Final?
|
||||
op.desc = op_desc # type: ignore
|
||||
# load any pre-trained weights
|
||||
|
@ -76,6 +79,7 @@ class Op(nn.Module, ABC, EnforceOverrides):
|
|||
|
||||
def get_trainables(self)->Mapping:
|
||||
return {'name': self.desc.name, 'sd': self.state_dict()}
|
||||
|
||||
def set_trainables(self, state_dict)->None:
|
||||
if state_dict is not None:
|
||||
assert state_dict['name'] == self.desc.name
|
||||
|
@ -95,16 +99,6 @@ class Op(nn.Module, ABC, EnforceOverrides):
|
|||
else:
|
||||
_ops_factory[name] = factory_fn
|
||||
|
||||
# must override if op has alphas, otherwise this returns nothing!
|
||||
def alphas(self)->Iterable[nn.Parameter]:
|
||||
return # when supported, derived class should override it
|
||||
yield
|
||||
|
||||
# must override if op has alphas, otherwise this will return weights + alphas!
|
||||
def weights(self)->Iterable[nn.Parameter]:
|
||||
for w in self.parameters():
|
||||
yield w
|
||||
|
||||
def finalize(self)->Tuple[OpDesc, Optional[float]]:
|
||||
"""for trainable op, return final op and its rank"""
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ nas:
|
|||
aug: '' # additional augmentations to use
|
||||
cutout: 16 # cutout length, use cutout augmentation when > 0
|
||||
load_train: True # load train split of dataset
|
||||
train_batch: 96
|
||||
train_batch: 68 # 96 is too aggressive for 1080Ti, better set it to 68
|
||||
train_workers: 4
|
||||
test_workers: '_copy: ../train_workers' # if null then 4
|
||||
load_test: True # load test split of dataset
|
||||
|
@ -229,6 +229,8 @@ nas:
|
|||
decay: 1.0e-3
|
||||
betas: [0.5, 0.999]
|
||||
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
|
||||
alpha_lr_schedule:
|
||||
type: ''
|
||||
lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
__include__: "darts.yaml" # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
trainer:
|
||||
alpha_optimizer:
|
||||
type: 'sgd'
|
||||
lr: 0.025 # init learning rate
|
||||
decay: 3.0e-4
|
||||
momentum: 0.9 # pytorch default is 0
|
||||
nesterov: False
|
||||
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
|
||||
alpha_lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
|
||||
warmup: null
|
|
@ -8,7 +8,7 @@ nas:
|
|||
model_desc:
|
||||
cell_post_op: 'proj_channels'
|
||||
loader:
|
||||
train_batch: 96
|
||||
train_batch: 64
|
||||
search:
|
||||
search_iters: 4
|
||||
pareto:
|
||||
|
|
|
@ -334,3 +334,5 @@ So we need ways to retrieve parameters by:
|
|||
- Are they shared or owned
|
||||
|
||||
This can be achieved by naming convention for variables where such parameters will be stored. Let's define this convention as kind_arch_param. This way any parameter with name ending in _arch_param is considered as architecture parameter. Their full name in the form module1.module2.kind1_arch_param defines where they reside. The part after last "." and without _arch_param suffix defines the kind of the parameter. While Pytorch automatically avoids double listing for shared parameters, a module can have following convention to keep things clean: Module keeps arch parameters in dictionary where key is same as what their variable names would have been. This way Pytorch doesn't register them automatically. If module does own these parameters, it will create variables with same name so they get registered. Module then can provide following methods: get_owned_params, get_shared_params, is_owned_param(p). For parameter sharing, module may receive dictionary of parameters owned by someone else and given module can decide to share some or all of those.
|
||||
|
||||
So, we stipulate that each instance of nn.Module type have same number and type of arch params. So Cell may have one set of arch params, Op have another, Model have another and so on. Question is (1) is it possible to share only subset of one's parameters among instances? (2) how Cell1 can share its arch parameters with Cell2 and Cell3 and Cell4 can with Cell5, Cell6. I think supporting this level of infinite flexibility can potentially make things complex. So let's see how we can do subset of these functionalities. We will have MacroBuilder decide which module shares arch params with which one. This can be done with base *Desc object having member specifying identity of object it will receive parameters from. If no arch parameter is received then object shall create its own. If it did, it may take whole or portion of it and create rest of its own. One can access arch_params method to access params for that module directly and pass parameter recursive=True to get arch params of entire module hierarchy. The return value is ArchParams object.
|
|
@ -9,6 +9,8 @@ from archai.algos.manual.manual_exp_runner import ManualExperimentRunner
|
|||
from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
|
||||
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
|
||||
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
|
||||
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
|
||||
|
||||
|
||||
def main():
|
||||
runner_types:Dict[str, Type[ExperimentRunner]] = {
|
||||
|
@ -18,11 +20,12 @@ def main():
|
|||
'random': RandomExperimentRunner,
|
||||
'manual': ManualExperimentRunner,
|
||||
'gs': GsExperimentRunner,
|
||||
'divnas': DivnasExperimentRunner
|
||||
'divnas': DivnasExperimentRunner,
|
||||
'didarts': DiDartsExperimentRunner
|
||||
}
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS E2E Runs')
|
||||
parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,divnas,manual',
|
||||
parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,manual,didarts,divnas',
|
||||
help='NAS algos to run, seperated by comma')
|
||||
parser.add_argument('--datasets', type=str, default='cifar10',
|
||||
help='datasets to use, separated by comma')
|
||||
|
|
|
@ -93,5 +93,28 @@ def imagenet_test():
|
|||
dl_train, *_ = data.get_data(conf_loader)
|
||||
|
||||
|
||||
def exclusion_test(data_len=32, labels_len=2, val_ratio=0.5):
|
||||
x = np.array(range(data_len))
|
||||
labels = np.array(range(labels_len))
|
||||
y = np.repeat(labels, math.ceil(float(data_len)/labels_len))[:data_len]
|
||||
np.random.shuffle(y)
|
||||
dataset = ListDataset(x, y)
|
||||
|
||||
train_sampler = DistributedStratifiedSampler(dataset,
|
||||
val_ratio=val_ratio, is_val=False, shuffle=True,
|
||||
max_items=-1, world_size=1, rank=0)
|
||||
|
||||
valid_sampler = DistributedStratifiedSampler(dataset,
|
||||
val_ratio=val_ratio, is_val=True, shuffle=True,
|
||||
max_items=-1, world_size=1, rank=0)
|
||||
tidx = list(train_sampler)
|
||||
vidx = list(valid_sampler)
|
||||
|
||||
assert len(tidx) == len(vidx) == 16
|
||||
assert all(ti not in vidx for ti in tidx)
|
||||
# print(len(tidx), tidx)
|
||||
# print(len(vidx), vidx)
|
||||
|
||||
exclusion_test()
|
||||
_dist_no_val(1, 100, val_ratio=0.1)
|
||||
test_combinations()
|
Загрузка…
Ссылка в новой задаче