Mirror of https://github.com/microsoft/nni.git
Promote Retiarii to NAS (step 1) - move files (#5020)
This commit is contained in:
Parent
481aa29299
Commit
867871b244
@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
|
||||
from .trainer import CdartsTrainer
@@ -1,143 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
|
||||
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
|
||||
from nni.algorithms.nas.pytorch.darts import DartsMutator # pylint: disable=wrong-import-order
|
||||
from nni.nas.pytorch.mutables import LayerChoice # pylint: disable=wrong-import-order
|
||||
from nni.nas.pytorch.mutator import Mutator # pylint: disable=wrong-import-order
|
||||
|
||||
|
||||
class RegularizedDartsMutator(DartsMutator):
|
||||
"""
|
||||
This is :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator` basically, with two differences.
|
||||
|
||||
1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cutted choices will not be used in
|
||||
forward pass and thus consumes no memory.
|
||||
|
||||
2. Regularization on choices, to prevent the mutator from overfitting on some choices.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Warnings
|
||||
--------
|
||||
Renamed :func:`~reset_with_loss` to return regularization loss on reset.
|
||||
"""
|
||||
raise ValueError("You should probably call `reset_with_loss`.")
|
||||
|
||||
def cut_choices(self, cut_num=2):
|
||||
"""
|
||||
Cut the choices with the smallest weights.
|
||||
``cut_num`` should be the accumulative number of cutting, e.g., if first time cutting
|
||||
is 2, the second time should be 4 to cut another two.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cut_num : int
|
||||
Number of choices to cut, so far.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive gradient of 0,
|
||||
which introduced ``nan`` problem when calling ``optimizer.step()``. To solve this issue, a simple way is to
|
||||
reset nan to :math:`-\infty` each time after the parameters are updated.
|
||||
"""
|
||||
# `cut_choices` is implemented but not used in current implementation of CdartsTrainer
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
_, idx = torch.topk(-self.choices[mutable.key], cut_num)
|
||||
with torch.no_grad():
|
||||
for i in idx:
|
||||
self.choices[mutable.key][i] = -float("inf")
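
# A hypothetical usage sketch (not from the original file): because ``cut_num`` is cumulative,
# each later call must pass the total number of choices cut so far, e.g.
#     mutator = RegularizedDartsMutator(model)
#     mutator.cut_choices(cut_num=2)   # first call: bypass the 2 weakest choices per LayerChoice
#     mutator.cut_choices(cut_num=4)   # second call: bypass 2 more, 4 in total
# After each optimizer step, reset the resulting ``nan`` entries back to -inf, as warned above.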
|
||||
|
||||
def reset_with_loss(self):
|
||||
"""
|
||||
Resample and return loss. If loss is 0, to avoid device issue, it will return ``None``.
|
||||
|
||||
Currently loss penalty are proportional to the L1-norm of parameters corresponding
|
||||
to modules if their type name contains certain substrings. These substrings include: ``poolwithoutbn``,
|
||||
``identity``, ``dilconv``.
|
||||
"""
|
||||
self._cache, reg_loss = self.sample_search()
|
||||
return reg_loss
|
||||
|
||||
def sample_search(self):
|
||||
result = super().sample_search()
|
||||
loss = []
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
def need_reg(choice):
|
||||
return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])
|
||||
|
||||
for i, choice in enumerate(mutable.choices):
|
||||
if need_reg(choice):
|
||||
norm = torch.abs(self.choices[mutable.key][i])
|
||||
if norm < 1E10:
|
||||
loss.append(norm)
|
||||
if not loss:
|
||||
return result, None
|
||||
return result, sum(loss)
|
||||
|
||||
def export(self, logger=None):
|
||||
"""
|
||||
Export an architecture with logger. Genotype will be printed with logger.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A mapping from mutable keys to decisions.
|
||||
"""
|
||||
result = self.sample_final()
|
||||
if hasattr(self.model, "plot_genotype") and logger is not None:
|
||||
genotypes = self.model.plot_genotype(result, logger)
|
||||
return result, genotypes
|
||||
|
||||
|
||||
class RegularizedMutatorParallel(DistributedDataParallel):
|
||||
"""
|
||||
Parallelize :class:`~RegularizedDartsMutator`.
|
||||
|
||||
This makes :func:`~RegularizedDartsMutator.reset_with_loss` method parallelized,
|
||||
also allowing :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
|
||||
to be easily accessible.
|
||||
"""
|
||||
def reset_with_loss(self):
|
||||
"""
|
||||
Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
|
||||
"""
|
||||
result = self.module.reset_with_loss()
|
||||
self.callback_queued = False
|
||||
return result
|
||||
|
||||
def cut_choices(self, *args, **kwargs):
|
||||
"""
|
||||
Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
|
||||
"""
|
||||
self.module.cut_choices(*args, **kwargs)
|
||||
|
||||
def export(self, logger):
|
||||
"""
|
||||
Parallelized :func:`~RegularizedDartsMutator.export`.
|
||||
"""
|
||||
return self.module.export(logger)
|
||||
|
||||
|
||||
class DartsDiscreteMutator(Mutator):
|
||||
"""
|
||||
A mutator that applies the final sampling result of a parent mutator on another model to train.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
The model to apply the mutator.
|
||||
parent_mutator : nni.nas.pytorch.mutator.Mutator
|
||||
The mutator that provides ``sample_final`` method, that will be called to get the architecture.
|
||||
"""
|
||||
def __init__(self, model, parent_mutator):
|
||||
super().__init__(model)
|
||||
self.__dict__["parent_mutator"] = parent_mutator # avoid parameters to be included
|
||||
|
||||
def sample_search(self):
|
||||
return self.parent_mutator.sample_final()
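
The regularization loss returned by ``reset_with_loss`` is meant to be added to the task loss before
backpropagation. A minimal sketch of that pattern, assuming ``model``, ``criterion`` and a weighting factor
``reg_coeff`` are defined by the user (these names are illustrative, not from the original file):

mutator = RegularizedDartsMutator(model)
reg_loss = mutator.reset_with_loss()     # resample the architecture and get the L1 penalty (or None)
logits = model(inputs)
loss = criterion(logits, targets)
if reg_loss is not None:
    loss = loss + reg_coeff * reg_loss
loss.backward()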
@@ -1,275 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import apex # pylint: disable=import-error
|
||||
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
|
||||
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator # pylint: disable=wrong-import-order
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup # pylint: disable=wrong-import-order
|
||||
|
||||
from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics
|
||||
|
||||
PHASE_SMALL = "small"
|
||||
PHASE_LARGE = "large"
|
||||
|
||||
|
||||
class InteractiveKLLoss(nn.Module):
|
||||
def __init__(self, temperature):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
# self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
|
||||
self.kl_loss = nn.KLDivLoss()
|
||||
|
||||
def forward(self, student, teacher):
|
||||
return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
|
||||
F.softmax(teacher / self.temperature, dim=1))
|
||||
|
||||
|
||||
class CdartsTrainer(object):
|
||||
"""
|
||||
CDARTS trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_small : nn.Module
|
||||
PyTorch model to be trained. This is the search network of CDARTS.
|
||||
model_large : nn.Module
|
||||
PyTorch model to be trained. This is the evaluation network of CDARTS.
|
||||
criterion : callable
|
||||
Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
|
||||
loaders : list of torch.utils.data.DataLoader
|
||||
List of train data and valid data loaders, for training weights and architecture weights respectively.
|
||||
samplers : list of torch.utils.data.Sampler
|
||||
List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
|
||||
In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
|
||||
logger : logging.Logger
|
||||
The logger for logging. Will use nni logger by default (if logger is ``None``).
|
||||
regular_coeff : float
|
||||
The coefficient of regular loss.
|
||||
regular_ratio : float
|
||||
The ratio of regular loss.
|
||||
warmup_epochs : int
|
||||
The epochs to warmup the search network
|
||||
fix_head : bool
|
||||
``True`` if fixing the paramters of auxiliary heads, else unfix the paramters of auxiliary heads.
|
||||
epochs : int
|
||||
Number of epochs planned for training.
|
||||
steps_per_epoch : int
|
||||
Steps of one epoch.
|
||||
loss_alpha : float
|
||||
The loss coefficient.
|
||||
loss_T : float
|
||||
The loss coefficient.
|
||||
distributed : bool
|
||||
``True`` if using distributed training, else non-distributed training.
|
||||
log_frequency : int
|
||||
Step count per logging.
|
||||
grad_clip : float
|
||||
Gradient clipping for weights.
|
||||
interactive_type : string
|
||||
``kl`` or ``smoothl1``.
|
||||
output_path : string
|
||||
Log storage path.
|
||||
w_lr : float
|
||||
Learning rate of the search network parameters.
|
||||
w_momentum : float
|
||||
Momentum of the search and the evaluation network.
|
||||
w_weight_decay : float
|
||||
The weight decay the search and the evaluation network parameters.
|
||||
alpha_lr : float
|
||||
Learning rate of the architecture parameters.
|
||||
alpha_weight_decay : float
|
||||
The weight decay the architecture parameters.
|
||||
nasnet_lr : float
|
||||
Learning rate of the evaluation network parameters.
|
||||
local_rank : int
|
||||
The number of thread.
|
||||
share_module : bool
|
||||
``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
|
||||
"""
|
||||
def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
|
||||
regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
|
||||
epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
|
||||
log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
|
||||
w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
|
||||
nasnet_lr=0.2, local_rank=0, share_module=True):
|
||||
if logger is None:
|
||||
logger = logging.getLogger(__name__)
|
||||
train_loader, valid_loader = loaders
|
||||
train_sampler, valid_sampler = samplers
|
||||
self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
|
||||
self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)
|
||||
|
||||
self.regular_coeff = regular_coeff
|
||||
self.regular_ratio = regular_ratio
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.fix_head = fix_head
|
||||
self.epochs = epochs
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
if self.steps_per_epoch is None:
|
||||
self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
|
||||
self.loss_alpha = loss_alpha
|
||||
self.grad_clip = grad_clip
|
||||
if interactive_type == "kl":
|
||||
self.interactive_loss = InteractiveKLLoss(loss_T)
|
||||
elif interactive_type == "smoothl1":
|
||||
self.interactive_loss = nn.SmoothL1Loss()
|
||||
self.loss_T = loss_T
|
||||
self.distributed = distributed
|
||||
self.log_frequency = log_frequency
|
||||
self.main_proc = not distributed or local_rank == 0
|
||||
|
||||
self.logger = logger
|
||||
self.checkpoint_dir = output_path
|
||||
if self.main_proc:
|
||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||
if distributed:
|
||||
torch.distributed.barrier()
|
||||
|
||||
self.model_small = model_small
|
||||
self.model_large = model_large
|
||||
if self.fix_head:
|
||||
for param in self.model_small.aux_head.parameters():
|
||||
param.requires_grad = False
|
||||
for param in self.model_large.aux_head.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
|
||||
self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
|
||||
self.criterion = criterion
|
||||
|
||||
self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
|
||||
momentum=w_momentum, weight_decay=w_weight_decay)
|
||||
self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
|
||||
momentum=w_momentum, weight_decay=w_weight_decay)
|
||||
self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
|
||||
betas=(0.5, 0.999), weight_decay=alpha_weight_decay)
|
||||
|
||||
if distributed:
|
||||
apex.parallel.convert_syncbn_model(self.model_small)
|
||||
apex.parallel.convert_syncbn_model(self.model_large)
|
||||
self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
|
||||
self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
|
||||
self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
|
||||
if share_module:
|
||||
self.model_small.callback_queued = True
|
||||
self.model_large.callback_queued = True
|
||||
# mutator_large never gets optimized, so it does not need to be parallelized
|
||||
|
||||
def _warmup(self, phase, epoch):
|
||||
assert phase in [PHASE_SMALL, PHASE_LARGE]
|
||||
if phase == PHASE_SMALL:
|
||||
model, optimizer = self.model_small, self.optimizer_small
|
||||
elif phase == PHASE_LARGE:
|
||||
model, optimizer = self.model_large, self.optimizer_large
|
||||
model.train()
|
||||
meters = AverageMeterGroup()
|
||||
for step in range(self.steps_per_epoch):
|
||||
x, y = next(self.train_loader)
|
||||
x, y = x.cuda(), y.cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits_main, _ = model(x)
|
||||
loss = self.criterion(logits_main, y)
|
||||
loss.backward()
|
||||
|
||||
self._clip_grad_norm(model)
|
||||
optimizer.step()
|
||||
prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
|
||||
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
|
||||
metrics = reduce_metrics(metrics, self.distributed)
|
||||
meters.update(metrics)
|
||||
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
|
||||
self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs,
|
||||
step + 1, self.steps_per_epoch, phase, meters)
|
||||
|
||||
def _clip_grad_norm(self, model):
|
||||
if isinstance(model, DistributedDataParallel):
|
||||
nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
|
||||
else:
|
||||
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
|
||||
|
||||
def _reset_nan(self, parameters):
|
||||
with torch.no_grad():
|
||||
for param in parameters:
|
||||
for i, p in enumerate(param):
|
||||
if p != p: # equivalent to `isnan(p)`
|
||||
param[i] = float("-inf")
|
||||
|
||||
def _joint_train(self, epoch):
|
||||
self.model_large.train()
|
||||
self.model_small.train()
|
||||
meters = AverageMeterGroup()
|
||||
for step in range(self.steps_per_epoch):
|
||||
trn_x, trn_y = next(self.train_loader)
|
||||
val_x, val_y = next(self.valid_loader)
|
||||
trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
|
||||
val_x, val_y = val_x.cuda(), val_y.cuda()
|
||||
|
||||
# step 1. optimize architecture
|
||||
self.optimizer_alpha.zero_grad()
|
||||
self.optimizer_large.zero_grad()
|
||||
reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
|
||||
(self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
|
||||
loss_regular = self.mutator_small.reset_with_loss()
|
||||
if loss_regular:
|
||||
loss_regular *= reg_decay
|
||||
logits_search, emsemble_logits_search = self.model_small(val_x)
|
||||
logits_main, emsemble_logits_main = self.model_large(val_x)
|
||||
loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
|
||||
loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
|
||||
loss = loss_cls + loss_interactive + loss_regular
|
||||
loss.backward()
|
||||
self._clip_grad_norm(self.model_large)
|
||||
self.optimizer_large.step()
|
||||
self.optimizer_alpha.step()
|
||||
# NOTE: `self._reset_nan(self.mutator_small.parameters())` needs to be called here if `cut_choices` is used
|
||||
|
||||
# step 2. optimize op weights
|
||||
self.optimizer_small.zero_grad()
|
||||
with torch.no_grad():
|
||||
# resample architecture since parameters have been changed
|
||||
self.mutator_small.reset_with_loss()
|
||||
logits_search_train, _ = self.model_small(trn_x)
|
||||
loss_weight = self.criterion(logits_search_train, trn_y)
|
||||
loss_weight.backward()
|
||||
self._clip_grad_norm(self.model_small)
|
||||
self.optimizer_small.step()
|
||||
|
||||
metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
|
||||
"loss_regular": loss_regular, "loss_weight": loss_weight}
|
||||
metrics = reduce_metrics(metrics, self.distributed)
|
||||
meters.update(metrics)
|
||||
|
||||
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
|
||||
self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
|
||||
step + 1, self.steps_per_epoch, meters)
|
||||
|
||||
def train(self):
|
||||
for epoch in range(self.epochs):
|
||||
if epoch < self.warmup_epochs:
|
||||
with torch.no_grad(): # otherwise grads will be retained on the architecture params
|
||||
self.mutator_small.reset_with_loss()
|
||||
self._warmup(PHASE_SMALL, epoch)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
self.mutator_large.reset()
|
||||
self._warmup(PHASE_LARGE, epoch)
|
||||
self._joint_train(epoch)
|
||||
|
||||
self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
|
||||
os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))
|
||||
|
||||
def export(self, file, genotype_file):
|
||||
if self.main_proc:
|
||||
mutator_export, genotypes = self.mutator_small.export(self.logger)
|
||||
with open(file, "w") as f:
|
||||
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
|
||||
with open(genotype_file, "w") as f:
|
||||
f.write(str(genotypes))
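
A hypothetical sketch of how this trainer is typically driven; the networks, loaders and samplers below are
user-provided placeholders (the search and evaluation networks are expected to expose ``aux_head`` and
``plot_genotype``, as the code above assumes):

trainer = CdartsTrainer(search_net, eval_net,
                        criterion=torch.nn.CrossEntropyLoss(),
                        loaders=[train_loader, valid_loader],
                        samplers=[train_sampler, valid_sampler],
                        epochs=32, warmup_epochs=2, distributed=False)
trainer.train()   # warmup phases followed by joint optimization
trainer.export("final_arch.json", "final_arch.genotypes")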
@@ -1,76 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class CyclicIterator:
|
||||
def __init__(self, loader, sampler, distributed):
|
||||
self.loader = loader
|
||||
self.sampler = sampler
|
||||
self.epoch = 0
|
||||
self.distributed = distributed
|
||||
self._next_epoch()
|
||||
|
||||
def _next_epoch(self):
|
||||
if self.distributed:
|
||||
self.sampler.set_epoch(self.epoch)
|
||||
self.iterator = iter(self.loader)
|
||||
self.epoch += 1
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self._next_epoch()
|
||||
return next(self.iterator)
|
||||
|
||||
|
||||
class TorchTensorEncoder(json.JSONEncoder):
|
||||
def default(self, o): # pylint: disable=method-hidden
|
||||
if isinstance(o, torch.Tensor):
|
||||
return o.tolist()
|
||||
return super().default(o)
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
""" Computes the precision@k for the specified values of k """
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
# one-hot case
|
||||
if target.ndimension() > 1:
|
||||
target = target.max(1)[1]
|
||||
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(1.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= float(os.environ["WORLD_SIZE"])
|
||||
return rt
|
||||
|
||||
|
||||
def reduce_metrics(metrics, distributed=False):
|
||||
if distributed:
|
||||
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
|
||||
return {k: v.item() for k, v in metrics.items()}
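
A small worked example of ``accuracy`` with illustrative values (only 3 classes, so top-2 is used instead of
top-5; ``torch`` is already imported at the top of this file):

output = torch.tensor([[0.1, 0.2, 0.7],    # predicted class 2, target 2 -> correct at top-1
                       [0.8, 0.1, 0.1],    # predicted class 0, target 0 -> correct at top-1
                       [0.1, 0.3, 0.6],    # predicted class 2, target 1 -> only correct at top-2
                       [0.2, 0.5, 0.3]])   # predicted class 1, target 1 -> correct at top-1
target = torch.tensor([2, 0, 1, 1])
prec1, prec2 = accuracy(output, target, topk=(1, 2))   # prec1 = 0.75, prec2 = 1.0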
@@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import get_and_apply_next_architecture
@@ -1,221 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
import nni
|
||||
from nni.runtime.env_vars import trial_env_vars
|
||||
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
|
||||
from nni.nas.pytorch.mutator import Mutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
|
||||
LAYER_CHOICE = "layer_choice"
|
||||
INPUT_CHOICE = "input_choice"
|
||||
|
||||
|
||||
def get_and_apply_next_architecture(model):
|
||||
"""
|
||||
Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
|
||||
similar to ``get_next_parameter`` for HPO.
|
||||
|
||||
It will generate search space based on ``model``.
|
||||
If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
|
||||
generating search space for the experiment.
|
||||
If not, there are still two mode, one is nni experiment mode where users
|
||||
use ``nnictl`` to start an experiment. The other is standalone mode
|
||||
where users directly run the trial command, this mode chooses the first
|
||||
one(s) for each LayerChoice and InputChoice.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
|
||||
"""
|
||||
ClassicMutator(model)
|
||||
|
||||
|
||||
class ClassicMutator(Mutator):
|
||||
"""
|
||||
This mutator is to apply the architecture chosen from tuner.
|
||||
It implements the forward function of LayerChoice and InputChoice,
|
||||
to only activate the chosen ones.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super(ClassicMutator, self).__init__(model)
|
||||
self._chosen_arch = {}
|
||||
self._search_space = self._generate_search_space()
|
||||
if NNI_GEN_SEARCH_SPACE in os.environ:
|
||||
# dry run for only generating search space
|
||||
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
|
||||
sys.exit(0)
|
||||
|
||||
if trial_env_vars.NNI_PLATFORM is None:
|
||||
logger.warning("This is in standalone mode, the chosen are the first one(s).")
|
||||
self._chosen_arch = self._standalone_generate_chosen()
|
||||
else:
|
||||
# get chosen arch from tuner
|
||||
self._chosen_arch = nni.get_next_parameter()
|
||||
if self._chosen_arch is None:
|
||||
if trial_env_vars.NNI_PLATFORM == "unittest":
|
||||
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
|
||||
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
|
||||
self._chosen_arch = self._standalone_generate_chosen()
|
||||
else:
|
||||
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
|
||||
self.reset()
|
||||
|
||||
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
|
||||
"""
|
||||
Convert layer choice to tensor representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable : Mutable
|
||||
idx : int
|
||||
The ``idx``-th item of the list will be selected.
|
||||
value : str
|
||||
The verbose representation of the selected value.
|
||||
search_space_item : list
|
||||
The list for corresponding search space.
|
||||
"""
|
||||
# doesn't support multihot for layer choice yet
|
||||
onehot_list = [False] * len(mutable)
|
||||
assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
|
||||
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
|
||||
onehot_list[idx] = True
|
||||
return torch.tensor(onehot_list, dtype=torch.bool) # pylint: disable=not-callable
|
||||
|
||||
def _sample_input_choice(self, mutable, idx, value, search_space_item):
|
||||
"""
|
||||
Convert input choice to tensor representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable : Mutable
|
||||
idx : int
|
||||
The ``idx``-th item of the list will be selected.
|
||||
value : str
|
||||
The verbose representation of the selected value.
|
||||
search_space_item : list
|
||||
The list for corresponding search space.
|
||||
"""
|
||||
candidate_repr = search_space_item["candidates"]
|
||||
multihot_list = [False] * mutable.n_candidates
|
||||
for i, v in zip(idx, value):
|
||||
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
|
||||
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
|
||||
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
|
||||
multihot_list[i] = True
|
||||
return torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
|
||||
|
||||
def sample_search(self):
|
||||
"""
|
||||
See :meth:`sample_final`.
|
||||
"""
|
||||
return self.sample_final()
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Convert the chosen arch and apply it on model.
|
||||
"""
|
||||
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
|
||||
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
|
||||
self._chosen_arch.keys())
|
||||
result = dict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, (LayerChoice, InputChoice)):
|
||||
assert mutable.key in self._chosen_arch, \
|
||||
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
|
||||
data = self._chosen_arch[mutable.key]
|
||||
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
|
||||
"'{}' is not a valid choice.".format(data)
|
||||
if isinstance(mutable, LayerChoice):
|
||||
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
|
||||
self._search_space[mutable.key]["_value"])
|
||||
elif isinstance(mutable, InputChoice):
|
||||
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
|
||||
self._search_space[mutable.key]["_value"])
|
||||
elif isinstance(mutable, MutableScope):
|
||||
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
|
||||
else:
|
||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
|
||||
return result
|
||||
|
||||
def _standalone_generate_chosen(self):
|
||||
"""
|
||||
Generate the chosen architecture for standalone mode,
|
||||
i.e., choose the first one(s) for LayerChoice and InputChoice.
|
||||
::
|
||||
{ key_name: {"_value": "conv1",
|
||||
"_idx": 0} }
|
||||
{ key_name: {"_value": ["in1"],
|
||||
"_idx": [0]} }
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the chosen architecture
|
||||
"""
|
||||
chosen_arch = {}
|
||||
for key, val in self._search_space.items():
|
||||
if val["_type"] == LAYER_CHOICE:
|
||||
choices = val["_value"]
|
||||
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
|
||||
elif val["_type"] == INPUT_CHOICE:
|
||||
choices = val["_value"]["candidates"]
|
||||
n_chosen = val["_value"]["n_chosen"]
|
||||
if n_chosen is None:
|
||||
n_chosen = len(choices)
|
||||
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
|
||||
else:
|
||||
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
|
||||
return chosen_arch
|
||||
|
||||
def _generate_search_space(self):
|
||||
"""
|
||||
Generate search space from mutables.
|
||||
Here is the search space format:
|
||||
::
|
||||
{ key_name: {"_type": "layer_choice",
|
||||
"_value": ["conv1", "conv2"]} }
|
||||
{ key_name: {"_type": "input_choice",
|
||||
"_value": {"candidates": ["in1", "in2"],
|
||||
"n_chosen": 1}} }
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the generated search space
|
||||
"""
|
||||
search_space = {}
|
||||
for mutable in self.mutables:
|
||||
# for now we only generate flattened search space
|
||||
if isinstance(mutable, LayerChoice):
|
||||
key = mutable.key
|
||||
val = mutable.names
|
||||
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
|
||||
elif isinstance(mutable, InputChoice):
|
||||
key = mutable.key
|
||||
search_space[key] = {"_type": INPUT_CHOICE,
|
||||
"_value": {"candidates": mutable.choose_from,
|
||||
"n_chosen": mutable.n_chosen}}
|
||||
elif isinstance(mutable, MutableScope):
|
||||
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
|
||||
else:
|
||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
|
||||
return search_space
|
||||
|
||||
def _dump_search_space(self, file_path):
|
||||
with open(file_path, "w") as ss_file:
|
||||
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
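
A hypothetical trial script sketch showing how this mutator is typically used; the network below is an
illustrative placeholder, and the import path is the one used before this move:

import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(3, 8, 5, padding=2)])

    def forward(self, x):
        return self.conv(x)

model = Net()
get_and_apply_next_architecture(model)   # fixes each choice to the tuner's decision (or the first one in standalone mode)
# ... then train the fixed architecture and report the final metric via nni.report_final_result(...)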
@@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .trainer import CreamSupernetTrainer
@@ -1,403 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from nni.nas.pytorch.trainer import Trainer
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup
|
||||
|
||||
from .utils import accuracy, reduce_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreamSupernetTrainer(Trainer):
|
||||
"""
|
||||
This trainer trains a supernet and output prioritized architectures that can be used for other tasks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
Model with mutables.
|
||||
loss : callable
|
||||
Called with logits and targets. Returns a loss tensor.
|
||||
val_loss : callable
|
||||
Called with logits and targets for validation only. Returns a loss tensor.
|
||||
optimizer : Optimizer
|
||||
Optimizer that optimizes the model.
|
||||
num_epochs : int
|
||||
Number of epochs of training.
|
||||
train_loader : iterablez
|
||||
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
|
||||
valid_loader : iterablez
|
||||
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
|
||||
mutator : Mutator
|
||||
A mutator object that has been initialized with the model.
|
||||
batch_size : int
|
||||
Batch size.
|
||||
log_frequency : int
|
||||
Number of mini-batches to log metrics.
|
||||
meta_sta_epoch : int
|
||||
start epoch of using meta matching network to pick teacher architecture
|
||||
update_iter : int
|
||||
interval of updating meta matching networks
|
||||
slices : int
|
||||
batch size of mini training data in the process of training meta matching network
|
||||
pool_size : int
|
||||
board size
|
||||
pick_method : basestring
|
||||
how to pick teacher network
|
||||
choice_num : int
|
||||
number of operations in supernet
|
||||
sta_num : int
|
||||
layer number of each stage in supernet (5 stage in supernet)
|
||||
acc_gap : int
|
||||
maximum accuracy improvement to omit the limitation of flops
|
||||
flops_dict : Dict
|
||||
dictionary of each layer's operations in supernet
|
||||
flops_fixed : int
|
||||
flops of fixed part in supernet
|
||||
local_rank : int
|
||||
index of current rank
|
||||
callbacks : list of Callback
|
||||
Callbacks to plug into the trainer. See Callbacks.
|
||||
"""
|
||||
|
||||
def __init__(self, model, loss, val_loss,
|
||||
optimizer, num_epochs, train_loader, valid_loader,
|
||||
mutator=None, batch_size=64, log_frequency=None,
|
||||
meta_sta_epoch=20, update_iter=200, slices=2,
|
||||
pool_size=10, pick_method='meta', choice_num=6,
|
||||
sta_num=(4, 4, 4, 4, 4), acc_gap=5,
|
||||
flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
|
||||
assert torch.cuda.is_available()
|
||||
super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
|
||||
optimizer, num_epochs, None, None,
|
||||
batch_size, None, None, log_frequency, callbacks)
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.val_loss = val_loss
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
self.log_frequency = log_frequency
|
||||
self.batch_size = batch_size
|
||||
self.optimizer = optimizer
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.num_epochs = num_epochs
|
||||
self.meta_sta_epoch = meta_sta_epoch
|
||||
self.update_iter = update_iter
|
||||
self.slices = slices
|
||||
self.pick_method = pick_method
|
||||
self.pool_size = pool_size
|
||||
self.local_rank = local_rank
|
||||
self.choice_num = choice_num
|
||||
self.sta_num = sta_num
|
||||
self.acc_gap = acc_gap
|
||||
self.flops_dict = flops_dict
|
||||
self.flops_fixed = flops_fixed
|
||||
|
||||
self.current_student_arch = None
|
||||
self.current_teacher_arch = None
|
||||
self.main_proc = (local_rank == 0)
|
||||
self.current_epoch = 0
|
||||
|
||||
self.prioritized_board = []
|
||||
|
||||
# size of prioritized board
|
||||
def _board_size(self):
|
||||
return len(self.prioritized_board)
|
||||
|
||||
# select teacher architecture according to the logit difference
|
||||
def _select_teacher(self):
|
||||
self._replace_mutator_cand(self.current_student_arch)
|
||||
|
||||
if self.pick_method == 'top1':
|
||||
meta_value, teacher_cand = 0.5, sorted(
|
||||
self.prioritized_board, reverse=True)[0][3]
|
||||
elif self.pick_method == 'meta':
|
||||
meta_value, cand_idx, teacher_cand = -1000000000, -1, None
|
||||
for now_idx, item in enumerate(self.prioritized_board):
|
||||
inputx = item[4]
|
||||
output = torch.nn.functional.softmax(self.model(inputx), dim=1)
|
||||
weight = self.model.module.forward_meta(output - item[5])
|
||||
if weight > meta_value:
|
||||
meta_value = weight
|
||||
cand_idx = now_idx
|
||||
teacher_cand = self.prioritized_board[cand_idx][3]
|
||||
assert teacher_cand is not None
|
||||
meta_value = torch.nn.functional.sigmoid(-weight)
|
||||
else:
|
||||
raise ValueError('Method Not supported')
|
||||
|
||||
return meta_value, teacher_cand
|
||||
|
||||
# check whether to update prioritized board
|
||||
def _isUpdateBoard(self, prec1, flops):
|
||||
if self.current_epoch <= self.meta_sta_epoch:
|
||||
return False
|
||||
|
||||
if len(self.prioritized_board) < self.pool_size:
|
||||
return True
|
||||
|
||||
if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
|
||||
return True
|
||||
|
||||
if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# update prioritized board
|
||||
def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
|
||||
if self._isUpdateBoard(prec1, flops):
|
||||
val_prec1 = prec1
|
||||
training_data = deepcopy(inputs[:self.slices].detach())
|
||||
if len(self.prioritized_board) == 0:
|
||||
features = deepcopy(outputs[:self.slices].detach())
|
||||
else:
|
||||
features = deepcopy(
|
||||
teacher_output[:self.slices].detach())
|
||||
self.prioritized_board.append(
|
||||
(val_prec1,
|
||||
prec1,
|
||||
flops,
|
||||
self.current_student_arch,
|
||||
training_data,
|
||||
torch.nn.functional.softmax(
|
||||
features,
|
||||
dim=1)))
|
||||
self.prioritized_board = sorted(
|
||||
self.prioritized_board, reverse=True)
|
||||
|
||||
if len(self.prioritized_board) > self.pool_size:
|
||||
del self.prioritized_board[-1]
|
||||
|
||||
# only update student network weights
|
||||
def _update_student_weights_only(self, grad_1):
|
||||
for weight, grad_item in zip(
|
||||
self.model.module.rand_parameters(self.current_student_arch), grad_1):
|
||||
weight.grad = grad_item
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.module.rand_parameters(self.current_student_arch), 1)
|
||||
self.optimizer.step()
|
||||
for weight, grad_item in zip(
|
||||
self.model.module.rand_parameters(self.current_student_arch), grad_1):
|
||||
del weight.grad
|
||||
|
||||
# only update meta networks weights
|
||||
def _update_meta_weights_only(self, teacher_cand, grad_teacher):
|
||||
for weight, grad_item in zip(self.model.module.rand_parameters(
|
||||
teacher_cand, self.pick_method == 'meta'), grad_teacher):
|
||||
weight.grad = grad_item
|
||||
|
||||
# clip gradients
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.module.rand_parameters(
|
||||
self.current_student_arch, self.pick_method == 'meta'), 1)
|
||||
|
||||
self.optimizer.step()
|
||||
for weight, grad_item in zip(self.model.module.rand_parameters(
|
||||
teacher_cand, self.pick_method == 'meta'), grad_teacher):
|
||||
del weight.grad
|
||||
|
||||
# simulate sgd updating
|
||||
def _simulate_sgd_update(self, w, g, optimizer):
|
||||
return g * optimizer.param_groups[-1]['lr'] + w
|
||||
|
||||
# split training images into several slices
|
||||
def _get_minibatch_input(self, input): # pylint: disable=redefined-builtin
|
||||
slice = self.slices # pylint: disable=redefined-builtin
|
||||
x = deepcopy(input[:slice].clone().detach())
|
||||
return x
|
||||
|
||||
# calculate 1st gradient of student architectures
|
||||
def _calculate_1st_gradient(self, kd_loss):
|
||||
self.optimizer.zero_grad()
|
||||
grad = torch.autograd.grad(
|
||||
kd_loss,
|
||||
self.model.module.rand_parameters(self.current_student_arch),
|
||||
create_graph=True)
|
||||
return grad
|
||||
|
||||
# calculate 2nd gradient of meta networks
|
||||
def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
|
||||
self.optimizer.zero_grad()
|
||||
grad_student_val = torch.autograd.grad(
|
||||
validation_loss,
|
||||
self.model.module.rand_parameters(self.current_student_arch),
|
||||
retain_graph=True)
|
||||
|
||||
grad_teacher = torch.autograd.grad(
|
||||
students_weight[0],
|
||||
self.model.module.rand_parameters(
|
||||
teacher_cand,
|
||||
self.pick_method == 'meta'),
|
||||
grad_outputs=grad_student_val)
|
||||
return grad_teacher
|
||||
|
||||
# forward training data
|
||||
def _forward_training(self, x, meta_value):
|
||||
self._replace_mutator_cand(self.current_student_arch)
|
||||
output = self.model(x)
|
||||
|
||||
with torch.no_grad():
|
||||
self._replace_mutator_cand(self.current_teacher_arch)
|
||||
teacher_output = self.model(x)
|
||||
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
|
||||
|
||||
kd_loss = meta_value * \
|
||||
self._cross_entropy_loss_with_soft_target(output, soft_label)
|
||||
return kd_loss
|
||||
|
||||
# calculate soft target loss
|
||||
def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
|
||||
logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
|
||||
|
||||
# forward validation data
|
||||
def _forward_validation(self, input, target): # pylint: disable=redefined-builtin
|
||||
slice = self.slices # pylint: disable=redefined-builtin
|
||||
x = input[slice:slice * 2].clone()
|
||||
|
||||
self._replace_mutator_cand(self.current_student_arch)
|
||||
output_2 = self.model(x)
|
||||
|
||||
validation_loss = self.loss(output_2, target[slice:slice * 2])
|
||||
return validation_loss
|
||||
|
||||
def _isUpdateMeta(self, batch_idx):
|
||||
isUpdate = True
|
||||
isUpdate &= (self.current_epoch > self.meta_sta_epoch)
|
||||
isUpdate &= (batch_idx > 0)
|
||||
isUpdate &= (batch_idx % self.update_iter == 0)
|
||||
isUpdate &= (self._board_size() > 0)
|
||||
return isUpdate
|
||||
|
||||
def _replace_mutator_cand(self, cand):
|
||||
self.mutator._cache = cand
|
||||
|
||||
# update meta matching networks
|
||||
def _run_update(self, input, target, batch_idx): # pylint: disable=redefined-builtin
|
||||
if self._isUpdateMeta(batch_idx):
|
||||
x = self._get_minibatch_input(input)
|
||||
|
||||
meta_value, teacher_cand = self._select_teacher()
|
||||
|
||||
kd_loss = self._forward_training(x, meta_value)
|
||||
|
||||
# calculate 1st gradient
|
||||
grad_1st = self._calculate_1st_gradient(kd_loss)
|
||||
|
||||
# simulate updated student weights
|
||||
students_weight = [
|
||||
self._simulate_sgd_update(
|
||||
p, grad_item, self.optimizer) for p, grad_item in zip(
|
||||
self.model.module.rand_parameters(self.current_student_arch), grad_1st)]
|
||||
|
||||
# update student weights
|
||||
self._update_student_weights_only(grad_1st)
|
||||
|
||||
validation_loss = self._forward_validation(input, target)
|
||||
|
||||
# calculate 2nd gradient
|
||||
grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)
|
||||
|
||||
# update meta matching networks
|
||||
self._update_meta_weights_only(teacher_cand, grad_teacher)
|
||||
|
||||
# delete intermediate variables
|
||||
del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight
|
||||
|
||||
def _get_cand_flops(self, cand):
|
||||
flops = 0
|
||||
for block_id, block in enumerate(cand):
|
||||
if block == 'LayerChoice1' or block_id == 'LayerChoice23':
|
||||
continue
|
||||
for idx, choice in enumerate(cand[block]):
|
||||
flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
|
||||
return flops + self.flops_fixed
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
meters = AverageMeterGroup()
|
||||
self.steps_per_epoch = len(self.train_loader)
|
||||
for step, (input_data, target) in enumerate(self.train_loader):
|
||||
self.mutator.reset()
|
||||
self.current_student_arch = self.mutator._cache
|
||||
|
||||
input_data, target = input_data.cuda(), target.cuda()
|
||||
|
||||
# calculate flops of current architecture
|
||||
cand_flops = self._get_cand_flops(self.mutator._cache)
|
||||
|
||||
# update meta matching network
|
||||
self._run_update(input_data, target, step)
|
||||
|
||||
if self._board_size() > 0:
|
||||
# select teacher architecture
|
||||
meta_value, teacher_cand = self._select_teacher()
|
||||
self.current_teacher_arch = teacher_cand
|
||||
|
||||
# forward supernet
|
||||
if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
|
||||
self._replace_mutator_cand(self.current_student_arch)
|
||||
output = self.model(input_data)
|
||||
|
||||
loss = self.loss(output, target)
|
||||
kd_loss, teacher_output, teacher_cand = None, None, None
|
||||
else:
|
||||
self._replace_mutator_cand(self.current_student_arch)
|
||||
output = self.model(input_data)
|
||||
|
||||
gt_loss = self.loss(output, target)
|
||||
|
||||
with torch.no_grad():
|
||||
self._replace_mutator_cand(self.current_teacher_arch)
|
||||
teacher_output = self.model(input_data).detach()
|
||||
|
||||
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
|
||||
kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
|
||||
|
||||
loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2
|
||||
|
||||
# update network
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# update metrics
|
||||
prec1, prec5 = accuracy(output, target, topk=(1, 5))
|
||||
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
|
||||
metrics = reduce_metrics(metrics)
|
||||
meters.update(metrics)
|
||||
|
||||
# update prioritized board
|
||||
self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)
|
||||
|
||||
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
|
||||
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
|
||||
step + 1, len(self.train_loader), meters)
|
||||
|
||||
if self.main_proc and self.num_epochs == epoch + 1:
|
||||
for idx, i in enumerate(self.prioritized_board):
|
||||
logger.info("No.%s %s", idx, i[:4])
|
||||
|
||||
def validate_one_epoch(self, epoch):
|
||||
self.model.eval()
|
||||
meters = AverageMeterGroup()
|
||||
with torch.no_grad():
|
||||
for step, (x, y) in enumerate(self.valid_loader):
|
||||
self.mutator.reset()
|
||||
logits = self.model(x)
|
||||
loss = self.val_loss(logits, y)
|
||||
prec1, prec5 = accuracy(logits, y, topk=(1, 5))
|
||||
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
|
||||
metrics = reduce_metrics(metrics)
|
||||
meters.update(metrics)
|
||||
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.valid_loader), meters)
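
The joint objective above mixes a hard-label loss and a soft-target distillation loss weighted by ``meta_value``.
A minimal standalone sketch of that combination, with purely illustrative tensors:

import torch
import torch.nn.functional as F

student_logits = torch.randn(8, 10)            # hypothetical batch of 8 samples, 10 classes
teacher_logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
meta_value = 0.7                               # teacher confidence, as produced by _select_teacher

soft_label = F.softmax(teacher_logits, dim=1)
kd_loss = torch.mean(torch.sum(-soft_label * F.log_softmax(student_logits, dim=1), dim=1))
gt_loss = F.cross_entropy(student_logits, targets)
loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2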
@@ -1,37 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import os
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
""" Computes the precision@k for the specified values of k """
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
# one-hot case
|
||||
if target.ndimension() > 1:
|
||||
target = target.max(1)[1]
|
||||
|
||||
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(1.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def reduce_metrics(metrics):
|
||||
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
|
||||
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
rt = tensor.clone()
|
||||
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
||||
rt /= float(os.environ["WORLD_SIZE"])
|
||||
return rt
@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import DartsMutator
|
||||
from .trainer import DartsTrainer
@@ -1,85 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nni.nas.pytorch.mutator import Mutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DartsMutator(Mutator):
|
||||
"""
|
||||
Connects the model in a DARTS (differentiable) way.
|
||||
|
||||
An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
|
||||
op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
|
||||
(not chosen).
|
||||
|
||||
All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
|
||||
on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
|
||||
will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.
|
||||
|
||||
It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
|
||||
value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
|
||||
be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
choices: ParameterDict
|
||||
dict that maps keys of LayerChoices to weighted-connection float tensors.
|
||||
"""
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
self.choices = nn.ParameterDict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
|
||||
|
||||
def device(self):
|
||||
for v in self.choices.values():
|
||||
return v.device
|
||||
|
||||
def sample_search(self):
|
||||
result = dict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
|
||||
elif isinstance(mutable, InputChoice):
|
||||
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
|
||||
return result
|
||||
|
||||
def sample_final(self):
|
||||
result = dict()
|
||||
edges_max = dict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
|
||||
edges_max[mutable.key] = max_val
|
||||
result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, InputChoice):
|
||||
if mutable.n_chosen is not None:
|
||||
weights = []
|
||||
for src_key in mutable.choose_from:
|
||||
if src_key not in edges_max:
|
||||
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
|
||||
weights.append(edges_max.get(src_key, 0.))
|
||||
weights = torch.tensor(weights) # pylint: disable=not-callable
|
||||
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
|
||||
selected_multihot = []
|
||||
for i, src_key in enumerate(mutable.choose_from):
|
||||
if i not in topk_edge_indices and src_key in result:
|
||||
# If an edge is never selected, there is no need to calculate any op on this edge.
|
||||
# This is to eliminate redundant calculation.
|
||||
result[src_key] = torch.zeros_like(result[src_key])
|
||||
selected_multihot.append(i in topk_edge_indices)
|
||||
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
|
||||
else:
|
||||
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
|
||||
return result
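
A small sketch of how the implicit "zero op" column behaves (illustrative numbers only): each LayerChoice with
N candidate ops gets a parameter vector of length N + 1, and the last entry, which stands for the zero op, is
dropped when sampling, so the exported weights over the real ops need not sum to 1.

import torch
import torch.nn.functional as F

alpha = torch.tensor([1.2, 0.3, -0.5, 2.0])    # 3 real ops plus the implicit zero op
weights = F.softmax(alpha, dim=-1)[:-1]        # weights over the 3 real ops only
print(weights, weights.sum())                  # the sum is below 1 because of the dropped zero-op mass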
@@ -1,214 +0,0 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from nni.nas.pytorch.trainer import Trainer
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup
|
||||
|
||||
from .mutator import DartsMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DartsTrainer(Trainer):
|
||||
"""
|
||||
DARTS trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
PyTorch model to be trained.
|
||||
loss : callable
|
||||
Receives logits and ground truth labels, and returns a loss tensor.
|
||||
metrics : callable
|
||||
Receives logits and ground truth labels, and returns a dict of metrics.
|
||||
optimizer : Optimizer
|
||||
The optimizer used for optimizing the model.
|
||||
num_epochs : int
|
||||
Number of epochs planned for training.
|
||||
dataset_train : Dataset
|
||||
Dataset for training. Will be split for training weights and architecture weights.
|
||||
dataset_valid : Dataset
|
||||
Dataset for testing.
|
||||
mutator : DartsMutator
|
||||
Use this to customize your own DartsMutator. By default, a DartsMutator will be instantiated.
|
||||
batch_size : int
|
||||
Batch size.
|
||||
workers : int
|
||||
Workers for data loading.
|
||||
device : torch.device
|
||||
``torch.device("cpu")`` or ``torch.device("cuda")``.
|
||||
log_frequency : int
|
||||
Step count per logging.
|
||||
callbacks : list of Callback
|
||||
List of callbacks to trigger at events.
|
||||
arc_learning_rate : float
|
||||
Learning rate of architecture parameters.
|
||||
unrolled : bool
``True`` if using second-order optimization, else first-order optimization.
|
||||
"""
|
||||
def __init__(self, model, loss, metrics,
|
||||
optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
|
||||
callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
|
||||
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
|
||||
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
batch_size, workers, device, log_frequency, callbacks)
|
||||
|
||||
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
|
||||
weight_decay=1.0E-3)
|
||||
self.unrolled = unrolled
|
||||
|
||||
n_train = len(self.dataset_train)
|
||||
split = n_train // 2
|
||||
indices = list(range(n_train))
|
||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
|
||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
|
||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=workers)
|
||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=batch_size,
|
||||
sampler=valid_sampler,
|
||||
num_workers=workers)
|
||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
|
||||
batch_size=batch_size,
|
||||
num_workers=workers)
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.model.train()
|
||||
self.mutator.train()
|
||||
meters = AverageMeterGroup()
|
||||
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
|
||||
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
|
||||
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
|
||||
|
||||
# phase 1. architecture step
|
||||
self.ctrl_optim.zero_grad()
|
||||
if self.unrolled:
|
||||
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
|
||||
else:
|
||||
self._backward(val_X, val_y)
|
||||
self.ctrl_optim.step()
|
||||
|
||||
# phase 2: child network step
|
||||
self.optimizer.zero_grad()
|
||||
logits, loss = self._logits_and_loss(trn_X, trn_y)
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping
|
||||
self.optimizer.step()
|
||||
|
||||
metrics = self.metrics(logits, trn_y)
|
||||
metrics["loss"] = loss.item()
|
||||
meters.update(metrics)
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.train_loader), meters)
|
||||
|
||||
def validate_one_epoch(self, epoch):
|
||||
self.model.eval()
|
||||
self.mutator.eval()
|
||||
meters = AverageMeterGroup()
|
||||
with torch.no_grad():
|
||||
self.mutator.reset()
|
||||
for step, (X, y) in enumerate(self.test_loader):
|
||||
X, y = X.to(self.device), y.to(self.device)
|
||||
logits = self.model(X)
|
||||
metrics = self.metrics(logits, y)
|
||||
meters.update(metrics)
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.test_loader), meters)
|
||||
|
||||
def _logits_and_loss(self, X, y):
|
||||
self.mutator.reset()
|
||||
logits = self.model(X)
|
||||
loss = self.loss(logits, y)
|
||||
self._write_graph_status()
|
||||
return logits, loss
|
||||
|
||||
def _backward(self, val_X, val_y):
|
||||
"""
|
||||
Simple backward with gradient descent
|
||||
"""
|
||||
_, loss = self._logits_and_loss(val_X, val_y)
|
||||
loss.backward()
|
||||
|
||||
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
|
||||
"""
|
||||
Compute unrolled loss and backward its gradients
|
||||
"""
|
||||
backup_params = copy.deepcopy(tuple(self.model.parameters()))
|
||||
|
||||
# do virtual step on training data
|
||||
lr = self.optimizer.param_groups[0]["lr"]
|
||||
momentum = self.optimizer.param_groups[0]["momentum"]
|
||||
weight_decay = self.optimizer.param_groups[0]["weight_decay"]
|
||||
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
|
||||
|
||||
# calculate unrolled loss on validation data
|
||||
# keep gradients for the model here to compute the hessian
|
||||
_, loss = self._logits_and_loss(val_X, val_y)
|
||||
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
|
||||
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
|
||||
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
|
||||
|
||||
# compute hessian and final gradients
|
||||
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
|
||||
with torch.no_grad():
|
||||
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
|
||||
# gradient = dalpha - lr * hessian
|
||||
param.grad = d - lr * h
|
||||
|
||||
# restore weights
|
||||
self._restore_weights(backup_params)
|
||||
|
||||
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
|
||||
"""
|
||||
Compute unrolled weights w`
|
||||
"""
|
||||
# don't need zero_grad, using autograd to calculate gradients
|
||||
_, loss = self._logits_and_loss(X, y)
|
||||
gradients = torch.autograd.grad(loss, self.model.parameters())
|
||||
with torch.no_grad():
|
||||
for w, g in zip(self.model.parameters(), gradients):
|
||||
m = self.optimizer.state[w].get("momentum_buffer", 0.)
|
||||
w -= lr * (momentum * m + g + weight_decay * w)  # in-place, so the virtual step actually updates the model weights
|
||||
|
||||
def _restore_weights(self, backup_params):
|
||||
with torch.no_grad():
|
||||
for param, backup in zip(self.model.parameters(), backup_params):
|
||||
param.copy_(backup)
|
||||
|
||||
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
|
||||
"""
|
||||
dw = dw` { L_val(w`, alpha) }
|
||||
w+ = w + eps * dw
|
||||
w- = w - eps * dw
|
||||
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
|
||||
eps = 0.01 / ||dw||
|
||||
"""
|
||||
self._restore_weights(backup_params)
|
||||
norm = torch.cat([w.view(-1) for w in dw]).norm()
|
||||
eps = 0.01 / norm
|
||||
if norm < 1E-8:
|
||||
logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())
|
||||
|
||||
dalphas = []
|
||||
for e in [eps, -2. * eps]:
|
||||
# w+ = w + eps*dw`, w- = w - eps*dw`
|
||||
with torch.no_grad():
|
||||
for p, d in zip(self.model.parameters(), dw):
|
||||
p += e * d
|
||||
|
||||
_, loss = self._logits_and_loss(trn_X, trn_y)
|
||||
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
|
||||
|
||||
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
|
||||
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
|
||||
return hessian
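# A minimal standalone sketch of the finite-difference rule implemented by
# `_compute_hessian` above: for a toy scalar loss L(w, alpha), the term
# d/dalpha(dL/dw) * dw is approximated by perturbing w by +/- eps * dw and
# differencing the alpha-gradients. All values and names below are made up
# purely for illustration.
import torch

toy_w = torch.tensor(1.5, requires_grad=True)
toy_alpha = torch.tensor(0.7, requires_grad=True)
toy_loss = lambda w, a: (w * a) ** 2        # has a non-zero mixed second derivative

# dw: gradient of the "validation" loss w.r.t. the weight
dw = torch.autograd.grad(toy_loss(toy_w, toy_alpha), toy_w)[0]
eps = 0.01 / dw.abs().clamp_min(1e-8)

def dalpha_at(w_value):
    w = torch.tensor(w_value)
    return torch.autograd.grad(toy_loss(w, toy_alpha), toy_alpha)[0]

# central difference: (dalpha(w + eps*dw) - dalpha(w - eps*dw)) / (2 * eps)
hvp = (dalpha_at((toy_w + eps * dw).item())
       - dalpha_at((toy_w - eps * dw).item())) / (2 * eps)
print("approximate hessian-vector product:", hvp.item())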
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import EnasMutator
|
||||
from .trainer import EnasTrainer
|
|
@ -1,197 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nni.nas.pytorch.mutator import Mutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
|
||||
|
||||
|
||||
class StackedLSTMCell(nn.Module):
|
||||
def __init__(self, layers, size, bias):
|
||||
super().__init__()
|
||||
self.lstm_num_layers = layers
|
||||
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
|
||||
for _ in range(self.lstm_num_layers)])
|
||||
|
||||
def forward(self, inputs, hidden):
|
||||
prev_h, prev_c = hidden
|
||||
next_h, next_c = [], []
|
||||
for i, m in enumerate(self.lstm_modules):
|
||||
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
|
||||
next_c.append(curr_c)
|
||||
next_h.append(curr_h)
|
||||
# current implementation only supports batch size equals 1,
|
||||
# but the algorithm does not necessarily have this limitation
|
||||
inputs = curr_h[-1].view(1, -1)
|
||||
return next_h, next_c
|
||||
|
||||
|
||||
class EnasMutator(Mutator):
|
||||
"""
|
||||
A mutator that mutates the graph with RL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
PyTorch model.
|
||||
lstm_size : int
|
||||
Controller LSTM hidden units.
|
||||
lstm_num_layers : int
|
||||
Number of layers for stacked LSTM.
|
||||
tanh_constant : float
|
||||
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
|
||||
cell_exit_extra_step : bool
|
||||
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
|
||||
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of the paper.
|
||||
skip_target : float
|
||||
Target probability that skipconnect will appear.
|
||||
temperature : float
|
||||
Temperature constant that divides the logits.
|
||||
branch_bias : float
|
||||
Manual bias applied to make some operations more likely to be chosen.
|
||||
Currently this is implemented with a hardcoded match rule that aligns with original repo.
|
||||
If a mutable has a ``reduce`` in its key, all its op choices
|
||||
that contain ``conv`` in their typename will receive a bias of ``+self.branch_bias`` initially, while others
|
||||
receive a bias of ``-self.branch_bias``.
|
||||
entropy_reduction : str
|
||||
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
|
||||
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
|
||||
super().__init__(model)
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_num_layers = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
self.cell_exit_extra_step = cell_exit_extra_step
|
||||
self.skip_target = skip_target
|
||||
self.branch_bias = branch_bias
|
||||
|
||||
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
|
||||
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
|
||||
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
|
||||
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
|
||||
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
|
||||
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
|
||||
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
|
||||
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
|
||||
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
||||
self.bias_dict = nn.ParameterDict()
|
||||
|
||||
self.max_layer_choice = 0
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
if self.max_layer_choice == 0:
|
||||
self.max_layer_choice = len(mutable)
|
||||
assert self.max_layer_choice == len(mutable), \
|
||||
"ENAS mutator requires all layer choice have the same number of candidates."
|
||||
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
|
||||
if "reduce" in mutable.key:
|
||||
def is_conv(choice):
|
||||
return "conv" in str(type(choice)).lower()
|
||||
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
|
||||
for choice in mutable])
|
||||
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
|
||||
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
|
||||
|
||||
def sample_search(self):
|
||||
self._initialize()
|
||||
self._sample(self.mutables)
|
||||
return self._choices
|
||||
|
||||
def sample_final(self):
|
||||
return self.sample_search()
|
||||
|
||||
def _sample(self, tree):
|
||||
mutable = tree.mutable
|
||||
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
|
||||
self._choices[mutable.key] = self._sample_layer_choice(mutable)
|
||||
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
|
||||
self._choices[mutable.key] = self._sample_input_choice(mutable)
|
||||
for child in tree.children:
|
||||
self._sample(child)
|
||||
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
|
||||
if self.cell_exit_extra_step:
|
||||
self._lstm_next_step()
|
||||
self._mark_anchor(mutable.key)
|
||||
|
||||
def _initialize(self):
|
||||
self._choices = dict()
|
||||
self._anchors_hid = dict()
|
||||
self._inputs = self.g_emb.data
|
||||
self._c = [torch.zeros((1, self.lstm_size),
|
||||
dtype=self._inputs.dtype,
|
||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
|
||||
self._h = [torch.zeros((1, self.lstm_size),
|
||||
dtype=self._inputs.dtype,
|
||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
|
||||
self.sample_log_prob = 0
|
||||
self.sample_entropy = 0
|
||||
self.sample_skip_penalty = 0
|
||||
|
||||
def _lstm_next_step(self):
|
||||
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
|
||||
|
||||
def _mark_anchor(self, key):
|
||||
self._anchors_hid[key] = self._h[-1]
|
||||
|
||||
def _sample_layer_choice(self, mutable):
|
||||
self._lstm_next_step()
|
||||
logit = self.soft(self._h[-1])
|
||||
if self.temperature is not None:
|
||||
logit /= self.temperature
|
||||
if self.tanh_constant is not None:
|
||||
logit = self.tanh_constant * torch.tanh(logit)
|
||||
if mutable.key in self.bias_dict:
|
||||
logit += self.bias_dict[mutable.key]
|
||||
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
|
||||
log_prob = self.cross_entropy_loss(logit, branch_id)
|
||||
self.sample_log_prob += self.entropy_reduction(log_prob)
|
||||
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
|
||||
self.sample_entropy += self.entropy_reduction(entropy)
|
||||
self._inputs = self.embedding(branch_id)
|
||||
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
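# A standalone toy version of the logit shaping and sampling done in
# `_sample_layer_choice` above: temperature scaling, tanh squashing, a
# multinomial draw, and log-prob / entropy bookkeeping. The constants are
# arbitrary illustration values, not the defaults of any particular run.
import torch
import torch.nn.functional as F

logit = torch.randn(1, 5)                    # scores for 5 candidate ops
temperature, tanh_constant = 5.0, 1.5
logit = tanh_constant * torch.tanh(logit / temperature)

branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = F.cross_entropy(logit, branch_id, reduction="none")
entropy = (log_prob * torch.exp(-log_prob)).detach()
one_hot_mask = F.one_hot(branch_id, num_classes=5).bool().view(-1)
print(branch_id.item(), log_prob.item(), entropy.item(), one_hot_mask.tolist())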
|
||||
|
||||
def _sample_input_choice(self, mutable):
|
||||
query, anchors = [], []
|
||||
for label in mutable.choose_from:
|
||||
if label not in self._anchors_hid:
|
||||
self._lstm_next_step()
|
||||
self._mark_anchor(label)  # anchor not sampled yet: take an extra LSTM step and record it
|
||||
query.append(self.attn_anchor(self._anchors_hid[label]))
|
||||
anchors.append(self._anchors_hid[label])
|
||||
query = torch.cat(query, 0)
|
||||
query = torch.tanh(query + self.attn_query(self._h[-1]))
|
||||
query = self.v_attn(query)
|
||||
if self.temperature is not None:
|
||||
query /= self.temperature
|
||||
if self.tanh_constant is not None:
|
||||
query = self.tanh_constant * torch.tanh(query)
|
||||
|
||||
if mutable.n_chosen is None:
|
||||
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
|
||||
|
||||
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
|
||||
skip_prob = torch.sigmoid(logit)
|
||||
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
|
||||
self.sample_skip_penalty += kl
|
||||
log_prob = self.cross_entropy_loss(logit, skip)
|
||||
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
|
||||
else:
|
||||
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
|
||||
logit = query.view(1, -1)
|
||||
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
|
||||
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
|
||||
log_prob = self.cross_entropy_loss(logit, index)
|
||||
self._inputs = anchors[index.item()]
|
||||
|
||||
self.sample_log_prob += self.entropy_reduction(log_prob)
|
||||
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
|
||||
self.sample_entropy += self.entropy_reduction(entropy)
|
||||
return skip.bool()
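# A toy version of the skip-connection handling in `_sample_input_choice`
# when `n_chosen is None`: each candidate input gets an independent binary
# decision, and the sigmoid probabilities are pulled towards `skip_target`
# by a KL term (weighted later by `skip_weight` in the trainer). Shapes and
# numbers below are invented for illustration.
import torch
import torch.nn.functional as F

skip_target = 0.4
query = torch.randn(3, 1)                    # one attention score per candidate input
logit = torch.cat([-query, query], 1)        # two-way logits: [no-skip, skip]
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)

skip_prob = torch.sigmoid(logit)
skip_targets = torch.tensor([1.0 - skip_target, skip_target])
kl = torch.sum(skip_prob * torch.log(skip_prob / skip_targets))
print("sampled skips:", skip.tolist(), "kl penalty:", kl.item())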
|
|
@ -1,209 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
from itertools import cycle
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from nni.nas.pytorch.trainer import Trainer
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
|
||||
from .mutator import EnasMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnasTrainer(Trainer):
|
||||
"""
|
||||
ENAS trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
PyTorch model to be trained.
|
||||
loss : callable
|
||||
Receives logits and ground truth label, return a loss tensor.
|
||||
metrics : callable
|
||||
Receives logits and ground truth label, return a dict of metrics.
|
||||
reward_function : callable
|
||||
Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as reward.
|
||||
optimizer : Optimizer
|
||||
The optimizer used for optimizing the model.
|
||||
num_epochs : int
|
||||
Number of epochs planned for training.
|
||||
dataset_train : Dataset
|
||||
Dataset for training. Will be split for training weights and architecture weights.
|
||||
dataset_valid : Dataset
|
||||
Dataset for testing.
|
||||
mutator : EnasMutator
|
||||
Use when customizing your own mutator or a mutator with customized parameters.
|
||||
batch_size : int
|
||||
Batch size.
|
||||
workers : int
|
||||
Workers for data loading.
|
||||
device : torch.device
|
||||
``torch.device("cpu")`` or ``torch.device("cuda")``.
|
||||
log_frequency : int
|
||||
Step count per logging.
|
||||
callbacks : list of Callback
|
||||
list of callbacks to trigger at events.
|
||||
entropy_weight : float
|
||||
Weight of sample entropy loss.
|
||||
skip_weight : float
|
||||
Weight of skip penalty loss.
|
||||
baseline_decay : float
|
||||
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
|
||||
child_steps : int
|
||||
How many mini-batches for model training per epoch.
|
||||
mutator_lr : float
|
||||
Learning rate for RL controller.
|
||||
mutator_steps_aggregate : int
|
||||
Number of steps that will be aggregated into one mini-batch for RL controller.
|
||||
mutator_steps : int
|
||||
Number of mini-batches for each epoch of RL controller learning.
|
||||
aux_weight : float
|
||||
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
|
||||
test_arc_per_epoch : int
|
||||
How many architectures are chosen for direct test after each epoch.
|
||||
"""
|
||||
def __init__(self, model, loss, metrics, reward_function,
|
||||
optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
|
||||
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
|
||||
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
|
||||
test_arc_per_epoch=1):
|
||||
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
|
||||
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
batch_size, workers, device, log_frequency, callbacks)
|
||||
self.reward_function = reward_function
|
||||
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
|
||||
self.batch_size = batch_size
|
||||
self.workers = workers
|
||||
|
||||
self.entropy_weight = entropy_weight
|
||||
self.skip_weight = skip_weight
|
||||
self.baseline_decay = baseline_decay
|
||||
self.baseline = 0.
|
||||
self.mutator_steps_aggregate = mutator_steps_aggregate
|
||||
self.mutator_steps = mutator_steps
|
||||
self.child_steps = child_steps
|
||||
self.aux_weight = aux_weight
|
||||
self.test_arc_per_epoch = test_arc_per_epoch
|
||||
|
||||
self.init_dataloader()
|
||||
|
||||
def init_dataloader(self):
|
||||
n_train = len(self.dataset_train)
|
||||
split = n_train // 10
|
||||
indices = list(range(n_train))
|
||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
|
||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
|
||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=self.batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=self.workers)
|
||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=self.batch_size,
|
||||
sampler=valid_sampler,
|
||||
num_workers=self.workers)
|
||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.workers)
|
||||
self.train_loader = cycle(self.train_loader)
|
||||
self.valid_loader = cycle(self.valid_loader)
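# A minimal sketch of what `init_dataloader` above sets up: a 90/10 split of
# the training set with SubsetRandomSampler (90% for child-model weights,
# 10% for the RL controller), with both loaders wrapped in itertools.cycle so
# that `next(loader)` can be called an arbitrary number of times per epoch.
# The random tensors stand in for a real dataset.
from itertools import cycle

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
split = len(dataset) // 10
indices = list(range(len(dataset)))
child_loader = cycle(DataLoader(dataset, batch_size=8,
                                sampler=SubsetRandomSampler(indices[:-split])))
ctrl_loader = cycle(DataLoader(dataset, batch_size=8,
                               sampler=SubsetRandomSampler(indices[-split:])))
x, y = next(child_loader)   # can be drawn indefinitely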
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
# Sample model and train
|
||||
self.model.train()
|
||||
self.mutator.eval()
|
||||
meters = AverageMeterGroup()
|
||||
for step in range(1, self.child_steps + 1):
|
||||
x, y = next(self.train_loader)
|
||||
x, y = to_device(x, self.device), to_device(y, self.device)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
with torch.no_grad():
|
||||
self.mutator.reset()
|
||||
self._write_graph_status()
|
||||
logits = self.model(x)
|
||||
|
||||
if isinstance(logits, tuple):
|
||||
logits, aux_logits = logits
|
||||
aux_loss = self.loss(aux_logits, y)
|
||||
else:
|
||||
aux_loss = 0.
|
||||
metrics = self.metrics(logits, y)
|
||||
loss = self.loss(logits, y)
|
||||
loss = loss + self.aux_weight * aux_loss
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
|
||||
self.optimizer.step()
|
||||
metrics["loss"] = loss.item()
|
||||
meters.update(metrics)
|
||||
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
|
||||
self.num_epochs, step, self.child_steps, meters)
|
||||
|
||||
# Train sampler (mutator)
|
||||
self.model.eval()
|
||||
self.mutator.train()
|
||||
meters = AverageMeterGroup()
|
||||
for mutator_step in range(1, self.mutator_steps + 1):
|
||||
self.mutator_optim.zero_grad()
|
||||
for step in range(1, self.mutator_steps_aggregate + 1):
|
||||
x, y = next(self.valid_loader)
|
||||
x, y = to_device(x, self.device), to_device(y, self.device)
|
||||
|
||||
self.mutator.reset()
|
||||
with torch.no_grad():
|
||||
logits = self.model(x)
|
||||
self._write_graph_status()
|
||||
metrics = self.metrics(logits, y)
|
||||
reward = self.reward_function(logits, y)
|
||||
if self.entropy_weight:
|
||||
reward += self.entropy_weight * self.mutator.sample_entropy.item()
|
||||
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
|
||||
loss = self.mutator.sample_log_prob * (reward - self.baseline)
|
||||
if self.skip_weight:
|
||||
loss += self.skip_weight * self.mutator.sample_skip_penalty
|
||||
metrics["reward"] = reward
|
||||
metrics["loss"] = loss.item()
|
||||
metrics["ent"] = self.mutator.sample_entropy.item()
|
||||
metrics["log_prob"] = self.mutator.sample_log_prob.item()
|
||||
metrics["baseline"] = self.baseline
|
||||
metrics["skip"] = self.mutator.sample_skip_penalty
|
||||
|
||||
loss /= self.mutator_steps_aggregate
|
||||
loss.backward()
|
||||
meters.update(metrics)
|
||||
|
||||
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
|
||||
if self.log_frequency is not None and cur_step % self.log_frequency == 0:
|
||||
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
|
||||
mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
|
||||
meters)
|
||||
|
||||
nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
|
||||
self.mutator_optim.step()
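# The controller update above is REINFORCE with an exponential moving-average
# baseline: baseline <- baseline * decay + reward * (1 - decay), and the
# surrogate loss is log_prob * (reward - baseline), plus optional entropy and
# skip penalties. A minimal numeric sketch with made-up rewards:
import torch

baseline, baseline_decay = 0.0, 0.999
log_prob = torch.tensor(1.2, requires_grad=True)   # stand-in for mutator.sample_log_prob
for reward in (0.62, 0.71, 0.68):                  # stand-in validation accuracies
    baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
    loss = log_prob * (reward - baseline)
    loss.backward()                                 # gradients accumulate on log_prob
print("baseline:", baseline, "accumulated grad:", log_prob.grad.item())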
|
||||
|
||||
def validate_one_epoch(self, epoch):
|
||||
with torch.no_grad():
|
||||
for arc_id in range(self.test_arc_per_epoch):
|
||||
meters = AverageMeterGroup()
|
||||
for x, y in self.test_loader:
|
||||
x, y = to_device(x, self.device), to_device(y, self.device)
|
||||
self.mutator.reset()
|
||||
logits = self.model(x)
|
||||
if isinstance(logits, tuple):
|
||||
logits, _ = logits
|
||||
metrics = self.metrics(logits, y)
|
||||
loss = self.loss(logits, y)
|
||||
metrics["loss"] = loss.item()
|
||||
meters.update(metrics)
|
||||
|
||||
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
|
||||
epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
|
||||
meters.summary())
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .mutator import FBNetMutator # noqa: F401
|
||||
from .trainer import FBNetTrainer # noqa: F401
|
||||
from .utils import ( # noqa: F401
|
||||
LookUpTable,
|
||||
NASConfig,
|
||||
RegularizerLoss,
|
||||
model_init,
|
||||
supernet_sample,
|
||||
)
|
|
@ -1,268 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from nni.nas.pytorch.base_mutator import BaseMutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
"""
|
||||
This class is to instantiate and manage info of one LayerChoice.
|
||||
It includes architecture weights and member functions for the weights.
|
||||
"""
|
||||
|
||||
def __init__(self, mutable, latency):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
mutable : LayerChoice
|
||||
A LayerChoice in user model
|
||||
latency : List
|
||||
performance cost for each op in mutable
|
||||
"""
|
||||
super(MixedOp, self).__init__()
|
||||
self.latency = latency
|
||||
n_choices = len(mutable)
|
||||
self.path_alpha = nn.Parameter(
|
||||
torch.FloatTensor([1.0 / n_choices for i in range(n_choices)])
|
||||
)
|
||||
self.path_alpha.requires_grad = False
|
||||
self.temperature = 1.0
|
||||
|
||||
def get_path_alpha(self):
|
||||
"""Return the architecture parameter."""
|
||||
return self.path_alpha
|
||||
|
||||
def get_weighted_latency(self):
|
||||
"""Return the weighted perf_cost of current mutable."""
|
||||
soft_masks = self.probs_over_ops()
|
||||
weighted_latency = sum(m * l for m, l in zip(soft_masks, self.latency))
|
||||
return weighted_latency
|
||||
|
||||
def set_temperature(self, temperature):
|
||||
"""
|
||||
Set the annealed temperature for gumbel softmax.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
temperature : float
|
||||
The annealed temperature for gumbel softmax
|
||||
"""
|
||||
self.temperature = temperature
|
||||
|
||||
def to_requires_grad(self):
|
||||
"""Enable gradient calculation."""
|
||||
self.path_alpha.requires_grad = True
|
||||
|
||||
def to_disable_grad(self):
|
||||
"""Disable gradient calculation."""
|
||||
self.path_alpha.requires_grad = False
|
||||
|
||||
def probs_over_ops(self):
|
||||
"""Apply gumbel softmax to generate probability distribution."""
|
||||
return F.gumbel_softmax(self.path_alpha, self.temperature)
|
||||
|
||||
def forward(self, mutable, x):
|
||||
"""
|
||||
Define forward of LayerChoice.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable : LayerChoice
|
||||
this layer's mutable
|
||||
x : tensor
|
||||
inputs of this layer, only support one input
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: tensor
|
||||
output of this layer
|
||||
"""
|
||||
candidate_ops = list(mutable)
|
||||
soft_masks = self.probs_over_ops()
|
||||
output = sum(m * op(x) for m, op in zip(soft_masks, candidate_ops))
|
||||
|
||||
return output
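# What `probs_over_ops` / `forward` above compute, in isolation: a
# Gumbel-softmax over the architecture parameters yields soft masks that mix
# the candidate ops, and lowering the temperature makes the mask harder
# (closer to one-hot). The candidate modules here are arbitrary placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

candidates = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1),
                            nn.Conv2d(8, 8, 5, padding=2),
                            nn.Identity()])
path_alpha = torch.full((len(candidates),), 1.0 / len(candidates))
x = torch.randn(1, 8, 16, 16)

for temperature in (5.0, 0.1):
    soft_mask = F.gumbel_softmax(path_alpha, tau=temperature)
    mixed = sum(m * op(x) for m, op in zip(soft_mask, candidates))
    print(temperature, [round(v, 3) for v in soft_mask.tolist()], tuple(mixed.shape))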
|
||||
|
||||
@property
|
||||
def chosen_index(self):
|
||||
"""
|
||||
choose the op with max prob
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
index of the chosen one
|
||||
"""
|
||||
alphas = self.path_alpha.data.detach().cpu().numpy()
|
||||
index = int(np.argmax(alphas))
|
||||
return index
|
||||
|
||||
|
||||
class FBNetMutator(BaseMutator):
|
||||
"""
|
||||
This mutator initializes and operates all the LayerChoices of the supernet.
|
||||
It is for the related trainer to control the training flow of LayerChoices,
|
||||
coordinating with whole training process.
|
||||
"""
|
||||
|
||||
def __init__(self, model, lookup_table):
|
||||
"""
|
||||
Init a MixedOp instance for each mutable i.e., LayerChoice.
|
||||
And register the instantiated MixedOp in corresponding LayerChoice.
|
||||
If it is not registered in the LayerChoice, DataParallel will not work,
|
||||
because the architecture weights would not be included in the DataParallel model.
|
||||
When MixedOPs are registered, we use ```requires_grad``` to control
|
||||
whether to calculate gradients of architecture weights.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
The model that users want to tune,
|
||||
it includes search space defined with nni nas apis
|
||||
lookup_table : class
|
||||
lookup table object to manage model space information,
|
||||
including candidate ops for each stage as the model space,
|
||||
input channels/output channels/stride/fm_size as the layer config,
|
||||
and the performance information for perf_cost accumulation.
|
||||
|
||||
"""
|
||||
super(FBNetMutator, self).__init__(model)
|
||||
self.mutable_list = []
|
||||
|
||||
# Collect the op names of the candidate ops within each mutable
|
||||
ops_names_mutable = dict()
|
||||
left = 0
|
||||
right = 1
|
||||
for stage_name in lookup_table.layer_num:
|
||||
right = lookup_table.layer_num[stage_name]
|
||||
stage_ops = lookup_table.lut_ops[stage_name]
|
||||
ops_names = [op_name for op_name in stage_ops]
|
||||
|
||||
for i in range(left, left + right):
|
||||
ops_names_mutable[i] = ops_names
|
||||
left += right
|
||||
|
||||
# Create the mixed op
|
||||
for i, mutable in enumerate(self.undedup_mutables):
|
||||
ops_names = ops_names_mutable[i]
|
||||
latency_mutable = lookup_table.lut_perf[i]
|
||||
latency = [latency_mutable[op_name] for op_name in ops_names]
|
||||
self.mutable_list.append(mutable)
|
||||
mutable.registered_module = MixedOp(mutable, latency)
|
||||
|
||||
def on_forward_layer_choice(self, mutable, *args, **kwargs):
|
||||
"""
|
||||
Callback of layer choice forward. This function defines the forward
|
||||
logic of the input mutable. The mutable itself is only an interface; its real
|
||||
implementation is defined in the mutator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable: LayerChoice
|
||||
forward logic of this input mutable
|
||||
args: list of torch.Tensor
|
||||
inputs of this mutable
|
||||
kwargs: dict
|
||||
inputs of this mutable
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
output of this mutable, i.e., LayerChoice
|
||||
int
|
||||
index of the chosen op
|
||||
"""
|
||||
# FIXME: return mask, to be consistent with other algorithms
|
||||
idx = mutable.registered_module.chosen_index
|
||||
return mutable.registered_module(mutable, *args, **kwargs), idx
|
||||
|
||||
def num_arch_params(self):
|
||||
"""
|
||||
The number of mutables, i.e., LayerChoice
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
the number of LayerChoice in user model
|
||||
"""
|
||||
return len(self.mutable_list)
|
||||
|
||||
def get_architecture_parameters(self):
|
||||
"""
|
||||
Get all the architecture parameters.
|
||||
|
||||
yield
|
||||
-----
|
||||
PyTorch Parameter
|
||||
Return path_alpha of the traversed mutable
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
yield mutable.registered_module.get_path_alpha()
|
||||
|
||||
def get_weighted_latency(self):
|
||||
"""
|
||||
Get the latency weighted by gumbel softmax coefficients.
|
||||
|
||||
yield
|
||||
-----
|
||||
Tuple
|
||||
Return the weighted_latency of the traversed mutable
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
yield mutable.registered_module.get_weighted_latency()
|
||||
|
||||
def set_temperature(self, temperature):
|
||||
"""
|
||||
Set the annealed temperature of the op for gumbel softmax.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
temperature : float
|
||||
The annealed temperature for gumbel softmax
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.set_temperature(temperature)
|
||||
|
||||
def arch_requires_grad(self):
|
||||
"""
|
||||
Make architecture weights require gradient
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.to_requires_grad()
|
||||
|
||||
def arch_disable_grad(self):
|
||||
"""
|
||||
Disable gradients of architecture weights, i.e., do not
|
||||
calculate gradient for them.
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.to_disable_grad()
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Generate the final chosen architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the choice of each mutable, i.e., LayerChoice
|
||||
"""
|
||||
result = dict()
|
||||
for mutable in self.undedup_mutables:
|
||||
assert isinstance(mutable, LayerChoice)
|
||||
index = mutable.registered_module.chosen_index
|
||||
# pylint: disable=not-callable
|
||||
result[mutable.key] = (
|
||||
F.one_hot(torch.tensor(index), num_classes=len(mutable))
|
||||
.view(-1)
|
||||
.bool(),
|
||||
)
|
||||
return result
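# A numeric sketch of `get_weighted_latency` above: the Gumbel-softmax mask of
# one layer is dotted with that layer's per-op latency vector from the lookup
# table, giving a differentiable expected latency that the trainer sums over
# layers as `perf_cost`. The latency numbers are invented placeholders.
import torch
import torch.nn.functional as F

path_alpha = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)  # arch params of one layer
latency = torch.tensor([1.8, 3.2, 0.4])                         # e.g. per-op milliseconds
soft_mask = F.gumbel_softmax(path_alpha, tau=1.0)
expected_latency = torch.sum(soft_mask * latency)               # differentiable w.r.t. path_alpha
expected_latency.backward()
print(expected_latency.item(), path_alpha.grad)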
|
|
@ -1,413 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch.autograd import Variable
|
||||
from nni.nas.pytorch.base_trainer import BaseTrainer
|
||||
from nni.nas.pytorch.trainer import TorchTensorEncoder
|
||||
from nni.nas.pytorch.utils import AverageMeter
|
||||
from .mutator import FBNetMutator
|
||||
from .utils import RegularizerLoss, accuracy
|
||||
|
||||
|
||||
class FBNetTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model_optim,
|
||||
criterion,
|
||||
device,
|
||||
device_ids,
|
||||
lookup_table,
|
||||
train_loader,
|
||||
valid_loader,
|
||||
n_epochs=120,
|
||||
load_ckpt=False,
|
||||
arch_path=None,
|
||||
logger=None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
the user model, which has mutables
|
||||
model_optim : pytorch optimizer
|
||||
the user defined optimizer
|
||||
criterion : pytorch loss
|
||||
the main task loss, e.g., nn.CrossEntropyLoss() for classification
|
||||
device : pytorch device
|
||||
the devices to train/search the model
|
||||
device_ids : list of int
|
||||
the indexes of devices used for training
|
||||
lookup_table : class
|
||||
lookup table object for fbnet training
|
||||
train_loader : pytorch data loader
|
||||
data loader for the training set
|
||||
valid_loader : pytorch data loader
|
||||
data loader for the validation set
|
||||
n_epochs : int
|
||||
number of epochs to train/search
|
||||
load_ckpt : bool
|
||||
whether to load a checkpoint
|
||||
arch_path : str
|
||||
the path to store chosen architecture
|
||||
logger : logger
|
||||
the logger
|
||||
"""
|
||||
self.model = model
|
||||
self.model_optim = model_optim
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
self.device = device
|
||||
self.dev_num = len(device_ids)
|
||||
self.n_epochs = n_epochs
|
||||
self.lookup_table = lookup_table
|
||||
self.config = lookup_table.config
|
||||
self.start_epoch = self.config.start_epoch
|
||||
self.temp = self.config.init_temperature
|
||||
self.exp_anneal_rate = self.config.exp_anneal_rate
|
||||
self.mode = self.config.mode
|
||||
|
||||
self.load_ckpt = load_ckpt
|
||||
self.arch_path = arch_path
|
||||
self.logger = logger
|
||||
|
||||
# scheduler of learning rate
|
||||
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
model_optim, T_max=n_epochs, last_epoch=-1
|
||||
)
|
||||
|
||||
# init mutator
|
||||
self.mutator = FBNetMutator(model, lookup_table)
|
||||
self.mutator.set_temperature(self.temp)
|
||||
|
||||
# DataParallel should be put behind the init of mutator
|
||||
self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
|
||||
self.model.to(device)
|
||||
|
||||
# build architecture optimizer
|
||||
self.arch_optimizer = torch.optim.AdamW(
|
||||
self.mutator.get_architecture_parameters(),
|
||||
self.config.nas_lr,
|
||||
weight_decay=self.config.nas_weight_decay,
|
||||
)
|
||||
self.reg_loss = RegularizerLoss(config=self.config)
|
||||
|
||||
self.criterion = criterion
|
||||
self.epoch = 0
|
||||
|
||||
def _layer_choice_sample(self):
|
||||
"""
|
||||
Sample the index of network within layer choice
|
||||
"""
|
||||
stages = [stage_name for stage_name in self.lookup_table.layer_num]
|
||||
stage_lnum = [self.lookup_table.layer_num[stage] for stage in stages]
|
||||
|
||||
# get the choice idx in each layer
|
||||
choice_ids = list()
|
||||
layer_id = 0
|
||||
for param in self.mutator.get_architecture_parameters():
|
||||
param_np = param.cpu().detach().numpy()
|
||||
op_idx = np.argmax(param_np)
|
||||
choice_ids.append(op_idx)
|
||||
self.logger.info(
|
||||
"layer {}: {}, index: {}".format(layer_id, param_np, op_idx)
|
||||
)
|
||||
layer_id += 1
|
||||
|
||||
# get the arch_sample
|
||||
choice_names = list()
|
||||
layer_id = 0
|
||||
for i, stage_name in enumerate(stages):
|
||||
ops_names = [op for op in self.lookup_table.lut_ops[stage_name]]
|
||||
for _ in range(stage_lnum[i]):
|
||||
searched_op = ops_names[choice_ids[layer_id]]
|
||||
choice_names.append(searched_op)
|
||||
layer_id += 1
|
||||
|
||||
self.logger.info(choice_names)
|
||||
return choice_names
|
||||
|
||||
def _get_perf_cost(self, requires_grad=True):
|
||||
"""
|
||||
Get the accumulated performance cost.
|
||||
"""
|
||||
perf_cost = Variable(
|
||||
torch.zeros(1), requires_grad=requires_grad
|
||||
).to(self.device, non_blocking=True)
|
||||
|
||||
for latency in self.mutator.get_weighted_latency():
|
||||
perf_cost = perf_cost + latency
|
||||
|
||||
return perf_cost
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Do validation. During validation, LayerChoices use the mixed-op.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float, float, float
|
||||
average loss, average top1 accuracy, average top5 accuracy
|
||||
"""
|
||||
self.valid_loader.batch_sampler.drop_last = False
|
||||
batch_time = AverageMeter("batch_time")
|
||||
losses = AverageMeter("losses")
|
||||
top1 = AverageMeter("top1")
|
||||
top5 = AverageMeter("top5")
|
||||
|
||||
# test on validation set under eval mode
|
||||
self.model.eval()
|
||||
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for i, (images, labels) in enumerate(self.valid_loader):
|
||||
images = images.to(self.device, non_blocking=True)
|
||||
labels = labels.to(self.device, non_blocking=True)
|
||||
|
||||
output = self.model(images)
|
||||
|
||||
loss = self.criterion(output, labels)
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
losses.update(loss, images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % 10 == 0 or i + 1 == len(self.valid_loader):
|
||||
test_log = (
|
||||
"Valid" + ": [{0}/{1}]\t"
|
||||
"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
|
||||
"Loss {loss.val:.4f} ({loss.avg:.4f})\t"
|
||||
"Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
|
||||
"Top-5 acc {top5.val:.3f} ({top5.avg:.3f})".format(
|
||||
i,
|
||||
len(self.valid_loader) - 1,
|
||||
batch_time=batch_time,
|
||||
loss=losses,
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
)
|
||||
)
|
||||
self.logger.info(test_log)
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
||||
def _train_epoch(self, epoch, optimizer, arch_train=False):
|
||||
"""
|
||||
Train one epoch.
|
||||
"""
|
||||
batch_time = AverageMeter("batch_time")
|
||||
data_time = AverageMeter("data_time")
|
||||
losses = AverageMeter("losses")
|
||||
top1 = AverageMeter("top1")
|
||||
top5 = AverageMeter("top5")
|
||||
|
||||
# switch to train mode
|
||||
self.model.train()
|
||||
|
||||
data_loader = self.valid_loader if arch_train else self.train_loader
|
||||
end = time.time()
|
||||
for i, (images, labels) in enumerate(data_loader):
|
||||
data_time.update(time.time() - end)
|
||||
images = images.to(self.device, non_blocking=True)
|
||||
labels = labels.to(self.device, non_blocking=True)
|
||||
|
||||
output = self.model(images)
|
||||
loss = self.criterion(output, labels)
|
||||
|
||||
# hardware-aware loss
|
||||
perf_cost = self._get_perf_cost(requires_grad=True)
|
||||
regu_loss = self.reg_loss(perf_cost)
|
||||
if self.mode.startswith("mul"):
|
||||
loss = loss * regu_loss
|
||||
elif self.mode.startswith("add"):
|
||||
loss = loss + regu_loss
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
losses.update(loss.item(), images.size(0))
|
||||
top1.update(acc1[0].item(), images.size(0))
|
||||
top5.update(acc5[0].item(), images.size(0))
|
||||
# compute gradient and do SGD step
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % 10 == 0:
|
||||
batch_log = (
|
||||
"Warmup Train [{0}][{1}]\t"
|
||||
"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
|
||||
"Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
|
||||
"Loss {losses.val:.4f} ({losses.avg:.4f})\t"
|
||||
"Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
|
||||
"Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t".format(
|
||||
epoch + 1,
|
||||
i,
|
||||
batch_time=batch_time,
|
||||
data_time=data_time,
|
||||
losses=losses,
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
)
|
||||
)
|
||||
self.logger.info(batch_log)
|
||||
|
||||
def _warm_up(self):
|
||||
"""
|
||||
Warm up the model, while the architecture weights are not trained.
|
||||
"""
|
||||
for epoch in range(self.epoch, self.start_epoch):
|
||||
self.logger.info("\n--------Warmup epoch: %d--------\n", epoch + 1)
|
||||
self._train_epoch(epoch, self.model_optim)
|
||||
# adjust learning rate
|
||||
self.scheduler.step()
|
||||
|
||||
# validation
|
||||
val_loss, val_top1, val_top5 = self._validate()
|
||||
val_log = (
|
||||
"Warmup Valid [{0}/{1}]\t"
|
||||
"loss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}".format(
|
||||
epoch + 1, self.start_epoch, val_loss, val_top1, val_top5
|
||||
)
|
||||
)
|
||||
self.logger.info(val_log)
|
||||
|
||||
if epoch % 10 == 0:
|
||||
filename = os.path.join(
|
||||
self.config.model_dir, "checkpoint_%s.pth" % epoch
|
||||
)
|
||||
self.save_checkpoint(epoch, filename)
|
||||
|
||||
def _train(self):
|
||||
"""
|
||||
Train the model: it trains both model weights and architecture weights.
|
||||
Architecture weights are trained according to the schedule.
|
||||
Before updating architecture weights, ```requires_grad``` is enabled.
|
||||
Then it is disabled after the update, so as not to update
|
||||
architecture weights when training model weights.
|
||||
"""
|
||||
arch_param_num = self.mutator.num_arch_params()
|
||||
self.logger.info("#arch_params: {}".format(arch_param_num))
|
||||
self.epoch = max(self.start_epoch, self.epoch)
|
||||
|
||||
ckpt_path = self.config.model_dir
|
||||
choice_names = None
|
||||
top1_best = 0.0
|
||||
|
||||
for epoch in range(self.epoch, self.n_epochs):
|
||||
self.logger.info("\n--------Train epoch: %d--------\n", epoch + 1)
|
||||
# update the weight parameters
|
||||
self._train_epoch(epoch, self.model_optim)
|
||||
# adjust learning rate
|
||||
self.scheduler.step()
|
||||
|
||||
self.logger.info("Update architecture parameters")
|
||||
# update the architecture parameters
|
||||
self.mutator.arch_requires_grad()
|
||||
self._train_epoch(epoch, self.arch_optimizer, True)
|
||||
self.mutator.arch_disable_grad()
|
||||
# temperature annealing
|
||||
self.temp = self.temp * self.exp_anneal_rate
|
||||
self.mutator.set_temperature(self.temp)
|
||||
# sample the architecture of sub-network
|
||||
choice_names = self._layer_choice_sample()
|
||||
|
||||
# validate
|
||||
val_loss, val_top1, val_top5 = self._validate()
|
||||
val_log = (
|
||||
"Valid [{0}]\t"
|
||||
"loss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}".format(
|
||||
epoch + 1, val_loss, val_top1, val_top5
|
||||
)
|
||||
)
|
||||
self.logger.info(val_log)
|
||||
|
||||
if epoch % 10 == 0:
|
||||
filename = os.path.join(ckpt_path, "checkpoint_%s.pth" % epoch)
|
||||
self.save_checkpoint(epoch, filename, choice_names)
|
||||
|
||||
val_top1 = val_top1.cpu().item()  # convert to a plain float for comparison (tensors have no as_numpy())
|
||||
if val_top1 > top1_best:
|
||||
filename = os.path.join(ckpt_path, "checkpoint_best.pth")
|
||||
self.save_checkpoint(epoch, filename, choice_names)
|
||||
top1_best = val_top1
|
||||
|
||||
def save_checkpoint(self, epoch, filename, choice_names=None):
|
||||
"""
|
||||
Save checkpoint of the whole model.
|
||||
Saving model weights and architecture weights as ```filename```,
|
||||
and saving currently chosen architecture in ```arch_path```.
|
||||
"""
|
||||
state = {
|
||||
"model": self.model.state_dict(),
|
||||
"optim": self.model_optim.state_dict(),
|
||||
"epoch": epoch,
|
||||
"arch_sample": choice_names,
|
||||
}
|
||||
torch.save(state, filename)
|
||||
self.logger.info("Save checkpoint to {0:}".format(filename))
|
||||
|
||||
if self.arch_path:
|
||||
self.export(self.arch_path)
|
||||
|
||||
def load_checkpoint(self, filename):
|
||||
"""
|
||||
Load the checkpoint from ```filename```.
|
||||
"""
|
||||
ckpt = torch.load(filename)
|
||||
self.epoch = ckpt["epoch"]
|
||||
self.model.load_state_dict(ckpt["model"])
|
||||
self.model_optim.load_state_dict(ckpt["optim"])
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
Train the whole model.
|
||||
"""
|
||||
if self.load_ckpt:
|
||||
ckpt_path = self.config.model_dir
|
||||
filename = os.path.join(ckpt_path, "checkpoint_best.pth")
|
||||
if os.path.exists(filename):
|
||||
self.load_checkpoint(filename)
|
||||
|
||||
if self.epoch < self.start_epoch:
|
||||
self._warm_up()
|
||||
self._train()
|
||||
|
||||
def export(self, file_name):
|
||||
"""
|
||||
Export the chosen architecture into a file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name : str
|
||||
the file that stores exported chosen architecture
|
||||
"""
|
||||
exported_arch = self.mutator.sample_final()
|
||||
with open(file_name, "w") as f:
|
||||
json.dump(
|
||||
exported_arch,
|
||||
f,
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
cls=TorchTensorEncoder,
|
||||
)
|
||||
|
||||
def validate(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def checkpoint(self):
|
||||
raise NotImplementedError
|
|
@ -1,433 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import ast
|
||||
import os
|
||||
import timeit
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.compression.pytorch.utils import count_flops_params
|
||||
|
||||
LUT_FILE = "lut.npy"
|
||||
LUT_JSON_FILE = "lut.txt"
|
||||
LUT_PATH = "lut"
|
||||
|
||||
DATA_TYPE = "float"
|
||||
|
||||
class NASConfig:
|
||||
def __init__(
|
||||
self,
|
||||
perf_metric="flops",
|
||||
lut_load=False,
|
||||
lut_load_format="json",
|
||||
model_dir=None,
|
||||
nas_lr=0.01,
|
||||
nas_weight_decay=5e-4,
|
||||
mode="mul",
|
||||
alpha=0.25,
|
||||
beta=0.6,
|
||||
start_epoch=50,
|
||||
init_temperature=5.0,
|
||||
exp_anneal_rate=np.exp(-0.045),
|
||||
search_space=None,
|
||||
):
|
||||
# LUT of performance metric
|
||||
# flops means the number of multiplications; latency means the time cost on the platform
|
||||
self.perf_metric = perf_metric
|
||||
assert perf_metric in [
|
||||
"flops",
|
||||
"latency",
|
||||
], "perf_metric should be ['flops', 'latency']"
|
||||
# whether to load or create the lut file
|
||||
self.lut_load = lut_load
|
||||
|
||||
assert lut_load_format in [
|
||||
"json",
|
||||
"numpy",
|
||||
], "lut_load_format should be ['json', 'numpy']"
|
||||
self.lut_load_format = lut_load_format
|
||||
|
||||
# necessary dirs
|
||||
self.lut_en = model_dir is not None
|
||||
if self.lut_en:
|
||||
self.model_dir = model_dir
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
self.lut_path = os.path.join(model_dir, LUT_PATH)
|
||||
os.makedirs(self.lut_path, exist_ok=True)
|
||||
# NAS learning setting
|
||||
self.nas_lr = nas_lr
|
||||
self.nas_weight_decay = nas_weight_decay
|
||||
# hardware-aware loss setting
|
||||
self.mode = mode
|
||||
assert mode in ["mul", "add"], "mode should be ['mul', 'add']"
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
# NAS training setting
|
||||
self.start_epoch = start_epoch
|
||||
self.init_temperature = init_temperature
|
||||
self.exp_anneal_rate = exp_anneal_rate
|
||||
# definition of search blocks and space
|
||||
self.search_space = search_space
|
||||
|
||||
|
||||
class RegularizerLoss(nn.Module):
|
||||
"""Auxilliary loss for hardware-aware NAS."""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : class
|
||||
to manage the configuration for NAS training, and search space etc.
|
||||
"""
|
||||
super(RegularizerLoss, self).__init__()
|
||||
self.mode = config.mode
|
||||
self.alpha = config.alpha
|
||||
self.beta = config.beta
|
||||
|
||||
def forward(self, perf_cost, batch_size=1):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
perf_cost : tensor
|
||||
the accumulated performance cost
|
||||
batch_size : int
|
||||
batch size for normalization
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: tensor
|
||||
the hardware-aware constraint loss
|
||||
"""
|
||||
if self.mode == "mul":
|
||||
log_loss = torch.log(perf_cost / batch_size) ** self.beta
|
||||
return self.alpha * log_loss
|
||||
elif self.mode == "add":
|
||||
linear_loss = (perf_cost / batch_size) ** self.beta
|
||||
return self.alpha * linear_loss
|
||||
else:
|
||||
raise NotImplementedError
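# A numeric sketch of the two regularizer modes above, combined with the task
# loss the way FBNetTrainer._train_epoch does (a batch_size of 1 is assumed).
# All numbers are arbitrary.
import torch

alpha, beta = 0.25, 0.6
task_loss = torch.tensor(2.3)
perf_cost = torch.tensor(150.0)          # e.g. accumulated FLOPs or latency

mul_reg = alpha * torch.log(perf_cost) ** beta
add_reg = alpha * perf_cost ** beta
print("mul mode:", (task_loss * mul_reg).item())
print("add mode:", (task_loss + add_reg).item())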
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""
|
||||
Computes the precision@k for the specified values of k
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : pytorch tensor
|
||||
output, e.g., predicted value
|
||||
target : pytorch tensor
|
||||
label
|
||||
topk : tuple
|
||||
specify top1 and top5
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
accuracy of top1 and top5
|
||||
"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the transposed tensor is not contiguous
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
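# Minimal usage of the `accuracy` helper defined above, cross-checked against
# a direct top-1 computation. With 10 classes and random targets the expected
# top-1 accuracy is around 10%. This assumes the `accuracy` function above is
# in scope.
import torch

output = torch.randn(64, 10)               # logits for a batch of 64
target = torch.randint(0, 10, (64,))
top1, top5 = accuracy(output, target, topk=(1, 5))
manual_top1 = (output.argmax(dim=1) == target).float().mean() * 100.0
print(top1.item(), top5.item(), manual_top1.item())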
|
||||
|
||||
|
||||
def supernet_sample(model, state_dict, sampled_arch=[], lookup_table=None):
|
||||
"""
|
||||
Initialize the searched sub-model from supernet.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
the created subnet
|
||||
state_dict : checkpoint
|
||||
the checkpoint of supernet, including the pre-trained params
|
||||
sampled_arch : list of str
|
||||
the searched layer names of the subnet
|
||||
lookup_table : class
|
||||
to manage the candidate ops, layer information and layer performance
|
||||
"""
|
||||
replace = list()
|
||||
stages = [stage for stage in lookup_table.layer_num]
|
||||
stage_lnum = [lookup_table.layer_num[stage] for stage in stages]
|
||||
|
||||
if sampled_arch:
|
||||
layer_id = 0
|
||||
for i, stage in enumerate(stages):
|
||||
ops_names = [op_name for op_name in lookup_table.lut_ops[stage]]
|
||||
for _ in range(stage_lnum[i]):
|
||||
searched_op = sampled_arch[layer_id]
|
||||
op_i = ops_names.index(searched_op)
|
||||
replace.append(
|
||||
[
|
||||
"blocks.{}.".format(layer_id),
|
||||
"blocks.{}.op.".format(layer_id),
|
||||
"blocks.{}.{}.".format(layer_id, op_i),
|
||||
]
|
||||
)
|
||||
layer_id += 1
|
||||
model_init(model, state_dict, replace=replace)
|
||||
|
||||
|
||||
def model_init(model, state_dict, replace=[]):
|
||||
"""Initialize the model from state_dict."""
|
||||
prefix = "module."
|
||||
param_dict = dict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith(prefix):
|
||||
k = k[7:]
|
||||
param_dict[k] = v
|
||||
|
||||
for k, (name, m) in enumerate(model.named_modules()):
|
||||
if replace:
|
||||
for layer_replace in replace:
|
||||
assert len(layer_replace) == 3, "The elements should be three."
|
||||
pre_scope, key, replace_key = layer_replace
|
||||
if pre_scope in name:
|
||||
name = name.replace(key, replace_key)
|
||||
|
||||
# Copy the state_dict to current model
|
||||
if (name + ".weight" in param_dict) or (
|
||||
name + ".running_mean" in param_dict
|
||||
):
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
shape = m.running_mean.shape
|
||||
if shape == param_dict[name + ".running_mean"].shape:
|
||||
if m.weight is not None:
|
||||
m.weight.data = param_dict[name + ".weight"]
|
||||
m.bias.data = param_dict[name + ".bias"]
|
||||
m.running_mean = param_dict[name + ".running_mean"]
|
||||
m.running_var = param_dict[name + ".running_var"]
|
||||
|
||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
shape = m.weight.data.shape
|
||||
if shape == param_dict[name + ".weight"].shape:
|
||||
m.weight.data = param_dict[name + ".weight"]
|
||||
if m.bias is not None:
|
||||
m.bias.data = param_dict[name + ".bias"]
|
||||
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
m.weight.data = param_dict[name + ".weight"]
|
||||
if m.bias is not None:
|
||||
m.bias.data = param_dict[name + ".bias"]
|
||||
|
||||
|
||||
class LookUpTable:
|
||||
"""Build look-up table for NAS."""
|
||||
|
||||
def __init__(self, config, primitives):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
config : class
|
||||
to manage the configuration for NAS training, and search space etc.
|
||||
"""
|
||||
self.config = config
|
||||
# definition of search blocks and space
|
||||
self.search_space = config.search_space
|
||||
# layers for NAS
|
||||
self.cnt_layers = len(self.search_space["input_shape"])
|
||||
# constructors for each operation
|
||||
self.lut_ops = {
|
||||
stage_name: {
|
||||
op_name: primitives[op_name]
|
||||
for op_name in self.search_space["stages"][stage_name]["ops"]
|
||||
}
|
||||
for stage_name in self.search_space["stages"]
|
||||
}
|
||||
self.layer_num = {
|
||||
stage_name: self.search_space["stages"][stage_name]["layer_num"]
|
||||
for stage_name in self.search_space["stages"]
|
||||
}
|
||||
|
||||
# arguments for the op constructors; input_shapes kept just for convenience
|
||||
self.layer_configs, self.layer_in_shapes = self._layer_configs()
|
||||
|
||||
# lookup_table
|
||||
self.perf_metric = config.perf_metric
|
||||
|
||||
if config.lut_en:
|
||||
self.lut_perf = None
|
||||
self.lut_file = os.path.join(config.lut_path, LUT_FILE)
|
||||
self.lut_json_file = LUT_JSON_FILE
|
||||
if config.lut_load:
|
||||
if config.lut_load_format == "numpy":
|
||||
# Load data from numpy file
|
||||
self._load_from_file()
|
||||
else:
|
||||
# Load data from json file
|
||||
self._load_from_json_file()
|
||||
else:
|
||||
self._create_perfs()
|
||||
|
||||
def _layer_configs(self):
|
||||
"""Generate basic params for different layers."""
|
||||
# layer_configs are : c_in, c_out, stride, fm_size
|
||||
layer_configs = [
|
||||
[
|
||||
self.search_space["input_shape"][layer_id][0],
|
||||
self.search_space["channel_size"][layer_id],
|
||||
self.search_space["strides"][layer_id],
|
||||
self.search_space["fm_size"][layer_id],
|
||||
]
|
||||
for layer_id in range(self.cnt_layers)
|
||||
]
|
||||
|
||||
# layer_in_shapes are (C_in, input_w, input_h)
|
||||
layer_in_shapes = self.search_space["input_shape"]
|
||||
|
||||
return layer_configs, layer_in_shapes
|
||||
|
||||
def _create_perfs(self, cnt_of_runs=200):
|
||||
"""Create performance cost for each op."""
|
||||
if self.perf_metric == "latency":
|
||||
self.lut_perf = self._calculate_latency(cnt_of_runs)
|
||||
elif self.perf_metric == "flops":
|
||||
self.lut_perf = self._calculate_flops()
|
||||
|
||||
self._write_lut_to_file()
|
||||
|
||||
def _calculate_flops(self, eps=0.001):
|
||||
"""FLOPs cost."""
|
||||
flops_lut = [{} for i in range(self.cnt_layers)]
|
||||
layer_id = 0
|
||||
|
||||
for stage_name in self.lut_ops:
|
||||
stage_ops = self.lut_ops[stage_name]
|
||||
ops_num = self.layer_num[stage_name]
|
||||
|
||||
for _ in range(ops_num):
|
||||
for op_name in stage_ops:
|
||||
layer_config = self.layer_configs[layer_id]
|
||||
key_params = {"fm_size": layer_config[3]}
|
||||
op = stage_ops[op_name](*layer_config[0:3], **key_params)
|
||||
|
||||
# measured in Flops
|
||||
in_shape = self.layer_in_shapes[layer_id]
|
||||
x = (1, in_shape[0], in_shape[1], in_shape[2])
|
||||
flops, _, _ = count_flops_params(op, x, verbose=False)
|
||||
flops = eps if flops == 0.0 else flops
|
||||
flops_lut[layer_id][op_name] = float(flops)
|
||||
layer_id += 1
|
||||
|
||||
return flops_lut
|
||||
|
||||
def _calculate_latency(self, cnt_of_runs):
|
||||
"""Latency cost."""
|
||||
LATENCY_BATCH_SIZE = 1
|
||||
latency_lut = [{} for i in range(self.cnt_layers)]
|
||||
layer_id = 0
|
||||
|
||||
for stage_name in self.lut_ops:
|
||||
stage_ops = self.lut_ops[stage_name]
|
||||
ops_num = self.layer_num[stage_name]
|
||||
|
||||
for _ in range(ops_num):
|
||||
for op_name in stage_ops:
|
||||
layer_config = self.layer_configs[layer_id]
|
||||
key_params = {"fm_size": layer_config[3]}
|
||||
op = stage_ops[op_name](*layer_config[0:3], **key_params)
|
||||
input_data = torch.randn(
|
||||
(LATENCY_BATCH_SIZE, *self.layer_in_shapes[layer_id])
|
||||
)
|
||||
globals()["op"], globals()["input_data"] = op, input_data
|
||||
total_time = timeit.timeit(
|
||||
"output = op(input_data)",
|
||||
setup="gc.enable()",
|
||||
globals=globals(),
|
||||
number=cnt_of_runs,
|
||||
)
|
||||
# measured in micro-second
|
||||
latency_lut[layer_id][op_name] = (
|
||||
total_time / cnt_of_runs / LATENCY_BATCH_SIZE * 1e6
|
||||
)
|
||||
layer_id += 1
|
||||
|
||||
return latency_lut
|
||||
|
||||
def _write_lut_to_file(self):
|
||||
"""Save lut as numpy file."""
|
||||
np.save(self.lut_file, self.lut_perf)
|
||||
|
||||
def _load_from_file(self):
|
||||
"""Load numpy file."""
|
||||
self.lut_perf = np.load(self.lut_file, allow_pickle=True)
|
||||
|
||||
def _load_from_json_file(self):
|
||||
"""Load json file."""
|
||||
|
||||
"""
|
||||
lut_json_file ('lut.txt') format:
|
||||
{'op_name': operator_name,
|
||||
'op_data_shape': (input_w, input_h, C_in, C_out, stride),
|
||||
'op_dtype': data_type,
|
||||
'op_latency': latency}
|
||||
{...}
|
||||
{...}
|
||||
"""
|
||||
latency_file = open(self.lut_json_file, "r")
|
||||
ops_latency = latency_file.readlines()
|
||||
|
||||
"""ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}"""
|
||||
ops_lut = {}
|
||||
|
||||
for op_latency in ops_latency:
|
||||
assert isinstance(op_latency, str) or isinstance(op_latency, dict)
|
||||
|
||||
if isinstance(op_latency, str):
|
||||
record = ast.literal_eval(op_latency)
|
||||
elif isinstance(op_latency, dict):
|
||||
record = op_latency
|
||||
|
||||
op_name = record["op_name"]
|
||||
"""op_data_shape: (input_w, input_h, C_in, C_out, stride)"""
|
||||
op_data_shape = record["op_data_shape"]
|
||||
op_dtype = record["op_dtype"]
|
||||
op_latency = record["op_latency"]
|
||||
|
||||
if op_name not in ops_lut:
|
||||
ops_lut[op_name] = {}
|
||||
|
||||
if op_data_shape not in ops_lut[op_name]:
|
||||
ops_lut[op_name][op_data_shape] = {}
|
||||
|
||||
ops_lut[op_name][op_data_shape][op_dtype] = op_latency
|
||||
|
||||
self.lut_perf = [{} for i in range(self.cnt_layers)]
|
||||
layer_id = 0
|
||||
|
||||
for stage_name in self.lut_ops:
|
||||
stage_ops = self.lut_ops[stage_name]
|
||||
ops_num = self.layer_num[stage_name]
|
||||
|
||||
for _ in range(ops_num):
|
||||
for op_name in stage_ops:
|
||||
layer_config = self.layer_configs[layer_id]
|
||||
layer_in_shape = self.layer_in_shapes[layer_id]
|
||||
|
||||
input_w = layer_in_shape[1]
|
||||
input_h = layer_in_shape[2]
|
||||
c_in = layer_config[0]
|
||||
c_out = layer_config[1]
|
||||
stride = layer_config[2]
|
||||
op_data_shape = (input_w, input_h, c_in, c_out, stride)
|
||||
|
||||
if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
|
||||
self.lut_perf[layer_id][op_name] = \
|
||||
ops_lut[op_name][op_data_shape][DATA_TYPE]
|
||||
|
||||
layer_id += 1
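# A sketch of the lut.txt format consumed by `_load_from_json_file` above:
# one Python-literal record per line, parsed with ast.literal_eval and keyed
# by op name, data shape and dtype. The op names and latencies below are
# made-up placeholders, only meant to show the expected structure.
import ast

sample_lines = [
    "{'op_name': 'op_a', 'op_data_shape': (112, 112, 16, 16, 1), "
    "'op_dtype': 'float', 'op_latency': 0.83}",
    "{'op_name': 'op_b', 'op_data_shape': (56, 56, 16, 32, 2), "
    "'op_dtype': 'float', 'op_latency': 2.41}",
]

ops_lut = {}
for line in sample_lines:
    record = ast.literal_eval(line)
    ops_lut.setdefault(record["op_name"], {}) \
           .setdefault(record["op_data_shape"], {})[record["op_dtype"]] = record["op_latency"]
print(ops_lut["op_a"][(112, 112, 16, 16, 1)]["float"])   # -> 0.83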
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .trainer import PdartsTrainer
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from nni.algorithms.nas.pytorch.darts import DartsMutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice
|
||||
|
||||
|
||||
class PdartsMutator(DartsMutator):
|
||||
"""
|
||||
It works with PdartsTrainer to calculate ops weights,
|
||||
and drop weights in different PDARTS epochs.
|
||||
"""
|
||||
|
||||
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
|
||||
self.pdarts_epoch_index = pdarts_epoch_index
|
||||
self.pdarts_num_to_drop = pdarts_num_to_drop
|
||||
if switches is None:
|
||||
self.switches = {}
|
||||
else:
|
||||
self.switches = switches
|
||||
|
||||
super(PdartsMutator, self).__init__(model)
|
||||
|
||||
# this loop goes through mutables with different keys,
|
||||
# mainly to update the length of choices.
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
|
||||
switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
|
||||
choices = self.choices[mutable.key]
|
||||
|
||||
operations_count = np.sum(switches)
|
||||
# +1 and -1 are caused by the zero operation in the DARTS network.
|
||||
# the zero operation is not in the network's choices list, but its weight is,
|
||||
# so one extra weight and switch are needed for zero.
|
||||
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
|
||||
self.switches[mutable.key] = switches
|
||||
|
||||
# update LayerChoice instances in model,
|
||||
# physically removing the dropped candidate operations.
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, LayerChoice):
|
||||
switches = self.switches.get(module.key)
|
||||
choices = self.choices[module.key]
|
||||
if len(module) > len(choices):
|
||||
# iterate from last to first, so that removals do not affect earlier indexes.
|
||||
for index in range(len(switches)-1, -1, -1):
|
||||
if not switches[index]:
|
||||
del module[index]
|
||||
assert len(module) <= len(choices), "Failed to remove dropped choices."
|
||||
|
||||
def export(self):
|
||||
# Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
|
||||
results = super().sample_final()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
# Since some operations were physically dropped,
|
||||
# False is filled back in to keep track of the dropped positions.
|
||||
trained_result = results[mutable.key]
|
||||
trained_index = 0
|
||||
switches = self.switches[mutable.key]
|
||||
result = torch.Tensor(switches).bool()
|
||||
for index in range(len(result)):
|
||||
if result[index]:
|
||||
result[index] = trained_result[trained_index]
|
||||
trained_index += 1
|
||||
results[mutable.key] = result
|
||||
return results
|
||||
|
||||
def drop_paths(self):
|
||||
"""
|
||||
This method is called when a PDARTS epoch is finished.
|
||||
It prepares switches for the next epoch.
|
||||
Candidate operations whose switch is False will be dropped in the next epoch.
|
||||
"""
|
||||
all_switches = copy.deepcopy(self.switches)
|
||||
for key in all_switches:
|
||||
switches = all_switches[key]
|
||||
idxs = []
|
||||
for j in range(len(switches)):
|
||||
if switches[j]:
|
||||
idxs.append(j)
|
||||
sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
|
||||
drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
|
||||
for idx in drop:
|
||||
switches[idxs[idx]] = False
|
||||
return all_switches
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from nni.nas.pytorch.callbacks import LRSchedulerCallback
|
||||
from nni.algorithms.nas.pytorch.darts import DartsTrainer
|
||||
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder
|
||||
|
||||
from .mutator import PdartsMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdartsTrainer(BaseTrainer):
|
||||
"""
|
||||
This trainer implements the PDARTS algorithm.
|
||||
PDARTS is based on the DARTS algorithm and provides a network-growth approach to find a deeper and better network.
|
||||
This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
|
||||
pdarts_num_layers specifies how many more layers to use than in the first epoch.
|
||||
pdarts_num_to_drop means how many candidate operations should be dropped in each epoch.
|
||||
This way the grown network stays roughly the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, model_creator, init_layers, metrics,
|
||||
num_epochs, dataset_train, dataset_valid,
|
||||
pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
|
||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
|
||||
super(PdartsTrainer, self).__init__()
|
||||
self.model_creator = model_creator
|
||||
self.init_layers = init_layers
|
||||
self.pdarts_num_layers = pdarts_num_layers
|
||||
self.pdarts_num_to_drop = pdarts_num_to_drop
|
||||
self.pdarts_epoch = len(pdarts_num_to_drop)
|
||||
self.darts_parameters = {
|
||||
"metrics": metrics,
|
||||
"num_epochs": num_epochs,
|
||||
"dataset_train": dataset_train,
|
||||
"dataset_valid": dataset_valid,
|
||||
"batch_size": batch_size,
|
||||
"workers": workers,
|
||||
"device": device,
|
||||
"log_frequency": log_frequency,
|
||||
"unrolled": unrolled
|
||||
}
|
||||
self.callbacks = callbacks if callbacks is not None else []
|
||||
|
||||
def train(self):
|
||||
|
||||
switches = None
|
||||
for epoch in range(self.pdarts_epoch):
|
||||
|
||||
layers = self.init_layers+self.pdarts_num_layers[epoch]
|
||||
model, criterion, optim, lr_scheduler = self.model_creator(layers)
|
||||
self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback.build(model, self.mutator, self)
|
||||
callback.on_epoch_begin(epoch)
|
||||
|
||||
darts_callbacks = []
|
||||
if lr_scheduler is not None:
|
||||
darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
|
||||
|
||||
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
|
||||
callbacks=darts_callbacks, **self.darts_parameters)
|
||||
logger.info("start pdarts training epoch %s...", epoch)
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
switches = self.mutator.drop_paths()
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch)
|
||||
|
||||
def validate(self):
|
||||
self.trainer.validate()
|
||||
|
||||
def export(self, file):
|
||||
mutator_export = self.mutator.export()
|
||||
with open(file, "w") as f:
|
||||
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
|
||||
|
||||
def checkpoint(self):
|
||||
raise NotImplementedError("Not implemented yet")
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import ProxylessNasMutator
|
||||
from .trainer import ProxylessNasTrainer
|
|
@ -1,478 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from nni.nas.pytorch.base_mutator import BaseMutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice
|
||||
from .utils import detach_variable
|
||||
|
||||
class ArchGradientFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, binary_gates, run_func, backward_func):
|
||||
ctx.run_func = run_func
|
||||
ctx.backward_func = backward_func
|
||||
|
||||
detached_x = detach_variable(x)
|
||||
with torch.enable_grad():
|
||||
output = run_func(detached_x)
|
||||
ctx.save_for_backward(detached_x, output)
|
||||
return output.data
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
detached_x, output = ctx.saved_tensors
|
||||
|
||||
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
|
||||
# compute gradients w.r.t. binary_gates
|
||||
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
|
||||
|
||||
return grad_x[0], binary_grads, None, None
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
"""
|
||||
This class is to instantiate and manage info of one LayerChoice.
|
||||
It includes architecture weights, binary weights, and member functions
|
||||
operating the weights.
|
||||
|
||||
forward_mode:
|
||||
forward/backward mode for LayerChoice: None, two, full, and full_v2.
|
||||
For training architecture weights, we use full_v2 by default, and for training
|
||||
model weights, we use None.
|
||||
"""
|
||||
forward_mode = None
|
||||
def __init__(self, mutable):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
mutable : LayerChoice
|
||||
A LayerChoice in user model
|
||||
"""
|
||||
super(MixedOp, self).__init__()
|
||||
self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
|
||||
self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
|
||||
self.ap_path_alpha.requires_grad = False
|
||||
self.ap_path_wb.requires_grad = False
|
||||
self.active_index = [0]
|
||||
self.inactive_index = None
|
||||
self.log_prob = None
|
||||
self.current_prob_over_ops = None
|
||||
self.n_choices = len(mutable)
|
||||
|
||||
def get_ap_path_alpha(self):
|
||||
return self.ap_path_alpha
|
||||
|
||||
def to_requires_grad(self):
|
||||
self.ap_path_alpha.requires_grad = True
|
||||
self.ap_path_wb.requires_grad = True
|
||||
|
||||
def to_disable_grad(self):
|
||||
self.ap_path_alpha.requires_grad = False
|
||||
self.ap_path_wb.requires_grad = False
|
||||
|
||||
def forward(self, mutable, x):
|
||||
"""
|
||||
Define forward of LayerChoice. For 'full_v2', backward is also defined.
|
||||
The 'two' mode is explained in section 3.2.1 in the paper.
|
||||
The 'full_v2' mode is explained in Appendix D in the paper.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable : LayerChoice
|
||||
this layer's mutable
|
||||
x : tensor
|
||||
inputs of this layer, only support one input
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: tensor
|
||||
output of this layer
|
||||
"""
|
||||
if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
|
||||
output = 0
|
||||
for _i in self.active_index:
|
||||
oi = self.candidate_ops[_i](x)
|
||||
output = output + self.ap_path_wb[_i] * oi
|
||||
for _i in self.inactive_index:
|
||||
oi = self.candidate_ops[_i](x)
|
||||
output = output + self.ap_path_wb[_i] * oi.detach()
|
||||
elif MixedOp.forward_mode == 'full_v2':
|
||||
def run_function(key, candidate_ops, active_id):
|
||||
def forward(_x):
|
||||
return candidate_ops[active_id](_x)
|
||||
return forward
|
||||
|
||||
def backward_function(key, candidate_ops, active_id, binary_gates):
|
||||
def backward(_x, _output, grad_output):
|
||||
binary_grads = torch.zeros_like(binary_gates.data)
|
||||
with torch.no_grad():
|
||||
for k in range(len(candidate_ops)):
|
||||
if k != active_id:
|
||||
out_k = candidate_ops[k](_x.data)
|
||||
else:
|
||||
out_k = _output.data
|
||||
grad_k = torch.sum(out_k * grad_output)
|
||||
binary_grads[k] = grad_k
|
||||
return binary_grads
|
||||
return backward
|
||||
output = ArchGradientFunction.apply(
|
||||
x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
|
||||
backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
|
||||
else:
|
||||
output = self.active_op(mutable)(x)
|
||||
return output
|
||||
|
||||
@property
|
||||
def probs_over_ops(self):
|
||||
"""
|
||||
Apply softmax on alpha to generate probability distribution
|
||||
|
||||
Returns
|
||||
-------
|
||||
pytorch tensor
|
||||
probability distribution
|
||||
"""
|
||||
probs = F.softmax(self.ap_path_alpha, dim=0) # softmax to probability
|
||||
return probs
|
||||
|
||||
@property
|
||||
def chosen_index(self):
|
||||
"""
|
||||
choose the op with max prob
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
index of the chosen one
|
||||
numpy.float32
|
||||
prob of the chosen one
|
||||
"""
|
||||
probs = self.probs_over_ops.data.cpu().numpy()
|
||||
index = int(np.argmax(probs))
|
||||
return index, probs[index]
|
||||
|
||||
def active_op(self, mutable):
|
||||
"""
|
||||
assume only one path is active
|
||||
|
||||
Returns
|
||||
-------
|
||||
PyTorch module
|
||||
the chosen operation
|
||||
"""
|
||||
return mutable[self.active_index[0]]
|
||||
|
||||
@property
|
||||
def active_op_index(self):
|
||||
"""
|
||||
return active op's index, the active op is sampled
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
index of the active op
|
||||
"""
|
||||
return self.active_index[0]
|
||||
|
||||
def set_chosen_op_active(self):
|
||||
"""
|
||||
set chosen index, active and inactive indexes
|
||||
"""
|
||||
chosen_idx, _ = self.chosen_index
|
||||
self.active_index = [chosen_idx]
|
||||
self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
|
||||
[_i for _i in range(chosen_idx + 1, self.n_choices)]
|
||||
|
||||
def binarize(self, mutable):
|
||||
"""
|
||||
Sample based on alpha, and set binary weights accordingly.
|
||||
ap_path_wb is set in this function; that is why it is called binarize.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable : LayerChoice
|
||||
this layer's mutable
|
||||
"""
|
||||
self.log_prob = None
|
||||
# reset binary gates
|
||||
self.ap_path_wb.data.zero_()
|
||||
probs = self.probs_over_ops
|
||||
if MixedOp.forward_mode == 'two':
|
||||
# sample two ops according to probs
|
||||
sample_op = torch.multinomial(probs.data, 2, replacement=False)
|
||||
probs_slice = F.softmax(torch.stack([
|
||||
self.ap_path_alpha[idx] for idx in sample_op
|
||||
]), dim=0)
|
||||
self.current_prob_over_ops = torch.zeros_like(probs)
|
||||
for i, idx in enumerate(sample_op):
|
||||
self.current_prob_over_ops[idx] = probs_slice[i]
|
||||
# choose one to be active and the other to be inactive according to probs_slice
|
||||
c = torch.multinomial(probs_slice.data, 1)[0] # 0 or 1
|
||||
active_op = sample_op[c].item()
|
||||
inactive_op = sample_op[1-c].item()
|
||||
self.active_index = [active_op]
|
||||
self.inactive_index = [inactive_op]
|
||||
# set binary gate
|
||||
self.ap_path_wb.data[active_op] = 1.0
|
||||
else:
|
||||
sample = torch.multinomial(probs, 1)[0].item()
|
||||
self.active_index = [sample]
|
||||
self.inactive_index = [_i for _i in range(0, sample)] + \
|
||||
[_i for _i in range(sample + 1, len(mutable))]
|
||||
self.log_prob = torch.log(probs[sample])
|
||||
self.current_prob_over_ops = probs
|
||||
self.ap_path_wb.data[sample] = 1.0
|
||||
# avoid over-regularization
|
||||
for choice in mutable:
|
||||
for _, param in choice.named_parameters():
|
||||
param.grad = None
|
||||
|
||||
@staticmethod
|
||||
def delta_ij(i, j):
|
||||
if i == j:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def set_arch_param_grad(self, mutable):
|
||||
"""
|
||||
Calculate alpha gradient for this LayerChoice.
|
||||
It is calculated using gradient of binary gate, probs of ops.
|
||||
"""
|
||||
binary_grads = self.ap_path_wb.grad.data
|
||||
if self.active_op(mutable).is_zero_layer():
|
||||
self.ap_path_alpha.grad = None
|
||||
return
|
||||
if self.ap_path_alpha.grad is None:
|
||||
self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
|
||||
if MixedOp.forward_mode == 'two':
|
||||
involved_idx = self.active_index + self.inactive_index
|
||||
probs_slice = F.softmax(torch.stack([
|
||||
self.ap_path_alpha[idx] for idx in involved_idx
|
||||
]), dim=0).data
|
||||
for i in range(2):
|
||||
for j in range(2):
|
||||
origin_i = involved_idx[i]
|
||||
origin_j = involved_idx[j]
|
||||
self.ap_path_alpha.grad.data[origin_i] += \
|
||||
binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
|
||||
for _i, idx in enumerate(self.active_index):
|
||||
self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
|
||||
for _i, idx in enumerate(self.inactive_index):
|
||||
self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
|
||||
else:
|
||||
probs = self.probs_over_ops.data
|
||||
for i in range(self.n_choices):
|
||||
for j in range(self.n_choices):
|
||||
self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
|
||||
return
|
||||
|
||||
def rescale_updated_arch_param(self):
|
||||
"""
|
||||
rescale architecture weights for the 'two' mode.
|
||||
"""
|
||||
if not isinstance(self.active_index[0], tuple):
|
||||
assert self.active_op.is_zero_layer()
|
||||
return
|
||||
involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
|
||||
old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
|
||||
new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]
|
||||
|
||||
offset = math.log(
|
||||
sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
|
||||
)
|
||||
|
||||
for idx in involved_idx:
|
||||
self.ap_path_alpha.data[idx] -= offset
|
||||
|
||||
|
||||
class ProxylessNasMutator(BaseMutator):
|
||||
"""
|
||||
This mutator initializes and operates all the LayerChoices of the input model.
|
||||
It is for the corresponding trainer to control the training process of LayerChoices,
|
||||
coordinating with whole training process.
|
||||
"""
|
||||
def __init__(self, model):
|
||||
"""
|
||||
Init a MixedOp instance for each mutable i.e., LayerChoice.
|
||||
And register the instantiated MixedOp in corresponding LayerChoice.
|
||||
If it is not registered in LayerChoice, DataParallel will not work,
|
||||
because architecture weights are not included in the DataParallel model.
|
||||
When MixedOps are registered, we use ``requires_grad`` to control
|
||||
whether to calculate gradients of architecture weights.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
The model that users want to tune; it includes a search space defined with NNI NAS APIs
|
||||
"""
|
||||
super(ProxylessNasMutator, self).__init__(model)
|
||||
self._unused_modules = None
|
||||
self.mutable_list = []
|
||||
for mutable in self.undedup_mutables:
|
||||
self.mutable_list.append(mutable)
|
||||
mutable.registered_module = MixedOp(mutable)
|
||||
|
||||
def on_forward_layer_choice(self, mutable, *args, **kwargs):
|
||||
"""
|
||||
Callback of layer choice forward. This function defines the forward
|
||||
logic of the input mutable. So the mutable is only an interface; its real
|
||||
implementation is defined in the mutator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mutable: LayerChoice
|
||||
the LayerChoice whose forward logic is being executed
|
||||
args: list of torch.Tensor
|
||||
inputs of this mutable
|
||||
kwargs: dict
|
||||
inputs of this mutable
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
output of this mutable, i.e., LayerChoice
|
||||
int
|
||||
index of the chosen op
|
||||
"""
|
||||
# FIXME: return mask, to be consistent with other algorithms
|
||||
idx = mutable.registered_module.active_op_index
|
||||
return mutable.registered_module(mutable, *args, **kwargs), idx
|
||||
|
||||
def reset_binary_gates(self):
|
||||
"""
|
||||
For each LayerChoice, binarize binary weights
|
||||
based on alpha to only activate one op.
|
||||
It traverses all the mutables in the model to do this.
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.binarize(mutable)
|
||||
|
||||
def set_chosen_op_active(self):
|
||||
"""
|
||||
For each LayerChoice, set the op with highest alpha as the chosen op.
|
||||
Usually used for validation.
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.set_chosen_op_active()
|
||||
|
||||
def num_arch_params(self):
|
||||
"""
|
||||
The number of mutables, i.e., LayerChoice
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
the number of LayerChoice in user model
|
||||
"""
|
||||
return len(self.mutable_list)
|
||||
|
||||
def set_arch_param_grad(self):
|
||||
"""
|
||||
For each LayerChoice, calculate gradients for architecture weights, i.e., alpha
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.set_arch_param_grad(mutable)
|
||||
|
||||
def get_architecture_parameters(self):
|
||||
"""
|
||||
Get all the architecture parameters.
|
||||
|
||||
Yields
|
||||
------
|
||||
PyTorch Parameter
|
||||
Return ap_path_alpha of the traversed mutable
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
yield mutable.registered_module.get_ap_path_alpha()
|
||||
|
||||
def change_forward_mode(self, mode):
|
||||
"""
|
||||
Update forward mode of MixedOps, as training architecture weights and
|
||||
model weights use different forward modes.
|
||||
"""
|
||||
MixedOp.forward_mode = mode
|
||||
|
||||
def get_forward_mode(self):
|
||||
"""
|
||||
Get forward mode of MixedOp
|
||||
|
||||
Returns
|
||||
-------
|
||||
string
|
||||
the current forward mode of MixedOp
|
||||
"""
|
||||
return MixedOp.forward_mode
|
||||
|
||||
def rescale_updated_arch_param(self):
|
||||
"""
|
||||
Rescale architecture weights in 'two' mode.
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.rescale_updated_arch_param()
|
||||
|
||||
def unused_modules_off(self):
|
||||
"""
|
||||
Remove unused modules for each mutables.
|
||||
The removed modules are kept in ``self._unused_modules`` so they can be restored later.
|
||||
"""
|
||||
self._unused_modules = []
|
||||
for mutable in self.undedup_mutables:
|
||||
mixed_op = mutable.registered_module
|
||||
unused = {}
|
||||
if self.get_forward_mode() in ['full', 'two', 'full_v2']:
|
||||
involved_index = mixed_op.active_index + mixed_op.inactive_index
|
||||
else:
|
||||
involved_index = mixed_op.active_index
|
||||
for i in range(mixed_op.n_choices):
|
||||
if i not in involved_index:
|
||||
unused[i] = mutable[i]
|
||||
mutable[i] = None
|
||||
self._unused_modules.append(unused)
|
||||
|
||||
def unused_modules_back(self):
|
||||
"""
|
||||
Put the removed modules back.
|
||||
"""
|
||||
if self._unused_modules is None:
|
||||
return
|
||||
for m, unused in zip(self.mutable_list, self._unused_modules):
|
||||
for i in unused:
|
||||
m[i] = unused[i]
|
||||
self._unused_modules = None
|
||||
|
||||
def arch_requires_grad(self):
|
||||
"""
|
||||
Make architecture weights require gradient
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.to_requires_grad()
|
||||
|
||||
def arch_disable_grad(self):
|
||||
"""
|
||||
Disable gradients of architecture weights, i.e., do not
|
||||
calculate gradients for them.
|
||||
"""
|
||||
for mutable in self.undedup_mutables:
|
||||
mutable.registered_module.to_disable_grad()
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Generate the final chosen architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the choice of each mutable, i.e., LayerChoice
|
||||
"""
|
||||
result = dict()
|
||||
for mutable in self.undedup_mutables:
|
||||
assert isinstance(mutable, LayerChoice)
|
||||
index, _ = mutable.registered_module.chosen_index
|
||||
# pylint: disable=not-callable
|
||||
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
|
||||
return result
|
|
@ -1,500 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from nni.nas.pytorch.base_trainer import BaseTrainer
|
||||
from nni.nas.pytorch.trainer import TorchTensorEncoder
|
||||
from nni.nas.pytorch.utils import AverageMeter
|
||||
from .mutator import ProxylessNasMutator
|
||||
from .utils import cross_entropy_with_label_smoothing, accuracy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ProxylessNasTrainer(BaseTrainer):
|
||||
def __init__(self, model, model_optim, device,
|
||||
train_loader, valid_loader, label_smoothing=0.1,
|
||||
n_epochs=120, init_lr=0.025, binary_mode='full_v2',
|
||||
arch_init_type='normal', arch_init_ratio=1e-3,
|
||||
arch_optim_lr=1e-3, arch_weight_decay=0,
|
||||
grad_update_arch_param_every=5, grad_update_steps=1,
|
||||
warmup=True, warmup_epochs=25,
|
||||
arch_valid_frequency=1,
|
||||
load_ckpt=False, ckpt_path=None, arch_path=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
the user model, which has mutables
|
||||
model_optim : pytorch optimizer
|
||||
the user defined optimizer
|
||||
device : pytorch device
|
||||
the devices to train/search the model
|
||||
train_loader : pytorch data loader
|
||||
data loader for the training set
|
||||
valid_loader : pytorch data loader
|
||||
data loader for the validation set
|
||||
label_smoothing : float
|
||||
for label smoothing
|
||||
n_epochs : int
|
||||
number of epochs to train/search
|
||||
init_lr : float
|
||||
init learning rate for training the model
|
||||
binary_mode : str
|
||||
the forward/backward mode for the binary weights in mutator
|
||||
arch_init_type : str
|
||||
the way to init architecture parameters
|
||||
arch_init_ratio : float
|
||||
the ratio to init architecture parameters
|
||||
arch_optim_lr : float
|
||||
learning rate of the architecture parameters optimizer
|
||||
arch_weight_decay : float
|
||||
weight decay of the architecture parameters optimizer
|
||||
grad_update_arch_param_every : int
|
||||
update architecture weights every this number of minibatches
|
||||
grad_update_steps : int
|
||||
during each update of architecture weights, the number of steps to train
|
||||
warmup : bool
|
||||
whether to do warmup
|
||||
warmup_epochs : int
|
||||
the number of epochs to do during warmup
|
||||
arch_valid_frequency : int
|
||||
frequency of printing validation result
|
||||
load_ckpt : bool
|
||||
whether to load a checkpoint
|
||||
ckpt_path : str
|
||||
checkpoint path, if load_ckpt is True, ckpt_path cannot be None
|
||||
arch_path : str
|
||||
the path to store chosen architecture
|
||||
"""
|
||||
self.model = model
|
||||
self.model_optim = model_optim
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
self.device = device
|
||||
self.n_epochs = n_epochs
|
||||
self.init_lr = init_lr
|
||||
self.warmup = warmup
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.arch_valid_frequency = arch_valid_frequency
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
self.train_batch_size = train_loader.batch_sampler.batch_size
|
||||
self.valid_batch_size = valid_loader.batch_sampler.batch_size
|
||||
# update architecture parameters every this number of minibatches
|
||||
self.grad_update_arch_param_every = grad_update_arch_param_every
|
||||
# the number of steps per architecture parameter update
|
||||
self.grad_update_steps = grad_update_steps
|
||||
self.binary_mode = binary_mode
|
||||
|
||||
self.load_ckpt = load_ckpt
|
||||
self.ckpt_path = ckpt_path
|
||||
self.arch_path = arch_path
|
||||
|
||||
# init mutator
|
||||
self.mutator = ProxylessNasMutator(model)
|
||||
|
||||
# DataParallel should be put behind the init of mutator
|
||||
self.model = torch.nn.DataParallel(self.model)
|
||||
self.model.to(self.device)
|
||||
|
||||
# iter of valid dataset for training architecture weights
|
||||
self._valid_iter = None
|
||||
# init architecture weights
|
||||
self._init_arch_params(arch_init_type, arch_init_ratio)
|
||||
# build architecture optimizer
|
||||
self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
|
||||
arch_optim_lr,
|
||||
weight_decay=arch_weight_decay,
|
||||
betas=(0, 0.999),
|
||||
eps=1e-8)
|
||||
|
||||
self.criterion = nn.CrossEntropyLoss()
|
||||
self.warmup_curr_epoch = 0
|
||||
self.train_curr_epoch = 0
|
||||
|
||||
def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
|
||||
"""
|
||||
Initialize architecture weights
|
||||
"""
|
||||
for param in self.mutator.get_architecture_parameters():
|
||||
if init_type == 'normal':
|
||||
param.data.normal_(0, init_ratio)
|
||||
elif init_type == 'uniform':
|
||||
param.data.uniform_(-init_ratio, init_ratio)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Do validation. During validation, LayerChoices use the chosen active op.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float, float, float
|
||||
average loss, average top1 accuracy, average top5 accuracy
|
||||
"""
|
||||
self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
|
||||
self.valid_loader.batch_sampler.drop_last = False
|
||||
|
||||
self.mutator.set_chosen_op_active()
|
||||
# remove unused modules to save memory
|
||||
self.mutator.unused_modules_off()
|
||||
# test on validation set under train mode
|
||||
self.model.train()
|
||||
batch_time = AverageMeter('batch_time')
|
||||
losses = AverageMeter('losses')
|
||||
top1 = AverageMeter('top1')
|
||||
top5 = AverageMeter('top5')
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for i, (images, labels) in enumerate(self.valid_loader):
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
output = self.model(images)
|
||||
loss = self.criterion(output, labels)
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
losses.update(loss, images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % 10 == 0 or i + 1 == len(self.valid_loader):
|
||||
test_log = 'Valid' + ': [{0}/{1}]\t'\
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
|
||||
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
|
||||
format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
|
||||
# return top5:
|
||||
test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
|
||||
logger.info(test_log)
|
||||
self.mutator.unused_modules_back()
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
||||
def _warm_up(self):
|
||||
"""
|
||||
Warm up the model. During warm-up, architecture weights are not trained.
|
||||
"""
|
||||
lr_max = 0.05
|
||||
data_loader = self.train_loader
|
||||
nBatch = len(data_loader)
|
||||
T_total = self.warmup_epochs * nBatch # total num of batches
|
||||
|
||||
for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
|
||||
logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
|
||||
batch_time = AverageMeter('batch_time')
|
||||
data_time = AverageMeter('data_time')
|
||||
losses = AverageMeter('losses')
|
||||
top1 = AverageMeter('top1')
|
||||
top5 = AverageMeter('top5')
|
||||
# switch to train mode
|
||||
self.model.train()
|
||||
|
||||
end = time.time()
|
||||
logger.info('warm_up epoch: %d', epoch)
|
||||
for i, (images, labels) in enumerate(data_loader):
|
||||
data_time.update(time.time() - end)
|
||||
# lr
|
||||
T_cur = epoch * nBatch + i
|
||||
warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
|
||||
for param_group in self.model_optim.param_groups:
|
||||
param_group['lr'] = warmup_lr
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
# compute output
|
||||
self.mutator.reset_binary_gates() # random sample binary gates
|
||||
self.mutator.unused_modules_off() # remove unused module for speedup
|
||||
output = self.model(images)
|
||||
if self.label_smoothing > 0:
|
||||
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
|
||||
else:
|
||||
loss = self.criterion(output, labels)
|
||||
# measure accuracy and record loss
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
losses.update(loss, images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
# compute gradient and do SGD step
|
||||
self.model.zero_grad()
|
||||
loss.backward()
|
||||
self.model_optim.step()
|
||||
# unused modules back
|
||||
self.mutator.unused_modules_back()
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % 10 == 0 or i + 1 == nBatch:
|
||||
batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
|
||||
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
|
||||
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
|
||||
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
|
||||
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
|
||||
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
|
||||
losses=losses, top1=top1, top5=top5, lr=warmup_lr)
|
||||
logger.info(batch_log)
|
||||
val_loss, val_top1, val_top5 = self._validate()
|
||||
val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
|
||||
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
|
||||
format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
|
||||
logger.info(val_log)
|
||||
self.save_checkpoint()
|
||||
self.warmup_curr_epoch += 1
|
||||
|
||||
def _get_update_schedule(self, nBatch):
|
||||
"""
|
||||
Generate the schedule for training architecture weights. A key is the minibatch index after which
|
||||
architecture weights are updated, and the value is the number of update steps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nBatch : int
|
||||
the total number of minibatches in one epoch
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the schedule for updating architecture weights
|
||||
"""
|
||||
schedule = {}
|
||||
for i in range(nBatch):
|
||||
if (i + 1) % self.grad_update_arch_param_every == 0:
|
||||
schedule[i] = self.grad_update_steps
|
||||
return schedule
|
||||
|
||||
def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
|
||||
"""
|
||||
Compute the cosine-annealed learning rate for the current step.
|
||||
"""
|
||||
T_total = self.n_epochs * nBatch
|
||||
T_cur = epoch * nBatch + batch
|
||||
lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
|
||||
return lr
|
||||
|
||||
def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
|
||||
"""
|
||||
Adjust the learning rate of a given optimizer and return the new learning rate
|
||||
|
||||
Parameters
|
||||
----------
|
||||
optimizer : pytorch optimizer
|
||||
the used optimizer
|
||||
epoch : int
|
||||
the current epoch number
|
||||
batch : int
|
||||
the current minibatch
|
||||
nBatch : int
|
||||
the total number of minibatches in one epoch
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
the adjusted learning rate
|
||||
"""
|
||||
new_lr = self._calc_learning_rate(epoch, batch, nBatch)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = new_lr
|
||||
return new_lr
|
||||
|
||||
def _train(self):
|
||||
"""
|
||||
Train the model; both model weights and architecture weights are trained.
|
||||
Architecture weights are trained according to the schedule.
|
||||
Before updating architecture weights, ```requires_grad``` is enabled.
|
||||
Then, it is disabled after the updating, in order not to update
|
||||
architecture weights when training model weights.
|
||||
"""
|
||||
nBatch = len(self.train_loader)
|
||||
arch_param_num = self.mutator.num_arch_params()
|
||||
binary_gates_num = self.mutator.num_arch_params()
|
||||
logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)
|
||||
|
||||
update_schedule = self._get_update_schedule(nBatch)
|
||||
|
||||
for epoch in range(self.train_curr_epoch, self.n_epochs):
|
||||
logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
|
||||
batch_time = AverageMeter('batch_time')
|
||||
data_time = AverageMeter('data_time')
|
||||
losses = AverageMeter('losses')
|
||||
top1 = AverageMeter('top1')
|
||||
top5 = AverageMeter('top5')
|
||||
# switch to train mode
|
||||
self.model.train()
|
||||
|
||||
end = time.time()
|
||||
for i, (images, labels) in enumerate(self.train_loader):
|
||||
data_time.update(time.time() - end)
|
||||
lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
|
||||
# train weight parameters
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
self.mutator.reset_binary_gates()
|
||||
self.mutator.unused_modules_off()
|
||||
output = self.model(images)
|
||||
if self.label_smoothing > 0:
|
||||
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
|
||||
else:
|
||||
loss = self.criterion(output, labels)
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
losses.update(loss, images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
self.model.zero_grad()
|
||||
loss.backward()
|
||||
self.model_optim.step()
|
||||
self.mutator.unused_modules_back()
|
||||
if epoch > 0:
|
||||
for _ in range(update_schedule.get(i, 0)):
|
||||
start_time = time.time()
|
||||
# GradientArchSearchConfig
|
||||
self.mutator.arch_requires_grad()
|
||||
arch_loss, exp_value = self._gradient_step()
|
||||
self.mutator.arch_disable_grad()
|
||||
used_time = time.time() - start_time
|
||||
log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t null %s' % \
|
||||
(epoch + 1, i, used_time, arch_loss, exp_value)
|
||||
logger.info(log_str)
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
# training log
|
||||
if i % 10 == 0 or i + 1 == nBatch:
|
||||
batch_log = 'Train [{0}][{1}/{2}]\t' \
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
|
||||
'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
|
||||
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
|
||||
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
|
||||
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
|
||||
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
|
||||
losses=losses, top1=top1, top5=top5, lr=lr)
|
||||
logger.info(batch_log)
|
||||
# validate
|
||||
if (epoch + 1) % self.arch_valid_frequency == 0:
|
||||
val_loss, val_top1, val_top5 = self._validate()
|
||||
val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
|
||||
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
|
||||
format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
|
||||
logger.info(val_log)
|
||||
self.save_checkpoint()
|
||||
self.train_curr_epoch += 1
|
||||
|
||||
def _valid_next_batch(self):
|
||||
"""
|
||||
Get the next minibatch from the validation set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(tensor, tensor)
|
||||
the tuple of images and labels
|
||||
"""
|
||||
if self._valid_iter is None:
|
||||
self._valid_iter = iter(self.valid_loader)
|
||||
try:
|
||||
data = next(self._valid_iter)
|
||||
except StopIteration:
|
||||
self._valid_iter = iter(self.valid_loader)
|
||||
data = next(self._valid_iter)
|
||||
return data
|
||||
|
||||
def _gradient_step(self):
|
||||
"""
|
||||
This gradient step is for updating architecture weights.
|
||||
Mutator is intensively used in this function to operate on
|
||||
architecture weights.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float, None
|
||||
loss of the model, None
|
||||
"""
|
||||
# use the same batch size as train batch size for architecture weights
|
||||
self.valid_loader.batch_sampler.batch_size = self.train_batch_size
|
||||
self.valid_loader.batch_sampler.drop_last = True
|
||||
self.model.train()
|
||||
self.mutator.change_forward_mode(self.binary_mode)
|
||||
time1 = time.time() # time
|
||||
# sample a batch of data from validation set
|
||||
images, labels = self._valid_next_batch()
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
time2 = time.time() # time
|
||||
self.mutator.reset_binary_gates()
|
||||
self.mutator.unused_modules_off()
|
||||
output = self.model(images)
|
||||
time3 = time.time()
|
||||
ce_loss = self.criterion(output, labels)
|
||||
expected_value = None
|
||||
loss = ce_loss
|
||||
self.model.zero_grad()
|
||||
loss.backward()
|
||||
self.mutator.set_arch_param_grad()
|
||||
self.arch_optimizer.step()
|
||||
if self.mutator.get_forward_mode() == 'two':
|
||||
self.mutator.rescale_updated_arch_param()
|
||||
self.mutator.unused_modules_back()
|
||||
self.mutator.change_forward_mode(None)
|
||||
time4 = time.time()
|
||||
logger.info('(%.4f, %.4f, %.4f)', time2 - time1, time3 - time2, time4 - time3)
|
||||
return loss.data.item(), expected_value.item() if expected_value is not None else None
|
||||
|
||||
def save_checkpoint(self):
|
||||
"""
|
||||
Save a checkpoint of the whole model. Model weights and architecture weights are saved to
|
||||
``ckpt_path``, and the currently chosen architecture is saved to ``arch_path``.
|
||||
"""
|
||||
if self.ckpt_path:
|
||||
state = {
|
||||
'warmup_curr_epoch': self.warmup_curr_epoch,
|
||||
'train_curr_epoch': self.train_curr_epoch,
|
||||
'model': self.model.state_dict(),
|
||||
'optim': self.model_optim.state_dict(),
|
||||
'arch_optim': self.arch_optimizer.state_dict()
|
||||
}
|
||||
torch.save(state, self.ckpt_path)
|
||||
if self.arch_path:
|
||||
self.export(self.arch_path)
|
||||
|
||||
def load_checkpoint(self):
|
||||
"""
|
||||
Load the checkpoint from ``ckpt_path``.
|
||||
"""
|
||||
assert self.ckpt_path is not None, "If load_ckpt is True, ckpt_path should not be None"
|
||||
ckpt = torch.load(self.ckpt_path)
|
||||
self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
|
||||
self.train_curr_epoch = ckpt['train_curr_epoch']
|
||||
self.model.load_state_dict(ckpt['model'])
|
||||
self.model_optim.load_state_dict(ckpt['optim'])
|
||||
self.arch_optimizer.load_state_dict(ckpt['arch_optim'])
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
Train the whole model.
|
||||
"""
|
||||
if self.load_ckpt:
|
||||
self.load_checkpoint()
|
||||
if self.warmup:
|
||||
self._warm_up()
|
||||
self._train()
|
||||
|
||||
def export(self, file_name):
|
||||
"""
|
||||
Export the chosen architecture into a file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name : str
|
||||
the file that stores exported chosen architecture
|
||||
"""
|
||||
exported_arch = self.mutator.sample_final()
|
||||
with open(file_name, 'w') as f:
|
||||
json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
|
||||
|
||||
def validate(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def checkpoint(self):
|
||||
raise NotImplementedError
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def detach_variable(inputs):
|
||||
"""
|
||||
Detach variables
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : pytorch tensors
|
||||
pytorch tensors
|
||||
"""
|
||||
if isinstance(inputs, tuple):
|
||||
return tuple([detach_variable(x) for x in inputs])
|
||||
else:
|
||||
x = inputs.detach()
|
||||
x.requires_grad = inputs.requires_grad
|
||||
return x
|
||||
|
||||
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pred : pytorch tensor
|
||||
predicted value
|
||||
target : pytorch tensor
|
||||
label
|
||||
label_smoothing : float
|
||||
the degree of label smoothing
|
||||
|
||||
Returns
|
||||
-------
|
||||
pytorch tensor
|
||||
cross entropy
|
||||
"""
|
||||
logsoftmax = nn.LogSoftmax(dim=1)
|
||||
n_classes = pred.size(1)
|
||||
# convert to one-hot
|
||||
target = torch.unsqueeze(target, 1)
|
||||
soft_target = torch.zeros_like(pred)
|
||||
soft_target.scatter_(1, target, 1)
|
||||
# label smoothing
|
||||
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
|
||||
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""
|
||||
Computes the precision@k for the specified values of k
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : pytorch tensor
|
||||
output, e.g., predicted value
|
||||
target : pytorch tensor
|
||||
label
|
||||
topk : tuple
|
||||
specify top1 and top5
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
accuracy of top1 and top5
|
||||
"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import RandomMutator
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nni.nas.pytorch.mutator import Mutator
|
||||
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
|
||||
|
||||
|
||||
class RandomMutator(Mutator):
|
||||
"""
|
||||
Random mutator that samples a random candidate in the search space each time ``reset()`` is called.
|
||||
It uses PyTorch's random functions, so users can set the PyTorch seed to ensure deterministic behavior.
|
||||
"""
|
||||
|
||||
def sample_search(self):
|
||||
"""
|
||||
Sample a random candidate.
|
||||
"""
|
||||
result = dict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
gen_index = torch.randint(high=len(mutable), size=(1, ))
|
||||
result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
|
||||
elif isinstance(mutable, InputChoice):
|
||||
if mutable.n_chosen is None:
|
||||
result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
|
||||
else:
|
||||
perm = torch.randperm(mutable.n_candidates)
|
||||
mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
|
||||
result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable
|
||||
return result
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Same as :meth:`sample_search`.
|
||||
"""
|
||||
return self.sample_search()
|
|
@ -1,6 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .evolution import SPOSEvolution
|
||||
from .mutator import SPOSSupernetTrainingMutator
|
||||
from .trainer import SPOSSupernetTrainer
|
|
@ -1,223 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from nni.tuner import Tuner
|
||||
from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SPOSEvolution(Tuner):
|
||||
"""
|
||||
SPOS evolution tuner.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_epochs : int
|
||||
Maximum number of epochs to run.
|
||||
num_select : int
|
||||
Number of survival candidates of each epoch.
|
||||
num_population : int
|
||||
Number of candidates at the start of each epoch. If candidates generated by
|
||||
crossover and mutation are not enough, the rest will be filled with random
|
||||
candidates.
|
||||
m_prob : float
|
||||
The probability of mutation.
|
||||
num_crossover : int
|
||||
Number of candidates generated by crossover in each epoch.
|
||||
num_mutation : int
|
||||
Number of candidates generated by mutation in each epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
|
||||
num_crossover=25, num_mutation=25):
|
||||
assert num_population >= num_select
|
||||
self.max_epochs = max_epochs
|
||||
self.num_select = num_select
|
||||
self.num_population = num_population
|
||||
self.m_prob = m_prob
|
||||
self.num_crossover = num_crossover
|
||||
self.num_mutation = num_mutation
|
||||
self.epoch = 0
|
||||
self.candidates = []
|
||||
self._search_space = None
|
||||
self.random_state = np.random.RandomState(0)
|
||||
|
||||
# async status
|
||||
self._to_evaluate_queue = deque()
|
||||
self._sending_parameter_queue = deque()
|
||||
self._pending_result_ids = set()
|
||||
self._reward_dict = dict()
|
||||
self._id2candidate = dict()
|
||||
self._st_callback = None
|
||||
|
||||
def update_search_space(self, search_space):
|
||||
"""
|
||||
Handle the initialization/update event of search space.
|
||||
"""
|
||||
self._search_space = search_space
|
||||
self._next_round()
|
||||
|
||||
def _next_round(self):
|
||||
_logger.info("Epoch %d, generating...", self.epoch)
|
||||
if self.epoch == 0:
|
||||
self._get_random_population()
|
||||
self.export_results(self.candidates)
|
||||
else:
|
||||
best_candidates = self._select_top_candidates()
|
||||
self.export_results(best_candidates)
|
||||
if self.epoch >= self.max_epochs:
|
||||
return
|
||||
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
|
||||
self._get_random_population()
|
||||
self.epoch += 1
|
||||
|
||||
def _random_candidate(self):
|
||||
chosen_arch = dict()
|
||||
for key, val in self._search_space.items():
|
||||
if val["_type"] == LAYER_CHOICE:
|
||||
choices = val["_value"]
|
||||
index = self.random_state.randint(len(choices))
|
||||
chosen_arch[key] = {"_value": choices[index], "_idx": index}
|
||||
elif val["_type"] == INPUT_CHOICE:
|
||||
raise NotImplementedError("Input choice is not implemented yet.")
|
||||
return chosen_arch
|
||||
|
||||
def _add_to_evaluate_queue(self, cand):
|
||||
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
|
||||
self._reward_dict[self._hashcode(cand)] = 0.
|
||||
self._to_evaluate_queue.append(cand)
|
||||
|
||||
def _get_random_population(self):
|
||||
while len(self.candidates) < self.num_population:
|
||||
cand = self._random_candidate()
|
||||
if self._is_legal(cand):
|
||||
_logger.info("Random candidate generated.")
|
||||
self._add_to_evaluate_queue(cand)
|
||||
self.candidates.append(cand)
|
||||
|
||||
def _get_crossover(self, best):
|
||||
result = []
|
||||
for _ in range(10 * self.num_crossover):
|
||||
cand_p1 = best[self.random_state.randint(len(best))]
|
||||
cand_p2 = best[self.random_state.randint(len(best))]
|
||||
assert cand_p1.keys() == cand_p2.keys()
|
||||
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
|
||||
for k in cand_p1.keys()}
|
||||
if self._is_legal(cand):
|
||||
result.append(cand)
|
||||
self._add_to_evaluate_queue(cand)
|
||||
if len(result) >= self.num_crossover:
|
||||
break
|
||||
_logger.info("Found %d architectures with crossover.", len(result))
|
||||
return result
|
||||
|
||||
def _get_mutation(self, best):
|
||||
result = []
|
||||
for _ in range(10 * self.num_mutation):
|
||||
cand = best[self.random_state.randint(len(best))].copy()
|
||||
mutation_sample = self.random_state.random_sample(len(cand))
|
||||
for s, k in zip(mutation_sample, cand):
|
||||
if s < self.m_prob:
|
||||
choices = self._search_space[k]["_value"]
|
||||
index = self.random_state.randint(len(choices))
|
||||
cand[k] = {"_value": choices[index], "_idx": index}
|
||||
if self._is_legal(cand):
|
||||
result.append(cand)
|
||||
self._add_to_evaluate_queue(cand)
|
||||
if len(result) >= self.num_mutation:
|
||||
break
|
||||
_logger.info("Found %d architectures with mutation.", len(result))
|
||||
return result
|
||||
|
||||
def _get_architecture_repr(self, cand):
|
||||
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
|
||||
self._hashcode(cand))
|
||||
|
||||
def _is_legal(self, cand):
|
||||
if self._hashcode(cand) in self._reward_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _select_top_candidates(self):
|
||||
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
|
||||
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
|
||||
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
|
||||
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _hashcode(d):
|
||||
return json.dumps(d, sort_keys=True)
|
||||
|
||||
def _bind_and_send_parameters(self):
|
||||
"""
|
||||
There are two types of resources: parameter ids and candidates. This function is called at
|
||||
necessary times to bind these resources to send new trials with st_callback.
|
||||
"""
|
||||
result = []
|
||||
while self._sending_parameter_queue and self._to_evaluate_queue:
|
||||
parameter_id = self._sending_parameter_queue.popleft()
|
||||
parameters = self._to_evaluate_queue.popleft()
|
||||
self._id2candidate[parameter_id] = parameters
|
||||
result.append(parameters)
|
||||
self._pending_result_ids.add(parameter_id)
|
||||
self._st_callback(parameter_id, parameters)
|
||||
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
|
||||
return result
|
||||
|
||||
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
|
||||
"""
|
||||
Callback function necessary to implement a tuner. This will put more parameter ids into the
|
||||
parameter id queue.
|
||||
"""
|
||||
if "st_callback" in kwargs and self._st_callback is None:
|
||||
self._st_callback = kwargs["st_callback"]
|
||||
for parameter_id in parameter_id_list:
|
||||
self._sending_parameter_queue.append(parameter_id)
|
||||
self._bind_and_send_parameters()
|
||||
return []  # never rely on this return value; it might cause over-sending
|
||||
|
||||
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
|
||||
"""
|
||||
Callback function. Receive a trial result.
|
||||
"""
|
||||
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
|
||||
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
|
||||
|
||||
def trial_end(self, parameter_id, success, **kwargs):
|
||||
"""
|
||||
Callback function when a trial is ended and resource is released.
|
||||
"""
|
||||
self._pending_result_ids.remove(parameter_id)
|
||||
if not self._pending_result_ids and not self._to_evaluate_queue:
|
||||
# a new epoch now
|
||||
self._next_round()
|
||||
assert self._st_callback is not None
|
||||
self._bind_and_send_parameters()
|
||||
|
||||
def export_results(self, result):
|
||||
"""
|
||||
Export a number of candidates to `checkpoints` dir.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : dict
|
||||
Chosen architectures to be exported.
|
||||
"""
|
||||
os.makedirs("checkpoints", exist_ok=True)
|
||||
for i, cand in enumerate(result):
|
||||
converted = dict()
|
||||
for cand_key, cand_val in cand.items():
|
||||
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
|
||||
converted[cand_key] = onehot
|
||||
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
|
||||
json.dump(converted, fp)
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from nni.algorithms.nas.pytorch.random import RandomMutator
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SPOSSupernetTrainingMutator(RandomMutator):
|
||||
"""
|
||||
A random mutator with flops limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
PyTorch model.
|
||||
flops_func : callable
|
||||
Callable that takes a candidate from `sample_search` and returns its flops. When `flops_func`
|
||||
is None, functions related to flops will be deactivated.
|
||||
flops_lb : number
|
||||
Lower bound of flops.
|
||||
flops_ub : number
|
||||
Upper bound of flops.
|
||||
flops_bin_num : number
|
||||
Number of bins the flops interval is divided into, to ensure uniformity. A bigger number gives more
|
||||
uniform sampling, but the sampling will be slower.
|
||||
flops_sample_timeout : int
|
||||
Maximum number of attempts to sample before giving up and using a random candidate.
|
||||
"""
|
||||
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
|
||||
flops_bin_num=7, flops_sample_timeout=500):
|
||||
|
||||
super().__init__(model)
|
||||
self._flops_func = flops_func
|
||||
if self._flops_func is not None:
|
||||
self._flops_bin_num = flops_bin_num
|
||||
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
|
||||
self._flops_sample_timeout = flops_sample_timeout
|
||||
|
||||
def sample_search(self):
|
||||
"""
|
||||
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
|
||||
relative to flops.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
"""
|
||||
if self._flops_func is not None:
|
||||
for times in range(self._flops_sample_timeout):
|
||||
idx = np.random.randint(self._flops_bin_num)
|
||||
cand = super().sample_search()
|
||||
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
|
||||
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
|
||||
return cand
|
||||
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
|
||||
return super().sample_search()
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Implemented only to satisfy the interface of Mutator.
|
||||
"""
|
||||
return self.sample_search()
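# Hedged usage sketch (not part of the original file). `toy_flops` and the bounds below are
# made up for illustration; a real `flops_func` would estimate the FLOPs of the sampled subnet.
def toy_flops(candidate):
    # Stand-in estimate: sum the chosen index of every one-hot mask in the candidate dict.
    return float(sum(int(mask.float().argmax()) for mask in candidate.values()))
# Assuming `supernet` is a single-path-one-shot supernet defined elsewhere:
# mutator = SPOSSupernetTrainingMutator(supernet, flops_func=toy_flops,
#                                       flops_lb=0, flops_ub=10, flops_bin_num=5)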
|
|
@ -1,95 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from nni.nas.pytorch.trainer import Trainer
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup
|
||||
|
||||
from .mutator import SPOSSupernetTrainingMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SPOSSupernetTrainer(Trainer):
|
||||
"""
|
||||
This trainer trains a supernet that can be used for evolution search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
Model with mutables.
|
||||
mutator : nni.nas.pytorch.mutator.Mutator
|
||||
A mutator object that has been initialized with the model.
|
||||
loss : callable
|
||||
Called with logits and targets. Returns a loss tensor.
|
||||
metrics : callable
|
||||
Returns a dict that maps metrics keys to metrics data.
|
||||
optimizer : Optimizer
|
||||
Optimizer that optimizes the model.
|
||||
num_epochs : int
|
||||
Number of epochs of training.
|
||||
train_loader : iterable
|
||||
Data loader for training. Raises ``StopIteration`` when one epoch is exhausted.
|
||||
valid_loader : iterable
|
||||
Data loader for validation. Raises ``StopIteration`` when one epoch is exhausted.
|
||||
batch_size : int
|
||||
Batch size.
|
||||
workers: int
|
||||
Number of workers for data preprocessing. Not used by this trainer; may be removed in the future.
|
||||
device : torch.device
|
||||
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
|
||||
automatically detect GPUs and select GPU first.
|
||||
log_frequency : int
|
||||
Number of mini-batches to log metrics.
|
||||
callbacks : list of Callback
|
||||
Callbacks to plug into the trainer. See Callbacks.
|
||||
"""
|
||||
|
||||
def __init__(self, model, loss, metrics,
|
||||
optimizer, num_epochs, train_loader, valid_loader,
|
||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
|
||||
callbacks=None):
|
||||
assert torch.cuda.is_available()
|
||||
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
|
||||
loss, metrics, optimizer, num_epochs, None, None,
|
||||
batch_size, workers, device, log_frequency, callbacks)
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.model.train()
|
||||
meters = AverageMeterGroup()
|
||||
for step, (x, y) in enumerate(self.train_loader):
|
||||
x, y = x.to(self.device), y.to(self.device)
|
||||
self.optimizer.zero_grad()
|
||||
self.mutator.reset()
|
||||
logits = self.model(x)
|
||||
loss = self.loss(logits, y)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
metrics = self.metrics(logits, y)
|
||||
metrics["loss"] = loss.item()
|
||||
meters.update(metrics)
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.train_loader), meters)
|
||||
|
||||
def validate_one_epoch(self, epoch):
|
||||
self.model.eval()
|
||||
meters = AverageMeterGroup()
|
||||
with torch.no_grad():
|
||||
for step, (x, y) in enumerate(self.valid_loader):
|
||||
x, y = x.to(self.device), y.to(self.device)
|
||||
self.mutator.reset()
|
||||
logits = self.model(x)
|
||||
loss = self.loss(logits, y)
|
||||
metrics = self.metrics(logits, y)
|
||||
metrics["loss"] = loss.item()
|
||||
meters.update(metrics)
|
||||
if self.log_frequency is not None and step % self.log_frequency == 0:
|
||||
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.valid_loader), meters)
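# Hedged usage sketch (not part of the original file); `supernet`, `accuracy` and the two
# loaders are assumed to be defined elsewhere, and the hyper-parameters are illustrative only.
# trainer = SPOSSupernetTrainer(supernet, loss=torch.nn.CrossEntropyLoss(), metrics=accuracy,
#                               optimizer=torch.optim.SGD(supernet.parameters(), lr=0.1),
#                               num_epochs=120, train_loader=train_loader,
#                               valid_loader=valid_loader, log_frequency=10)
# trainer.train()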
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import get_and_apply_next_architecture
|
|
@ -1,217 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import nni
|
||||
from nni.runtime.env_vars import trial_env_vars
|
||||
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
|
||||
from nni.nas.tensorflow.mutator import Mutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
|
||||
LAYER_CHOICE = "layer_choice"
|
||||
INPUT_CHOICE = "input_choice"
|
||||
|
||||
|
||||
def get_and_apply_next_architecture(model):
|
||||
"""
|
||||
Wrapper of :class:`~nni.nas.tensorflow.classic_nas.mutator.ClassicMutator` to make it more meaningful,
|
||||
similar to ``get_next_parameter`` for HPO.
|
||||
It will generate the search space based on ``model``.
|
||||
If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
|
||||
generating search space for the experiment.
|
||||
If not, there are still two modes: one is NNI experiment mode, where users
|
||||
use ``nnictl`` to start an experiment. The other is standalone mode
|
||||
where users directly run the trial command; this mode chooses the first
|
||||
one(s) for each LayerChoice and InputChoice.
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
|
||||
"""
|
||||
ClassicMutator(model)
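# Hedged usage sketch (not part of the original file): inside a trial script, the tuner-chosen
# architecture is applied before training starts. `build_model` and `train` are hypothetical.
# model = build_model()                       # model containing LayerChoice / InputChoice
# get_and_apply_next_architecture(model)      # fixes the choices for this trial
# train(model)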
|
||||
|
||||
|
||||
class ClassicMutator(Mutator):
|
||||
"""
|
||||
This mutator is to apply the architecture chosen from tuner.
|
||||
It implements the forward function of LayerChoice and InputChoice,
|
||||
to only activate the chosen ones.
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super(ClassicMutator, self).__init__(model)
|
||||
self._chosen_arch = {}
|
||||
self._search_space = self._generate_search_space()
|
||||
if NNI_GEN_SEARCH_SPACE in os.environ:
|
||||
# dry run for only generating search space
|
||||
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
|
||||
sys.exit(0)
|
||||
|
||||
if trial_env_vars.NNI_PLATFORM is None:
|
||||
logger.warning("This is in standalone mode, the chosen are the first one(s).")
|
||||
self._chosen_arch = self._standalone_generate_chosen()
|
||||
else:
|
||||
# get chosen arch from tuner
|
||||
self._chosen_arch = nni.get_next_parameter()
|
||||
if self._chosen_arch is None:
|
||||
if trial_env_vars.NNI_PLATFORM == "unittest":
|
||||
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
|
||||
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
|
||||
self._chosen_arch = self._standalone_generate_chosen()
|
||||
else:
|
||||
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
|
||||
self.reset()
|
||||
|
||||
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
|
||||
"""
|
||||
Convert layer choice to tensor representation.
|
||||
Parameters
|
||||
----------
|
||||
mutable : Mutable
|
||||
idx : int
|
||||
Item number `idx` of the list will be selected.
|
||||
value : str
|
||||
The verbose representation of the selected value.
|
||||
search_space_item : list
|
||||
The list for corresponding search space.
|
||||
"""
|
||||
# doesn't support multihot for layer choice yet
|
||||
assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
|
||||
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
|
||||
mask = tf.one_hot(idx, len(mutable))
|
||||
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
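# Hedged illustration (not part of the original file): selecting index 1 of a 3-way
# LayerChoice produces the boolean mask [False, True, False].
example_mask = tf.cast(tf.reshape(tf.one_hot(1, 3), [-1]), tf.bool)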
|
||||
|
||||
def _sample_input_choice(self, mutable, idx, value, search_space_item):
|
||||
"""
|
||||
Convert input choice to tensor representation.
|
||||
Parameters
|
||||
----------
|
||||
mutable : Mutable
|
||||
idx : int
|
||||
Item number `idx` of the list will be selected.
|
||||
value : str
|
||||
The verbose representation of the selected value.
|
||||
search_space_item : list
|
||||
The list for corresponding search space.
|
||||
"""
|
||||
candidate_repr = search_space_item["candidates"]
|
||||
multihot_list = [False] * mutable.n_candidates
|
||||
for i, v in zip(idx, value):
|
||||
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
|
||||
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
|
||||
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
|
||||
multihot_list[i] = True
|
||||
return tf.cast(multihot_list, tf.bool) # pylint: disable=not-callable
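# Hedged illustration (not part of the original file): choosing candidates 0 and 2 out of
# four produces the multi-hot mask [True, False, True, False].
example_multihot = tf.cast([True, False, True, False], tf.bool)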
|
||||
|
||||
def sample_search(self):
|
||||
"""
|
||||
See :meth:`sample_final`.
|
||||
"""
|
||||
return self.sample_final()
|
||||
|
||||
def sample_final(self):
|
||||
"""
|
||||
Convert the chosen arch and apply it on model.
|
||||
"""
|
||||
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
|
||||
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
|
||||
self._chosen_arch.keys())
|
||||
result = dict()
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, (LayerChoice, InputChoice)):
|
||||
assert mutable.key in self._chosen_arch, \
|
||||
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
|
||||
data = self._chosen_arch[mutable.key]
|
||||
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
|
||||
"'{}' is not a valid choice.".format(data)
|
||||
if isinstance(mutable, LayerChoice):
|
||||
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
|
||||
self._search_space[mutable.key]["_value"])
|
||||
elif isinstance(mutable, InputChoice):
|
||||
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
|
||||
self._search_space[mutable.key]["_value"])
|
||||
elif isinstance(mutable, MutableScope):
|
||||
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
|
||||
else:
|
||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
|
||||
return result
|
||||
|
||||
def _standalone_generate_chosen(self):
|
||||
"""
|
||||
Generate the chosen architecture for standalone mode,
|
||||
i.e., choose the first one(s) for LayerChoice and InputChoice.
|
||||
::
|
||||
{ key_name: {"_value": "conv1",
|
||||
"_idx": 0} }
|
||||
{ key_name: {"_value": ["in1"],
|
||||
"_idx": [0]} }
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the chosen architecture
|
||||
"""
|
||||
chosen_arch = {}
|
||||
for key, val in self._search_space.items():
|
||||
if val["_type"] == LAYER_CHOICE:
|
||||
choices = val["_value"]
|
||||
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
|
||||
elif val["_type"] == INPUT_CHOICE:
|
||||
choices = val["_value"]["candidates"]
|
||||
n_chosen = val["_value"]["n_chosen"]
|
||||
if n_chosen is None:
|
||||
n_chosen = len(choices)
|
||||
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
|
||||
else:
|
||||
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
|
||||
return chosen_arch
|
||||
|
||||
def _generate_search_space(self):
|
||||
"""
|
||||
Generate search space from mutables.
|
||||
Here is the search space format:
|
||||
::
|
||||
{ key_name: {"_type": "layer_choice",
|
||||
"_value": ["conv1", "conv2"]} }
|
||||
{ key_name: {"_type": "input_choice",
|
||||
"_value": {"candidates": ["in1", "in2"],
|
||||
"n_chosen": 1}} }
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the generated search space
|
||||
"""
|
||||
search_space = {}
|
||||
for mutable in self.mutables:
|
||||
# for now we only generate flattened search space
|
||||
if isinstance(mutable, LayerChoice):
|
||||
key = mutable.key
|
||||
val = mutable.names
|
||||
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
|
||||
elif isinstance(mutable, InputChoice):
|
||||
key = mutable.key
|
||||
search_space[key] = {"_type": INPUT_CHOICE,
|
||||
"_value": {"candidates": mutable.choose_from,
|
||||
"n_chosen": mutable.n_chosen}}
|
||||
elif isinstance(mutable, MutableScope):
|
||||
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
|
||||
else:
|
||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
|
||||
return search_space
|
||||
|
||||
def _dump_search_space(self, file_path):
|
||||
with open(file_path, "w") as ss_file:
|
||||
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .mutator import EnasMutator
|
||||
from .trainer import EnasTrainer
|
|
@ -1,162 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
|
||||
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction
|
||||
|
||||
from nni.nas.tensorflow.mutator import Mutator
|
||||
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
|
||||
|
||||
|
||||
class EnasMutator(Mutator):
|
||||
def __init__(self, model,
|
||||
lstm_size=64,
|
||||
lstm_num_layers=1,
|
||||
tanh_constant=1.5,
|
||||
cell_exit_extra_step=False,
|
||||
skip_target=0.4,
|
||||
temperature=None,
|
||||
branch_bias=0.25,
|
||||
entropy_reduction='sum'):
|
||||
super().__init__(model)
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
self.cell_exit_extra_step = cell_exit_extra_step
|
||||
|
||||
cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
|
||||
self.lstm = RNN(cells, stateful=True)
|
||||
self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
|
||||
self.skip_targets = tf.constant([1.0 - skip_target, skip_target])
|
||||
|
||||
self.max_layer_choice = 0
|
||||
self.bias_dict = {}
|
||||
for mutable in self.mutables:
|
||||
if isinstance(mutable, LayerChoice):
|
||||
if self.max_layer_choice == 0:
|
||||
self.max_layer_choice = len(mutable)
|
||||
assert self.max_layer_choice == len(mutable), \
|
||||
"ENAS mutator requires all layer choice have the same number of candidates."
|
||||
if 'reduce' in mutable.key:
|
||||
bias = []
|
||||
for choice in mutable.choices:
|
||||
if 'conv' in str(type(choice)).lower():
|
||||
bias.append(branch_bias)
|
||||
else:
|
||||
bias.append(-branch_bias)
|
||||
self.bias_dict[mutable.key] = tf.constant(bias)
|
||||
|
||||
# exposed for trainer
|
||||
self.sample_log_prob = 0
|
||||
self.sample_entropy = 0
|
||||
self.sample_skip_penalty = 0
|
||||
|
||||
# internal nn layers
|
||||
self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
|
||||
self.soft = Dense(self.max_layer_choice, use_bias=False)
|
||||
self.attn_anchor = Dense(lstm_size, use_bias=False)
|
||||
self.attn_query = Dense(lstm_size, use_bias=False)
|
||||
self.v_attn = Dense(1, use_bias=False)
|
||||
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
|
||||
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
|
||||
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
|
||||
|
||||
self._first_sample = True
|
||||
|
||||
def sample_search(self):
|
||||
self._initialize()
|
||||
self._sample(self.mutables)
|
||||
self._first_sample = False
|
||||
return self._choices
|
||||
|
||||
def sample_final(self):
|
||||
return self.sample_search()
|
||||
|
||||
def _sample(self, tree):
|
||||
mutable = tree.mutable
|
||||
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
|
||||
self._choices[mutable.key] = self._sample_layer_choice(mutable)
|
||||
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
|
||||
self._choices[mutable.key] = self._sample_input_choice(mutable)
|
||||
for child in tree.children:
|
||||
self._sample(child)
|
||||
if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
|
||||
self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)
|
||||
|
||||
def _initialize(self):
|
||||
self._choices = {}
|
||||
self._anchors_hid = {}
|
||||
self._inputs = self.g_emb
|
||||
# the `input_shape` parameter of RNN does not seem to work;
|
||||
# work around it by skipping `reset_states` on the first run
|
||||
if not self._first_sample:
|
||||
self.lstm.reset_states()
|
||||
self.sample_log_prob = 0
|
||||
self.sample_entropy = 0
|
||||
self.sample_skip_penalty = 0
|
||||
|
||||
def _sample_layer_choice(self, mutable):
|
||||
logit = self.soft(self.lstm(self._inputs))
|
||||
if self.temperature is not None:
|
||||
logit /= self.temperature
|
||||
if self.tanh_constant is not None:
|
||||
logit = self.tanh_constant * tf.tanh(logit)
|
||||
if mutable.key in self.bias_dict:
|
||||
logit += self.bias_dict[mutable.key]
|
||||
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
|
||||
branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
|
||||
log_prob = self.cross_entropy_loss(branch_id, logit)
|
||||
self.sample_log_prob += self.entropy_reduction(log_prob)
|
||||
entropy = log_prob * tf.math.exp(-log_prob)
|
||||
self.sample_entropy += self.entropy_reduction(entropy)
|
||||
self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
|
||||
mask = tf.one_hot(branch_id, self.max_layer_choice)
|
||||
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
|
||||
|
||||
def _sample_input_choice(self, mutable):
|
||||
query, anchors = [], []
|
||||
for label in mutable.choose_from:
|
||||
if label not in self._anchors_hid:
|
||||
self._anchors_hid[label] = self.lstm(self._inputs)
|
||||
query.append(self.attn_anchor(self._anchors_hid[label]))
|
||||
anchors.append(self._anchors_hid[label])
|
||||
query = tf.concat(query, axis=0)
|
||||
query = tf.tanh(query + self.attn_query(anchors[-1]))
|
||||
query = self.v_attn(query)
|
||||
|
||||
if self.temperature is not None:
|
||||
query /= self.temperature
|
||||
if self.tanh_constant is not None:
|
||||
query = self.tanh_constant * tf.tanh(query)
|
||||
|
||||
if mutable.n_chosen is None:
|
||||
logit = tf.concat([-query, query], axis=1)
|
||||
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
|
||||
skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
|
||||
skip_prob = tf.math.sigmoid(logit)
|
||||
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
|
||||
self.sample_skip_penalty += kl
|
||||
log_prob = self.cross_entropy_loss(skip, logit)
|
||||
|
||||
skip = tf.cast(skip, tf.float32)
|
||||
inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
|
||||
self._inputs = tf.reshape(inputs, [1, 1, -1])
|
||||
|
||||
else:
|
||||
assert mutable.n_chosen == 1, "Input choice in ENAS must select exactly one input (or any number when n_chosen is None)."
|
||||
logit = tf.reshape(query, [1, -1])
|
||||
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
|
||||
index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
|
||||
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
|
||||
# when the size is 1, tf does not accept tensor here, complaining the shape is wrong
|
||||
# but using a numpy array seems fine
|
||||
log_prob = self.cross_entropy_loss(logit, query.numpy())
|
||||
self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])
|
||||
|
||||
self.sample_log_prob += self.entropy_reduction(log_prob)
|
||||
entropy = log_prob * tf.exp(-log_prob)
|
||||
self.sample_entropy += self.entropy_reduction(entropy)
|
||||
assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
|
||||
return tf.cast(skip, tf.bool)
|
|
@ -1,205 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import logging
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
|
||||
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
|
||||
|
||||
from .mutator import EnasMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnasTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
loss,
|
||||
metrics,
|
||||
reward_function,
|
||||
optimizer,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
dataset_train,
|
||||
dataset_valid,
|
||||
log_frequency=100,
|
||||
entropy_weight=0.0001,
|
||||
skip_weight=0.8,
|
||||
baseline_decay=0.999,
|
||||
child_steps=500,
|
||||
mutator_lr=0.00035,
|
||||
mutator_steps=50,
|
||||
mutator_steps_aggregate=20,
|
||||
aux_weight=0.4,
|
||||
test_arc_per_epoch=1,
|
||||
):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.metrics = metrics
|
||||
self.reward_function = reward_function
|
||||
self.optimizer = optimizer
|
||||
self.batch_size = batch_size
|
||||
self.num_epochs = num_epochs
|
||||
|
||||
x, y = dataset_train
|
||||
split = int(len(x) * 0.9)
|
||||
self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
|
||||
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
|
||||
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
|
||||
|
||||
self.log_frequency = log_frequency
|
||||
self.entropy_weight = entropy_weight
|
||||
self.skip_weight = skip_weight
|
||||
self.baseline_decay = baseline_decay
|
||||
self.child_steps = child_steps
|
||||
self.mutator_lr = mutator_lr
|
||||
self.mutator_steps = mutator_steps
|
||||
self.mutator_steps_aggregate = mutator_steps_aggregate
|
||||
self.aux_weight = aux_weight
|
||||
self.test_arc_per_epoch = test_arc_per_epoch
|
||||
|
||||
self.mutator = EnasMutator(model)
|
||||
self.mutator_optim = Adam(learning_rate=self.mutator_lr)
|
||||
|
||||
self.baseline = 0.0
|
||||
|
||||
def train(self, validate=True):
|
||||
for epoch in range(self.num_epochs):
|
||||
logger.info("Epoch %d Training", epoch + 1)
|
||||
self.train_one_epoch(epoch)
|
||||
logger.info("Epoch %d Validating", epoch + 1)
|
||||
self.validate_one_epoch(epoch)
|
||||
|
||||
def validate(self):
|
||||
self.validate_one_epoch(-1)
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
train_loader, valid_loader = self._create_train_loader()
|
||||
|
||||
# Sample model and train
|
||||
meters = AverageMeterGroup()
|
||||
|
||||
for step in range(1, self.child_steps + 1):
|
||||
x, y = next(train_loader)
|
||||
self.mutator.reset()
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
logits = self.model(x, training=True)
|
||||
if isinstance(logits, tuple):
|
||||
logits, aux_logits = logits
|
||||
aux_loss = self.loss(aux_logits, y)
|
||||
else:
|
||||
aux_loss = 0.0
|
||||
metrics = self.metrics(y, logits)
|
||||
loss = self.loss(y, logits) + self.aux_weight * aux_loss
|
||||
|
||||
grads = tape.gradient(loss, self.model.trainable_weights)
|
||||
grads = fill_zero_grads(grads, self.model.trainable_weights)
|
||||
grads, _ = tf.clip_by_global_norm(grads, 5.0)
|
||||
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
|
||||
|
||||
metrics["loss"] = tf.reduce_mean(loss).numpy()
|
||||
meters.update(metrics)
|
||||
|
||||
if self.log_frequency and step % self.log_frequency == 0:
|
||||
logger.info(
|
||||
"Model Epoch [%d/%d] Step [%d/%d] %s",
|
||||
epoch + 1,
|
||||
self.num_epochs,
|
||||
step,
|
||||
self.child_steps,
|
||||
meters,
|
||||
)
|
||||
|
||||
# Train sampler (mutator)
|
||||
meters = AverageMeterGroup()
|
||||
for mutator_step in range(1, self.mutator_steps + 1):
|
||||
grads_list = []
|
||||
for step in range(1, self.mutator_steps_aggregate + 1):
|
||||
with tf.GradientTape() as tape:
|
||||
x, y = next(valid_loader)
|
||||
self.mutator.reset()
|
||||
|
||||
logits = self.model(x, training=False)
|
||||
metrics = self.metrics(y, logits)
|
||||
reward = (
|
||||
self.reward_function(y, logits)
|
||||
+ self.entropy_weight * self.mutator.sample_entropy
|
||||
)
|
||||
self.baseline = self.baseline * self.baseline_decay + reward * (
|
||||
1 - self.baseline_decay
|
||||
)
|
||||
loss = self.mutator.sample_log_prob * (reward - self.baseline)
|
||||
loss += self.skip_weight * self.mutator.sample_skip_penalty
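# The two lines above form the REINFORCE surrogate loss: the log-probability of the sampled
# architecture weighted by the advantage (reward minus the moving-average baseline), plus a
# penalty that pulls the rate of skip connections towards `skip_target`.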
|
||||
|
||||
meters.update(
|
||||
{
|
||||
"reward": reward,
|
||||
"loss": tf.reduce_mean(loss).numpy(),
|
||||
"ent": self.mutator.sample_entropy.numpy(),
|
||||
"log_prob": self.mutator.sample_log_prob.numpy(),
|
||||
"baseline": self.baseline,
|
||||
"skip": self.mutator.sample_skip_penalty,
|
||||
}
|
||||
)
|
||||
|
||||
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
|
||||
if self.log_frequency and cur_step % self.log_frequency == 0:
|
||||
logger.info(
|
||||
"RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
|
||||
epoch + 1,
|
||||
self.num_epochs,
|
||||
mutator_step,
|
||||
self.mutator_steps,
|
||||
step,
|
||||
self.mutator_steps_aggregate,
|
||||
meters,
|
||||
)
|
||||
|
||||
grads = tape.gradient(loss, self.mutator.trainable_weights)
|
||||
grads = fill_zero_grads(grads, self.mutator.trainable_weights)
|
||||
grads_list.append(grads)
|
||||
total_grads = [
|
||||
tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
|
||||
]
|
||||
total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
|
||||
self.mutator_optim.apply_gradients(
|
||||
zip(total_grads, self.mutator.trainable_weights)
|
||||
)
|
||||
|
||||
def validate_one_epoch(self, epoch):
|
||||
test_loader = self._create_validate_loader()
|
||||
|
||||
for arc_id in range(self.test_arc_per_epoch):
|
||||
meters = AverageMeterGroup()
|
||||
for x, y in test_loader:
|
||||
self.mutator.reset()
|
||||
logits = self.model(x, training=False)
|
||||
if isinstance(logits, tuple):
|
||||
logits, _ = logits
|
||||
metrics = self.metrics(y, logits)
|
||||
loss = self.loss(y, logits)
|
||||
metrics["loss"] = tf.reduce_mean(loss).numpy()
|
||||
meters.update(metrics)
|
||||
|
||||
logger.info(
|
||||
"Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
|
||||
epoch + 1,
|
||||
self.num_epochs,
|
||||
arc_id + 1,
|
||||
self.test_arc_per_epoch,
|
||||
meters.summary(),
|
||||
)
|
||||
|
||||
def _create_train_loader(self):
|
||||
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
|
||||
test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
|
||||
return iter(train_set), iter(test_set)
|
||||
|
||||
def _create_validate_loader(self):
|
||||
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
__all__ = ['NasBench201Cell']
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, List, Dict, Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.nas.nn.pytorch import LayerChoice
|
||||
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
|
||||
|
||||
|
||||
class NasBench201Cell(nn.Module):
|
||||
"""
|
||||
Cell structure that is proposed in NAS-Bench-201.
|
||||
|
||||
Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.
|
||||
|
||||
This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
|
||||
For every i < j, there is an edge from i-th node to j-th node.
|
||||
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
|
||||
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
|
||||
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
|
||||
and returns a ``Module``.
|
||||
|
||||
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.
|
||||
|
||||
The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
|
||||
and :math:`N` is defined by ``num_tensors``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op_candidates : list of callable
|
||||
Operation candidates. Each should be a function that accepts input and output feature dimensions and returns an nn.Module.
|
||||
in_features : int
|
||||
Input dimension of cell.
|
||||
out_features : int
|
||||
Output dimension of cell.
|
||||
num_tensors : int
|
||||
Number of tensors in the cell (input included). Default: 4
|
||||
label : str
|
||||
Identifier of the cell. Cells sharing the same label will semantically share the same choice.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_dict(x):
|
||||
if isinstance(x, list):
|
||||
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
|
||||
return OrderedDict(x)
|
||||
|
||||
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
|
||||
in_features: int, out_features: int, num_tensors: int = 4,
|
||||
label: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._label = generate_new_label(label)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.num_tensors = num_tensors
|
||||
|
||||
op_candidates = self._make_dict(op_candidates)
|
||||
|
||||
for tid in range(1, num_tensors):
|
||||
node_ops = nn.ModuleList()
|
||||
for j in range(tid):
|
||||
inp = in_features if j == 0 else out_features
|
||||
op_choices = OrderedDict([(key, cls(inp, out_features))
|
||||
for key, cls in op_candidates.items()])
|
||||
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
|
||||
self.layers.append(node_ops)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
By default, the forward pass simply selects the first candidate of every choice.
|
||||
It shouldn't be called directly by users in most cases.
|
||||
"""
|
||||
tensors: List[torch.Tensor] = [inputs]
|
||||
for layer in self.layers:
|
||||
current_tensor: List[torch.Tensor] = []
|
||||
for i, op in enumerate(layer): # type: ignore
|
||||
current_tensor.append(op(tensors[i])) # type: ignore
|
||||
tensors.append(torch.sum(torch.stack(current_tensor), 0))
|
||||
return tensors[-1]
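# Hedged usage sketch (not part of the original file); the op candidates and sizes below are
# illustrative. In a plain forward pass each LayerChoice falls back to its first candidate.
example_ops = {
    'conv3x3': lambda cin, cout: nn.Conv2d(cin, cout, kernel_size=3, padding=1),
    'skip': lambda cin, cout: nn.Identity() if cin == cout else nn.Conv2d(cin, cout, kernel_size=1),
}
example_cell = NasBench201Cell(example_ops, in_features=16, out_features=16, num_tensors=4)
example_out = example_cell(torch.randn(2, 16, 8, 8))  # shape [2, 16, 8, 8]
# With 2 candidate ops and 4 tensors, the space size is |op| ** (N * (N - 1) / 2) = 2 ** 6 = 64.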
|
|
@ -3,21 +3,17 @@
|
|||
|
||||
import copy
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, List, Dict, Union, Tuple, Optional
|
||||
from typing import Callable, List, Union, Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
|
||||
from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
|
||||
|
||||
from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
|
||||
from .cell import Cell
|
||||
from .nasbench101 import NasBench101Cell, NasBench101Mutator
|
||||
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
|
||||
from .choice import ValueChoice, ValueChoiceX, ChoiceOf
|
||||
from .mutation_utils import Mutable, get_fixed_value
|
||||
|
||||
|
||||
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
|
||||
__all__ = ['Repeat']
|
||||
|
||||
|
||||
class Repeat(Mutable):
|
||||
|
@ -159,77 +155,3 @@ class Repeat(Mutable):
|
|||
|
||||
def __len__(self):
|
||||
return self.max_depth
|
||||
|
||||
|
||||
class NasBench201Cell(nn.Module):
|
||||
"""
|
||||
Cell structure that is proposed in NAS-Bench-201.
|
||||
|
||||
Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.
|
||||
|
||||
This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
|
||||
For every i < j, there is an edge from i-th node to j-th node.
|
||||
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
|
||||
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
|
||||
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
|
||||
and returns a ``Module``.
|
||||
|
||||
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.
|
||||
|
||||
The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
|
||||
and :math:`N` is defined by ``num_tensors``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op_candidates : list of callable
|
||||
Operation candidates. Each should be a function that accepts input and output feature dimensions and returns an nn.Module.
|
||||
in_features : int
|
||||
Input dimension of cell.
|
||||
out_features : int
|
||||
Output dimension of cell.
|
||||
num_tensors : int
|
||||
Number of tensors in the cell (input included). Default: 4
|
||||
label : str
|
||||
Identifier of the cell. Cells sharing the same label will semantically share the same choice.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_dict(x):
|
||||
if isinstance(x, list):
|
||||
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
|
||||
return OrderedDict(x)
|
||||
|
||||
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
|
||||
in_features: int, out_features: int, num_tensors: int = 4,
|
||||
label: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._label = generate_new_label(label)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.num_tensors = num_tensors
|
||||
|
||||
op_candidates = self._make_dict(op_candidates)
|
||||
|
||||
for tid in range(1, num_tensors):
|
||||
node_ops = nn.ModuleList()
|
||||
for j in range(tid):
|
||||
inp = in_features if j == 0 else out_features
|
||||
op_choices = OrderedDict([(key, cls(inp, out_features))
|
||||
for key, cls in op_candidates.items()])
|
||||
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
|
||||
self.layers.append(node_ops)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
By default, the forward pass simply selects the first candidate of every choice.
|
||||
It shouldn't be called directly by users in most cases.
|
||||
"""
|
||||
tensors: List[torch.Tensor] = [inputs]
|
||||
for layer in self.layers:
|
||||
current_tensor: List[torch.Tensor] = []
|
||||
for i, op in enumerate(layer): # type: ignore
|
||||
current_tensor.append(op(tensors[i])) # type: ignore
|
||||
tensors.append(torch.sum(torch.stack(current_tensor), 0))
|
||||
return tensors[-1]
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class StackedLSTMCell(nn.Module):
|
||||
def __init__(self, layers, size, bias):
|
||||
super().__init__()
|
||||
self.lstm_num_layers = layers
|
||||
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
|
||||
for _ in range(self.lstm_num_layers)])
|
||||
|
||||
def forward(self, inputs, hidden):
|
||||
prev_h, prev_c = hidden
|
||||
next_h, next_c = [], []
|
||||
for i, m in enumerate(self.lstm_modules):
|
||||
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
|
||||
next_c.append(curr_c)
|
||||
next_h.append(curr_h)
|
||||
# current implementation only supports batch size equals 1,
|
||||
# but the algorithm does not necessarily have this limitation
|
||||
inputs = curr_h[-1].view(1, -1)
|
||||
return next_h, next_c
|
||||
|
||||
|
||||
class ReinforceField:
|
||||
"""
|
||||
A field named ``name`` with ``total`` choices. ``choose_one`` is true if exactly one choice is meant to be
|
||||
selected. Otherwise, any number of choices can be chosen.
|
||||
"""
|
||||
|
||||
def __init__(self, name, total, choose_one):
|
||||
self.name = name
|
||||
self.total = total
|
||||
self.choose_one = choose_one
|
||||
|
||||
def __repr__(self):
|
||||
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
|
||||
|
||||
|
||||
class ReinforceController(nn.Module):
|
||||
"""
|
||||
A controller that mutates the graph with RL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fields : list of ReinforceField
|
||||
List of fields to choose.
|
||||
lstm_size : int
|
||||
Controller LSTM hidden units.
|
||||
lstm_num_layers : int
|
||||
Number of layers for stacked LSTM.
|
||||
tanh_constant : float
|
||||
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
|
||||
skip_target : float
|
||||
Target probability that a skip connection (chosen by InputChoice) will appear.
|
||||
If the sampled probability of skip connections deviates from ``skip_target``, a sample
|
||||
skip penalty, computed as a KL divergence, is added.
|
||||
temperature : float
|
||||
Temperature constant that divides the logits.
|
||||
entropy_reduction : str
|
||||
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
|
||||
"""
|
||||
|
||||
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
|
||||
skip_target=0.4, temperature=None, entropy_reduction='sum'):
|
||||
super(ReinforceController, self).__init__()
|
||||
self.fields = fields
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_num_layers = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
self.skip_target = skip_target
|
||||
|
||||
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
|
||||
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
|
||||
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
|
||||
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
|
||||
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
|
||||
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
|
||||
requires_grad=False)
|
||||
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
|
||||
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
|
||||
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
self.soft = nn.ModuleDict({
|
||||
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
|
||||
})
|
||||
self.embedding = nn.ModuleDict({
|
||||
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
|
||||
})
|
||||
|
||||
def resample(self):
|
||||
self._initialize()
|
||||
result = dict()
|
||||
for field in self.fields:
|
||||
result[field.name] = self._sample_single(field)
|
||||
return result
|
||||
|
||||
def _initialize(self):
|
||||
self._inputs = self.g_emb.data
|
||||
self._c = [torch.zeros((1, self.lstm_size),
|
||||
dtype=self._inputs.dtype,
|
||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
|
||||
self._h = [torch.zeros((1, self.lstm_size),
|
||||
dtype=self._inputs.dtype,
|
||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
|
||||
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
|
||||
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
|
||||
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
|
||||
|
||||
def _lstm_next_step(self):
|
||||
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
|
||||
|
||||
def _sample_single(self, field):
|
||||
self._lstm_next_step()
|
||||
logit = self.soft[field.name](self._h[-1])
|
||||
if self.temperature is not None:
|
||||
logit /= self.temperature
|
||||
if self.tanh_constant is not None:
|
||||
logit = self.tanh_constant * torch.tanh(logit)
|
||||
if field.choose_one:
|
||||
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
|
||||
log_prob = self.cross_entropy_loss(logit, sampled)
|
||||
self._inputs = self.embedding[field.name](sampled)
|
||||
else:
|
||||
logit = logit.view(-1, 1)
|
||||
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
|
||||
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
|
||||
skip_prob = torch.sigmoid(logit)
|
||||
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
|
||||
self.sample_skip_penalty += kl
|
||||
log_prob = self.cross_entropy_loss(logit, sampled)
|
||||
sampled = sampled.nonzero().view(-1)
|
||||
if sampled.sum().item():
|
||||
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
|
||||
else:
|
||||
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) # type: ignore
|
||||
|
||||
sampled = sampled.detach().cpu().numpy().tolist()
|
||||
self.sample_log_prob += self.entropy_reduction(log_prob)
|
||||
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
|
||||
self.sample_entropy += self.entropy_reduction(entropy)
|
||||
if len(sampled) == 1:
|
||||
sampled = sampled[0]
|
||||
return sampled
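# Hedged usage sketch (not part of the original file): two illustrative fields, an exclusive
# 4-way layer choice and a 2-way multi-select input choice; the field names are made up.
example_fields = [ReinforceField('op_1', 4, True), ReinforceField('skip_1', 2, False)]
example_controller = ReinforceController(example_fields)
example_sample = example_controller.resample()  # e.g. {'op_1': 2, 'skip_1': [0, 1]}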
|