general arch params implementation, support for multiple optimizers

This commit is contained in:
Shital Shah 2020-05-18 00:36:32 -07:00
Parent cf0ce350fe
Commit 57a10a2dac
36 changed files with 594 additions and 412 deletions

8
.vscode/launch.json vendored
View file

@ -53,6 +53,14 @@
"console": "integratedTerminal",
"args": ["--algos", "darts"]
},
{
"name": "DiDarts-E2E-Toy",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--algos", "didarts"]
},
{
"name": "Darts-Food101-Toy",
"type": "python",

View file

@ -1,4 +1,4 @@
from typing import Mapping, Optional, Union
from typing import Iterator, Mapping, Optional, Union
import copy
import torch
@ -10,6 +10,14 @@ from archai.common.config import Config
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.common import logger
from archai.common.utils import zip_eq
def _get_loss(model:Model, lossfn, x, y):
logits, *_ = model(x) # might also return aux tower logits
return lossfn(logits, y)
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
return model.all_owned().param_by_kind('alphas')
class BilevelOptimizer:
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
@ -25,8 +33,12 @@ class BilevelOptimizer:
# to compute grads for alphas without disturbing
# original weights
self._vmodel = copy.deepcopy(model).to(device)
self._alphas = list(_get_alphas(self._model))
self._valphas = list(_get_alphas(self._vmodel))
# this is the optimizer to optimize alphas parameter
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, self._alphas)
def state_dict(self)->dict:
return {
@ -47,16 +59,11 @@ class BilevelOptimizer:
def _vmodel_params(self):
return self._vmodel.parameters()
@staticmethod
def _get_loss(model, lossfn, x, y):
logits, *_ = model(x) # might also return aux tower logits
return lossfn(logits, y)
def _update_vmodel(self, x, y, lr: float, w_optim: Optimizer) -> None:
""" Update vmodel with w' (main model has w) """
# TODO: should this loss be stored for later use?
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
loss = _get_loss(self._model, self._lossfn, x, y)
gradients = autograd.grad(loss, self._model_params())
"""update weights in vmodel so we leave main model undisturbed
@ -76,7 +83,7 @@ class BilevelOptimizer:
vw.copy_(w - lr * (m + g + self._w_weight_decay*w))
# synchronize alphas
for a, va in zip(self._model.alphas(), self._vmodel.alphas()):
for a, va in zip_eq(self._alphas, self._valphas):
va.copy_(a)
def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor,
@ -114,11 +121,11 @@ class BilevelOptimizer:
# compute loss on validation set for model with w'
# wrt alphas. The autograd.grad is used instead of backward()
# to avoid having to loop through params
vloss = BilevelOptimizer._get_loss(
self._vmodel, self._lossfn, x_valid, y_valid)
vloss = _get_loss(self._vmodel, self._lossfn, x_valid, y_valid)
v_alphas = tuple(self._vmodel.alphas())
v_alphas = tuple(self._valphas)
v_weights = tuple(self._vmodel_params())
# TODO: if v_weights = all params then below does double counting of alphas
v_grads = autograd.grad(vloss, v_alphas + v_weights)
# grad(L(w', a), a), part of Eq. 6
@ -133,7 +140,7 @@ class BilevelOptimizer:
# update final gradient = dalpha - xi*hessian
# TODO: currently alphas lr is same as w lr
with torch.no_grad():
for alpha, da, h in zip(self._model.alphas(), dalpha, hessian):
for alpha, da, h in zip(self._alphas, dalpha, hessian):
alpha.grad = da - lr*h
# now that model has both w and alpha grads,
# we can run w_optim.step() to update the param values
@ -164,9 +171,9 @@ class BilevelOptimizer:
# Now that we have model with w+, we need to compute grads wrt alphas
# This loss needs to be on train set, not validation set
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
loss = _get_loss(self._model, self._lossfn, x, y)
dalpha_plus = autograd.grad(
loss, self._model.alphas()) # dalpha{L_trn(w+)}
loss, self._alphas) # dalpha{L_trn(w+)}
# get model with w- and then compute grads wrt alphas
# w- = w - eps*dw`
@ -176,8 +183,8 @@ class BilevelOptimizer:
p -= 2. * epsilon * v
# similarly get dalpha_minus
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
dalpha_minus = autograd.grad(loss, self._model.alphas())
loss = _get_loss(self._model, self._lossfn, x, y)
dalpha_minus = autograd.grad(loss, self._alphas)
# reset back params to original values by adding dw
with torch.no_grad():

View file

@ -1,4 +1,4 @@
from typing import Mapping, Optional, Union
from typing import Mapping, Optional, Union, Iterator
import copy
import torch
@ -10,7 +10,7 @@ from archai.common.config import Config
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.common import logger
from archai.common.utils import zip_eq
def _flatten_concate(xs):
"""
@ -21,6 +21,13 @@ def _flatten_concate(xs):
"""
return torch.cat([x.view(-1) for x in xs])
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
return model.all_owned().param_by_kind('alphas')
def _get_loss(model:Model, lossfn, x, y):
logits, *_ = model(x) # might also return aux tower logits
return lossfn(logits, y)
class BilevelOptimizer:
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
model: Model, lossfn: _Loss) -> None:
@ -29,8 +36,10 @@ class BilevelOptimizer:
self._lossfn = lossfn
self._model = model # main model with respect to w and alpha
self._alphas = list(_get_alphas(self._model))
# this is the optimizer to optimize alphas parameter
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, self._alphas)
def state_dict(self)->dict:
return {
@ -40,14 +49,9 @@ class BilevelOptimizer:
def load_state_dict(self, state_dict)->None:
self._alpha_optim.load_state_dict(state_dict['alpha_optim'])
@staticmethod
def _get_loss(model, lossfn, x, y):
logits, *_ = model(x) # might also return aux tower logits
return lossfn(logits, y)
def _unrolled_model(self, x, y, lr: float, main_optim: Optimizer)->Model:
# TODO: should this loss be stored for later use?
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
loss = _get_loss(self._model, self._lossfn, x, y)
params = _flatten_concate(self._model.parameters()).detach()
try:
@ -113,9 +117,9 @@ class BilevelOptimizer:
# compute loss on validation set for model with w'
# wrt alphas. The autograd.grad is used instead of backward()
# to avoid having to loop through params
vloss = BilevelOptimizer._get_loss(unrolled_model, self._lossfn, x_valid, y_valid)
vloss = _get_loss(unrolled_model, self._lossfn, x_valid, y_valid)
vloss.backward()
dalpha = [v.grad for v in unrolled_model.alphas()]
dalpha = [v.grad for v in _get_alphas(unrolled_model)]
dparams = [v.grad.data for v in unrolled_model.parameters()]
hessian = self._hessian_vector_product(dparams, x_train, y_train)
@ -125,7 +129,7 @@ class BilevelOptimizer:
# update final gradient = dalpha - xi*hessian
# TODO: currently alphas lr is same as w lr
with torch.no_grad():
for alpha, da, h in zip(self._model.alphas(), dalpha, hessian):
for alpha, da, h in zip_eq(self._alphas, dalpha, hessian):
alpha.grad = da - lr*h
# now that model has both w and alpha grads,
# we can run main_optim.step() to update the param values
@ -151,32 +155,32 @@ class BilevelOptimizer:
# w+ = w + epsilon * grad(w')
with torch.no_grad():
for p, v in zip(self._model.parameters(), dw):
for p, v in zip_eq(self._model.parameters(), dw):
p += epsilon * v
# Now that we have model with w+, we need to compute grads wrt alphas
# This loss needs to be on train set, not validation set
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
loss = _get_loss(self._model, self._lossfn, x, y)
dalpha_plus = autograd.grad(
loss, self._model.alphas()) # dalpha{L_trn(w+)}
loss, self._alphas) # dalpha{L_trn(w+)}
# get model with w- and then compute grads wrt alphas
# w- = w - eps*dw`
with torch.no_grad():
for p, v in zip(self._model.parameters(), dw):
for p, v in zip_eq(self._model.parameters(), dw):
# we had already added dw above so subtracting twice gives w-
p -= 2. * epsilon * v
# similarly get dalpha_minus
loss = BilevelOptimizer._get_loss(self._model, self._lossfn, x, y)
dalpha_minus = autograd.grad(loss, self._model.alphas())
loss = _get_loss(self._model, self._lossfn, x, y)
dalpha_minus = autograd.grad(loss, self._alphas)
# reset back params to original values by adding dw
with torch.no_grad():
for p, v in zip(self._model.parameters(), dw):
for p, v in zip_eq(self._model.parameters(), dw):
p += epsilon * v
# apply eq 8, final difference to compute hessian
h = [(p - m) / (2. * epsilon)
for p, m in zip(dalpha_plus, dalpha_minus)]
for p, m in zip_eq(dalpha_plus, dalpha_minus)]
return h
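The block above implements the finite-difference approximation of the Hessian-vector product used by second-order DARTS (Eq. 8 in the DARTS paper); a sketch in LaTeX, with L_trn the training loss, dw the validation-loss gradient at the unrolled weights, and epsilon the small perturbation chosen in the code:

\nabla^2_{\alpha,w} L_{trn}(w,\alpha)\, dw \;\approx\; \frac{\nabla_\alpha L_{trn}(w^+,\alpha) - \nabla_\alpha L_{trn}(w^-,\alpha)}{2\epsilon}, \qquad w^{\pm} = w \pm \epsilon\, dw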

View file

@ -9,8 +9,8 @@ class DartsCellBuilder(CellBuilder):
@overrides
def register_ops(self) -> None:
Op.register_op('mixed_op',
lambda op_desc, alphas, affine:
MixedOp(op_desc, alphas, affine))
lambda op_desc, arch_params, affine:
MixedOp(op_desc, arch_params, affine))
@overrides
def build(self, model_desc:ModelDesc, search_iter:int)->None:

View file

@ -8,6 +8,7 @@ from overrides import overrides
from archai.nas.model_desc import OpDesc
from archai.nas.operations import Op
from archai.nas.arch_params import ArchParams
# TODO: reduction cell might have output reduced by 2^1=2X due to
# stride 2 through input nodes however FactorizedReduce does only
@ -29,41 +30,32 @@ class MixedOp(Op):
'none' # this must be at the end so top1 doesn't choose it
]
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
affine:bool):
super().__init__()
# assume last PRIMITIVE is 'none'
assert MixedOp.PRIMITIVES[-1] == 'none'
self._set_alphas(alphas)
self._ops = nn.ModuleList()
for primitive in MixedOp.PRIMITIVES:
op = Op.create(
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
affine=affine, alphas=alphas)
affine=affine, arch_params=None)
self._ops.append(op)
# we do this at the end so that we can capture all arch params registered by
# any previous child modules
self._setup_arch_params(arch_params)
@overrides
def forward(self, x):
asm = F.softmax(self._alphas[0], dim=0)
return sum(w * op(x) for w, op in zip(asm, self._ops))
@overrides
def alphas(self) -> Iterable[nn.Parameter]:
for alpha in self._alphas:
yield alpha
@overrides
def weights(self) -> Iterable[nn.Parameter]:
for op in self._ops:
for w in op.parameters():
yield w
@overrides
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
# select except 'none' op
with torch.no_grad():
# select except 'none' op
val, i = torch.topk(self._alphas[0][:-1], 1)
desc, _ = self._ops[i].finalize()
return desc, float(val.item())
@ -76,16 +68,17 @@ class MixedOp(Op):
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
# must call before adding other ops
assert len(list(self.parameters())) == 0
self._alphas = list(alphas)
if not len(self._alphas):
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
# do we have shared arch params?
if arch_params is None:
# create our own arch params
new_p = nn.Parameter( # TODO: use better init than uniform random?
1.0e-3*torch.randn(len(MixedOp.PRIMITIVES)), requires_grad=True)
# NOTE: This is a way to register parameters with PyTorch.
# One creates a dummy variable with the parameters and then
# asks back for the parameters in the object from Pytorch
# which automagically registers the just created parameters.
self._reg_alphas = new_p
self._alphas = [p for p in self.parameters()]
self.create_arch_params([('alphas', new_p)])
else:
assert arch_params.has_kind('alphas')
self.set_arch_params(arch_params)
# we store alphas in a list so PyTorch doesn't register them
self._alphas = list(self.arch_params().param_by_kind('alphas'))
assert len(self._alphas)==1
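For reference, the forward() above is the usual DARTS continuous relaxation: the single 'alphas' vector is softmaxed over the primitive set O = MixedOp.PRIMITIVES and used to mix the primitive outputs. A sketch in LaTeX:

\mathrm{MixedOp}(x) \;=\; \sum_{o \in O} \frac{\exp(\alpha_o)}{\sum_{o' \in O} \exp(\alpha_{o'})}\; o(x)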

View file

@ -1,4 +1,4 @@
from typing import Mapping, Optional, Union
from typing import Mapping, Optional, Union, Tuple
import copy
import torch
@ -16,7 +16,7 @@ from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.common.common import logger
from archai.algos.darts.bilevel_optimizer import BilevelOptimizer
from archai.common.multi_optim import MultiOptim, OptimSched
class DidartsArchTrainer(ArchTrainer):
"""Train network using different optimizers for alphas and other parameters"""
@ -25,65 +25,29 @@ class DidartsArchTrainer(ArchTrainer):
checkpoint:Optional[CheckPoint]) -> None:
super().__init__(conf_train, model, checkpoint)
self._conf_w_optim = conf_train['optimizer']
self._conf_w_lossfn = conf_train['lossfn']
self._conf_alpha_optim = conf_train['alpha_optimizer']
self._conf_alpha_sched = conf_train['alpha_lr_schedule']
@overrides
def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
super().pre_fit(train_dl, val_dl)
def create_multi_optim(self, train_len:int)->MultiOptim:
# optimizers, schedulers needs to be recreated for each fit call
# as they have state
assert val_dl is not None
w_momentum = self._conf_w_optim['momentum']
w_decay = self._conf_w_optim['decay']
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())
# as they have state specific to each run
optim = self.create_optimizer(self.conf_optim, self.model.nonarch_params(recurse=True))
# create scheduler for optim before applying amp
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
w_decay, self.model, lossfn,
self.get_device(), self.batch_chunks)
alpha_optim = self.create_optimizer(self._conf_alpha_optim,
self.model.all_owned().param_by_kind(None))
alpha_sched, alpha_sched_on_epoch = self.create_scheduler(self._conf_alpha_sched, alpha_optim, train_len)
@overrides
def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
# delete state we created in pre_fit
del self._bilevel_optim
return super().post_fit(train_dl, val_dl)
multi_optim = MultiOptim()
multi_optim.append(OptimSched(optim, sched, sched_on_epoch))
multi_optim.append(OptimSched(alpha_optim, alpha_sched, alpha_sched_on_epoch))
@overrides
def pre_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
super().pre_epoch(train_dl, val_dl)
logger.info({'multi_optim_len': len(multi_optim)})
# prep val set to train alphas
self._valid_iter = iter(val_dl) # type: ignore
return multi_optim
@overrides
def post_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
del self._valid_iter # clean up
super().post_epoch(train_dl, val_dl)
@overrides
def pre_step(self, x: Tensor, y: Tensor) -> None:
super().pre_step(x, y)
# reset val loader if we exhausted it
try:
x_val, y_val = next(self._valid_iter)
except StopIteration:
# reinit iterator
self._valid_iter = iter(self._val_dl)
x_val, y_val = next(self._valid_iter)
# update alphas
self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
@overrides
def update_checkpoint(self, check_point:CheckPoint)->None:
super().update_checkpoint(check_point)
check_point['bilevel_optim'] = self._bilevel_optim.state_dict()
@overrides
def restore_checkpoint(self)->None:
super().restore_checkpoint()
self._bilevel_optim.load_state_dict(self.check_point['bilevel_optim'])

View file

@ -19,16 +19,11 @@ from archai.common.common import logger
class GsArchTrainer(ArchTrainer):
def __init__(self, conf_train: Config, model: Model,
checkpoint:Optional[CheckPoint]) -> None:
super().__init__(conf_train, model, checkpoint)
self._conf_w_optim = conf_train['optimizer']
# self._conf_w_lossfn = conf_train['lossfn']
@overrides
def create_optimizer(self) -> Optimizer:
# in this case we don't need to differentiate between alphas and weights
def create_optimizer(self, conf_optim:Config, params) -> Optimizer:
# in this case we don't need to differentiate between arch_params and weights
# as the same optimizer will update both
param_groups = [{'params': self.model.weights()}, {'params': self.model.alphas()}]
return ml_utils.create_optimizer(self._conf_w_optim, param_groups)
arch_params = list(self.model.all_owned().param_by_kind('alphas'))
nonarch_params = list(self.model.nonarch_params(recurse=True))
param_groups = [{'params': nonarch_params}, {'params': arch_params}]
return ml_utils.create_optimizer(conf_optim, param_groups)

View file

@ -9,8 +9,8 @@ class GsCellBuilder(CellBuilder):
@overrides
def register_ops(self) -> None:
Op.register_op('gs_op',
lambda op_desc, alphas, affine:
GsOp(op_desc, alphas, affine))
lambda op_desc, arch_params, affine:
GsOp(op_desc, arch_params, affine))
@overrides
def build(self, model_desc:ModelDesc, search_iter:int)->None:

View file

@ -8,6 +8,7 @@ from overrides import overrides
from archai.nas.model_desc import OpDesc
from archai.nas.operations import Op
from archai.nas.arch_params import ArchParams
# TODO: reduction cell might have output reduced by 2^1=2X due to
# stride 2 through input nodes however FactorizedReduce does only
@ -29,7 +30,7 @@ class GsOp(Op):
'none' # this must be at the end so top1 doesn't choose it
]
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
affine:bool):
super().__init__()
@ -38,20 +39,23 @@ class GsOp(Op):
self._gs_num_sample = op_desc.params['gs_num_sample']
self._set_alphas(alphas)
self._ops = nn.ModuleList()
for primitive in GsOp.PRIMITIVES:
op = Op.create(
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
affine=affine, alphas=alphas)
affine=affine, arch_params=None)
self._ops.append(op)
# we do this at the end so that we can capture all arch params registered by
# any previous child modules
self._setup_arch_params(arch_params)
@overrides
def forward(self, x):
# soft sample from the categorical distribution
# via gumbel softmax distribution
# TODO: should we be normalizing the ensemble?
#sampled = torch.zeros(self._alphas[0].size(), requires_grad=True)
#sampled = torch.zeros(alphas.size(), requires_grad=True)
sample_storage = []
for _ in range(self._gs_num_sample):
sampled = F.gumbel_softmax(self._alphas[0], tau=1, hard=False, eps=1e-10, dim=-1)
@ -60,18 +64,6 @@ class GsOp(Op):
samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)
return sum(w * op(x) for w, op in zip(samples_summed, self._ops))
@overrides
def alphas(self) -> Iterable[nn.Parameter]:
for alpha in self._alphas:
yield alpha
@overrides
def weights(self) -> Iterable[nn.Parameter]:
for op in self._ops:
for w in op.parameters():
yield w
@overrides
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
# finalization where each edge gets to keep as many
@ -110,7 +102,6 @@ class GsOp(Op):
return final_op_desc, None
@overrides
def can_drop_path(self) -> bool:
return False
@ -119,16 +110,17 @@ class GsOp(Op):
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
# must call before adding other ops
assert len(list(self.parameters())) == 0
self._alphas = list(alphas)
if not len(self._alphas):
# TODO: Better initialization than random?
new_p = nn.Parameter(1.0e-3*torch.randn(len(GsOp.PRIMITIVES)), requires_grad=True)
# NOTE: This is a way to register parameters with PyTorch.
# One creates a dummy variable with the parameters and then
# asks back for the parameters in the object from Pytorch
# which automagically registers the just created parameters.
self._reg_alphas = new_p
self._alphas = [p for p in self.parameters()]
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
# do we have shared arch params?
if arch_params is None:
# create our own arch params
new_p = nn.Parameter( # TODO: use better init than uniform random?
1.0e-3*torch.randn(len(GsOp.PRIMITIVES)), requires_grad=True)
self.create_arch_params([('alphas', new_p)])
else:
assert arch_params.has_kind('alphas')
self.set_arch_params(arch_params)
# we store alphas in a list so PyTorch doesn't register them
self._alphas = list(self.arch_params().param_by_kind('alphas'))
assert len(self._alphas)==1
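For reference, the forward() above draws K = gs_num_sample soft Gumbel-softmax samples from the 'alphas' logits (tau = 1), sums them, and uses the summed vector to weight the primitives; as the TODO notes, the ensemble is not normalized by K. A sketch in LaTeX:

y_k = \mathrm{softmax}\big((\alpha + g_k)/\tau\big), \qquad g_k \sim \mathrm{Gumbel}(0,1)^{|O|}

\mathrm{GsOp}(x) \;=\; \sum_{o \in O} \Big(\sum_{k=1}^{K} y_{k,o}\Big)\, o(x)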

View file

@ -11,13 +11,13 @@ class PetridishCellBuilder(CellBuilder):
@overrides
def register_ops(self) -> None:
Op.register_op('petridish_normal_op',
lambda op_desc, alphas, affine:
PetridishOp(op_desc, alphas, False, affine))
lambda op_desc, arch_params, affine:
PetridishOp(op_desc, arch_params, False, affine))
Op.register_op('petridish_reduction_op',
lambda op_desc, alphas, affine:
PetridishOp(op_desc, alphas, True, affine))
lambda op_desc, arch_params, affine:
PetridishOp(op_desc, arch_params, True, affine))
Op.register_op('temp_identity_op',
lambda op_desc, alphas, affine:
lambda op_desc, arch_params, affine:
TempIdentityOp(op_desc))
@overrides

View file

@ -10,7 +10,8 @@ from overrides import overrides
from archai.nas.model_desc import ConvMacroParams, OpDesc
from archai.nas.operations import Identity, Op, FactorizedReduce
from archai.common.utils import zip_eq
from archai.nas.arch_params import ArchParams
class StopForward(Op):
def __init__(self):
@ -79,16 +80,13 @@ class PetridishOp(Op):
'none' # this must be at the end so top1 doesn't choose it
]
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
def __init__(self, op_desc:OpDesc, arch_params: Optional[ArchParams],
reduction:bool, affine:bool):
super().__init__()
# assume last PRIMITIVE is 'none' (this is used for finalize)
assert PetridishOp.PRIMITIVES[-1] == 'none'
# create alphas for the op
self._set_alphas(alphas, op_desc.in_len)
# create edges for the op, each edge connects input state,
# within each edge we will have all N primitives
self._edges = nn.ModuleList()
@ -107,7 +105,7 @@ class PetridishOp(Op):
for primitive in PetridishOp.PRIMITIVES:
primitive_op = Op.create(OpDesc(primitive, params=params,
in_len=1, trainables=None),
affine=affine, alphas=alphas)
affine=affine, arch_params=None)
# wrap primitive with sg
op = nn.Sequential(StopGradient(), primitive_op)
edge.append(op)
@ -120,31 +118,23 @@ class PetridishOp(Op):
# won't match. So you have to use StopGradientReductionOp on s_1 to make it match.
self._sf = StopForward()
# we do this at the end so that we can capture all arch params registered by
# any previous child modules
self._setup_arch_params(arch_params, op_desc.in_len)
@overrides
def forward(self, x:List[Tensor]):
assert not isinstance(x, torch.Tensor)
s = 0.0
# apply each input in the list to associated edge
for i, (xi, edge) in enumerate(zip(x, self._edges)):
for i, (xi, edge) in enumerate(zip_eq(x, self._edges)):
# apply input to each primitive within edge
# TODO: is avg a better idea than sum here? sum can explode as
# the number of primitives goes up
s = sum(a * op(xi) for a, op in zip(self._alphas[i], edge)) + s
s = sum(a * op(xi) for a, op in zip_eq(self._alphas[0][i], edge)) + s
return self._sf(s)
@overrides
def alphas(self) -> Iterable[nn.Parameter]:
for alpha in self._alphas:
yield alpha
@overrides
def weights(self) -> Iterable[nn.Parameter]:
#TODO: cache this?
for edge in self._edges:
for op in edge:
for w in op.parameters():
yield w
@overrides
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
with torch.no_grad(): # probably this is not needed
@ -152,10 +142,10 @@ class PetridishOp(Op):
# Here op should be nn.Sequence of sg followed by primitive.
# First for loop gets edge and associated alphas.
# Second for loop gets op and associated alpha.
l = ((a, i, op[1]) \
l = ((a, i, op[1]) \
for edge_alphas, i, edge in \
zip(self._alphas, range(self.desc.in_len), self._edges) \
for a, op in zip(edge_alphas, edge)) # op is nn.Sequence
zip_eq(self._alphas[0], range(self.desc.in_len), self._edges) \
for a, op in zip_eq(edge_alphas, edge)) # op is nn.Sequence
# select 3 largest ops by alpha
sel = heapq.nlargest(3, l, key=lambda t: t[0]) # TODO: add config
@ -184,12 +174,9 @@ class PetridishOp(Op):
def ops(self)->Iterator['Op']: # type: ignore
return iter(self._ops)
def _set_alphas(self, alphas: Iterable[nn.Parameter], in_len:int) -> None:
assert len(list(self.parameters()))==0 # must call before adding other ops
# If we are using shared alphas from another cell, don't create our own
self._alphas = list(alphas)
if not len(self._alphas):
def _setup_arch_params(self, arch_params:Optional[ArchParams], in_len:int)->None:
# do we have shared arch params?
if arch_params is None:
# Each nn.Parameter is tensor with alphas for entire edge.
# We will create same numbers of nn.Parameter as number of edges
n_primitives = len(PetridishOp.PRIMITIVES)
@ -199,7 +186,12 @@ class PetridishOp(Op):
requires_grad=True)
for _ in range(in_len)
))
# register parameters with module
self._reg_alphas = pl
# save PyTorch registered alphas into list for later use
self._alphas = [p for p in self.parameters()]
self.create_arch_params([('alphas', pl)])
else:
assert arch_params.has_kind('alphas')
self.set_arch_params(arch_params)
# we store alphas in a list so PyTorch doesn't register them
self._alphas = list(self.arch_params().paramlist_by_kind('alphas'))
assert len(self._alphas)==1
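For reference, PetridishOp keeps a single nn.ParameterList of alphas, one tensor per input edge and each of length |O| for O = PetridishOp.PRIMITIVES; the forward() above mixes primitives per edge with raw (un-normalized) alphas. A sketch in LaTeX, with sg the StopGradient wrapper and sf the final StopForward:

\mathrm{PetridishOp}(x_1,\dots,x_{in\_len}) \;=\; \mathrm{sf}\Big(\sum_{i=1}^{in\_len} \sum_{o \in O} \alpha_{i,o}\; o\big(\mathrm{sg}(x_i)\big)\Big)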

View file

@ -23,14 +23,14 @@ class XnasArchTrainer(ArchTrainer):
checkpoint:Optional[CheckPoint]) -> None:
super().__init__(conf_train, model, checkpoint)
self._conf_w_optim = conf_train['optimizer']
self._conf_w_lossfn = conf_train['lossfn']
self._conf_alpha_optim = conf_train['alpha_optimizer']
@overrides
def create_optimizer(self) -> Optimizer:
def create_optimizer(self, conf_optim:Config, params) -> Optimizer:
# return optim that only operates on w, not alphas
return ml_utils.create_optimizer(self._conf_w_optim, self.model.weights())
return ml_utils.create_optimizer(conf_optim,
self.model.nonarch_params(recurse=True))
@overrides
def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader])->None:
@ -43,7 +43,6 @@ class XnasArchTrainer(ArchTrainer):
self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)
@overrides
def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
# delete state we created in pre_fit
@ -98,7 +97,6 @@ class _XnasOptimizer:
return lossfn(logits, y)
def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor) -> None:
# put model in train mode just to be safe
self._model.train()

View file

@ -9,8 +9,8 @@ class XnasCellBuilder(CellBuilder):
@overrides
def register_ops(self) -> None:
Op.register_op('xnas_op',
lambda op_desc, alphas, affine:
XnasOp(op_desc, alphas, affine))
lambda op_desc, arch_params, affine:
XnasOp(op_desc, arch_params, affine))
@overrides
def build(self, model_desc:ModelDesc, search_iter:int)->None:

View file

@ -10,6 +10,8 @@ from overrides import overrides
from archai.nas.model_desc import OpDesc
from archai.nas.operations import Op
from archai.nas.arch_params import ArchParams
from archai.common.utils import zip_eq
# TODO: reduction cell might have output reduced by 2^1=2X due to
# stride 2 through input nodes however FactorizedReduce does only
@ -31,25 +33,28 @@ class XnasOp(Op):
'none' # this must be at the end so top1 doesn't choose it
]
def __init__(self, op_desc:OpDesc, alphas: Iterable[nn.Parameter],
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
affine:bool):
super().__init__()
# assume last PRIMITIVE is 'none'
assert XnasOp.PRIMITIVES[-1] == 'none'
self._set_alphas(alphas)
self._ops = nn.ModuleList()
for primitive in XnasOp.PRIMITIVES:
op = Op.create(
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
affine=affine, alphas=alphas)
affine=affine, arch_params=None)
self._ops.append(op)
# for getting gradients to non-leaf node
self._is_first_call = True
self._avg_grad_meter = AverageMeter()
# we do this at the end so that we can capture all arch params registered by
# any previous child modules
self._setup_arch_params(arch_params)
def get_avg_grad(self)->torch.Tensor:
return self._avg_grad_meter.avg
@ -70,7 +75,7 @@ class XnasOp(Op):
@overrides
def forward(self, x):
self._activs = [op(x) for op in self._ops]
numer = sum(w * activ for w, activ in zip(self._alphas[0], self._activs))
numer = sum(w * activ for w, activ in zip_eq(self._alphas[0], self._activs))
denom = sum(self._alphas[0])
self.pt = torch.div(numer, denom)
@ -81,22 +86,10 @@ class XnasOp(Op):
return self.pt
@overrides
def alphas(self) -> Iterable[nn.Parameter]:
for alpha in self._alphas:
yield alpha
@overrides
def weights(self) -> Iterable[nn.Parameter]:
for op in self._ops:
for w in op.parameters():
yield w
@overrides
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
# select except 'none' op
with torch.no_grad():
# select except 'none' op
val, i = torch.topk(self._alphas[0][:-1], 1)
desc, _ = self._ops[i].finalize()
return desc, float(val.item())
@ -105,18 +98,18 @@ class XnasOp(Op):
def can_drop_path(self) -> bool:
return False
# TODO: Do we even need alphas to be registered with Pytorch
# since we don't have to compute gradients on them?
def _set_alphas(self, alphas: Iterable[nn.Parameter]) -> None:
# must call before adding other ops
assert len(list(self.parameters())) == 0
self._alphas = list(alphas)
if not len(self._alphas):
# TODO: Better initialization than random?
new_p = nn.Parameter(1.0e-3*torch.randn(len(XnasOp.PRIMITIVES)), requires_grad=False)
# NOTE: This is a way to register parameters with PyTorch.
# One creates a dummy variable with the parameters and then
# asks back for the parameters in the object from Pytorch
# which automagically registers the just created parameters.
self._reg_alphas = new_p
self._alphas = [p for p in self.parameters()]
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
# do we have shared arch params?
if arch_params is None:
# create our own arch params
# TODO: dey: why requires_grad = False?
new_p = nn.Parameter( # TODO: use better init than uniform random?
1.0e-3*torch.randn(len(XnasOp.PRIMITIVES)), requires_grad=False)
self.create_arch_params([('alphas', new_p)])
else:
assert arch_params.has_kind('alphas')
self.set_arch_params(arch_params)
# we store alphas in a list so PyTorch doesn't register them
self._alphas = list(self.arch_params().param_by_kind('alphas'))
assert len(self._alphas)==1

View file

@ -13,6 +13,7 @@ import psutil
from archai.common.config import Config
from archai.common import ml_utils, utils
from archai.common.ordereddict_logger import OrderedDictLogger
from archai.common.multi_optim import MultiOptim
class ApexUtils:
def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
@ -186,15 +187,22 @@ class ApexUtils:
else:
return val
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
def _get_optim(self, multi_optim:MultiOptim)->Optimizer:
assert len(multi_optim)==1, \
'Mixed precision is only supported for one optimizer' \
f' but {len(multi_optim)} optimizers were supplied'
return multi_optim[0].optim
def backward(self, loss:torch.Tensor, multi_optim:MultiOptim)->None:
if self.is_mixed():
optim = self._get_optim(multi_optim)
with self._amp.scale_loss(loss, optim) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
->Tuple[nn.Module, Optimizer]:
def to_amp(self, model:nn.Module, multi_optim:MultiOptim, batch_size:int)\
->nn.Module:
# convert BNs to sync BNs in distributed mode
if self.is_dist() and self._sync_bn:
model = self._ddp.convert_syncbn_model(model)
@ -203,6 +211,8 @@ class ApexUtils:
model = model.to(self.device)
if self.is_mixed():
optim = self._get_optim(multi_optim)
# scale LR
if self.is_dist() and self._scale_lr:
lr = ml_utils.get_optim_lr(optim)
@ -215,17 +225,21 @@ class ApexUtils:
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
)
# put back amp'd optim
multi_optim[0].optim = optim
if self.is_dist():
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
model = self._ddp.DistributedDataParallel(model, delay_allreduce=True)
return model, optim
return model
def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
if clip > 0.0:
if self.is_mixed():
optim = self._get_optim(multi_optim)
nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), clip)

View file

@ -14,7 +14,6 @@ import yaml
from . import yaml_utils
# global config instance
_config:'Config' = None

View file

@ -188,7 +188,8 @@ class Metrics:
def __getstate__(self):
state = self.__dict__.copy()
del state['_apex'] # cannot serialize this
if '_apex' in state:
del state['_apex'] # cannot serialize this
return state
# no need to define __setstate__ because _apex should be set from constructor

View file

@ -0,0 +1,71 @@
from typing import Iterator, List, Optional
from collections import UserList
from torch import nn, Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from archai.common.utils import zip_eq
class OptimSched:
"""Holds the optimizer and scheduler"""
def __init__(self, optim:Optimizer, sched:Optional[_LRScheduler],
sched_on_epoch:Optional[bool])->None:
self.optim = optim
self.sched = sched
self.sched_on_epoch = sched_on_epoch
class MultiOptim:
def __init__(self) -> None:
self._optim_scheds:List[OptimSched] = []
def append(self, optim_sched:OptimSched)->None:
self._optim_scheds.append(optim_sched)
def zero_grad(self)->None:
for optim_sched in self._optim_scheds:
optim_sched.optim.zero_grad()
def step(self)->None:
for optim_sched in self._optim_scheds:
optim_sched.optim.step()
if optim_sched.sched and not optim_sched.sched_on_epoch:
optim_sched.sched.step(epoch=None)
def epoch(self, epoch:Optional[int]=None)->None:
for optim_sched in self._optim_scheds:
if optim_sched.sched and optim_sched.sched_on_epoch:
optim_sched.sched.step(epoch=epoch)
def get_lr(self, optim_index:int, param_index:int)->float:
return self._optim_scheds[optim_index].optim.param_groups[param_index]['lr']
def state_dict(self)->dict:
optim_states = [optim_sched.optim.state_dict() for optim_sched in self]
sched_states = [optim_sched.sched.state_dict() if optim_sched.sched else None \
for optim_sched in self]
return {'optim_states': optim_states, 'sched_states':sched_states}
def load_state_dict(self, state_dict:dict)->None:
optim_states = state_dict['optim_states']
sched_states = state_dict['sched_states']
for optim_sched, optim_state, sched_state in zip_eq(self, optim_states, sched_states):
optim_sched.optim.load_state_dict(optim_state)
if optim_sched.sched:
assert sched_state is not None
optim_sched.sched.load_state_dict(sched_state)
else:
assert sched_state is None
def __getitem__(self, index)->OptimSched:
return self._optim_scheds[index]
def __len__(self)->int:
return len(self._optim_scheds)
def __iter__(self)->Iterator[OptimSched]:
return iter(self._optim_scheds)
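A minimal usage sketch for the new MultiOptim container (not from the commit; the tensors and hyperparameters below are made up), mirroring how DidartsArchTrainer.create_multi_optim pairs a weight optimizer with a separate alpha optimizer:

```python
import torch
from torch import nn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR

from archai.common.multi_optim import MultiOptim, OptimSched

weights = nn.Parameter(torch.randn(10))       # stand-in for regular model weights
alphas = nn.Parameter(1e-3 * torch.randn(8))  # stand-in for architecture parameters

w_optim = SGD([weights], lr=0.025, momentum=0.9)
a_optim = Adam([alphas], lr=3e-4)

multi_optim = MultiOptim()
multi_optim.append(OptimSched(w_optim, StepLR(w_optim, step_size=1), sched_on_epoch=True))
multi_optim.append(OptimSched(a_optim, None, None))   # alphas without a scheduler here

loss = (weights.sum() + alphas.sum()) ** 2
multi_optim.zero_grad()
loss.backward()
multi_optim.step()    # steps both optimizers (per-step schedulers would also advance here)
multi_optim.epoch()   # advances only the schedulers flagged with sched_on_epoch
print(multi_optim.get_lr(0, 0), len(multi_optim))
```

state_dict() and load_state_dict() round-trip every optimizer and any attached schedulers, which is what Trainer now stores under the 'multi_optim' checkpoint key.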

View file

@ -1,25 +1,28 @@
from typing import Callable, Tuple, Optional
from torch import nn, Tensor, torch
import torch
from torch import nn, Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from overrides import EnforceOverrides
from .metrics import Metrics
from .tester import Tester
from .config import Config
from . import utils, ml_utils
from ..common.common import logger
from ..common.checkpoint import CheckPoint
from ..common.apex_utils import ApexUtils
from archai.common.metrics import Metrics
from archai.common.tester import Tester
from archai.common.config import Config
from archai.common import utils, ml_utils
from archai.common.common import logger
from archai.common.checkpoint import CheckPoint
from archai.common.apex_utils import ApexUtils
from archai.common.multi_optim import MultiOptim, OptimSched
class Trainer(EnforceOverrides):
def __init__(self, conf_train:Config, model:nn.Module,
checkpoint:Optional[CheckPoint])->None:
# region config vars
self.conf_train = conf_train
conf_lossfn = conf_train['lossfn']
self._aux_weight = conf_train['aux_weight']
self._grad_clip = conf_train['grad_clip']
@ -27,8 +30,8 @@ class Trainer(EnforceOverrides):
self._logger_freq = conf_train['logger_freq']
self._title = conf_train['title']
self._epochs = conf_train['epochs']
self._conf_optim = conf_train['optimizer']
self._conf_sched = conf_train['lr_schedule']
self.conf_optim = conf_train['optimizer']
self.conf_sched = conf_train['lr_schedule']
self.batch_chunks = conf_train['batch_chunks']
conf_validation = conf_train['validation']
conf_apex = conf_train['apex']
@ -58,14 +61,11 @@ class Trainer(EnforceOverrides):
self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)
# optimizers, schedulers needs to be recreated for each fit call
# as they have state specific to each run
optim = self.create_optimizer()
# create scheduler for optim before applying amp
self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
# create optimizers and schedulers
self._multi_optim = self.create_multi_optim(len(train_dl))
# before checkpoint restore, convert to amp
self.model, self._optim = self._apex.to_amp(self.model, optim,
batch_size=train_dl.batch_size)
self.model = self._apex.to_amp(self.model, self._multi_optim,
batch_size=train_dl.batch_size)
self._lossfn = self._lossfn.to(self.get_device())
@ -101,42 +101,53 @@ class Trainer(EnforceOverrides):
logger.pushd('epochs')
for epoch in range(self._start_epoch, self._epochs):
logger.pushd(epoch)
assert self._metrics.epochs() == epoch
self._set_drop_path(epoch, self._epochs)
self.pre_epoch(train_dl, val_dl)
self._train_epoch(train_dl)
self.post_epoch(train_dl, val_dl)
logger.popd()
logger.popd()
self.post_fit(train_dl, val_dl)
# make sure we don't keep references to the graph
del self._optim
del self._sched
del self._multi_optim
logger.popd()
return self.get_metrics()
def create_optimizer(self)->Optimizer:
optim = ml_utils.create_optimizer(self._conf_optim, self.model.parameters())
logger.info({'conf_optim': self._conf_optim})
def create_multi_optim(self, train_len:int)->MultiOptim:
logger.info({'steps_per_epoch': train_len,
'conf_sched': self.conf_sched.to_dict()})
logger.info({'conf_optim': self.conf_optim.to_dict()})
# optimizers, schedulers needs to be recreated for each fit call
# as they have state specific to each run
optim = self.create_optimizer(self.conf_optim, self.model.parameters())
# create scheduler for optim before applying amp
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
multi_optim = MultiOptim()
multi_optim.append(OptimSched(optim, sched, sched_on_epoch))
logger.info({'multi_optim_len': len(multi_optim)})
return multi_optim
def create_optimizer(self, conf_optim:Config, params)->Optimizer:
optim = ml_utils.create_optimizer(conf_optim, params)
return optim
def _create_scheduler(self, optim:Optimizer, steps_per_epoch:int) \
def create_scheduler(self, conf_sched:Config, optim:Optimizer, steps_per_epoch:int) \
->Tuple[Optional[_LRScheduler],bool]:
logger.info({'steps_per_epoch': steps_per_epoch,
'scheduler': self._conf_sched.to_dict()})
return ml_utils.create_lr_scheduler(self._conf_sched, self._epochs,
return ml_utils.create_lr_scheduler(conf_sched, self._epochs,
optim, steps_per_epoch)
def get_optimizer(self)->Optimizer:
return self._optim
def get_scheduler(self)->Optional[_LRScheduler]:
return self._sched
def get_optimizer(self, index=0)->Optimizer:
return self._multi_optim[index].optim
def get_scheduler(self, index=0)->Optional[_LRScheduler]:
return self._multi_optim[index].sched
def get_metrics(self)->Metrics:
return self._metrics
@ -149,7 +160,7 @@ class Trainer(EnforceOverrides):
self._metrics.post_run()
def pre_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
self._metrics.pre_epoch(lr=self._optim.param_groups[0]['lr'])
self._metrics.pre_epoch(lr=self._multi_optim.get_lr(0, 0))
def post_epoch(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
val_metrics = None
@ -157,10 +168,15 @@ class Trainer(EnforceOverrides):
if val_dl and self._tester and self._validation_freq > 0:
if self._metrics.epochs() % self._validation_freq == 0 or \
self._metrics.epochs() >= self._epochs:
# optimizers such as bi-level may use the val set for their own purposes,
# which causes reshuffling due to automatic epoch counting;
# here we make sure that val_dl has the same epoch as train_dl
if hasattr(val_dl.sampler, 'set_epoch'):
val_dl.sampler.set_epoch(self._metrics.epochs())
val_metrics = self._tester.test(val_dl)
# update val metrics
self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])
self._metrics.post_epoch(val_metrics, lr=self._multi_optim.get_lr(0, 0))
# checkpoint if enabled with given freq or if this is the last epoch
if self._checkpoint is not None and self._apex.is_master() and \
@ -190,11 +206,7 @@ class Trainer(EnforceOverrides):
assert self._metrics.epochs() == last_epoch+1
self._apex.load_state_dict(state['amp'])
self.model.load_state_dict(state['model'])
self._optim.load_state_dict(state['optim'])
if self._sched:
self._sched.load_state_dict(state['sched'])
else:
assert state['sched'] is None
self._multi_optim.load_state_dict(state['multi_optim'])
self._start_epoch = last_epoch + 1
@ -204,8 +216,7 @@ class Trainer(EnforceOverrides):
'last_epoch': self._metrics.epochs()-1,
'metrics': self._metrics.state_dict(),
'model': self.model.state_dict(),
'optim': self._optim.state_dict(),
'sched': self._sched.state_dict() if self._sched else None,
'multi_optim': self._multi_optim.state_dict(),
'amp': self._apex.state_dict()
}
self._checkpoint['trainer'] = state
@ -221,7 +232,7 @@ class Trainer(EnforceOverrides):
self.pre_step(x, y)
self._optim.zero_grad()
self._multi_optim.zero_grad()
# divide batch in to chunks if needed so it fits in GPU RAM
if self.batch_chunks > 1:
@ -243,23 +254,20 @@ class Trainer(EnforceOverrides):
loss_c = self.compute_loss(self._lossfn, yc, logits_c,
self._aux_weight, aux_logits)
self._apex.backward(loss_c, self._optim)
self._apex.backward(loss_c, self._multi_optim)
loss_sum += loss_c.item() * len(logits_c)
loss_count += len(logits_c)
logits_chunks.append(logits_c.detach().cpu())
# TODO: original darts clips alphas as well but pt.darts doesn't
self._apex.clip_grad(self._grad_clip, self.model, self._optim)
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
self._optim.step()
self._multi_optim.step()
# TODO: we possibly need to sync so all replicas are up to date
self._apex.sync_devices()
if self._sched and not self._sched_on_epoch:
self._sched.step()
self.post_step(x, y,
ml_utils.join_chunks(logits_chunks),
torch.tensor(loss_sum/loss_count),
@ -268,8 +276,7 @@ class Trainer(EnforceOverrides):
# end of step
if self._sched and self._sched_on_epoch:
self._sched.step()
self._multi_optim.epoch()
logger.popd()
def compute_loss(self, lossfn:Callable, y:Tensor, logits:Tensor,

View file

@ -7,6 +7,7 @@ import sys
import os
import pathlib
import random
from itertools import zip_longest
import torch
import torch.backends.cudnn as cudnn
@ -230,3 +231,10 @@ def exec_shell_command(command:str, print_command=True)->None:
print(command)
subprocess.run(command, shell=True, check=True)
def zip_eq(*iterables):
sentinel = object()
for count, combo in enumerate(zip_longest(*iterables, fillvalue=sentinel)):
if any(True for c in combo if sentinel is c):
shorter_its = ','.join([str(i) for i,c in enumerate(combo) if sentinel is c])
raise ValueError(f'Iterators {shorter_its} have length {count}, which is shorter than the others')
yield combo
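A small usage sketch for the zip_eq helper above (assuming the archai package at this commit is importable); unlike zip, it raises instead of silently truncating:

```python
from archai.common.utils import zip_eq

print(list(zip_eq([1, 2, 3], ['a', 'b', 'c'])))  # [(1, 'a'), (2, 'b'), (3, 'c')]

try:
    list(zip_eq([1, 2, 3], ['a', 'b']))          # second iterable is one element short
except ValueError as e:
    print(e)                                     # reports which iterable ran out early
```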

47
archai/nas/arch_module.py Normal file
View file

@ -0,0 +1,47 @@
from abc import ABC
from typing import Iterable, List, Optional, Tuple, Iterator
from torch import nn
from archai.nas.arch_params import ArchParams, NNTypes
from archai.common import utils
class ArchModule(nn.Module, ABC):
def __init__(self) -> None:
super().__init__()
# these are the params this module should use; they may be shared or created by this module
self._arch_params = ArchParams.empty()
# these are the params created and registered in this module
self._owned_arch_params:Optional[ArchParams] = None
def create_arch_params(self, named_params:Iterable[Tuple[str, NNTypes]])->None:
if len(self._arch_params):
raise RuntimeError('Arch parameters for this module already exist')
self._owned_arch_params = ArchParams(named_params, registrar=self)
self.set_arch_params(self._owned_arch_params)
def set_arch_params(self, arch_params:ArchParams)->None:
if len(self._arch_params):
raise RuntimeError('Arch parameters for this module already exist')
self._arch_params = arch_params
def arch_params(self, recurse=False, only_owned=False)->ArchParams:
# note that we will cache lists on first call; this doesn't allow
# dynamic parameters but it makes these frequent calls much faster
if not recurse:
if not only_owned:
return self._arch_params
else:
return ArchParams.from_module(self, recurse=False)
else:
if not only_owned:
raise NotImplementedError('Recursively getting shared and owned arch params not implemented yet')
else:
return ArchParams.from_module(self, recurse=True)
def all_owned(self)->ArchParams:
return self.arch_params(recurse=True, only_owned=True)
def nonarch_params(self, recurse:bool)->Iterator[nn.Parameter]:
return ArchParams.nonarch_from_module(self, recurse)

75
archai/nas/arch_params.py Normal file
View file

@ -0,0 +1,75 @@
from collections import UserDict
from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union
import torch
from torch import nn
_param_suffix = '_arch_param' # all arch parameter names must have this suffix
NNTypes = Union[nn.Parameter, nn.ParameterDict, nn.ParameterList]
class ArchParams(UserDict):
"""This class holds set of learnable architecture parameter(s) for a given module. For example, one instance of this class would hold alphas for one instance of MixedOp. For sharing parameters, instance of this class can be passed around. Different algorithms may add learnable parameters for their need."""
def __init__(self, arch_params:Iterable[Tuple[str, NNTypes]], registrar:Optional[nn.Module]=None):
"""Create architecture parameters and register them
Arguments:
registrar {Optional[nn.Module]} -- If this parameter is being newly created instead of being shared by another module then the registrar should be specified. When registrar is not None, this method will create a variable in the registering module with the suffix _arch_param so that the parameter gets registered with PyTorch and becomes available in the module's .parameters() calls.
"""
super().__init__()
for name, param in arch_params:
self.data[name] = param
if registrar is not None:
setattr(registrar, name + _param_suffix, param)
def __setitem__(self, name:str, param:NNTypes)->None:
raise RuntimeError(f'ArchParams is immutable hence adding/updating key {name} is not allowed.')
def __delitem__(self, name:str) -> None:
raise RuntimeError(f'ArchParams is immutable hence removing key {name} is not allowed.')
def _by_kind(self, kind:Optional[str])->Iterator[NNTypes]:
# TODO: maybe optimize to avoid split() calls?
for name, param in self.items():
if kind is None or name.split('.')[-1]==kind:
yield param
def param_by_kind(self, kind:Optional[str])->Iterator[nn.Parameter]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
def paramlist_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterList]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
def paramdict_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterDict]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
def has_kind(self, kind:str)->bool:
# TODO: maybe optimize to avoid split() calls?
for name in self.keys():
if name.split('.')[-1]==kind:
return True
return False
@staticmethod
def from_module(module:nn.Module, recurse:bool=False)->'ArchParams':
suffix_len = len(_param_suffix)
# PyTorch named params have '.' in the name per module; we pick the last part and strip the _arch_param suffix
arch_params = ((name[:-suffix_len], param) \
for name, param in module.named_parameters(recurse=recurse)
if name.endswith(_param_suffix))
return ArchParams(arch_params)
@staticmethod
def nonarch_from_module(module:nn.Module, recurse:bool=False)->Iterator[nn.Parameter]:
# filter out arch params, i.e. names ending with the _arch_param suffix
return (param for name, param in module.named_parameters(recurse=recurse)
if not name.endswith(_param_suffix))
@staticmethod
def empty()->'ArchParams':
return ArchParams([])
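A hedged sketch (ToyOp is hypothetical, not part of the commit) of how an Op built on ArchModule either creates and owns its 'alphas' or shares ones owned by another instance, which is the pattern MixedOp, GsOp, XnasOp and PetridishOp now follow:

```python
import torch
from torch import nn

from archai.nas.arch_module import ArchModule

class ToyOp(ArchModule):
    """Hypothetical op: owns its alphas unless shared ones are passed in."""
    def __init__(self, arch_params=None):
        super().__init__()
        self.linear = nn.Linear(4, 4)                     # regular (non-arch) weights
        if arch_params is None:
            alphas = nn.Parameter(1e-3 * torch.randn(3))  # our own arch params
            self.create_arch_params([('alphas', alphas)]) # registered as 'alphas_arch_param'
        else:
            self.set_arch_params(arch_params)             # reuse without re-registering

owner = ToyOp()                                           # creates and owns alphas
sharer = ToyOp(arch_params=owner.arch_params())           # shares the same tensor

arch_ps = list(owner.all_owned().param_by_kind('alphas')) # only the alphas
nonarch = list(owner.nonarch_params(recurse=True))        # linear.weight, linear.bias
assert len(arch_ps) == 1 and len(sharer.all_owned()) == 0 # sharer owns nothing
```

Because registration happens through the _arch_param suffix, nonarch_params() and all_owned() can cleanly split weights from arch params for separate optimizers, replacing the per-class alphas()/weights() iterators this commit removes.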

View file

@ -27,6 +27,10 @@ class ArchTrainer(Trainer, EnforceOverrides):
self._l1_alphas = conf_train['l1_alphas']
self._plotsdir = conf_train['plotsdir']
# if l1 regularization is needed then cache alphas
if self._l1_alphas > 0.0:
self._alphas = list(self.model.all_owned().param_by_kind('alphas'))
@overrides
def compute_loss(self, lossfn: Callable,
y: Tensor, logits: Tensor,
@ -35,7 +39,7 @@ class ArchTrainer(Trainer, EnforceOverrides):
aux_weight, aux_logits)
# add L1 alpha regularization
if self._l1_alphas > 0.0:
l_extra = sum(torch.sum(a.abs()) for a in self.model.alphas())
l_extra = sum(torch.sum(a.abs()) for a in self._alphas)
loss += self._l1_alphas * l_extra
return loss

View file

@ -1,36 +1,35 @@
from typing import Callable, Iterable, List, Optional, Tuple
from abc import ABC, abstractmethod
from torch import nn, tensor
from overrides import overrides, EnforceOverrides
from ..common.common import logger
from .dag_edge import DagEdge
from .model_desc import ConvMacroParams, CellDesc, OpDesc, NodeDesc
from .operations import Op
from archai.nas.dag_edge import DagEdge
from archai.nas.model_desc import ConvMacroParams, CellDesc, OpDesc, NodeDesc
from archai.nas.operations import Op
from archai.nas.arch_module import ArchModule
class Cell(nn.Module, ABC, EnforceOverrides):
class Cell(ArchModule, EnforceOverrides):
def __init__(self, desc:CellDesc,
affine:bool, droppath:bool,
alphas_cell:Optional['Cell']):
template_cell:Optional['Cell']): # template cell, if any, to use for arch params
super().__init__()
# some of these members are public as finalizer needs access
self.shared_alphas = alphas_cell is not None
self.desc = desc
self.s0_op = Op.create(desc.s0_op, affine=affine)
self.s1_op = Op.create(desc.s1_op, affine=affine)
self.dag = Cell._create_dag(desc.nodes(),
affine=affine, droppath=droppath,
alphas_cell=alphas_cell)
template_cell=template_cell)
self.post_op = Op.create(desc.post_op, affine=affine)
@staticmethod
def _create_dag(nodes_desc:List[NodeDesc],
affine:bool, droppath:bool,
alphas_cell:Optional['Cell'])->nn.ModuleList:
template_cell:Optional['Cell'])->nn.ModuleList:
dag = nn.ModuleList()
for i, node_desc in enumerate(nodes_desc):
edges:nn.ModuleList = nn.ModuleList()
@ -39,21 +38,9 @@ class Cell(nn.Module, ABC, EnforceOverrides):
for j, edge_desc in enumerate(node_desc.edges):
edges.append(DagEdge(edge_desc,
affine=affine, droppath=droppath,
alphas_edge=alphas_cell.dag[i][j] if alphas_cell else None))
template_edge=template_cell.dag[i][j] if template_cell else None))
return dag
def alphas(self)->Iterable[nn.Parameter]:
for node in self.dag:
for edge in node:
for alpha in edge.alphas():
yield alpha
def weights(self)->Iterable[nn.Parameter]:
for node in self.dag:
for edge in node:
for p in edge.weights():
yield p
def ops(self)->Iterable[Op]:
for node in self.dag:
for edge in node:

View file

@ -4,16 +4,17 @@ import torch
from torch import nn
from overrides import overrides
from .operations import Op, DropPath_
from .model_desc import EdgeDesc
from archai.nas.operations import Op, DropPath_
from archai.nas.model_desc import EdgeDesc
from archai.nas.arch_module import ArchModule
class DagEdge(nn.Module):
class DagEdge(ArchModule):
def __init__(self, desc:EdgeDesc, affine:bool, droppath:bool,
alphas_edge:Optional['DagEdge'])->None:
template_edge:Optional['DagEdge'])->None:
super().__init__()
# we may need to wrap op if droppath is needed
self._wrapped = self._op = Op.create(desc.op_desc, affine,
alphas_edge.alphas() if alphas_edge else [])
template_edge.op().arch_params() if template_edge is not None else None)
if droppath and self._op.can_drop_path():
assert self.training
self._wrapped = nn.Sequential(self._op, DropPath_())
@ -29,14 +30,5 @@ class DagEdge(nn.Module):
else:
return self._wrapped([inputs[i] for i in self.input_ids])
def alphas(self)->Iterable[nn.Parameter]:
for alpha in self._op.alphas():
if alpha is not None:
yield alpha
def weights(self)->Iterable[nn.Parameter]:
for w in self._op.weights():
yield w
def op(self)->Op:
return self._op

View file

@ -56,7 +56,7 @@ class Finalizers(EnforceOverrides):
nodes = node_descs,
s0_op=cell.s0_op.finalize()[0],
s1_op=cell.s1_op.finalize()[0],
alphas_from = cell.desc.alphas_from,
template_cell = cell.desc.template_cell,
max_final_edges=cell.desc.max_final_edges,
node_ch_out=cell.desc.node_ch_out,
post_op=cell.post_op.finalize()[0]

View file

@ -99,21 +99,21 @@ class MacroBuilder(EnforceOverrides):
# channels of prev-prev whole cell, prev whole cell and current cell node
pp_ch_out, p_ch_out, node_ch_out = stem_ch_out, stem_ch_out, self.init_node_ch
# stores first cells of each type with which alphas would be shared
# stores first cells of each type with which arch params would be shared
first_normal, first_reduction = -1, -1
for ci in range(self.n_cells):
# find cell type and output channels for this cell
# also update if this is our first cell from which alphas will be shared
# also update if this is our first cell from which arch params will be shared
reduction = self._is_reduction(ci)
if reduction:
node_ch_out, cell_type = node_ch_out*2, CellType.Reduction
first_reduction = ci if first_reduction < 0 else first_reduction
alphas_from = first_reduction
template_cell = first_reduction
else:
cell_type = CellType.Regular
first_normal = ci if first_normal < 0 else first_normal
alphas_from = first_normal
template_cell = first_normal
s0_op, s1_op = self._get_cell_stems(
node_ch_out, p_ch_out, pp_ch_out, reduction_p)
@ -134,7 +134,7 @@ class MacroBuilder(EnforceOverrides):
cell_type=cell_type, id=ci,
nodes=nodes,
s0_op=s0_op, s1_op=s1_op,
alphas_from=alphas_from,
template_cell=template_cell,
max_final_edges=max_final_edges,
node_ch_out=node_ch_out,
post_op=self.cell_post_op

View file

@ -6,17 +6,17 @@ import os
import torch
from torch import nn, Tensor
from overrides import overrides
from archai.nas.arch_params import ArchParams
from archai.nas.cell import Cell
from archai.nas.operations import Op, DropPath_
from archai.nas.model_desc import ModelDesc, AuxTowerDesc, CellDesc
from archai.common.common import logger
from archai.common import utils, ml_utils
from archai.nas.arch_module import ArchModule
from .cell import Cell
from .operations import Op, DropPath_
from .model_desc import ModelDesc, AuxTowerDesc, CellDesc
from ..common.common import logger
from ..common import utils, ml_utils
class Model(nn.Module):
class Model(ArchModule):
def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
super().__init__()
@ -45,35 +45,26 @@ class Model(nn.Module):
def _build_cell(self, cell_desc:CellDesc,
aux_tower_desc:Optional[AuxTowerDesc],
droppath:bool, affine:bool)->None:
alphas_cell = None if cell_desc.alphas_from==cell_desc.id \
else self.cells[cell_desc.alphas_from]
template_cell = None if cell_desc.template_cell==cell_desc.id \
else self.cells[cell_desc.template_cell]
cell = Cell(cell_desc, affine=affine, droppath=droppath,
alphas_cell=alphas_cell)
template_cell=template_cell)
self.cells.append(cell)
self._aux_towers.append(AuxTower(aux_tower_desc) \
if aux_tower_desc else None)
def summary(self)->dict:
all_arch_params = list(self.all_owned()
.param_by_kind(kind=None))
return {
'cell_count': len(self.cells),
#'cell_params': [ml_utils.param_size(c) for c in self.cells]
'params': ml_utils.param_size(self),
'alphas_p': len(list(a for a in self.alphas())),
'alphas': np.sum(a.numel() for a in self.alphas()),
'arch_params_len': len(all_arch_params),
'arch_params_numel': np.sum(a.numel() for a in all_arch_params),
'ops': np.sum(len(n.edges) for c in self.desc.cell_descs() for n in c.nodes()),
}
def alphas(self)->Iterable[nn.Parameter]:
for cell in self.cells:
if not cell.shared_alphas:
for alpha in cell.alphas():
yield alpha
def weights(self)->Iterable[nn.Parameter]:
for cell in self.cells:
for w in cell.weights():
yield w
def ops(self)->Iterable[Op]:
for cell in self.cells:
for op in cell.ops():

View file

@ -135,7 +135,7 @@ class CellType(Enum):
class CellDesc:
def __init__(self, cell_type:CellType, id:int, nodes:List[NodeDesc],
s0_op:OpDesc, s1_op:OpDesc, alphas_from:int, max_final_edges:int,
s0_op:OpDesc, s1_op:OpDesc, template_cell:int, max_final_edges:int,
node_ch_out:int, post_op:Union[str,OpDesc])->None:
assert s0_op.params['conv'].ch_out == s1_op.params['conv'].ch_out
assert s0_op.params['conv'].ch_out == node_ch_out
@ -143,7 +143,7 @@ class CellDesc:
self.cell_type = cell_type
self.id = id
self.s0_op, self.s1_op = s0_op, s1_op
self.alphas_from = alphas_from # cell id with which we share alphas
self.template_cell = template_cell # cell id with which we share arch params
self.max_final_edges = max_final_edges
self.cell_ch_out = -1 # will be set later by reset_nodes
@ -151,7 +151,7 @@ class CellDesc:
assert self.cell_ch_out > 0
def clone(self, id:int)->'CellDesc':
c = copy.deepcopy(self) # note that alphas_from is also cloned
c = copy.deepcopy(self) # note that template_cell is also cloned
c.id = id
return c
@ -284,7 +284,7 @@ class ModelDesc:
def reset_cells(self, cell_descs:List[CellDesc],
aux_tower_descs:List[Optional[AuxTowerDesc]])->None:
assert len(cell_descs) == len(aux_tower_descs)
# every cell should have unique ID so we can tell where alphas are shared
# every cell should have unique ID so we can tell where arch params are shared
assert len(set(c.id for c in cell_descs)) == len(cell_descs)
self._cell_descs = cell_descs

View file

@ -8,66 +8,69 @@ from overrides import overrides, EnforceOverrides
import torch
from torch import affine_grid_generator, nn, Tensor, strided
from ..common import utils, ml_utils
from .model_desc import OpDesc, ConvMacroParams
from archai.common import utils, ml_utils
from archai.nas.model_desc import OpDesc, ConvMacroParams
from archai.nas.arch_params import ArchParams
from archai.nas.arch_module import ArchModule
# type alias
OpFactoryFn = Callable[[OpDesc, Iterable[nn.Parameter]], 'Op']
# Each op is a unary tensor operator, all take same constructor params
# TODO: swap order of arch_params and affine to match with create signature
_ops_factory:Dict[str, Callable] = {
'max_pool_3x3': lambda op_desc, alphas, affine:
'max_pool_3x3': lambda op_desc, arch_params, affine:
PoolBN('max', op_desc, affine),
'avg_pool_3x3': lambda op_desc, alphas, affine:
'avg_pool_3x3': lambda op_desc, arch_params, affine:
PoolBN('avg', op_desc, affine),
'skip_connect': lambda op_desc, alphas, affine:
'skip_connect': lambda op_desc, arch_params, affine:
SkipConnect(op_desc, affine),
'sep_conv_3x3': lambda op_desc, alphas, affine:
'sep_conv_3x3': lambda op_desc, arch_params, affine:
SepConv(op_desc, 3, 1, affine),
'sep_conv_5x5': lambda op_desc, alphas, affine:
'sep_conv_5x5': lambda op_desc, arch_params, affine:
SepConv(op_desc, 5, 2, affine),
'dil_conv_3x3': lambda op_desc, alphas, affine:
'dil_conv_3x3': lambda op_desc, arch_params, affine:
DilConv(op_desc, 3, op_desc.params['stride'], 2, 2, affine),
'dil_conv_5x5': lambda op_desc, alphas, affine:
'dil_conv_5x5': lambda op_desc, arch_params, affine:
DilConv(op_desc, 5, op_desc.params['stride'], 4, 2, affine),
'none': lambda op_desc, alphas, affine:
'none': lambda op_desc, arch_params, affine:
Zero(op_desc),
'identity': lambda op_desc, alphas, affine:
'identity': lambda op_desc, arch_params, affine:
Identity(op_desc),
'sep_conv_7x7': lambda op_desc, alphas, affine:
'sep_conv_7x7': lambda op_desc, arch_params, affine:
SepConv(op_desc, 7, 3, affine),
'conv_7x1_1x7': lambda op_desc, alphas, affine:
'conv_7x1_1x7': lambda op_desc, arch_params, affine:
FacConv(op_desc, 7, 3, affine),
'prepr_reduce': lambda op_desc, alphas, affine:
'prepr_reduce': lambda op_desc, arch_params, affine:
FactorizedReduce(op_desc, affine),
'prepr_normal': lambda op_desc, alphas, affine:
'prepr_normal': lambda op_desc, arch_params, affine:
ReLUConvBN(op_desc, 1, 1, 0, affine),
'stem_conv3x3': lambda op_desc, alphas, affine:
'stem_conv3x3': lambda op_desc, arch_params, affine:
StemConv3x3(op_desc, affine),
'stem_conv3x3_s4': lambda op_desc, alphas, affine:
'stem_conv3x3_s4': lambda op_desc, arch_params, affine:
StemConv3x3S4(op_desc, affine),
'stem_conv3x3_s4s2': lambda op_desc, alphas, affine:
'stem_conv3x3_s4s2': lambda op_desc, arch_params, affine:
StemConv3x3S4S2(op_desc, affine),
'pool_adaptive_avg2d': lambda op_desc, alphas, affine:
'pool_adaptive_avg2d': lambda op_desc, arch_params, affine:
PoolAdaptiveAvg2D(),
'pool_avg2d7x7': lambda op_desc, alphas, affine:
'pool_avg2d7x7': lambda op_desc, arch_params, affine:
AvgPool2d7x7(),
'concate_channels': lambda op_desc, alphas, affine:
'concate_channels': lambda op_desc, arch_params, affine:
ConcateChannelsOp(op_desc, affine),
'proj_channels': lambda op_desc, alphas, affine:
'proj_channels': lambda op_desc, arch_params, affine:
ProjectChannelsOp(op_desc, affine),
'linear': lambda op_desc, alphas, affine:
'linear': lambda op_desc, arch_params, affine:
LinearOp(op_desc),
'multi_op': lambda op_desc, alphas, affine:
'multi_op': lambda op_desc, arch_params, affine:
MultiOp(op_desc, affine)
}
class Op(nn.Module, ABC, EnforceOverrides):
class Op(ArchModule, ABC, EnforceOverrides):
@staticmethod
def create(op_desc:OpDesc, affine:bool,
alphas:Iterable[nn.Parameter]=[])->'Op':
op = _ops_factory[op_desc.name](op_desc, alphas, affine)
def create(op_desc:OpDesc, affine:bool, arch_params:Optional[ArchParams]=None)->'Op':
global _ops_factory
op = _ops_factory[op_desc.name](op_desc, arch_params, affine)
# TODO: annotate as Final?
op.desc = op_desc # type: ignore
# load any pre-trained weights
@ -76,6 +79,7 @@ class Op(nn.Module, ABC, EnforceOverrides):
def get_trainables(self)->Mapping:
return {'name': self.desc.name, 'sd': self.state_dict()}
def set_trainables(self, state_dict)->None:
if state_dict is not None:
assert state_dict['name'] == self.desc.name
@ -95,16 +99,6 @@ class Op(nn.Module, ABC, EnforceOverrides):
else:
_ops_factory[name] = factory_fn
# must override if op has alphas, otherwise this returns nothing!
def alphas(self)->Iterable[nn.Parameter]:
return # when supported, derived class should override it
yield
# must override if op has alphas, otherwise this will return weights + alphas!
def weights(self)->Iterable[nn.Parameter]:
for w in self.parameters():
yield w
def finalize(self)->Tuple[OpDesc, Optional[float]]:
"""for trainable op, return final op and its rank"""

View file

@ -82,7 +82,7 @@ nas:
aug: '' # additional augmentations to use
cutout: 16 # cutout length, use cutout augmentation when > 0
load_train: True # load train split of dataset
train_batch: 96
train_batch: 68 # 96 is too aggressive for a 1080Ti; 68 is safer
train_workers: 4
test_workers: '_copy: ../train_workers' # if null then 4
load_test: True # load test split of dataset
@ -229,6 +229,8 @@ nas:
decay: 1.0e-3
betas: [0.5, 0.999]
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
alpha_lr_schedule:
type: ''
lr_schedule:
type: 'cosine'
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler

16
confs/algos/didarts.yaml Normal file
View file

@ -0,0 +1,16 @@
__include__: "darts.yaml" # just use darts defaults
nas:
search:
trainer:
alpha_optimizer:
type: 'sgd'
lr: 0.025 # init learning rate
decay: 3.0e-4
momentum: 0.9 # pytorch default is 0
nesterov: False
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
alpha_lr_schedule:
type: 'cosine'
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
warmup: null

View file

@ -8,7 +8,7 @@ nas:
model_desc:
cell_post_op: 'proj_channels'
loader:
train_batch: 96
train_batch: 64
search:
search_iters: 4
pareto:

View file

@ -334,3 +334,5 @@ So we need ways to retrieve parameters by:
- Are they shared or owned
This can be achieved by a naming convention for the variables in which such parameters are stored. Let's define this convention as kind_arch_param: any parameter whose name ends in _arch_param is considered an architecture parameter. Its full name, in the form module1.module2.kind1_arch_param, defines where it resides, and the part after the last "." with the _arch_param suffix stripped defines the kind of the parameter. While PyTorch automatically avoids double listing of shared parameters, a module can follow an additional convention to keep things clean: a module that merely shares arch parameters keeps them in a dictionary whose keys match what their variable names would have been, so PyTorch does not register them automatically; a module that owns these parameters creates variables with those names so they do get registered. Each module can then provide the methods get_owned_params, get_shared_params and is_owned_param(p). For parameter sharing, a module may receive a dictionary of parameters owned by someone else and decide to share some or all of them.
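As a minimal sketch of this convention (assuming PyTorch; ExampleArchModule and its method names are illustrative stand-ins, not the actual Archai classes):

import torch
from torch import nn

class ExampleArchModule(nn.Module):
    """Illustrative module following the kind_arch_param naming convention."""
    def __init__(self, n_edges:int, n_ops:int, shared:dict=None):
        super().__init__()
        # shared arch params live in a plain dict, keyed by the variable name
        # they would have had, so PyTorch does not register them on this module
        self._shared = dict(shared) if shared else {}
        if 'alphas_arch_param' not in self._shared:
            # owned arch param: registered under a name ending in _arch_param;
            # the prefix before the suffix ("alphas") is its kind
            self.alphas_arch_param = nn.Parameter(1.0e-3*torch.randn(n_edges, n_ops))

    def get_owned_params(self)->dict:
        # parameters registered directly on this module and marked as arch params
        return {name: p for name, p in self.named_parameters(recurse=False)
                if name.endswith('_arch_param')}

    def get_shared_params(self)->dict:
        return dict(self._shared)

    def is_owned_param(self, p:nn.Parameter)->bool:
        return any(p is q for q in self.get_owned_params().values())

owner = ExampleArchModule(n_edges=4, n_ops=7)
sharer = ExampleArchModule(n_edges=4, n_ops=7, shared=owner.get_owned_params())
assert owner.is_owned_param(owner.alphas_arch_param)
assert not sharer.get_owned_params()   # the shared param is not registered twice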
So we stipulate that every instance of a given nn.Module type has the same number and kinds of arch params: Cell may have one set of arch params, Op another, Model another, and so on. Two questions arise: (1) is it possible to share only a subset of one's parameters among instances? (2) how can Cell1 share its arch parameters with Cell2 and Cell3, while Cell4 shares with Cell5 and Cell6? Supporting that level of flexibility could make things complex, so we implement a subset of these functionalities. We let MacroBuilder decide which module shares arch params with which: the base *Desc object gets a member specifying the identity of the object it will receive parameters from. If no arch parameters are received, the object creates its own; if some are received, it may take all or a portion of them and create the rest on its own. The arch_params method gives access to the params of that module directly, and passing recursive=True returns the arch params of the entire module hierarchy; the return value is an ArchParams object.
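A hypothetical sketch of this builder-driven sharing, with ToyCell and build_cells standing in for Cell and MacroBuilder, and the template ids mirroring CellDesc.template_cell:

import torch
from torch import nn

class ToyCell(nn.Module):
    """Owns its arch params unless a template cell is supplied."""
    def __init__(self, cell_id:int, template:'ToyCell'=None, n_edges:int=4, n_ops:int=7):
        super().__init__()
        self.cell_id = cell_id
        if template is None:
            # first cell of its type: create and register its own arch param
            self.alphas_arch_param = nn.Parameter(1.0e-3*torch.randn(n_edges, n_ops))
            self._shared = {}
        else:
            # later cell of the same type: borrow the template's param; the plain
            # dict keeps PyTorch from registering it a second time
            self._shared = {'alphas_arch_param': template.alphas_arch_param}

    def alphas(self)->nn.Parameter:
        if hasattr(self, 'alphas_arch_param'):
            return self.alphas_arch_param
        return self._shared['alphas_arch_param']

def build_cells(template_ids):
    # template_ids[i] names the cell that cell i takes its arch params from;
    # template_ids[i] == i means the cell owns them
    cells = []
    for i, t in enumerate(template_ids):
        cells.append(ToyCell(i, template=None if t == i else cells[t]))
    return cells

# cells 0 and 1 own their params (first normal / first reduction); 2 and 3 reuse them
cells = build_cells([0, 1, 0, 1])
assert cells[2].alphas() is cells[0].alphas()
assert len(list(cells[2].parameters())) == 0   # nothing registered on the sharing cell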

View file

@ -9,6 +9,8 @@ from archai.algos.manual.manual_exp_runner import ManualExperimentRunner
from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
def main():
runner_types:Dict[str, Type[ExperimentRunner]] = {
@ -18,11 +20,12 @@ def main():
'random': RandomExperimentRunner,
'manual': ManualExperimentRunner,
'gs': GsExperimentRunner,
'divnas': DivnasExperimentRunner
'divnas': DivnasExperimentRunner,
'didarts': DiDartsExperimentRunner
}
parser = argparse.ArgumentParser(description='NAS E2E Runs')
parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,divnas,manual',
parser.add_argument('--algos', type=str, default='darts,petridish,xnas,random,gs,manual,didarts,divnas',
help='NAS algos to run, separated by comma')
parser.add_argument('--datasets', type=str, default='cifar10',
help='datasets to use, separated by comma')

View file

@ -93,5 +93,28 @@ def imagenet_test():
dl_train, *_ = data.get_data(conf_loader)
def exclusion_test(data_len=32, labels_len=2, val_ratio=0.5):
x = np.array(range(data_len))
labels = np.array(range(labels_len))
y = np.repeat(labels, math.ceil(float(data_len)/labels_len))[:data_len]
np.random.shuffle(y)
dataset = ListDataset(x, y)
train_sampler = DistributedStratifiedSampler(dataset,
val_ratio=val_ratio, is_val=False, shuffle=True,
max_items=-1, world_size=1, rank=0)
valid_sampler = DistributedStratifiedSampler(dataset,
val_ratio=val_ratio, is_val=True, shuffle=True,
max_items=-1, world_size=1, rank=0)
tidx = list(train_sampler)
vidx = list(valid_sampler)
assert len(tidx) == len(vidx) == 16
assert all(ti not in vidx for ti in tidx)
# print(len(tidx), tidx)
# print(len(vidx), vidx)
exclusion_test()
_dist_no_val(1, 100, val_ratio=0.1)
test_combinations()