Promote Retiarii to NAS (step 1) - move files (#5020)

This commit is contained in:
Yuge Zhang 2022-07-27 13:18:21 +08:00 коммит произвёл GitHub
Родитель 481aa29299
Коммит 867871b244
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
138 изменённых файлов: 242 добавлений и 8414 удалений

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer

Просмотреть файл

@ -1,143 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from nni.algorithms.nas.pytorch.darts import DartsMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator # pylint: disable=wrong-import-order
class RegularizedDartsMutator(DartsMutator):
    """
    This is :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator` basically, with two differences.

    1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in
       the forward pass and thus consume no memory.
    2. Regularization on choices, to prevent the mutator from overfitting on some choices.
    """

    def reset(self):
        """
        Always raise; callers must use :func:`~reset_with_loss` instead.

        Warnings
        --------
        Renamed :func:`~reset_with_loss` to return regularization loss on reset.

        Raises
        ------
        ValueError
            Unconditionally.
        """
        raise ValueError("You should probably call `reset_with_loss`.")

    def cut_choices(self, cut_num=2):
        r"""
        Cut the choices with the smallest weights.
        ``cut_num`` should be the accumulative number of cutting, e.g., if first time cutting
        is 2, the second time should be 4 to cut another two.

        Parameters
        ----------
        cut_num : int
            Number of choices to cut, so far.

        Warnings
        --------
        Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive gradient of 0,
        which introduced ``nan`` problem when calling ``optimizer.step()``. To solve this issue, a simple way is to
        reset nan to :math:`-\infty` each time after the parameters are updated.
        """
        # `cut_choices` is implemented but not used in current implementation of CdartsTrainer
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # top-k of the negated weights == indices of the `cut_num` smallest weights
                _, idx = torch.topk(-self.choices[mutable.key], cut_num)
                with torch.no_grad():
                    for i in idx:
                        self.choices[mutable.key][i] = -float("inf")

    def reset_with_loss(self):
        """
        Resample and return loss. If loss is 0, to avoid device issue, it will return ``None``.

        Currently loss penalty are proportional to the L1-norm of parameters corresponding
        to modules if their type name contains certain substrings. These substrings include: ``poolwithoutbn``,
        ``identity``, ``dilconv``.

        Returns
        -------
        torch.Tensor or None
            The regularization loss, or ``None`` when no choice contributed to it.
        """
        self._cache, reg_loss = self.sample_search()
        return reg_loss

    def sample_search(self):
        """
        Sample via the parent class and compute the regularization loss alongside.

        Returns
        -------
        tuple
            ``(result, loss)`` where ``result`` is whatever the parent ``sample_search`` returned and
            ``loss`` is the sum of the absolute architecture weights of regularized choices, or
            ``None`` if there are none.
        """
        result = super().sample_search()
        # A choice is regularized when its lower-cased type name contains one of these substrings.
        keywords = ("poolwithoutbn", "identity", "dilconv")
        loss = []
        for mutable in self.mutables:
            if not isinstance(mutable, LayerChoice):
                continue
            for i, choice in enumerate(mutable.choices):
                type_name = str(type(choice)).lower()
                if not any(t in type_name for t in keywords):
                    continue
                norm = torch.abs(self.choices[mutable.key][i])
                # Skip weights that were cut (set to -inf by ``cut_choices``).
                if norm < 1E10:
                    loss.append(norm)
        if not loss:
            return result, None
        return result, sum(loss)

    def export(self, logger=None):
        """
        Export an architecture with logger. Genotype will be printed with logger.

        Parameters
        ----------
        logger : logging.Logger or None
            Passed to ``model.plot_genotype``. When ``None``, or when the model has no
            ``plot_genotype`` attribute, no genotype is produced.

        Returns
        -------
        tuple
            ``(result, genotypes)``. ``result`` is the value of ``sample_final()`` (a mapping from
            mutable keys to decisions); ``genotypes`` is the return value of ``model.plot_genotype``
            or ``None`` when it was not called.
        """
        result = self.sample_final()
        # Default so that the return statement is valid when plot_genotype is unavailable.
        genotypes = None
        if hasattr(self.model, "plot_genotype") and logger is not None:
            genotypes = self.model.plot_genotype(result, logger)
        return result, genotypes
class RegularizedMutatorParallel(DistributedDataParallel):
    """
    Distributed wrapper around :class:`~RegularizedDartsMutator`.

    Forwards :func:`~RegularizedDartsMutator.reset_with_loss`,
    :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
    to the wrapped mutator (``self.module``) so that callers can use the wrapper
    exactly like the mutator itself.
    """

    def reset_with_loss(self):
        """
        Call ``reset_with_loss`` on the wrapped mutator and return its result.

        ``callback_queued`` is cleared afterwards.
        NOTE(review): presumably this re-arms the gradient all-reduce callback of the
        apex wrapper — confirm against the apex implementation.
        """
        sampled_loss = self.module.reset_with_loss()
        self.callback_queued = False
        return sampled_loss

    def cut_choices(self, *args, **kwargs):
        """
        Forward all arguments to ``cut_choices`` of the wrapped mutator. Returns nothing.
        """
        self.module.cut_choices(*args, **kwargs)

    def export(self, logger):
        """
        Forward to ``export`` of the wrapped mutator and return its result.
        """
        return self.module.export(logger)
class DartsDiscreteMutator(Mutator):
    """
    Mutator that replays the final decision of another (parent) mutator on its own model.

    Parameters
    ----------
    model : nn.Module
        The model to apply the mutator.
    parent_mutator : nni.nas.pytorch.mutator.Mutator
        The mutator that provides ``sample_final`` method, that will be called to get the architecture.
    """

    def __init__(self, model, parent_mutator):
        super().__init__(model)
        # Store through the instance dict instead of normal attribute assignment,
        # to avoid the parent's parameters being included in this mutator.
        vars(self)["parent_mutator"] = parent_mutator

    def sample_search(self):
        """Return the parent mutator's ``sample_final()`` result."""
        return self.parent_mutator.sample_final()

Просмотреть файл

@ -1,275 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import apex # pylint: disable=import-error
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup # pylint: disable=wrong-import-order
from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics
# Phase identifiers accepted by ``CdartsTrainer._warmup``:
# "small" selects ``model_small`` / ``optimizer_small``, "large" selects ``model_large`` / ``optimizer_large``.
PHASE_SMALL = "small"
PHASE_LARGE = "large"
class InteractiveKLLoss(nn.Module):
    """
    KL-divergence loss between temperature-scaled student and teacher logits.

    Parameters
    ----------
    temperature : float
        Both logit tensors are divided by this value before the softmax.
    """

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        # self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
        self.kl_loss = nn.KLDivLoss()

    def forward(self, student, teacher):
        """
        Return ``KLDivLoss(log_softmax(student / T), softmax(teacher / T))``,
        with both softmaxes taken over dimension 1.
        """
        student_log_probs = F.log_softmax(student / self.temperature, dim=1)
        teacher_probs = F.softmax(teacher / self.temperature, dim=1)
        return self.kl_loss(student_log_probs, teacher_probs)
class CdartsTrainer(object):
    """
    CDARTS trainer.

    Parameters
    ----------
    model_small : nn.Module
        PyTorch model to be trained. This is the search network of CDARTS.
    model_large : nn.Module
        PyTorch model to be trained. This is the evaluation network of CDARTS.
    criterion : callable
        Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
    loaders : list of torch.utils.data.DataLoader
        List of train data and valid data loaders, for training weights and architecture weights respectively.
    samplers : list of torch.utils.data.Sampler
        List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
        In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
    logger : logging.Logger
        The logger for logging. Will use nni logger by default (if logger is ``None``).
    regular_coeff : float
        The coefficient of regular loss.
    regular_ratio : float
        The ratio of regular loss.
    warmup_epochs : int
        The epochs to warmup the search network
    fix_head : bool
        ``True`` if fixing the parameters of auxiliary heads, else unfix the parameters of auxiliary heads.
    epochs : int
        Number of epochs planned for training.
    steps_per_epoch : int
        Steps of one epoch. Defaults to the length of the shorter of the two loaders.
    loss_alpha : float
        The loss coefficient.
    loss_T : float
        The loss coefficient.
    distributed : bool
        ``True`` if using distributed training, else non-distributed training.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping for weights.
    interactive_type : string
        ``kl`` or ``smoothl1``.
    output_path : string
        Log storage path.
    w_lr : float
        Learning rate of the search network parameters.
    w_momentum : float
        Momentum of the search and the evaluation network.
    w_weight_decay : float
        The weight decay the search and the evaluation network parameters.
    alpha_lr : float
        Learning rate of the architecture parameters.
    alpha_weight_decay : float
        The weight decay the architecture parameters.
    nasnet_lr : float
        Learning rate of the evaluation network parameters.
    local_rank : int
        Rank of this process. In distributed mode only rank 0 is treated as the main process
        (it creates the output directory, logs and exports).
    share_module : bool
        ``True`` if sharing the stem and auxiliary heads, else not sharing these modules.

    Raises
    ------
    ValueError
        If ``interactive_type`` is neither ``kl`` nor ``smoothl1``.
    """

    def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                 regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                 epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
                 log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                 w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                 nasnet_lr=0.2, local_rank=0, share_module=True):
        if logger is None:
            logger = logging.getLogger(__name__)
        train_loader, valid_loader = loaders
        train_sampler, valid_sampler = samplers
        # Wrap the loaders so that ``next()`` never runs out of batches.
        self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
        self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)

        self.regular_coeff = regular_coeff
        self.regular_ratio = regular_ratio
        self.warmup_epochs = warmup_epochs
        self.fix_head = fix_head
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        if self.steps_per_epoch is None:
            self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
        self.loss_alpha = loss_alpha
        self.grad_clip = grad_clip
        if interactive_type == "kl":
            self.interactive_loss = InteractiveKLLoss(loss_T)
        elif interactive_type == "smoothl1":
            self.interactive_loss = nn.SmoothL1Loss()
        else:
            # Fail early instead of hitting an AttributeError in the first joint-training step.
            raise ValueError("Unsupported interactive_type '{}', expected 'kl' or 'smoothl1'.".format(interactive_type))
        self.loss_T = loss_T
        self.distributed = distributed
        self.log_frequency = log_frequency
        self.main_proc = not distributed or local_rank == 0

        self.logger = logger
        self.checkpoint_dir = output_path
        if self.main_proc:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        if distributed:
            # Let every process wait until the main process has created the output directory.
            torch.distributed.barrier()

        self.model_small = model_small
        self.model_large = model_large
        if self.fix_head:
            for param in self.model_small.aux_head.parameters():
                param.requires_grad = False
            for param in self.model_large.aux_head.parameters():
                param.requires_grad = False

        self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
        self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
        self.criterion = criterion

        self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
                                                betas=(0.5, 0.999), weight_decay=alpha_weight_decay)

        if distributed:
            apex.parallel.convert_syncbn_model(self.model_small)
            apex.parallel.convert_syncbn_model(self.model_large)
            self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
            self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
            self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
            if share_module:
                self.model_small.callback_queued = True
                self.model_large.callback_queued = True
            # mutator large never gets optimized, so do not need parallelized

    def _warmup(self, phase, epoch):
        """
        Train one network on the training loader for one epoch with plain cross-entropy.

        Parameters
        ----------
        phase : str
            ``PHASE_SMALL`` to train the search network, ``PHASE_LARGE`` to train the evaluation network.
        epoch : int
            Zero-based epoch index, used for logging only.
        """
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            x, y = next(self.train_loader)
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            logits_main, _ = model(x)
            loss = self.criterion(logits_main, y)
            loss.backward()

            self._clip_grad_norm(model)
            optimizer.step()
            prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d]  (%s)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)

    def _clip_grad_norm(self, model):
        """Clip the gradient norm of ``model`` to ``self.grad_clip``, unwrapping a distributed wrapper first."""
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        """
        Replace every ``nan`` entry of each tensor in ``parameters`` by ``-inf``, in place.

        Works for tensors of any number of dimensions.
        """
        with torch.no_grad():
            for param in parameters:
                param.masked_fill_(torch.isnan(param), float("-inf"))

    def _joint_train(self, epoch):
        """
        Run one epoch of joint training.

        Each step first updates the architecture parameters and the evaluation network on a
        validation batch, then updates the search network weights on a training batch.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index; drives the decay of the regularization coefficient and logging.
        """
        self.model_large.train()
        self.model_small.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            trn_x, trn_y = next(self.train_loader)
            val_x, val_y = next(self.valid_loader)
            trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()

            # step 1. optimize architecture
            self.optimizer_alpha.zero_grad()
            self.optimizer_large.zero_grad()
            # Linear decay from ``regular_coeff`` (first joint epoch) down to 0, reached after
            # ``regular_ratio`` of the post-warmup epochs; never negative.
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            loss_regular = self.mutator_small.reset_with_loss()
            logits_search, ensemble_logits_search = self.model_small(val_x)
            logits_main, ensemble_logits_main = self.model_large(val_x)
            loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
            loss_interactive = self.interactive_loss(ensemble_logits_search, ensemble_logits_main) \
                * (self.loss_T ** 2) * self.loss_alpha
            if loss_regular is None:
                # ``reset_with_loss`` returns None when nothing is regularized; use a zero on the
                # right device so that both the sum below and the metrics keep working.
                loss_regular = torch.zeros_like(loss_cls)
            else:
                loss_regular = loss_regular * reg_decay
            loss = loss_cls + loss_interactive + loss_regular
            loss.backward()
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.optimizer_alpha.step()
            # NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`

            # step 2. optimize op weights
            self.optimizer_small.zero_grad()
            with torch.no_grad():
                # resample architecture since parameters have been changed
                self.mutator_small.reset_with_loss()
            logits_search_train, _ = self.model_small(trn_x)
            loss_weight = self.criterion(logits_search_train, trn_y)
            loss_weight.backward()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()

            metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
                       "loss_regular": loss_regular, "loss_weight": loss_weight}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d]  (joint)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        """
        Run the whole schedule: warm up the search network for ``warmup_epochs`` epochs, then for
        every remaining epoch warm up the evaluation network and train jointly. The current
        architecture is exported after every epoch.
        """
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():  # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)

            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        """
        Write the architecture exported by the search mutator to disk (main process only).

        Parameters
        ----------
        file : str
            Path of the JSON file receiving the exported architecture.
        genotype_file : str
            Path of the text file receiving ``str(genotypes)``.
        """
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))

Просмотреть файл

@ -1,76 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import torch
import torch.distributed as dist
class CyclicIterator:
    """
    Endless iterator over a loader: when the loader is exhausted it is restarted transparently.

    Parameters
    ----------
    loader : iterable
        The underlying loader; must support ``iter()`` and ``len()``.
    sampler : object
        Only used when ``distributed`` is true, in which case ``sampler.set_epoch(epoch)``
        is called before every pass over the loader.
    distributed : bool
        Whether to notify the sampler of each new epoch.
    """

    def __init__(self, loader, sampler, distributed):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self.distributed = distributed
        self._next_epoch()

    def _next_epoch(self):
        """Start a fresh pass over the loader and advance the epoch counter."""
        if self.distributed:
            self.sampler.set_epoch(self.epoch)
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        """Length of one pass, i.e. ``len(loader)``."""
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next item, restarting the loader once if it has just been exhausted."""
        try:
            item = next(self.iterator)
        except StopIteration:
            self._next_epoch()
            item = next(self.iterator)
        return item
class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``torch.Tensor`` objects as (nested) lists."""

    def default(self, o):  # pylint: disable=method-hidden
        """Convert tensors with ``tolist()``; defer everything else to the base encoder."""
        if not isinstance(o, torch.Tensor):
            return super().default(o)
        return o.tolist()
def accuracy(output, target, topk=(1,)):
    """
    Compute the precision@k for each k in ``topk``.

    ``output`` is ranked along dimension 1. ``target`` holds class indices; if it has more than
    one dimension it is reduced to indices with an argmax over dimension 1 (one-hot case).

    Returns a list with one scalar tensor per entry of ``topk``: the fraction of samples whose
    target is among the k highest-scoring predictions.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, top_idx = output.topk(largest_k, 1, True, True)
    top_idx = top_idx.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
    return [hits[:k].reshape(-1).float().sum(0).mul_(1.0 / n_samples) for k in topk]
def reduce_tensor(tensor):
    """
    Return the average of ``tensor`` over all processes of the default process group.

    The input is left untouched; a clone is sum-reduced across processes and divided by the
    world size. The world size is queried from the process group itself rather than from the
    ``WORLD_SIZE`` environment variable, which is not guaranteed to be set.
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(dist.get_world_size())
    return rt
def reduce_metrics(metrics, distributed=False):
    """
    Convert a mapping of scalar tensors into a mapping of Python numbers.

    When ``distributed`` is true every value is first averaged across processes with
    ``reduce_tensor``; otherwise ``.item()`` is applied directly.
    """
    if not distributed:
        return {name: value.item() for name, value in metrics.items()}
    return {name: reduce_tensor(value).item() for name, value in metrics.items()}

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture

Просмотреть файл

@ -1,221 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import sys
import torch
import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.pytorch.mutator import Mutator
# Module-level logger.
logger = logging.getLogger(__name__)
# Name of the environment variable that switches ClassicMutator into dry-run mode;
# its value is the path the generated search space is dumped to.
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
# Values of the "_type" field in generated search-space entries.
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
    """
    Thin, more descriptively named wrapper around
    :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator`, similar to
    ``get_next_parameter`` for HPO.

    A search space is generated from ``model``. Three modes exist:

    * If the environment variable ``NNI_GEN_SEARCH_SPACE`` is set, this is a dry run that only
      generates the search space for the experiment.
    * In nni experiment mode (experiment started through ``nnictl``) the architecture is
      obtained from the tuner.
    * In standalone mode (trial command run directly) the first one(s) of each LayerChoice and
      InputChoice are chosen.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    # Constructing the mutator is what applies the architecture; the instance itself is not kept.
    ClassicMutator(model)
class ClassicMutator(Mutator):
    """
    Mutator that applies an architecture chosen by the tuner.

    It provides the decisions for LayerChoice and InputChoice so that only the chosen
    candidates are activated.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """

    def __init__(self, model):
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()

        if NNI_GEN_SEARCH_SPACE in os.environ:
            # dry run for only generating search space
            self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
            sys.exit(0)

        platform = trial_env_vars.NNI_PLATFORM
        if platform is None:
            logger.warning("This is in standalone mode, the chosen are the first one(s).")
            self._chosen_arch = self._standalone_generate_chosen()
        else:
            # get chosen arch from tuner
            self._chosen_arch = nni.get_next_parameter()
            if self._chosen_arch is None:
                if trial_env_vars.NNI_PLATFORM != "unittest":
                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
                # happens if NNI_PLATFORM is intentionally set, e.g., in UT
                logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
                self._chosen_arch = self._standalone_generate_chosen()
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Build the one-hot boolean tensor representing a layer choice.

        Parameters
        ----------
        mutable : Mutable
            The layer choice; its length is the number of candidates.
        idx : int
            Index of the selected candidate.
        value : str
            The verbose representation of the selected value; must equal ``search_space_item[idx]``.
        search_space_item : list
            The candidate list of the corresponding search space entry.
        """
        # doesn't support multihot for layer choice yet
        selection = [False] * len(mutable)
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        selection[idx] = True
        return torch.tensor(selection, dtype=torch.bool)  # pylint: disable=not-callable

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Build the multi-hot boolean tensor representing an input choice.

        Parameters
        ----------
        mutable : Mutable
            The input choice; ``mutable.n_candidates`` is the length of the result.
        idx : list of int
            Indices of the selected candidates; no index may appear twice.
        value : list of str
            The verbose representations of the selected values, aligned with ``idx``.
        search_space_item : dict
            The corresponding search space entry; its ``"candidates"`` list is checked against ``value``.
        """
        candidates = search_space_item["candidates"]
        selection = [False] * mutable.n_candidates
        for position, name in zip(idx, value):
            assert 0 <= position < mutable.n_candidates and candidates[position] == name, \
                "Index '{}' in search space '{}' is not '{}'".format(position, candidates, name)
            assert not selection[position], \
                "'{}' is selected twice in '{}', which is not allowed.".format(position, idx)
            selection[position] = True
        return torch.tensor(selection, dtype=torch.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen architecture into per-mutable boolean tensors.

        Returns
        -------
        dict
            Mapping from the key of every LayerChoice / InputChoice to its boolean decision tensor.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        decisions = {}
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
            if isinstance(mutable, LayerChoice):
                decisions[mutable.key] = self._sample_layer_choice(
                    mutable, data["_idx"], data["_value"], self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, InputChoice):
                decisions[mutable.key] = self._sample_input_choice(
                    mutable, data["_idx"], data["_value"], self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return decisions

    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.

        ::

            { key_name: {"_value": "conv1",
                         "_idx": 0} }

            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen = {}
        for key, val in self._search_space.items():
            kind = val["_type"]
            if kind == LAYER_CHOICE:
                candidates = val["_value"]
                chosen[key] = {"_value": candidates[0], "_idx": 0}
            elif kind == INPUT_CHOICE:
                candidates = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                # ``None`` means "no limit": take every candidate.
                if n_chosen is None:
                    n_chosen = len(candidates)
                chosen[key] = {"_value": candidates[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen

    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:

        ::

            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }

            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                space[mutable.key] = {"_type": LAYER_CHOICE, "_value": mutable.names}
            elif isinstance(mutable, InputChoice):
                space[mutable.key] = {"_type": INPUT_CHOICE,
                                      "_value": {"candidates": mutable.choose_from,
                                                 "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return space

    def _dump_search_space(self, file_path):
        """Write the generated search space to ``file_path`` as indented JSON with sorted keys."""
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import CreamSupernetTrainer

Просмотреть файл

@ -1,403 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from copy import deepcopy
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .utils import accuracy, reduce_metrics
# Module-level logger.
logger = logging.getLogger(__name__)
class CreamSupernetTrainer(Trainer):
"""
This trainer trains a supernet and output prioritized architectures that can be used for other tasks.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
val_loss : callable
Called with logits and targets for validation only. Returns a loss tensor.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
valid_loader : iterable
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
mutator : Mutator
A mutator object that has been initialized with the model.
batch_size : int
Batch size.
log_frequency : int
Number of mini-batches to log metrics.
meta_sta_epoch : int
start epoch of using meta matching network to pick teacher architecture
update_iter : int
interval of updating meta matching networks
slices : int
batch size of mini training data in the process of training meta matching network
pool_size : int
board size
pick_method : basestring
how to pick teacher network
choice_num : int
number of operations in supernet
sta_num : int
layer number of each stage in supernet (5 stage in supernet)
acc_gap : int
maximum accuracy improvement to omit the limitation of flops
flops_dict : Dict
dictionary of each layer's operations in supernet
flops_fixed : int
flops of fixed part in supernet
local_rank : int
index of current rank
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def __init__(self, model, loss, val_loss,
             optimizer, num_epochs, train_loader, valid_loader,
             mutator=None, batch_size=64, log_frequency=None,
             meta_sta_epoch=20, update_iter=200, slices=2,
             pool_size=10, pick_method='meta', choice_num=6,
             sta_num=(4, 4, 4, 4, 4), acc_gap=5,
             flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
    """
    Store the configuration and initialize an empty prioritized board.

    CUDA must be available (asserted). The base trainer is initialized without metrics,
    datasets, worker count or device; the loaders are kept on this instance instead.
    """
    assert torch.cuda.is_available()
    super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
                                               optimizer, num_epochs, None, None,
                                               batch_size, None, None, log_frequency, callbacks)
    # Model, criteria and optimization.
    self.model = model
    self.loss = loss
    self.val_loss = val_loss
    self.optimizer = optimizer
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    # Data and logging.
    self.train_loader = train_loader
    self.valid_loader = valid_loader
    self.log_frequency = log_frequency

    # Meta matching network / prioritized board configuration.
    self.meta_sta_epoch = meta_sta_epoch
    self.update_iter = update_iter
    self.slices = slices
    self.pick_method = pick_method
    self.pool_size = pool_size

    # Supernet description and flops constraint.
    self.choice_num = choice_num
    self.sta_num = sta_num
    self.acc_gap = acc_gap
    self.flops_dict = flops_dict
    self.flops_fixed = flops_fixed

    # Process information.
    self.local_rank = local_rank
    self.main_proc = (local_rank == 0)

    # Mutable training state.
    self.current_student_arch = None
    self.current_teacher_arch = None
    self.current_epoch = 0
    self.prioritized_board = []
# size of prioritized board
def _board_size(self):
return len(self.prioritized_board)
# select teacher architecture according to the logit difference
def _select_teacher(self):
self._replace_mutator_cand(self.current_student_arch)
if self.pick_method == 'top1':
meta_value, teacher_cand = 0.5, sorted(
self.prioritized_board, reverse=True)[0][3]
elif self.pick_method == 'meta':
meta_value, cand_idx, teacher_cand = -1000000000, -1, None
for now_idx, item in enumerate(self.prioritized_board):
inputx = item[4]
output = torch.nn.functional.softmax(self.model(inputx), dim=1)
weight = self.model.module.forward_meta(output - item[5])
if weight > meta_value:
meta_value = weight
cand_idx = now_idx
teacher_cand = self.prioritized_board[cand_idx][3]
assert teacher_cand is not None
meta_value = torch.nn.functional.sigmoid(-weight)
else:
raise ValueError('Method Not supported')
return meta_value, teacher_cand
# check whether to update prioritized board
def _isUpdateBoard(self, prec1, flops):
if self.current_epoch <= self.meta_sta_epoch:
return False
if len(self.prioritized_board) < self.pool_size:
return True
if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
return True
if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
return True
return False
# update prioritized board
def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
if self._isUpdateBoard(prec1, flops):
val_prec1 = prec1
training_data = deepcopy(inputs[:self.slices].detach())
if len(self.prioritized_board) == 0:
features = deepcopy(outputs[:self.slices].detach())
else:
features = deepcopy(
teacher_output[:self.slices].detach())
self.prioritized_board.append(
(val_prec1,
prec1,
flops,
self.current_student_arch,
training_data,
torch.nn.functional.softmax(
features,
dim=1)))
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
if len(self.prioritized_board) > self.pool_size:
del self.prioritized_board[-1]
# only update student network weights
def _update_student_weights_only(self, grad_1):
    """
    Apply one optimizer step using ``grad_1`` as the gradients of the student's parameters.

    The gradients are attached to the parameters returned by
    ``model.module.rand_parameters(current_student_arch)``, clipped to norm 1, applied with
    ``self.optimizer.step()`` and removed again afterwards.
    """
    module = self.model.module
    arch = self.current_student_arch

    # ``rand_parameters`` is queried anew each time, exactly as many times as before.
    for param, grad in zip(module.rand_parameters(arch), grad_1):
        param.grad = grad
    torch.nn.utils.clip_grad_norm_(module.rand_parameters(arch), 1)

    self.optimizer.step()

    for param, _ in zip(module.rand_parameters(arch), grad_1):
        del param.grad
# only update meta networks weights
def _update_meta_weights_only(self, teacher_cand, grad_teacher):
    """
    Apply one optimizer step using ``grad_teacher`` as the gradients of the parameters returned
    by ``model.module.rand_parameters(teacher_cand, pick_method == 'meta')``.

    The gradients are attached, clipped to norm 1, applied with ``self.optimizer.step()`` and
    removed again afterwards.
    """
    module = self.model.module
    use_meta = self.pick_method == 'meta'

    for param, grad in zip(module.rand_parameters(teacher_cand, use_meta), grad_teacher):
        param.grad = grad

    # clip gradients
    # NOTE(review): clipping queries the parameters of the current STUDENT architecture while
    # the gradients above were set for ``teacher_cand`` — confirm this is intended.
    torch.nn.utils.clip_grad_norm_(
        module.rand_parameters(self.current_student_arch, self.pick_method == 'meta'), 1)

    self.optimizer.step()

    for param, _ in zip(module.rand_parameters(teacher_cand, self.pick_method == 'meta'), grad_teacher):
        del param.grad
# simulate sgd updating
def _simulate_sgd_update(self, w, g, optimizer):
return g * optimizer.param_groups[-1]['lr'] + w
# split training images into several slices
def _get_minibatch_input(self, input): # pylint: disable=redefined-builtin
slice = self.slices # pylint: disable=redefined-builtin
x = deepcopy(input[:slice].clone().detach())
return x
# calculate 1st gradient of student architectures
def _calculate_1st_gradient(self, kd_loss):
    """
    Return the gradients of ``kd_loss`` with respect to the current student's parameters.

    The optimizer's gradients are zeroed first. The graph of the gradient computation is kept
    (``create_graph=True``) so that the result can be differentiated again.
    """
    self.optimizer.zero_grad()
    student_params = self.model.module.rand_parameters(self.current_student_arch)
    return torch.autograd.grad(kd_loss, student_params, create_graph=True)
# calculate 2nd gradient of meta networks
def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
    """
    Return the gradients for the parameters of ``rand_parameters(teacher_cand, pick_method == 'meta')``.

    First the gradients of ``validation_loss`` with respect to the current student's parameters
    are computed (keeping the graph). They are then used as ``grad_outputs`` when differentiating
    ``students_weight[0]`` with respect to the teacher-side parameters. The optimizer's
    gradients are zeroed beforehand.
    """
    self.optimizer.zero_grad()
    module = self.model.module

    student_val_grads = torch.autograd.grad(
        validation_loss,
        module.rand_parameters(self.current_student_arch),
        retain_graph=True)

    teacher_params = module.rand_parameters(teacher_cand, self.pick_method == 'meta')
    return torch.autograd.grad(
        students_weight[0],
        teacher_params,
        grad_outputs=student_val_grads)
# forward training data
def _forward_training(self, x, meta_value):
self._replace_mutator_cand(self.current_student_arch)
output = self.model(x)
with torch.no_grad():
self._replace_mutator_cand(self.current_teacher_arch)
teacher_output = self.model(x)
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
kd_loss = meta_value * \
self._cross_entropy_loss_with_soft_target(output, soft_label)
return kd_loss
# calculate soft target loss
def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
    """
    Cross entropy between logits and a soft target distribution.

    Parameters
    ----------
    pred : torch.Tensor
        Logits; the log-softmax is taken over dimension 1.
    soft_target : torch.Tensor
        Target weights, multiplied element-wise with the log-probabilities.

    Returns
    -------
    torch.Tensor
        Mean over the remaining dimensions of the per-sample sum over dimension 1.
    """
    # The dimension is given explicitly: the implicit choice made by ``LogSoftmax()``
    # is deprecated and warns on every call. ``dim=1`` matches the sum below.
    log_probs = torch.nn.functional.log_softmax(pred, dim=1)
    return torch.mean(torch.sum(-soft_target * log_probs, 1))
# forward validation data
def _forward_validation(self, input, target):  # pylint: disable=redefined-builtin
    """
    Evaluate the student architecture on the second slice of the batch.

    Entries ``[slices, 2 * slices)`` of ``input`` and ``target`` are used.

    Returns
    -------
    Value of ``self.loss`` on that slice.
    """
    begin = self.slices
    end = begin * 2
    batch = input[begin:end].clone()
    self._replace_mutator_cand(self.current_student_arch)
    prediction = self.model(batch)
    return self.loss(prediction, target[begin:end])
def _isUpdateMeta(self, batch_idx):
    """
    Tell whether the meta update should run for this batch.

    All of the following must hold: the current epoch is past ``meta_sta_epoch``,
    ``batch_idx`` is positive and a multiple of ``update_iter``, and the
    prioritized board is not empty.
    """
    # every condition is evaluated (no short-circuit)
    conditions = (
        self.current_epoch > self.meta_sta_epoch,
        batch_idx > 0,
        batch_idx % self.update_iter == 0,
        self._board_size() > 0,
    )
    return all(conditions)
def _replace_mutator_cand(self, cand):
    """Store ``cand`` as the mutator's cached architecture (``_cache``)."""
    mutator = self.mutator
    mutator._cache = cand
# update meta matching networks
def _run_update(self, input, target, batch_idx):  # pylint: disable=redefined-builtin
    """
    Update the meta matching network when ``_isUpdateMeta(batch_idx)`` allows it.

    Steps: compute the distillation loss on the first slice of the batch, take its
    gradient for the student, simulate the resulting student weights, apply the
    student update, evaluate the validation loss on the second slice, and push the
    second-order gradient into the meta weights.
    """
    if not self._isUpdateMeta(batch_idx):
        return

    x = self._get_minibatch_input(input)
    meta_value, teacher_cand = self._select_teacher()
    kd_loss = self._forward_training(x, meta_value)

    # calculate 1st gradient
    grad_1st = self._calculate_1st_gradient(kd_loss)

    # simulate updated student weights
    student_params = self.model.module.rand_parameters(self.current_student_arch)
    students_weight = []
    for param, grad in zip(student_params, grad_1st):
        students_weight.append(self._simulate_sgd_update(param, grad, self.optimizer))

    # update student weights
    self._update_student_weights_only(grad_1st)

    validation_loss = self._forward_validation(input, target)

    # calculate 2nd gradient
    grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)

    # update meta matching networks
    self._update_meta_weights_only(teacher_cand, grad_teacher)

    # delete internal variants
    del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight
def _get_cand_flops(self, cand):
    """
    Sum the FLOPs of the choices selected in ``cand``.

    Parameters
    ----------
    cand : dict
        Maps block keys to an iterable of per-choice flags.

    Returns
    -------
    FLOPs of all selected choices plus ``self.flops_fixed``.
    """
    # Blocks excluded from the count. The original test compared the integer
    # position with 'LayerChoice23', which is never true, so only the first
    # block was actually skipped.
    skipped_blocks = ('LayerChoice1', 'LayerChoice23')
    flops = 0
    for block_id, block in enumerate(cand):
        if block in skipped_blocks:
            continue
        # ``flops_dict`` is indexed by the block's position in ``cand``
        for idx, choice in enumerate(cand[block]):
            flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
    return flops + self.flops_fixed
def train_one_epoch(self, epoch):
    """
    Train the supernet for one epoch, distilling from a teacher when one is available.

    For every batch the mutator samples a student architecture, ``_run_update`` is
    given the chance to update the meta matching network, and the supernet is
    optimized. While the prioritized board is empty or ``epoch <= meta_sta_epoch``
    the plain loss is used; otherwise it is blended with a soft-target loss against
    the teacher returned by ``_select_teacher``.

    Parameters
    ----------
    epoch : int
        Index of the current epoch (logged as ``epoch + 1``).
    """
    self.current_epoch = epoch
    meters = AverageMeterGroup()
    self.steps_per_epoch = len(self.train_loader)

    for step, (input_data, target) in enumerate(self.train_loader):
        # sample the student architecture for this step
        self.mutator.reset()
        self.current_student_arch = self.mutator._cache

        input_data, target = input_data.cuda(), target.cuda()

        # calculate flops of current architecture
        cand_flops = self._get_cand_flops(self.mutator._cache)

        # update meta matching network
        self._run_update(input_data, target, step)

        if self._board_size() > 0:
            # select teacher architecture
            meta_value, teacher_cand = self._select_teacher()
            self.current_teacher_arch = teacher_cand

        # decided before the forward pass, as in the original control flow
        supervised_only = self._board_size() == 0 or epoch <= self.meta_sta_epoch

        # forward supernet with the student architecture
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(input_data)
        gt_loss = self.loss(output, target)

        if supervised_only:
            loss = gt_loss
            kd_loss = teacher_output = teacher_cand = None
        else:
            with torch.no_grad():
                self._replace_mutator_cand(self.current_teacher_arch)
                teacher_output = self.model(input_data).detach()
                soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
            kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
            loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

        # update network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update metrics
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        metrics = reduce_metrics({"prec1": prec1, "prec5": prec5, "loss": loss})
        meters.update(metrics)

        # update prioritized board
        self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)

        if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
                        step + 1, len(self.train_loader), meters)

    # after the last epoch, dump the prioritized board
    if self.main_proc and self.num_epochs == epoch + 1:
        for rank, entry in enumerate(self.prioritized_board):
            logger.info("No.%s %s", rank, entry[:4])
def validate_one_epoch(self, epoch):
    """
    Evaluate the supernet on ``self.valid_loader`` with a freshly sampled architecture per batch.

    The model is switched to eval mode and gradients are disabled. Metrics are
    reduced across processes through ``reduce_metrics`` before being averaged.

    Parameters
    ----------
    epoch : int
        Index of the current epoch (logged as ``epoch + 1``).
    """
    self.model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(self.valid_loader):
            self.mutator.reset()
            logits = self.model(x)
            loss = self.val_loss(logits, y)
            prec1, prec5 = accuracy(logits, y, topk=(1, 5))
            metrics = reduce_metrics({"prec1": prec1, "prec5": prec5, "loss": loss})
            meters.update(metrics)

            if self.log_frequency is None:
                continue
            if step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.valid_loader), meters)

Просмотреть файл

@ -1,37 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch.distributed as dist
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output : torch.Tensor
        Scores; the top-k is taken along dimension 1.
    target : torch.Tensor
        Class indices, or a tensor with more than one dimension (e.g. one-hot), in
        which case the arg-max along dimension 1 is used as the label.
    topk : tuple of int
        Values of k to evaluate.

    Returns
    -------
    list of torch.Tensor
        One scalar tensor per k: the fraction of samples whose label is in the top k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed layout of ``pred``, so a multi-row
        # slice is not contiguous and ``view(-1)`` raises; ``reshape`` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res
def reduce_metrics(metrics):
    """
    Reduce every tensor in ``metrics`` with ``reduce_tensor`` and convert it to a Python number.

    Parameters
    ----------
    metrics : dict
        Maps metric names to tensors.

    Returns
    -------
    dict
        Same keys, with the reduced values obtained through ``.item()``.
    """
    reduced = {}
    for name, value in metrics.items():
        reduced[name] = reduce_tensor(value).item()
    return reduced
def reduce_tensor(tensor):
    """
    Average ``tensor`` over all processes of the default process group.

    The input is left untouched; a clone is all-reduced (sum) and divided by the
    number of processes.

    Returns
    -------
    torch.Tensor
        The averaged copy.
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    # Ask the process group for its size instead of reading the ``WORLD_SIZE``
    # environment variable, which raises KeyError when it is not set.
    rt /= float(dist.get_world_size())
    return rt

Просмотреть файл

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import DartsMutator
from .trainer import DartsTrainer

Просмотреть файл

@ -1,85 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
_logger = logging.getLogger(__name__)
class DartsMutator(Mutator):
    """
    Connects the model in a DARTS (differentiable) way.

    An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
    op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
    (not chosen).

    All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
    on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
    will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.

    It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
    value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
    be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.

    Attributes
    ----------
    choices: ParameterDict
        dict that maps keys of LayerChoices to weighted-connection float tensors.
    """

    def __init__(self, model):
        super().__init__(model)
        self.choices = nn.ParameterDict()
        for mutable in self.mutables:
            if not isinstance(mutable, LayerChoice):
                continue
            # one logit per candidate plus one for the extra "no op" connection
            init_logits = 1.0E-3 * torch.randn(mutable.length + 1)
            self.choices[mutable.key] = nn.Parameter(init_logits)

    def device(self):
        """Return the device of the first architecture parameter, or ``None`` if there is none."""
        first = next(iter(self.choices.values()), None)
        return first.device if first is not None else None

    def sample_search(self):
        """
        Return the soft choices used while searching.

        LayerChoices get the softmax over their logits without the last (extra) entry;
        InputChoices get an all-true mask.
        """
        sampled = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                probs = F.softmax(self.choices[mutable.key], dim=-1)
                sampled[mutable.key] = probs[:-1]
            elif isinstance(mutable, InputChoice):
                sampled[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
        return sampled

    def sample_final(self):
        """
        Return the discrete architecture to export.

        Each LayerChoice keeps its strongest candidate. Each InputChoice with
        ``n_chosen`` set keeps the ``n_chosen`` inputs whose source LayerChoice has
        the largest top probability; LayerChoices feeding an unselected input are zeroed.
        """
        result = {}
        edges_max = {}

        # pass 1: pick the best op of every layer choice
        for mutable in self.mutables:
            if not isinstance(mutable, LayerChoice):
                continue
            probs = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
            max_val, index = torch.max(probs, 0)
            edges_max[mutable.key] = max_val
            result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()

        # pass 2: pick the inputs of every input choice
        for mutable in self.mutables:
            if not isinstance(mutable, InputChoice):
                continue
            if mutable.n_chosen is None:
                result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
                continue

            weights = []
            for src_key in mutable.choose_from:
                if src_key not in edges_max:
                    _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
                weights.append(edges_max.get(src_key, 0.))
            weights = torch.tensor(weights)  # pylint: disable=not-callable
            _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)

            selected_multihot = []
            for i, src_key in enumerate(mutable.choose_from):
                chosen = i in topk_edge_indices
                if not chosen and src_key in result:
                    # If an edge is never selected, there is no need to calculate any op on this edge.
                    # This is to eliminate redundant calculation.
                    result[src_key] = torch.zeros_like(result[src_key])
                selected_multihot.append(chosen)
            result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
        return result

Просмотреть файл

@ -1,214 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
logger = logging.getLogger(__name__)
class DartsTrainer(Trainer):
    """
    DARTS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, return a loss tensor.
    metrics : callable
        Receives logits and ground truth label, return a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : DartsMutator
        Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        list of callbacks to trigger at events.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    unrolled : bool
        ``True`` if using second order optimization, else first order optimization.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
        super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                           weight_decay=1.0E-3)
        self.unrolled = unrolled

        # First half of the training set trains the weights, second half the architecture.
        n_train = len(self.dataset_train)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=batch_size,
                                                       num_workers=workers)

    def train_one_epoch(self, epoch):
        """Alternate one architecture step and one weight step for every pair of batches."""
        self.model.train()
        self.mutator.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
            val_X, val_y = val_X.to(self.device), val_y.to(self.device)

            # phase 1. architecture step
            self.ctrl_optim.zero_grad()
            if self.unrolled:
                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
            else:
                self._backward(val_X, val_y)
            self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        """Evaluate on ``test_loader`` with a single architecture sampled before the loop."""
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)

    def _logits_and_loss(self, X, y):
        """Reset the mutator, run the model on ``X`` and return ``(logits, loss)``."""
        self.mutator.reset()
        logits = self.model(X)
        loss = self.loss(logits, y)
        self._write_graph_status()
        return logits, loss

    def _backward(self, val_X, val_y):
        """
        Simple backward with gradient descent
        """
        _, loss = self._logits_and_loss(val_X, val_y)
        loss.backward()

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
        """
        Compute unrolled loss and backward its gradients
        """
        backup_params = copy.deepcopy(tuple(self.model.parameters()))

        # do virtual step on training data
        lr = self.optimizer.param_groups[0]["lr"]
        momentum = self.optimizer.param_groups[0]["momentum"]
        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)

        # calculate unrolled loss on validation data
        # keep gradients for model here for compute hessian
        _, loss = self._logits_and_loss(val_X, val_y)
        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
        with torch.no_grad():
            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)

    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
        """
        Compute unrolled weights w` and write them into the model in place.

        The caller is expected to restore the real weights afterwards.
        """
        # don't need zero_grad, using autograd to calculate gradients
        _, loss = self._logits_and_loss(X, y)
        gradients = torch.autograd.grad(loss, self.model.parameters())
        with torch.no_grad():
            for w, g in zip(self.model.parameters(), gradients):
                m = self.optimizer.state[w].get("momentum_buffer", 0.)
                # In-place update: assigning ``w = w - ...`` only rebinds the loop
                # variable and leaves the model weights unchanged.
                w.sub_(lr * (momentum * m + g + weight_decay * w))

    def _restore_weights(self, backup_params):
        """Copy ``backup_params`` back into the model parameters, in order."""
        with torch.no_grad():
            for param, backup in zip(self.model.parameters(), backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
        """
            dw = dw` { L_val(w`, alpha) }
            w+ = w + eps * dw
            w- = w - eps * dw
            hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
            eps = 0.01 / ||dw||

        The model weights are left at ``w-`` when this method returns.
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            # report eps, which is what the message describes
            logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", eps.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.model.parameters(), dw):
                    p += e * d

            _, loss = self._logits_and_loss(trn_X, trn_y)
            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))

        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian

Просмотреть файл

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer

Просмотреть файл

@ -1,197 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
    """A stack of ``layers`` LSTM cells of equal input and hidden size, stepped one after another."""

    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        cells = [nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)]
        self.lstm_modules = nn.ModuleList(cells)

    def forward(self, inputs, hidden):
        """
        Advance every layer by one step.

        Parameters
        ----------
        inputs : torch.Tensor
            Input of the first layer.
        hidden : tuple
            ``(prev_h, prev_c)``, each indexable by layer.

        Returns
        -------
        tuple of list
            ``(next_h, next_c)`` with one tensor per layer.
        """
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        layer_input = inputs
        for layer_idx, cell in enumerate(self.lstm_modules):
            h_new, c_new = cell(layer_input, (prev_h[layer_idx], prev_c[layer_idx]))
            next_c.append(c_new)
            next_h.append(h_new)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            layer_input = h_new[-1].view(1, -1)
        return next_h, next_c
class EnasMutator(Mutator):
    """
    A mutator that mutates the graph with RL.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    cell_exit_extra_step : bool
        If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
        and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
    skip_target : float
        Target probability that skipconnect will appear.
    temperature : float
        Temperature constant that divides the logits.
    branch_bias : float
        Manual bias applied to make some operations more likely to be chosen.
        Currently this is implemented with a hardcoded match rule that aligns with original repo.
        If a mutable has a ``reduce`` in its key, all its op choices
        that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
        receive a bias of ``-self.branch_bias``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
        super().__init__(model)
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
        self.branch_bias = branch_bias

        # controller network: stacked LSTM plus the attention layers used for input choices
        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        # learned first input of the controller
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        # fixed (non-trainable) target distribution [1 - skip_target, skip_target] for the skip penalty
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        # per-LayerChoice fixed logit biases, only filled for keys containing "reduce"
        self.bias_dict = nn.ParameterDict()

        self.max_layer_choice = 0
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # the first LayerChoice seen fixes the number of candidates
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choice have the same number of candidates."
                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                if "reduce" in mutable.key:
                    def is_conv(choice):
                        return "conv" in str(type(choice)).lower()
                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
                                         for choice in mutable])
                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)

        # one extra embedding row beyond the number of candidates
        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

    def sample_search(self):
        """Reset the controller state, sample every mutable and return the dict of choices."""
        self._initialize()
        self._sample(self.mutables)
        return self._choices

    def sample_final(self):
        """Sample an architecture exactly as in :meth:`sample_search`."""
        return self.sample_search()

    def _sample(self, tree):
        """
        Recursively sample the mutable of ``tree`` and then its children.

        A mutable whose key was already sampled is not sampled again. After the
        children, a MutableScope that has no anchor yet gets one (preceded by an
        extra LSTM step when ``cell_exit_extra_step`` is set).
        """
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            if self.cell_exit_extra_step:
                self._lstm_next_step()
            self._mark_anchor(mutable.key)

    def _initialize(self):
        """Clear choices and anchors, zero the LSTM state and reset the accumulated statistics."""
        self._choices = dict()
        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        # accumulated over all sampled mutables; read by the trainer after sampling
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        """Advance the controller LSTM one step on the current ``_inputs``."""
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _mark_anchor(self, key):
        """Remember the hidden state of the last LSTM layer as the anchor of ``key``."""
        self._anchors_hid[key] = self._h[-1]

    def _sample_layer_choice(self, mutable):
        """
        Sample one candidate of a LayerChoice.

        Returns a one-hot boolean mask of length ``max_layer_choice`` and adds the
        sample's log-probability and entropy terms to the running totals.
        """
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        # cross entropy against the sampled index, i.e. the negative log-probability of the sample
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        # the embedding of the sampled branch is the next controller input
        self._inputs = self.embedding(branch_id)
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def _sample_input_choice(self, mutable):
        """
        Sample the inputs of an InputChoice with attention over the anchors of ``choose_from``.

        With ``n_chosen`` unset every candidate is kept or skipped independently and
        a KL penalty against ``skip_targets`` is accumulated; otherwise exactly one
        candidate is selected. Returns a boolean mask over the candidates.
        """
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # empty loop, fill not found
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * torch.tanh(query)

        if mutable.n_chosen is None:
            # two-way logits per candidate: column 0 is ``-query``, column 1 is ``query``
            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type

            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            # next input: sum of the kept anchors divided by (1 + number kept)
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
        else:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        return skip.bool()

Просмотреть файл

@ -1,209 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from itertools import cycle
import torch
import torch.nn as nn
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer(Trainer):
    """
    ENAS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, return a loss tensor.
    metrics : callable
        Receives logits and ground truth label, return a dict of metrics.
    reward_function : callable
        Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : EnasMutator
        Use when customizing your own mutator or a mutator with customized parameters.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        list of callbacks to trigger at events.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    child_steps : int
        How many mini-batches for model training per epoch.
    mutator_lr : float
        Learning rate for RL controller.
    mutator_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    mutator_steps : int
        Number of mini-batches for each epoch of RL controller learning.
    aux_weight : float
        Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
    test_arc_per_epoch : int
        How many architectures are chosen for direct test after each epoch.
    """

    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
                 test_arc_per_epoch=1):
        super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.reward_function = reward_function
        self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.batch_size = batch_size
        self.workers = workers

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.mutator_steps = mutator_steps
        self.child_steps = child_steps
        self.aux_weight = aux_weight
        self.test_arc_per_epoch = test_arc_per_epoch

        self.init_dataloader()

    @staticmethod
    def _repeat_forever(loader):
        """
        Yield batches from ``loader`` endlessly, re-iterating it whenever it is exhausted.

        Unlike ``itertools.cycle`` nothing is cached, so memory stays flat and the
        loader's sampler draws a new order on every pass. An empty loader ends the
        generator instead of spinning forever.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    def init_dataloader(self):
        """
        Build the train / valid / test loaders.

        The last tenth of ``dataset_train`` feeds the controller (``valid_loader``),
        the rest trains the child network. ``train_loader`` and ``valid_loader`` are
        endless iterators consumed with ``next``.
        """
        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=self.workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=self.workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=self.batch_size,
                                                       num_workers=self.workers)
        # ``itertools.cycle`` would keep every batch of the first pass in memory and
        # replay exactly those batches afterwards; restart the loaders instead.
        self.train_loader = self._repeat_forever(self.train_loader)
        self.valid_loader = self._repeat_forever(self.valid_loader)

    def train_one_epoch(self, epoch):
        """
        Run ``child_steps`` weight updates, then ``mutator_steps`` controller updates.

        Each controller update accumulates the gradients of
        ``mutator_steps_aggregate`` sampled architectures before stepping.
        """
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            # sampling needs no gradient while the child network is trained
            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                # sampling runs with gradients, the child forward pass without
                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                self._write_graph_status()
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                # scale so the accumulated gradient is the mean over the aggregate
                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()

    def validate_one_epoch(self, epoch):
        """Evaluate ``test_arc_per_epoch`` rounds on ``test_loader``, resampling the architecture per batch."""
        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                            meters.summary())

Просмотреть файл

@ -1,14 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import
from .mutator import FBNetMutator # noqa: F401
from .trainer import FBNetTrainer # noqa: F401
from .utils import ( # noqa: F401
LookUpTable,
NASConfig,
RegularizerLoss,
model_init,
supernet_sample,
)

Просмотреть файл

@ -1,268 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
class MixedOp(nn.Module):
    """
    Holder of the architecture weights that belong to a single LayerChoice.

    It stores one weight per candidate op (``path_alpha``), the per-op
    performance cost, and the gumbel-softmax temperature, and it implements
    the weighted forward pass over the candidates.
    """

    def __init__(self, mutable, latency):
        """
        Parameters
        ----------
        mutable : LayerChoice
            A LayerChoice in the user model; only its length is used here
        latency : List
            performance cost of each candidate op in ``mutable``
        """
        super().__init__()
        self.latency = latency
        num_ops = len(mutable)
        # Every candidate starts with the same weight; gradients stay off
        # until ``to_requires_grad`` is called.
        uniform = [1.0 / num_ops for _ in range(num_ops)]
        self.path_alpha = nn.Parameter(torch.FloatTensor(uniform))
        self.path_alpha.requires_grad = False
        self.temperature = 1.0

    def get_path_alpha(self):
        """Return the architecture parameter."""
        return self.path_alpha

    def get_weighted_latency(self):
        """Return the performance cost weighted by the sampled op probabilities."""
        total = 0
        for prob, cost in zip(self.probs_over_ops(), self.latency):
            total = total + prob * cost
        return total

    def set_temperature(self, temperature):
        """
        Set the annealed temperature for gumbel softmax.

        Parameters
        ----------
        temperature : float
            The annealed temperature for gumbel softmax
        """
        self.temperature = temperature

    def to_requires_grad(self):
        """Turn gradient tracking on for the architecture weights."""
        self.path_alpha.requires_grad = True

    def to_disable_grad(self):
        """Turn gradient tracking off for the architecture weights."""
        self.path_alpha.requires_grad = False

    def probs_over_ops(self):
        """Draw a gumbel-softmax sample over the candidate ops."""
        return F.gumbel_softmax(self.path_alpha, self.temperature)

    def forward(self, mutable, x):
        """
        Compute the probability-weighted sum of all candidate op outputs.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable; iterating it yields the candidate ops
        x : tensor
            input of this layer, only one input is supported

        Returns
        -------
        output: tensor
            output of this layer
        """
        probs = self.probs_over_ops()
        result = 0
        for prob, op in zip(probs, list(mutable)):
            result = result + prob * op(x)
        return result

    @property
    def chosen_index(self):
        """
        Index of the candidate op with the largest architecture weight.

        Returns
        -------
        int
            index of the chosen op
        """
        return int(np.argmax(self.path_alpha.data.detach().cpu().numpy()))
class FBNetMutator(BaseMutator):
    """
    Mutator that attaches a :class:`MixedOp` to every LayerChoice of the
    supernet and exposes helpers for the trainer to drive the architecture
    weights (temperature, gradient switch, latency, final choice).
    """

    def __init__(self, model, lookup_table):
        """
        Create one MixedOp per LayerChoice and store it on the LayerChoice as
        ``registered_module``.

        Registering the MixedOp inside the LayerChoice keeps the architecture
        weights part of the model, which DataParallel needs. Whether they are
        trained is controlled through ``requires_grad``.

        Parameters
        ----------
        model : pytorch model
            The model that users want to tune,
            it includes search space defined with nni nas apis
        lookup_table : class
            lookup table object; ``layer_num``, ``lut_ops`` and ``lut_perf``
            are read from it
        """
        super(FBNetMutator, self).__init__(model)
        self.mutable_list = []

        # Map each layer index to the op names of the stage it belongs to.
        layer_ops = {}
        offset = 0
        for stage_name, num_layers in lookup_table.layer_num.items():
            names = list(lookup_table.lut_ops[stage_name])
            for layer_idx in range(offset, offset + num_layers):
                layer_ops[layer_idx] = names
            offset += num_layers

        # Build the mixed op of every LayerChoice from its per-op cost.
        for layer_idx, mutable in enumerate(self.undedup_mutables):
            perf = lookup_table.lut_perf[layer_idx]
            latency = [perf[name] for name in layer_ops[layer_idx]]
            self.mutable_list.append(mutable)
            mutable.registered_module = MixedOp(mutable, latency)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward; delegates to the registered MixedOp.

        Parameters
        ----------
        mutable: LayerChoice
            the LayerChoice being forwarded
        args: list of torch.Tensor
            inputs of this mutable
        kwargs: dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        mixed_op = mutable.registered_module
        chosen = mixed_op.chosen_index
        return mixed_op(mutable, *args, **kwargs), chosen

    def num_arch_params(self):
        """
        Count the LayerChoices registered at construction time.

        Returns
        -------
        int
            the number of LayerChoice in user model
        """
        return len(self.mutable_list)

    def get_architecture_parameters(self):
        """
        Iterate over the architecture parameters.

        yield
        -----
        PyTorch Parameter
            ``path_alpha`` of each traversed mutable
        """
        yield from (m.registered_module.get_path_alpha() for m in self.undedup_mutables)

    def get_weighted_latency(self):
        """
        Iterate over the latency weighted by gumbel softmax coefficients.

        yield
        -----
        Tuple
            the weighted latency of each traversed mutable
        """
        yield from (m.registered_module.get_weighted_latency() for m in self.undedup_mutables)

    def set_temperature(self, temperature):
        """
        Propagate the annealed gumbel-softmax temperature to every mixed op.

        Parameters
        ----------
        temperature : float
            The annealed temperature for gumbel softmax
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.set_temperature(temperature)

    def arch_requires_grad(self):
        """
        Enable gradients for all architecture weights.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradients for all architecture weights.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            maps each mutable key to a 1-tuple holding the boolean one-hot
            mask of the chosen op
        """
        chosen = dict()
        for layer_choice in self.undedup_mutables:
            assert isinstance(layer_choice, LayerChoice)
            best = layer_choice.registered_module.chosen_index
            # pylint: disable=not-callable
            mask = F.one_hot(torch.tensor(best), num_classes=len(layer_choice)).view(-1).bool()
            # NOTE(review): the mask is wrapped in a 1-tuple, exactly as the
            # original code did; confirm whether consumers expect the bare mask.
            chosen[layer_choice.key] = (mask,)
        return chosen

Просмотреть файл

@ -1,413 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import json
import os
import time
import torch
import numpy as np
from torch.autograd import Variable
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import FBNetMutator
from .utils import RegularizerLoss, accuracy
class FBNetTrainer(BaseTrainer):
    """
    Trainer for FBNet-style search.

    Epochs before ``config.start_epoch`` only warm up the model weights.
    Afterwards each epoch trains the model weights on ``train_loader``, then
    the architecture weights on ``valid_loader``, anneals the gumbel-softmax
    temperature, validates and periodically saves checkpoints.
    """

    def __init__(
        self,
        model,
        model_optim,
        criterion,
        device,
        device_ids,
        lookup_table,
        train_loader,
        valid_loader,
        n_epochs=120,
        load_ckpt=False,
        arch_path=None,
        logger=None,
    ):
        """
        Parameters
        ----------
        model : pytorch model
            the user model, which has mutables
        model_optim : pytorch optimizer
            the user defined optimizer
        criterion : pytorch loss
            the main task loss, nn.CrossEntropyLoss() is for classification
        device : pytorch device
            the devices to train/search the model
        device_ids : list of int
            the indexes of devices used for training
        lookup_table : class
            lookup table object for fbnet training
        train_loader : pytorch data loader
            data loader for the training set
        valid_loader : pytorch data loader
            data loader for the validation set
        n_epochs : int
            number of epochs to train/search
        load_ckpt : bool
            whether load checkpoint
        arch_path : str
            the path to store chosen architecture
        logger : logger
            the logger; a module-level logger is used when omitted
        """
        if logger is None:
            # The trainer logs unconditionally, so the default must still be
            # a usable logger instead of None.
            import logging

            logger = logging.getLogger(__name__)

        self.model = model
        self.model_optim = model_optim
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.dev_num = len(device_ids)
        self.n_epochs = n_epochs
        self.lookup_table = lookup_table
        self.config = lookup_table.config
        self.start_epoch = self.config.start_epoch
        self.temp = self.config.init_temperature
        self.exp_anneal_rate = self.config.exp_anneal_rate
        self.mode = self.config.mode
        self.load_ckpt = load_ckpt
        self.arch_path = arch_path
        self.logger = logger

        # scheduler of learning rate
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_optim, T_max=n_epochs, last_epoch=-1
        )

        # init mutator
        self.mutator = FBNetMutator(model, lookup_table)
        self.mutator.set_temperature(self.temp)

        # DataParallel should be put behind the init of mutator
        self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        self.model.to(device)

        # build architecture optimizer
        self.arch_optimizer = torch.optim.AdamW(
            self.mutator.get_architecture_parameters(),
            self.config.nas_lr,
            weight_decay=self.config.nas_weight_decay,
        )
        self.reg_loss = RegularizerLoss(config=self.config)

        self.criterion = criterion
        self.epoch = 0

    def _layer_choice_sample(self):
        """
        Pick, for every layer, the op with the largest architecture weight.

        Returns
        -------
        list of str
            the chosen op name of each layer, in layer order
        """
        stages = list(self.lookup_table.layer_num)
        stage_lnum = [self.lookup_table.layer_num[stage] for stage in stages]

        # get the choice idx in each layer
        choice_ids = []
        for layer_id, param in enumerate(self.mutator.get_architecture_parameters()):
            param_np = param.cpu().detach().numpy()
            op_idx = np.argmax(param_np)
            choice_ids.append(op_idx)
            self.logger.info(
                "layer {}: {}, index: {}".format(layer_id, param_np, op_idx)
            )

        # get the arch_sample
        choice_names = []
        layer_id = 0
        for i, stage_name in enumerate(stages):
            ops_names = list(self.lookup_table.lut_ops[stage_name])
            for _ in range(stage_lnum[i]):
                choice_names.append(ops_names[choice_ids[layer_id]])
                layer_id += 1

        self.logger.info(choice_names)
        return choice_names

    def _get_perf_cost(self, requires_grad=True):
        """
        Sum the weighted latency of every layer into one tensor on ``self.device``.
        """
        # A plain tensor replaces the deprecated ``Variable`` wrapper.
        perf_cost = torch.zeros(1, requires_grad=requires_grad).to(
            self.device, non_blocking=True
        )

        for latency in self.mutator.get_weighted_latency():
            perf_cost = perf_cost + latency

        return perf_cost

    def _validate(self):
        """
        Do validation. During validation, LayerChoices use the mixed-op.

        Returns
        -------
        float, float, float
            average loss, average top1 accuracy, average top5 accuracy
        """
        self.valid_loader.batch_sampler.drop_last = False
        batch_time = AverageMeter("batch_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # test on validation set under eval mode
        self.model.eval()

        end = time.time()
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_loader):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                output = self.model(images)

                loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Record plain floats (as _train_epoch does) so the averages
                # returned below really are floats.
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0].item(), images.size(0))
                top5.update(acc5[0].item(), images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == len(self.valid_loader):
                    test_log = (
                        "Valid" + ": [{0}/{1}]\t"
                        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                        "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                        "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                        "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})".format(
                            i,
                            len(self.valid_loader) - 1,
                            batch_time=batch_time,
                            loss=losses,
                            top1=top1,
                            top5=top5,
                        )
                    )
                    self.logger.info(test_log)

        return losses.avg, top1.avg, top5.avg

    def _train_epoch(self, epoch, optimizer, arch_train=False):
        """
        Train one epoch with ``optimizer``.

        Parameters
        ----------
        epoch : int
            zero-based epoch index, used for logging
        optimizer : pytorch optimizer
            the optimizer stepped after every batch
        arch_train : bool
            when True the validation loader is iterated, otherwise the
            training loader
        """
        batch_time = AverageMeter("batch_time")
        data_time = AverageMeter("data_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # switch to train mode
        self.model.train()

        data_loader = self.valid_loader if arch_train else self.train_loader
        end = time.time()
        for i, (images, labels) in enumerate(data_loader):
            data_time.update(time.time() - end)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            output = self.model(images)
            loss = self.criterion(output, labels)

            # hardware-aware loss
            perf_cost = self._get_perf_cost(requires_grad=True)
            regu_loss = self.reg_loss(perf_cost)
            if self.mode.startswith("mul"):
                loss = loss * regu_loss
            elif self.mode.startswith("add"):
                loss = loss + regu_loss

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                # NOTE(review): this label says "Warmup" for every kind of
                # epoch; the message is left unchanged.
                batch_log = (
                    "Warmup Train [{0}][{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {losses.val:.4f} ({losses.avg:.4f})\t"
                    "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t".format(
                        epoch + 1,
                        i,
                        batch_time=batch_time,
                        data_time=data_time,
                        losses=losses,
                        top1=top1,
                        top5=top5,
                    )
                )
                self.logger.info(batch_log)

    def _warm_up(self):
        """
        Warm up the model, while the architecture weights are not trained.
        """
        for epoch in range(self.epoch, self.start_epoch):
            self.logger.info("\n--------Warmup epoch: %d--------\n", epoch + 1)
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            # validation
            val_loss, val_top1, val_top5 = self._validate()
            # Warm-up ends at ``start_epoch``; the trainer never defines a
            # separate ``warmup_epochs`` attribute.
            val_log = (
                "Warmup Valid [{0}/{1}]\t"
                "loss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}".format(
                    epoch + 1, self.start_epoch, val_loss, val_top1, val_top5
                )
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(
                    self.config.model_dir, "checkpoint_%s.pth" % epoch
                )
                self.save_checkpoint(epoch, filename)

    def _train(self):
        """
        Train the model, it trains model weights and architecture weights.
        Architecture weights are trained according to the schedule.
        Before updating architecture weights, ```requires_grad``` is enabled.
        Then, it is disabled after the updating, in order not to update
        architecture weights when training model weights.
        """
        arch_param_num = self.mutator.num_arch_params()
        self.logger.info("#arch_params: {}".format(arch_param_num))
        self.epoch = max(self.start_epoch, self.epoch)

        ckpt_path = self.config.model_dir
        choice_names = None
        top1_best = 0.0

        for epoch in range(self.epoch, self.n_epochs):
            self.logger.info("\n--------Train epoch: %d--------\n", epoch + 1)
            # update the weight parameters
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            self.logger.info("Update architecture parameters")
            # update the architecture parameters
            self.mutator.arch_requires_grad()
            self._train_epoch(epoch, self.arch_optimizer, True)
            self.mutator.arch_disable_grad()
            # temperature annealing
            self.temp = self.temp * self.exp_anneal_rate
            self.mutator.set_temperature(self.temp)
            # sample the architecture of sub-network
            choice_names = self._layer_choice_sample()

            # validate
            val_loss, val_top1, val_top5 = self._validate()
            val_log = (
                "Valid [{0}]\t"
                "loss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}".format(
                    epoch + 1, val_loss, val_top1, val_top5
                )
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(ckpt_path, "checkpoint_%s.pth" % epoch)
                self.save_checkpoint(epoch, filename, choice_names)

            # Compare as a plain float; tensors have no ``as_numpy`` method.
            val_top1 = float(val_top1)
            if val_top1 > top1_best:
                filename = os.path.join(ckpt_path, "checkpoint_best.pth")
                self.save_checkpoint(epoch, filename, choice_names)
                top1_best = val_top1

    def save_checkpoint(self, epoch, filename, choice_names=None):
        """
        Save checkpoint of the whole model.
        Saving model weights and architecture weights as ```filename```,
        and saving currently chosen architecture in ```arch_path```.
        """
        state = {
            "model": self.model.state_dict(),
            "optim": self.model_optim.state_dict(),
            "epoch": epoch,
            "arch_sample": choice_names,
        }
        torch.save(state, filename)
        self.logger.info("Save checkpoint to {0:}".format(filename))
        if self.arch_path:
            self.export(self.arch_path)

    def load_checkpoint(self, filename):
        """
        Load epoch, model weights and optimizer state from ```filename```.
        """
        # map_location lets a checkpoint saved on another device be restored.
        ckpt = torch.load(filename, map_location=self.device)
        self.epoch = ckpt["epoch"]
        self.model.load_state_dict(ckpt["model"])
        self.model_optim.load_state_dict(ckpt["optim"])

    def train(self):
        """
        Train the whole model: optionally resume from the best checkpoint,
        warm up if needed, then run the search.
        """
        if self.load_ckpt:
            ckpt_path = self.config.model_dir
            filename = os.path.join(ckpt_path, "checkpoint_best.pth")
            if os.path.exists(filename):
                self.load_checkpoint(filename)

        if self.epoch < self.start_epoch:
            self._warm_up()
        self._train()

    def export(self, file_name):
        """
        Export the chosen architecture into a file

        Parameters
        ----------
        file_name : str
            the file that stores exported chosen architecture
        """
        exported_arch = self.mutator.sample_final()
        with open(file_name, "w") as f:
            json.dump(
                exported_arch,
                f,
                indent=2,
                sort_keys=True,
                cls=TorchTensorEncoder,
            )

    def validate(self):
        raise NotImplementedError

    def checkpoint(self):
        raise NotImplementedError

Просмотреть файл

@ -1,433 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import ast
import os
import timeit
import torch
import numpy as np
import torch.nn as nn
from nni.compression.pytorch.utils import count_flops_params
# File name of the numpy look-up table, joined with the config's LUT directory.
LUT_FILE = "lut.npy"
# File name of the text look-up table parsed by LookUpTable._load_from_json_file.
LUT_JSON_FILE = "lut.txt"
# Sub-directory of the model directory that holds the look-up table.
LUT_PATH = "lut"
# Data-type key used to select a latency entry from the text look-up table.
DATA_TYPE = "float"
class NASConfig:
    """
    Configuration container for FBNet search: performance metric, look-up
    table handling, NAS optimizer settings, hardware-aware loss settings,
    schedule and search space.
    """

    def __init__(
        self,
        perf_metric="flops",
        lut_load=False,
        lut_load_format="json",
        model_dir=None,
        nas_lr=0.01,
        nas_weight_decay=5e-4,
        mode="mul",
        alpha=0.25,
        beta=0.6,
        start_epoch=50,
        init_temperature=5.0,
        exp_anneal_rate=np.exp(-0.045),
        search_space=None,
    ):
        """
        Parameters
        ----------
        perf_metric : str
            ``"flops"`` or ``"latency"``
        lut_load : bool
            whether to load an existing look-up table instead of creating one
        lut_load_format : str
            ``"json"`` or ``"numpy"``
        model_dir : str or None
            directory for checkpoints and the look-up table; it is created if
            missing. When None, ``lut_en`` is False and no directory is made.
        nas_lr : float
            learning rate of the architecture optimizer
        nas_weight_decay : float
            weight decay of the architecture optimizer
        mode : str
            ``"mul"`` or ``"add"``, how the hardware-aware loss is combined
        alpha : float
            scale of the hardware-aware loss
        beta : float
            exponent of the hardware-aware loss
        start_epoch : int
            first epoch in which architecture weights are trained
        init_temperature : float
            initial gumbel-softmax temperature
        exp_anneal_rate : float
            per-epoch multiplicative decay of the temperature
        search_space : dict
            definition of search blocks and space
        """
        # Validate everything first so that a bad argument does not leave
        # directories behind.
        assert perf_metric in [
            "flops",
            "latency",
        ], "perf_metric should be ['flops', 'latency']"
        assert lut_load_format in [
            "json",
            "numpy",
        ], "lut_load_format should be ['json', 'numpy']"
        assert mode in ["mul", "add"], "mode should be ['mul', 'add']"

        # LUT of performance metric
        # flops means the multiplies, latency means the time cost on platform
        self.perf_metric = perf_metric
        # whether load or create lut file
        self.lut_load = lut_load
        self.lut_load_format = lut_load_format

        # necessary dirs; the attributes always exist so that readers of
        # ``model_dir`` / ``lut_path`` never hit an AttributeError.
        self.lut_en = model_dir is not None
        self.model_dir = model_dir
        self.lut_path = None
        if self.lut_en:
            os.makedirs(model_dir, exist_ok=True)
            self.lut_path = os.path.join(model_dir, LUT_PATH)
            os.makedirs(self.lut_path, exist_ok=True)

        # NAS learning setting
        self.nas_lr = nas_lr
        self.nas_weight_decay = nas_weight_decay
        # hardware-aware loss setting
        self.mode = mode
        self.alpha = alpha
        self.beta = beta
        # NAS training setting
        self.start_epoch = start_epoch
        self.init_temperature = init_temperature
        self.exp_anneal_rate = exp_anneal_rate
        # definition of search blocks and space
        self.search_space = search_space
class RegularizerLoss(nn.Module):
    """Auxiliary hardware-aware loss term used during NAS."""

    def __init__(self, config):
        """
        Parameters
        ----------
        config : class
            configuration object; ``mode``, ``alpha`` and ``beta`` are read
        """
        super().__init__()
        self.mode = config.mode
        self.alpha = config.alpha
        self.beta = config.beta

    def forward(self, perf_cost, batch_size=1):
        """
        Parameters
        ----------
        perf_cost : tensor
            the accumulated performance cost
        batch_size : int
            batch size for normalization

        Returns
        -------
        output: tensor
            ``alpha * log(cost / batch_size) ** beta`` in ``"mul"`` mode,
            ``alpha * (cost / batch_size) ** beta`` in ``"add"`` mode

        Raises
        ------
        NotImplementedError
            if the mode is neither ``"mul"`` nor ``"add"``
        """
        if self.mode not in ("mul", "add"):
            raise NotImplementedError

        normalized = perf_cost / batch_size
        base = torch.log(normalized) if self.mode == "mul" else normalized
        return self.alpha * base ** self.beta
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k

    Parameters
    ----------
    output : pytorch tensor
        output, e.g., predicted value
    target : pytorch tensor
        label
    topk : tuple
        the values of k to evaluate, e.g. ``(1, 5)``

    Returns
    -------
    list
        one single-element tensor per k, holding the accuracy in percent
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
        # copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def supernet_sample(model, state_dict, sampled_arch=None, lookup_table=None):
    """
    Initialize the searched sub-model from supernet.

    Parameters
    ----------
    model : pytorch model
        the created subnet
    state_dict : checkpoint
        the checkpoint of supernet, including the pre-trained params
    sampled_arch : list of str
        the searched layer names of the subnet; when empty or None the
        state dict is loaded without any key renaming
    lookup_table : class
        to manage the candidate ops, layer information and layer performance;
        required when ``sampled_arch`` is given

    Raises
    ------
    ValueError
        if ``sampled_arch`` is given without a ``lookup_table``
    """
    replace = []
    if sampled_arch:
        # The lookup table is only needed to translate op names to indexes,
        # so it is not touched when no architecture is supplied.
        if lookup_table is None:
            raise ValueError("lookup_table is required when sampled_arch is given")
        layer_id = 0
        for stage, layer_count in lookup_table.layer_num.items():
            ops_names = list(lookup_table.lut_ops[stage])
            for _ in range(layer_count):
                op_i = ops_names.index(sampled_arch[layer_id])
                # [scope to match, key in the subnet name, key in the supernet name]
                replace.append(
                    [
                        "blocks.{}.".format(layer_id),
                        "blocks.{}.op.".format(layer_id),
                        "blocks.{}.{}.".format(layer_id, op_i),
                    ]
                )
                layer_id += 1
    model_init(model, state_dict, replace=replace)
def model_init(model, state_dict, replace=None):
    """
    Initialize the model from state_dict.

    A leading ``"module."`` is stripped from the checkpoint keys. Every
    BatchNorm2d, Conv2d, Linear and ConvTranspose2d module of ``model`` whose
    (possibly renamed) name is found in the checkpoint gets its tensors
    assigned. Tensors are assigned, not copied, so the model shares storage
    with ``state_dict`` afterwards. Entries that are missing from the
    checkpoint are skipped.

    Parameters
    ----------
    model : pytorch model
        the model to initialize in place
    state_dict : dict
        parameter name to tensor mapping
    replace : list of [pre_scope, key, replace_key], optional
        for module names containing ``pre_scope``, ``key`` is replaced by
        ``replace_key`` before the checkpoint lookup
    """
    prefix = "module."
    param_dict = {}
    for param_name, value in state_dict.items():
        if param_name.startswith(prefix):
            param_name = param_name[len(prefix):]
        param_dict[param_name] = value

    replace = replace or []
    for name, module in model.named_modules():
        for layer_replace in replace:
            assert len(layer_replace) == 3, "The elements should be three."
            pre_scope, key, replace_key = layer_replace
            if pre_scope in name:
                name = name.replace(key, replace_key)

        weight_key = name + ".weight"
        bias_key = name + ".bias"
        mean_key = name + ".running_mean"
        var_key = name + ".running_var"
        if weight_key not in param_dict and mean_key not in param_dict:
            continue

        # Copy the state_dict to current model
        if isinstance(module, nn.BatchNorm2d):
            if module.running_mean is None or mean_key not in param_dict:
                continue
            if module.running_mean.shape == param_dict[mean_key].shape:
                if module.weight is not None and weight_key in param_dict:
                    module.weight.data = param_dict[weight_key]
                    if bias_key in param_dict:
                        module.bias.data = param_dict[bias_key]
                module.running_mean = param_dict[mean_key]
                if var_key in param_dict:
                    module.running_var = param_dict[var_key]

        elif isinstance(module, (nn.Conv2d, nn.Linear)):
            if weight_key not in param_dict:
                continue
            if module.weight.data.shape == param_dict[weight_key].shape:
                module.weight.data = param_dict[weight_key]
                if module.bias is not None and bias_key in param_dict:
                    module.bias.data = param_dict[bias_key]

        elif isinstance(module, nn.ConvTranspose2d):
            if weight_key not in param_dict:
                continue
            # No shape check is done for transposed convolutions.
            module.weight.data = param_dict[weight_key]
            if module.bias is not None and bias_key in param_dict:
                module.bias.data = param_dict[bias_key]
class LookUpTable:
    """Build look-up table for NAS."""

    def __init__(self, config, primitives):
        """
        Parameters
        ----------
        config : class
            to manage the configuration for NAS training, and search space etc.
        primitives : dict
            maps an op name to the callable that constructs that op; it is
            called as ``ctor(c_in, c_out, stride, fm_size=...)``
        """
        self.config = config
        # definition of search blocks and space
        self.search_space = config.search_space
        # layers for NAS
        self.cnt_layers = len(self.search_space["input_shape"])
        # constructors for each operation
        self.lut_ops = {
            stage_name: {
                op_name: primitives[op_name]
                for op_name in self.search_space["stages"][stage_name]["ops"]
            }
            for stage_name in self.search_space["stages"]
        }
        self.layer_num = {
            stage_name: self.search_space["stages"][stage_name]["layer_num"]
            for stage_name in self.search_space["stages"]
        }

        # arguments for the ops constructors, input_shapes just for convenience
        self.layer_configs, self.layer_in_shapes = self._layer_configs()

        # lookup_table
        self.perf_metric = config.perf_metric
        # Always defined, so readers get None instead of an AttributeError
        # when the table is disabled.
        self.lut_perf = None

        if config.lut_en:
            self.lut_file = os.path.join(config.lut_path, LUT_FILE)
            # NOTE(review): the text table is a bare file name, i.e. it is
            # resolved against the current working directory.
            self.lut_json_file = LUT_JSON_FILE
            if config.lut_load:
                if config.lut_load_format == "numpy":
                    # Load data from numpy file
                    self._load_from_file()
                else:
                    # Load data from json file
                    self._load_from_json_file()
            else:
                self._create_perfs()

    def _layer_configs(self):
        """Generate basic params for different layers."""
        # layer_configs are : c_in, c_out, stride, fm_size
        layer_configs = [
            [
                self.search_space["input_shape"][layer_id][0],
                self.search_space["channel_size"][layer_id],
                self.search_space["strides"][layer_id],
                self.search_space["fm_size"][layer_id],
            ]
            for layer_id in range(self.cnt_layers)
        ]

        # layer_in_shapes are (C_in, input_w, input_h)
        layer_in_shapes = self.search_space["input_shape"]

        return layer_configs, layer_in_shapes

    def _iter_layer_ops(self):
        """Yield ``(layer_id, op_name, constructor)`` for every candidate op of every layer."""
        layer_id = 0
        for stage_name, stage_ops in self.lut_ops.items():
            for _ in range(self.layer_num[stage_name]):
                for op_name, constructor in stage_ops.items():
                    yield layer_id, op_name, constructor
                layer_id += 1

    def _build_op(self, layer_id, constructor):
        """Instantiate one candidate op with the configuration of ``layer_id``."""
        layer_config = self.layer_configs[layer_id]
        return constructor(*layer_config[0:3], fm_size=layer_config[3])

    def _create_perfs(self, cnt_of_runs=200):
        """Create performance cost for each op and save it to ``lut_file``."""
        if self.perf_metric == "latency":
            self.lut_perf = self._calculate_latency(cnt_of_runs)
        elif self.perf_metric == "flops":
            self.lut_perf = self._calculate_flops()

        self._write_lut_to_file()

    def _calculate_flops(self, eps=0.001):
        """FLOPs cost; a zero count is replaced by ``eps``."""
        flops_lut = [{} for _ in range(self.cnt_layers)]

        for layer_id, op_name, constructor in self._iter_layer_ops():
            op = self._build_op(layer_id, constructor)

            # measured in Flops
            in_shape = self.layer_in_shapes[layer_id]
            x = (1, in_shape[0], in_shape[1], in_shape[2])
            flops, _, _ = count_flops_params(op, x, verbose=False)
            flops = eps if flops == 0.0 else flops
            flops_lut[layer_id][op_name] = float(flops)

        return flops_lut

    def _calculate_latency(self, cnt_of_runs):
        """Latency cost, the mean time of ``cnt_of_runs`` forward passes."""
        LATENCY_BATCH_SIZE = 1
        latency_lut = [{} for _ in range(self.cnt_layers)]

        for layer_id, op_name, constructor in self._iter_layer_ops():
            op = self._build_op(layer_id, constructor)
            input_data = torch.randn(
                (LATENCY_BATCH_SIZE, *self.layer_in_shapes[layer_id])
            )
            # Time in a private namespace: nothing leaks into module globals,
            # and ``gc`` is imported by the setup itself because it is not a
            # name of that namespace.
            total_time = timeit.timeit(
                "output = op(input_data)",
                setup="import gc; gc.enable()",
                globals={"op": op, "input_data": input_data},
                number=cnt_of_runs,
            )
            # measured in micro-second
            latency_lut[layer_id][op_name] = (
                total_time / cnt_of_runs / LATENCY_BATCH_SIZE * 1e6
            )

        return latency_lut

    def _write_lut_to_file(self):
        """Save lut as numpy file."""
        np.save(self.lut_file, self.lut_perf)

    def _load_from_file(self):
        """Load numpy file."""
        # NOTE(review): allow_pickle=True unpickles the file content; only
        # load look-up tables from a trusted source.
        self.lut_perf = np.load(self.lut_file, allow_pickle=True)

    def _load_from_json_file(self):
        """
        Load the text look-up table.

        lut_json_file ('lut.txt') format, one Python-literal dict per line:
            {'op_name': operator_name,
            'op_data_shape': (input_w, input_h, C_in, C_out, stride),
            'op_dtype': data_type,
            'op_latency': latency}
            {...}
            {...}
        """
        # ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}
        ops_lut = {}
        with open(self.lut_json_file, "r") as latency_file:
            for line in latency_file:
                if not line.strip():
                    # tolerate blank lines
                    continue
                record = ast.literal_eval(line)
                # op_data_shape: (input_w, input_h, C_in, C_out, stride);
                # normalized to a tuple so it is hashable and matches the
                # lookup key built below.
                op_data_shape = tuple(record["op_data_shape"])
                by_shape = ops_lut.setdefault(record["op_name"], {})
                by_dtype = by_shape.setdefault(op_data_shape, {})
                by_dtype[record["op_dtype"]] = record["op_latency"]

        self.lut_perf = [{} for _ in range(self.cnt_layers)]

        for layer_id, op_name, _ in self._iter_layer_ops():
            layer_config = self.layer_configs[layer_id]
            layer_in_shape = self.layer_in_shapes[layer_id]

            input_w = layer_in_shape[1]
            input_h = layer_in_shape[2]
            c_in = layer_config[0]
            c_out = layer_config[1]
            stride = layer_config[2]
            op_data_shape = (input_w, input_h, c_in, c_out, stride)

            # Ops without a matching record are left out of the table.
            if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
                self.lut_perf[layer_id][op_name] = \
                    ops_lut[op_name][op_data_shape][DATA_TYPE]

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import PdartsTrainer

Просмотреть файл

@ -1,93 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import numpy as np
import torch
from torch import nn
from nni.algorithms.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator):
    """
    It works with PdartsTrainer to calculate ops weights,
    and drop weights in different PDARTS epochs.
    """

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        """
        Parameters
        ----------
        model : pytorch model
            the model whose LayerChoices are managed
        pdarts_epoch_index : int
            index of the current PDARTS epoch, used to index ``pdarts_num_to_drop``
        pdarts_num_to_drop : list of int
            number of candidate operations to drop in each PDARTS epoch
        switches : dict, optional
            maps a LayerChoice key to a list of booleans, False marking an
            operation that was dropped earlier. A fresh dict is created when
            omitted, so instances never share switch state.
        """
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        # The dict is stored and written to below, so it must not be a
        # shared default argument.
        self.switches = {} if switches is None else switches

        super(PdartsMutator, self).__init__(model)

        # this loop go through mutables with different keys,
        # it's mainly to update length of choices.
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                layer_switches = self.switches.get(mutable.key, [True] * len(mutable))
                operations_count = int(np.sum(layer_switches))
                # +1 and -1 are caused by zero operation in darts network
                # the zero operation is not in choices list in network, but its weight are in,
                # so it needs one more weights and switch for zero.
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
                self.switches[mutable.key] = layer_switches

        # update LayerChoice instances in model,
        # it's physically remove dropped choices operations.
        for module in self.model.modules():
            if isinstance(module, LayerChoice):
                layer_switches = self.switches.get(module.key)
                choices = self.choices[module.key]
                if len(module) > len(choices):
                    # from last to first, so that it won't effect previous indexes after removed one.
                    for index in range(len(layer_switches) - 1, -1, -1):
                        if not layer_switches[index]:
                            del module[index]
                assert len(module) <= len(choices), "Failed to remove dropped choices."

    def export(self):
        """
        Export the final choice of every LayerChoice, padded back to the
        original number of candidates.

        Returns
        -------
        dict
            maps each LayerChoice key to a boolean tensor with one entry per
            original candidate; dropped candidates are always False
        """
        # Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
        results = super().sample_final()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # As some operations are dropped physically,
                # so it needs to fill back false to track dropped operations.
                trained_result = results[mutable.key]
                trained_index = 0
                result = torch.Tensor(self.switches[mutable.key]).bool()
                for index in range(len(result)):
                    if result[index]:
                        result[index] = trained_result[trained_index]
                        trained_index += 1
                results[mutable.key] = result
        return results

    def drop_paths(self):
        """
        This method is called when a PDARTS epoch is finished.
        It prepares switches for next epoch.
        candidate operations with False switch will be dropped in next epoch.

        Returns
        -------
        dict
            a deep copy of the current switches with the lowest-weighted
            operations of this epoch switched off
        """
        all_switches = copy.deepcopy(self.switches)
        num_to_drop = self.pdarts_num_to_drop[self.pdarts_epoch_index]
        for key, switches in all_switches.items():
            # positions (in the full candidate list) that are still active
            idxs = [j for j, enabled in enumerate(switches) if enabled]
            # the last weight belongs to the zero operation and is excluded
            weights = self.choices[key].data.cpu().numpy()[:-1]
            for idx in np.argsort(weights)[:num_to_drop]:
                switches[idxs[idx]] = False
        return all_switches

Просмотреть файл

@ -1,86 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder
from .mutator import PdartsMutator
logger = logging.getLogger(__name__)
class PdartsTrainer(BaseTrainer):
    """
    Trainer implementing the PDARTS algorithm.

    PDARTS builds on DARTS and grows the network in stages to find a deeper
    and better network. ``pdarts_num_layers`` gives, per stage, how many
    layers are added on top of ``init_layers``; ``pdarts_num_to_drop`` gives,
    per stage, how many candidate operations are dropped afterwards, so that
    the grown network stays at a similar size.
    """

    def __init__(self, model_creator, init_layers, metrics,
                 num_epochs, dataset_train, dataset_valid,
                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
        super().__init__()
        self.model_creator = model_creator
        self.init_layers = init_layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        # one PDARTS stage per entry of the drop schedule
        self.pdarts_epoch = len(pdarts_num_to_drop)
        # keyword arguments forwarded to every DartsTrainer
        self.darts_parameters = dict(
            metrics=metrics,
            num_epochs=num_epochs,
            dataset_train=dataset_train,
            dataset_valid=dataset_valid,
            batch_size=batch_size,
            workers=workers,
            device=device,
            log_frequency=log_frequency,
            unrolled=unrolled,
        )
        self.callbacks = [] if callbacks is None else callbacks

    def train(self):
        """
        Run every PDARTS stage: build a model of the stage's depth, train it
        with a DartsTrainer, then compute the switches for the next stage.
        """
        switches = None
        for stage in range(self.pdarts_epoch):
            depth = self.init_layers + self.pdarts_num_layers[stage]
            model, criterion, optim, lr_scheduler = self.model_creator(depth)
            self.mutator = PdartsMutator(model, stage, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(stage)

            if lr_scheduler is not None:
                darts_callbacks = [LRSchedulerCallback(lr_scheduler)]
            else:
                darts_callbacks = []

            self.trainer = DartsTrainer(
                model,
                mutator=self.mutator,
                loss=criterion,
                optimizer=optim,
                callbacks=darts_callbacks,
                **self.darts_parameters
            )
            logger.info("start pdarts training epoch %s...", stage)
            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(stage)

    def validate(self):
        """Validate with the DartsTrainer of the most recent stage."""
        self.trainer.validate()

    def export(self, file):
        """Write the mutator's exported architecture to ``file`` as JSON."""
        exported = self.mutator.export()
        with open(file, "w") as f:
            json.dump(exported, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")

Просмотреть файл

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer

Просмотреть файл

@ -1,478 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
from .utils import detach_variable
class ArchGradientFunction(torch.autograd.Function):
    """
    Autograd function that runs ``run_func`` on a detached copy of the input in the forward pass,
    and in the backward pass returns both the gradient w.r.t. the input and the gradient for the
    binary gates, the latter being produced by ``backward_func``.
    """

    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        """
        Parameters
        ----------
        ctx
            autograd context
        x : tensor
            input of the operation
        binary_gates : tensor
            binary gates; not read here, but receives the gradient returned by :meth:`backward`
        run_func : callable
            computes the output from the detached input
        backward_func : callable
            called in :meth:`backward` as ``backward_func(input, output, grad_output)``

        Returns
        -------
        tensor
            ``.data`` of the output of ``run_func``
        """
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        x_detached = detach_variable(x)
        # Gradient recording is switched on explicitly so that backward() can differentiate
        # `result` with respect to `x_detached`.
        with torch.enable_grad():
            result = run_func(x_detached)
        ctx.save_for_backward(x_detached, result)
        return result.data

    @staticmethod
    def backward(ctx, grad_output):
        """
        Returns
        -------
        tuple
            gradient w.r.t. ``x``, gradient for ``binary_gates``, and ``None`` for the two callables
        """
        x_detached, result = ctx.saved_tensors

        input_grads = torch.autograd.grad(result, x_detached, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        gate_grads = ctx.backward_func(x_detached.data, result.data, grad_output.data)

        return input_grads[0], gate_grads, None, None
class MixedOp(nn.Module):
    """
    This class is to instantiate and manage info of one LayerChoice.
    It includes architecture weights, binary weights, and member functions
    operating the weights.

    forward_mode:
        forward/backward mode for LayerChoice: None, two, full, and full_v2.
        For training architecture weights, we use full_v2 by default, and for training
        model weights, we use None.
    """
    # Class-level switch shared by every MixedOp instance.
    forward_mode = None

    def __init__(self, mutable):
        """
        Parameters
        ----------
        mutable : LayerChoice
            A LayerChoice in user model
        """
        super(MixedOp, self).__init__()
        # Both parameters are allocated uninitialized here; nothing in this class fills
        # `ap_path_alpha`, so it must be initialized by the caller before use.
        self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
        self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
        self.ap_path_alpha.requires_grad = False
        self.ap_path_wb.requires_grad = False
        self.active_index = [0]
        self.inactive_index = None
        self.log_prob = None
        self.current_prob_over_ops = None
        self.n_choices = len(mutable)

    def get_ap_path_alpha(self):
        """Return the architecture weights (alpha) of this LayerChoice."""
        return self.ap_path_alpha

    def to_requires_grad(self):
        """Enable gradients for both the architecture weights and the binary gates."""
        self.ap_path_alpha.requires_grad = True
        self.ap_path_wb.requires_grad = True

    def to_disable_grad(self):
        """Disable gradients for both the architecture weights and the binary gates."""
        self.ap_path_alpha.requires_grad = False
        self.ap_path_wb.requires_grad = False

    def forward(self, mutable, x):
        """
        Define forward of LayerChoice. For 'full_v2', backward is also defined.
        The 'two' mode is explained in section 3.2.1 in the paper.
        The 'full_v2' mode is explained in Appendix D in the paper.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        x : tensor
            inputs of this layer, only support one input

        Returns
        -------
        output: tensor
            output of this layer
        """
        if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
            # The candidate operations live in `mutable`; this module only holds the weights.
            output = 0
            for _i in self.active_index:
                oi = mutable[_i](x)
                output = output + self.ap_path_wb[_i] * oi
            for _i in self.inactive_index:
                oi = mutable[_i](x)
                # Inactive paths contribute through their gate only, not through their own output.
                output = output + self.ap_path_wb[_i] * oi.detach()
        elif MixedOp.forward_mode == 'full_v2':
            def run_function(key, candidate_ops, active_id):
                def forward(_x):
                    return candidate_ops[active_id](_x)
                return forward

            def backward_function(key, candidate_ops, active_id, binary_gates):
                def backward(_x, _output, grad_output):
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(len(candidate_ops)):
                            if k != active_id:
                                out_k = candidate_ops[k](_x.data)
                            else:
                                # The active op's output is already available from the forward pass.
                                out_k = _output.data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward

            output = ArchGradientFunction.apply(
                x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
                backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
        else:
            output = self.active_op(mutable)(x)
        return output

    @property
    def probs_over_ops(self):
        """
        Apply softmax on alpha to generate probability distribution

        Returns
        -------
        pytorch tensor
            probability distribution
        """
        probs = F.softmax(self.ap_path_alpha, dim=0)  # softmax to probability
        return probs

    @property
    def chosen_index(self):
        """
        choose the op with max prob

        Returns
        -------
        int
            index of the chosen one
        numpy.float32
            prob of the chosen one
        """
        probs = self.probs_over_ops.data.cpu().numpy()
        index = int(np.argmax(probs))
        return index, probs[index]

    def active_op(self, mutable):
        """
        assume only one path is active

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable

        Returns
        -------
        PyTorch module
            the chosen operation
        """
        return mutable[self.active_index[0]]

    @property
    def active_op_index(self):
        """
        return active op's index, the active op is sampled

        Returns
        -------
        int
            index of the active op
        """
        return self.active_index[0]

    def set_chosen_op_active(self):
        """
        set chosen index, active and inactive indexes
        """
        chosen_idx, _ = self.chosen_index
        self.active_index = [chosen_idx]
        self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
                              [_i for _i in range(chosen_idx + 1, self.n_choices)]

    def binarize(self, mutable):
        """
        Sample based on alpha, and set binary weights accordingly.
        ap_path_wb is set in this function, which is called binarize.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        """
        self.log_prob = None
        # reset binary gates
        self.ap_path_wb.data.zero_()
        probs = self.probs_over_ops
        if MixedOp.forward_mode == 'two':
            # sample two ops according to probs
            sample_op = torch.multinomial(probs.data, 2, replacement=False)
            probs_slice = F.softmax(torch.stack([
                self.ap_path_alpha[idx] for idx in sample_op
            ]), dim=0)
            self.current_prob_over_ops = torch.zeros_like(probs)
            for i, idx in enumerate(sample_op):
                self.current_prob_over_ops[idx] = probs_slice[i]
            # choose one to be active and the other to be inactive according to probs_slice
            c = torch.multinomial(probs_slice.data, 1)[0]  # 0 or 1
            active_op = sample_op[c].item()
            inactive_op = sample_op[1 - c].item()
            self.active_index = [active_op]
            self.inactive_index = [inactive_op]
            # set binary gate
            self.ap_path_wb.data[active_op] = 1.0
        else:
            sample = torch.multinomial(probs, 1)[0].item()
            self.active_index = [sample]
            self.inactive_index = [_i for _i in range(0, sample)] + \
                                  [_i for _i in range(sample + 1, len(mutable))]
            self.log_prob = torch.log(probs[sample])
            self.current_prob_over_ops = probs
            self.ap_path_wb.data[sample] = 1.0
        # avoid over-regularization
        for choice in mutable:
            for _, param in choice.named_parameters():
                param.grad = None

    @staticmethod
    def delta_ij(i, j):
        """Kronecker delta: 1 if ``i == j`` else 0."""
        if i == j:
            return 1
        else:
            return 0

    def set_arch_param_grad(self, mutable):
        """
        Calculate alpha gradient for this LayerChoice.
        It is calculated using gradient of binary gate, probs of ops.

        In 'two' mode, ``active_index`` and ``inactive_index`` are additionally rewritten into
        ``(index, alpha)`` tuples that remember the alpha values before the optimizer step;
        :meth:`rescale_updated_arch_param` consumes them.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        """
        binary_grads = self.ap_path_wb.grad.data
        if self.active_op(mutable).is_zero_layer():
            self.ap_path_alpha.grad = None
            return
        if self.ap_path_alpha.grad is None:
            self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
        if MixedOp.forward_mode == 'two':
            involved_idx = self.active_index + self.inactive_index
            probs_slice = F.softmax(torch.stack([
                self.ap_path_alpha[idx] for idx in involved_idx
            ]), dim=0).data
            for i in range(2):
                for j in range(2):
                    origin_i = involved_idx[i]
                    origin_j = involved_idx[j]
                    self.ap_path_alpha.grad.data[origin_i] += \
                        binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
            for _i, idx in enumerate(self.active_index):
                self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
            for _i, idx in enumerate(self.inactive_index):
                self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
        else:
            probs = self.probs_over_ops.data
            for i in range(self.n_choices):
                for j in range(self.n_choices):
                    self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
        return

    def rescale_updated_arch_param(self):
        """
        rescale architecture weights for the 'two' mode.

        The two involved alphas are shifted by a common offset so that the sum of their
        exponentials equals the sum before the update.
        """
        if not isinstance(self.active_index[0], tuple):
            # `set_arch_param_grad` only leaves plain indices behind when it returned early
            # (active op is a zero layer), in which case there is nothing to rescale.
            # The previous sanity assert here called `self.active_op.is_zero_layer()` on the
            # bound method itself and therefore always raised AttributeError.
            return
        involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
        old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
        new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]

        offset = math.log(
            sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
        )

        for idx in involved_idx:
            self.ap_path_alpha.data[idx] -= offset
class ProxylessNasMutator(BaseMutator):
    """
    This mutator initializes and operates all the LayerChoices of the input model.
    It is for the corresponding trainer to control the training process of LayerChoices,
    coordinating with whole training process.
    """

    def __init__(self, model):
        """
        Init a MixedOp instance for each mutable i.e., LayerChoice.
        And register the instantiated MixedOp in corresponding LayerChoice.
        If does not register it in LayerChoice, DataParallel does not work then,
        because architecture weights are not included in the DataParallel model.
        When MixedOPs are registered, we use ```requires_grad``` to control
        whether calculate gradients of architecture weights.

        Parameters
        ----------
        model : pytorch model
            The model that users want to tune, it includes search space defined with nni nas apis
        """
        super(ProxylessNasMutator, self).__init__(model)
        self._unused_modules = None
        self.mutable_list = []
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module = MixedOp(layer_choice)
            self.mutable_list.append(layer_choice)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward. This function defines the forward
        logic of the input mutable. So mutable is only interface, its real
        implementation is defined in mutator.

        Parameters
        ----------
        mutable: LayerChoice
            forward logic of this input mutable
        args: list of torch.Tensor
            inputs of this mutable
        kwargs: dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        mixed_op = mutable.registered_module
        # The active index is read before the forward call, as in the original ordering.
        idx = mixed_op.active_op_index
        return mixed_op(mutable, *args, **kwargs), idx

    def reset_binary_gates(self):
        """
        For each LayerChoice, binarize binary weights
        based on alpha to only activate one op.
        It traverses all the mutables in the model to do this.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.binarize(layer_choice)

    def set_chosen_op_active(self):
        """
        For each LayerChoice, set the op with highest alpha as the chosen op.
        Usually used for validation.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.set_chosen_op_active()

    def num_arch_params(self):
        """
        The number of mutables, i.e., LayerChoice

        Returns
        -------
        int
            the number of LayerChoice in user model
        """
        return len(self.mutable_list)

    def set_arch_param_grad(self):
        """
        For each LayerChoice, calculate gradients for architecture weights, i.e., alpha
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.set_arch_param_grad(layer_choice)

    def get_architecture_parameters(self):
        """
        Get all the architecture parameters.

        yield
        -----
        PyTorch Parameter
            Return ap_path_alpha of the traversed mutable
        """
        for layer_choice in self.undedup_mutables:
            yield layer_choice.registered_module.get_ap_path_alpha()

    def change_forward_mode(self, mode):
        """
        Update forward mode of MixedOps, as training architecture weights and
        model weights use different forward modes.
        The mode is stored on the ``MixedOp`` class, so it applies to every MixedOp.
        """
        MixedOp.forward_mode = mode

    def get_forward_mode(self):
        """
        Get forward mode of MixedOp

        Returns
        -------
        string
            the current forward mode of MixedOp
        """
        return MixedOp.forward_mode

    def rescale_updated_arch_param(self):
        """
        Rescale architecture weights in 'two' mode.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.rescale_updated_arch_param()

    def unused_modules_off(self):
        """
        Remove unused modules for each mutables.
        The removed modules are kept in ```self._unused_modules``` for resume later.
        """
        self._unused_modules = []
        for layer_choice in self.undedup_mutables:
            mixed_op = layer_choice.registered_module
            if self.get_forward_mode() in ['full', 'two', 'full_v2']:
                kept = mixed_op.active_index + mixed_op.inactive_index
            else:
                kept = mixed_op.active_index
            removed = {}
            for position in range(mixed_op.n_choices):
                if position in kept:
                    continue
                # Park the module and leave a None placeholder at its position.
                removed[position] = layer_choice[position]
                layer_choice[position] = None
            self._unused_modules.append(removed)

    def unused_modules_back(self):
        """
        Resume the removed modules back.
        """
        if self._unused_modules is None:
            return
        for layer_choice, removed in zip(self.mutable_list, self._unused_modules):
            for position in removed:
                layer_choice[position] = removed[position]
        self._unused_modules = None

    def arch_requires_grad(self):
        """
        Make architecture weights require gradient
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradient of architecture weights, i.e., does not
        calculate gradient for them.
        """
        for layer_choice in self.undedup_mutables:
            layer_choice.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice
        """
        chosen = {}
        for layer_choice in self.undedup_mutables:
            assert isinstance(layer_choice, LayerChoice)
            index, _ = layer_choice.registered_module.chosen_index
            # pylint: disable=not-callable
            one_hot = F.one_hot(torch.tensor(index), num_classes=len(layer_choice))
            chosen[layer_choice.key] = one_hot.view(-1).bool()
        return chosen

Просмотреть файл

@ -1,500 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import time
import json
import logging
import torch
from torch import nn as nn
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import ProxylessNasMutator
from .utils import cross_entropy_with_label_smoothing, accuracy
logger = logging.getLogger(__name__)
class ProxylessNasTrainer(BaseTrainer):
    """
    Trainer for ProxylessNAS: optionally warms up the model weights, then alternates between
    training the model weights on the training set and training the architecture weights on
    batches drawn from the validation set.
    """

    def __init__(self, model, model_optim, device,
                 train_loader, valid_loader, label_smoothing=0.1,
                 n_epochs=120, init_lr=0.025, binary_mode='full_v2',
                 arch_init_type='normal', arch_init_ratio=1e-3,
                 arch_optim_lr=1e-3, arch_weight_decay=0,
                 grad_update_arch_param_every=5, grad_update_steps=1,
                 warmup=True, warmup_epochs=25,
                 arch_valid_frequency=1,
                 load_ckpt=False, ckpt_path=None, arch_path=None):
        """
        Parameters
        ----------
        model : pytorch model
            the user model, which has mutables
        model_optim : pytorch optimizer
            the user defined optimizer
        device : pytorch device
            the devices to train/search the model
        train_loader : pytorch data loader
            data loader for the training set
        valid_loader : pytorch data loader
            data loader for the validation set
        label_smoothing : float
            for label smoothing
        n_epochs : int
            number of epochs to train/search
        init_lr : float
            init learning rate for training the model
        binary_mode : str
            the forward/backward mode for the binary weights in mutator
        arch_init_type : str
            the way to init architecture parameters
        arch_init_ratio : float
            the ratio to init architecture parameters
        arch_optim_lr : float
            learning rate of the architecture parameters optimizer
        arch_weight_decay : float
            weight decay of the architecture parameters optimizer
        grad_update_arch_param_every : int
            update architecture weights every this number of minibatches
        grad_update_steps : int
            during each update of architecture weights, the number of steps to train
        warmup : bool
            whether to do warmup
        warmup_epochs : int
            the number of epochs to do during warmup
        arch_valid_frequency : int
            frequency of printing validation result
        load_ckpt : bool
            whether load checkpoint
        ckpt_path : str
            checkpoint path, if load_ckpt is True, ckpt_path cannot be None
        arch_path : str
            the path to store chosen architecture
        """
        self.model = model
        self.model_optim = model_optim
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.warmup = warmup
        self.warmup_epochs = warmup_epochs
        self.arch_valid_frequency = arch_valid_frequency
        self.label_smoothing = label_smoothing

        self.train_batch_size = train_loader.batch_sampler.batch_size
        self.valid_batch_size = valid_loader.batch_sampler.batch_size
        # update architecture parameters every this number of minibatches
        self.grad_update_arch_param_every = grad_update_arch_param_every
        # the number of steps per architecture parameter update
        self.grad_update_steps = grad_update_steps
        self.binary_mode = binary_mode

        self.load_ckpt = load_ckpt
        self.ckpt_path = ckpt_path
        self.arch_path = arch_path

        # init mutator
        self.mutator = ProxylessNasMutator(model)

        # DataParallel should be put behind the init of mutator
        self.model = torch.nn.DataParallel(self.model)
        self.model.to(self.device)

        # iter of valid dataset for training architecture weights
        self._valid_iter = None
        # init architecture weights
        self._init_arch_params(arch_init_type, arch_init_ratio)

        # build architecture optimizer
        self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
                                               arch_optim_lr,
                                               weight_decay=arch_weight_decay,
                                               betas=(0, 0.999),
                                               eps=1e-8)

        self.criterion = nn.CrossEntropyLoss()
        self.warmup_curr_epoch = 0
        self.train_curr_epoch = 0

    def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
        """
        Initialize architecture weights

        Parameters
        ----------
        init_type : str
            'normal' (mean 0, std ``init_ratio``) or 'uniform' (in ``[-init_ratio, init_ratio]``)
        init_ratio : float
            scale of the initialization

        Raises
        ------
        NotImplementedError
            for any other ``init_type``
        """
        for param in self.mutator.get_architecture_parameters():
            if init_type == 'normal':
                param.data.normal_(0, init_ratio)
            elif init_type == 'uniform':
                param.data.uniform_(-init_ratio, init_ratio)
            else:
                raise NotImplementedError

    def _validate(self):
        """
        Do validation. During validation, LayerChoices use the chosen active op.

        Returns
        -------
        float, float, float
            average loss, average top1 accuracy, average top5 accuracy
        """
        # _gradient_step may have changed these sampler settings; restore them for validation.
        self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
        self.valid_loader.batch_sampler.drop_last = False

        self.mutator.set_chosen_op_active()
        # remove unused modules to save memory
        self.mutator.unused_modules_off()
        # test on validation set under train mode
        self.model.train()
        batch_time = AverageMeter('batch_time')
        losses = AverageMeter('losses')
        top1 = AverageMeter('top1')
        top5 = AverageMeter('top5')
        end = time.time()
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                output = self.model(images)
                loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Track the loss as a plain float, consistent with the training loops.
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == len(self.valid_loader):
                    test_log = 'Valid' + ': [{0}/{1}]\t'\
                                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
                                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
                                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
                        format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
                    # return top5:
                    test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
                    logger.info(test_log)
        self.mutator.unused_modules_back()
        return losses.avg, top1.avg, top5.avg

    def _warm_up(self):
        """
        Warm up the model, during warm up, architecture weights are not trained.
        """
        lr_max = 0.05
        data_loader = self.train_loader
        nBatch = len(data_loader)
        T_total = self.warmup_epochs * nBatch  # total num of batches

        for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
            logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            logger.info('warm_up epoch: %d', epoch)
            for i, (images, labels) in enumerate(data_loader):
                data_time.update(time.time() - end)
                # lr: cosine schedule over all warmup minibatches
                T_cur = epoch * nBatch + i
                warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
                for param_group in self.model_optim.param_groups:
                    param_group['lr'] = warmup_lr
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output
                self.mutator.reset_binary_gates()  # random sample binary gates
                self.mutator.unused_modules_off()  # remove unused module for speedup
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Record a float, not the graph-attached tensor, so the meter cannot keep
                # autograd history of every minibatch alive across the epoch.
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # compute gradient and do SGD step
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                # unused modules back
                self.mutator.unused_modules_back()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=warmup_lr)
                    logger.info(batch_log)
            val_loss, val_top1, val_top5 = self._validate()
            val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
                      'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}M'. \
                format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
            logger.info(val_log)
            self.save_checkpoint()
            self.warmup_curr_epoch += 1

    def _get_update_schedule(self, nBatch):
        """
        Generate schedule for training architecture weights. Key means after which minibatch
        to update architecture weights, value means how many steps for the update.

        Parameters
        ----------
        nBatch : int
            the total number of minibatches in one epoch

        Returns
        -------
        dict
            the schedule for updating architecture weights
        """
        schedule = {}
        for i in range(nBatch):
            if (i + 1) % self.grad_update_arch_param_every == 0:
                schedule[i] = self.grad_update_steps
        return schedule

    def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
        """
        Compute the cosine-annealed learning rate for the given position in training.

        Parameters
        ----------
        epoch : int
            the current epoch number
        batch : int
            the current minibatch
        nBatch : int
            the total number of minibatches in one epoch; despite the ``None`` default it must be given

        Returns
        -------
        float
            the learning rate
        """
        T_total = self.n_epochs * nBatch
        T_cur = epoch * nBatch + batch
        lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
        return lr

    def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """
        Adjust learning of a given optimizer and return the new learning rate

        Parameters
        ----------
        optimizer : pytorch optimizer
            the used optimizer
        epoch : int
            the current epoch number
        batch : int
            the current minibatch
        nBatch : int
            the total number of minibatches in one epoch

        Returns
        -------
        float
            the adjusted learning rate
        """
        new_lr = self._calc_learning_rate(epoch, batch, nBatch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

    def _train(self):
        """
        Train the model, it trains model weights and architecture weights.
        Architecture weights are trained according to the schedule.
        Before updating architecture weights, ```requires_grad``` is enabled.
        Then, it is disabled after the updating, in order not to update
        architecture weights when training model weights.
        """
        nBatch = len(self.train_loader)
        arch_param_num = self.mutator.num_arch_params()
        binary_gates_num = self.mutator.num_arch_params()
        logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)

        update_schedule = self._get_update_schedule(nBatch)

        for epoch in range(self.train_curr_epoch, self.n_epochs):
            logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            for i, (images, labels) in enumerate(self.train_loader):
                data_time.update(time.time() - end)
                lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
                # train weight parameters
                images, labels = images.to(self.device), labels.to(self.device)
                self.mutator.reset_binary_gates()
                self.mutator.unused_modules_off()
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Record a float, not the graph-attached tensor, so the meter cannot keep
                # autograd history of every minibatch alive across the epoch.
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                self.mutator.unused_modules_back()
                # Architecture weights are left untouched during the first epoch.
                if epoch > 0:
                    for _ in range(update_schedule.get(i, 0)):
                        start_time = time.time()
                        # GradientArchSearchConfig
                        self.mutator.arch_requires_grad()
                        arch_loss, exp_value = self._gradient_step()
                        self.mutator.arch_disable_grad()
                        used_time = time.time() - start_time
                        log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t null %s' % \
                                  (epoch + 1, i, used_time, arch_loss, exp_value)
                        logger.info(log_str)
                batch_time.update(time.time() - end)
                end = time.time()
                # training log
                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=lr)
                    logger.info(batch_log)
            # validate
            if (epoch + 1) % self.arch_valid_frequency == 0:
                val_loss, val_top1, val_top5 = self._validate()
                val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
                          'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                    format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
                logger.info(val_log)
            self.save_checkpoint()
            self.train_curr_epoch += 1

    def _valid_next_batch(self):
        """
        Get next one minibatch from validation set

        Returns
        -------
        (tensor, tensor)
            the tuple of images and labels
        """
        if self._valid_iter is None:
            self._valid_iter = iter(self.valid_loader)
        try:
            data = next(self._valid_iter)
        except StopIteration:
            # The validation loader is exhausted: start over from the beginning.
            self._valid_iter = iter(self.valid_loader)
            data = next(self._valid_iter)
        return data

    def _gradient_step(self):
        """
        This gradient step is for updating architecture weights.
        Mutator is intensively used in this function to operate on
        architecture weights.

        Returns
        -------
        float, None
            loss of the model, None
        """
        # use the same batch size as train batch size for architecture weights
        self.valid_loader.batch_sampler.batch_size = self.train_batch_size
        self.valid_loader.batch_sampler.drop_last = True
        self.model.train()
        self.mutator.change_forward_mode(self.binary_mode)
        time1 = time.time()  # time
        # sample a batch of data from validation set
        images, labels = self._valid_next_batch()
        images, labels = images.to(self.device), labels.to(self.device)
        time2 = time.time()  # time
        self.mutator.reset_binary_gates()
        self.mutator.unused_modules_off()
        output = self.model(images)
        time3 = time.time()
        ce_loss = self.criterion(output, labels)
        expected_value = None
        loss = ce_loss
        self.model.zero_grad()
        loss.backward()
        self.mutator.set_arch_param_grad()
        self.arch_optimizer.step()
        if self.mutator.get_forward_mode() == 'two':
            self.mutator.rescale_updated_arch_param()
        self.mutator.unused_modules_back()
        # Back to the mode used for training model weights.
        self.mutator.change_forward_mode(None)
        time4 = time.time()
        logger.info('(%.4f, %.4f, %.4f)', time2 - time1, time3 - time2, time4 - time3)
        return loss.data.item(), expected_value.item() if expected_value is not None else None

    def save_checkpoint(self):
        """
        Save checkpoint of the whole model. Saving model weights and architecture weights in
        ```ckpt_path```, and saving currently chosen architecture in ```arch_path```.
        Each of the two is skipped when its path is not set.
        """
        if self.ckpt_path:
            state = {
                'warmup_curr_epoch': self.warmup_curr_epoch,
                'train_curr_epoch': self.train_curr_epoch,
                'model': self.model.state_dict(),
                'optim': self.model_optim.state_dict(),
                'arch_optim': self.arch_optimizer.state_dict()
            }
            torch.save(state, self.ckpt_path)
        if self.arch_path:
            self.export(self.arch_path)

    def load_checkpoint(self):
        """
        Load the checkpoint from ```ckpt_path```.
        """
        assert self.ckpt_path is not None, "If load_ckpt is not None, ckpt_path should not be None"
        ckpt = torch.load(self.ckpt_path)
        self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
        self.train_curr_epoch = ckpt['train_curr_epoch']
        self.model.load_state_dict(ckpt['model'])
        self.model_optim.load_state_dict(ckpt['optim'])
        self.arch_optimizer.load_state_dict(ckpt['arch_optim'])

    def train(self):
        """
        Train the whole model: optionally resume from a checkpoint, optionally warm up,
        then run the main training loop.
        """
        if self.load_ckpt:
            self.load_checkpoint()
        if self.warmup:
            self._warm_up()
        self._train()

    def export(self, file_name):
        """
        Export the chosen architecture into a file

        Parameters
        ----------
        file_name : str
            the file that stores exported chosen architecture
        """
        exported_arch = self.mutator.sample_final()
        with open(file_name, 'w') as f:
            json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def validate(self):
        """Not supported by this trainer."""
        raise NotImplementedError

    def checkpoint(self):
        """Not supported by this trainer; see :meth:`save_checkpoint`."""
        raise NotImplementedError

Просмотреть файл

@ -1,78 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
def detach_variable(inputs):
    """
    Detach variables from their computation graph while keeping their ``requires_grad`` flag.

    Parameters
    ----------
    inputs : pytorch tensors
        a tensor, or a (possibly nested) tuple of tensors

    Returns
    -------
    pytorch tensors
        detached tensors, in the same tuple structure as ``inputs``
    """
    if not isinstance(inputs, tuple):
        detached = inputs.detach()
        # detach() clears the flag; restore whatever the source tensor had
        detached.requires_grad = inputs.requires_grad
        return detached
    return tuple(detach_variable(item) for item in inputs)
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """
    Cross entropy between ``pred`` and label-smoothed one-hot targets, averaged over the batch.

    Parameters
    ----------
    pred : pytorch tensor
        predicted value; classes are along dimension 1
    target : pytorch tensor
        label, class indices
    label_smoothing : float
        the degree of label smoothing

    Returns
    -------
    pytorch tensor
        cross entropy
    """
    # The class dimension is 1 everywhere else in this function, so state it explicitly
    # instead of relying on the deprecated implicit-dimension choice of LogSoftmax.
    logsoftmax = nn.LogSoftmax(dim=1)
    n_classes = pred.size(1)
    # convert to one-hot
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros_like(pred)
    soft_target.scatter_(1, target, 1)
    # label smoothing
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def accuracy(output, target, topk=(1,)):
    """
    Compute the precision@k (in percent) for each requested k.

    Parameters
    ----------
    output : pytorch tensor
        output, e.g., predicted value, with classes along dimension 1
    target : pytorch tensor
        label
    topk : tuple
        the values of k to evaluate, e.g., ``(1, 5)`` for top1 and top5

    Returns
    -------
    list
        one single-element tensor per entry of ``topk``, holding the accuracy in percent
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the `largest_k` highest scores per sample, transposed to (k, batch).
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    percentages = []
    for k in topk:
        n_hits = hits[:k].reshape(-1).float().sum(0, keepdim=True)
        percentages.append(n_hits.mul_(100.0 / n_samples))
    return percentages

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RandomMutator

Просмотреть файл

@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class RandomMutator(Mutator):
    """
    Mutator that draws a fresh random candidate from the search space on every ``reset()``.
    Randomness comes from PyTorch's random functions, so seeding PyTorch makes the sampling deterministic.
    """

    def sample_search(self):
        """
        Draw one random candidate.

        Returns
        -------
        dict
            maps each ``LayerChoice`` / ``InputChoice`` key to a boolean mask tensor
        """
        decisions = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                n_ops = len(mutable)
                picked = torch.randint(high=n_ops, size=(1, ))
                decisions[mutable.key] = F.one_hot(picked, num_classes=n_ops).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                n_candidates = mutable.n_candidates
                if mutable.n_chosen is None:
                    # No fixed count: every candidate is kept or dropped independently.
                    decisions[mutable.key] = torch.randint(high=2, size=(n_candidates,)).view(-1).bool()
                else:
                    # Fixed count: keep the first `n_chosen` entries of a random permutation.
                    order = torch.randperm(n_candidates)
                    chosen_mask = torch.zeros(n_candidates, dtype=torch.bool)
                    chosen_mask[order[:mutable.n_chosen]] = True
                    decisions[mutable.key] = chosen_mask
        return decisions

    def sample_final(self):
        """
        Same as :meth:`sample_search`.
        """
        return self.sample_search()

Просмотреть файл

@ -1,6 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer

Просмотреть файл

@ -1,223 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
    """
    SPOS evolution tuner.

    Parameters
    ----------
    max_epochs : int
        Maximum number of epochs to run.
    num_select : int
        Number of survival candidates of each epoch.
    num_population : int
        Number of candidates at the start of each epoch. If candidates generated by
        crossover and mutation are not enough, the rest will be filled with random
        candidates.
    m_prob : float
        The probability of mutation.
    num_crossover : int
        Number of candidates generated by crossover in each epoch.
    num_mutation : int
        Number of candidates generated by mutation in each epoch.
    """

    def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
                 num_crossover=25, num_mutation=25):
        assert num_population >= num_select
        self.max_epochs = max_epochs
        self.num_select = num_select
        self.num_population = num_population
        self.m_prob = m_prob
        self.num_crossover = num_crossover
        self.num_mutation = num_mutation
        self.epoch = 0
        self.candidates = []
        # Kept for backward compatibility; every method reads `_search_space`,
        # which is filled in by `update_search_space`.
        self.search_space = None
        self._search_space = None
        # Fixed seed so that candidate generation is reproducible.
        self.random_state = np.random.RandomState(0)

        # async status
        self._to_evaluate_queue = deque()
        self._sending_parameter_queue = deque()
        self._pending_result_ids = set()
        self._reward_dict = dict()
        self._id2candidate = dict()
        self._st_callback = None

    def update_search_space(self, search_space):
        """
        Handle the initialization/update event of search space.

        Stores the search space and immediately starts the next round.
        """
        self._search_space = search_space
        self._next_round()

    def _next_round(self):
        """
        Start a new epoch: export the current results and build the next population.

        Epoch 0 is filled with random candidates only. Later epochs export the top
        candidates and, unless ``max_epochs`` has been reached, refill the population
        through mutation, crossover and random sampling.
        """
        _logger.info("Epoch %d, generating...", self.epoch)
        if self.epoch == 0:
            self._get_random_population()
            self.export_results(self.candidates)
        else:
            best_candidates = self._select_top_candidates()
            self.export_results(best_candidates)
            if self.epoch >= self.max_epochs:
                return
            self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
            self._get_random_population()
        self.epoch += 1

    def _random_candidate(self):
        """
        Sample one architecture uniformly: one random choice per layer choice.

        Raises
        ------
        NotImplementedError
            if the search space contains an input choice
        """
        chosen_arch = dict()
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                index = self.random_state.randint(len(choices))
                chosen_arch[key] = {"_value": choices[index], "_idx": index}
            elif val["_type"] == INPUT_CHOICE:
                raise NotImplementedError("Input choice is not implemented yet.")
        return chosen_arch

    def _add_to_evaluate_queue(self, cand):
        """
        Queue ``cand`` for evaluation and register it with a placeholder reward of 0.
        """
        _logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
        self._reward_dict[self._hashcode(cand)] = 0.
        self._to_evaluate_queue.append(cand)

    def _get_random_population(self):
        """
        Top up ``self.candidates`` with unseen random candidates until it holds ``num_population``.
        """
        while len(self.candidates) < self.num_population:
            cand = self._random_candidate()
            if self._is_legal(cand):
                _logger.info("Random candidate generated.")
                self._add_to_evaluate_queue(cand)
                self.candidates.append(cand)

    def _get_crossover(self, best):
        """
        Generate up to ``num_crossover`` unseen candidates by mixing two random parents from ``best``.

        At most ``10 * num_crossover`` attempts are made.
        """
        result = []
        for _ in range(10 * self.num_crossover):
            cand_p1 = best[self.random_state.randint(len(best))]
            cand_p2 = best[self.random_state.randint(len(best))]
            assert cand_p1.keys() == cand_p2.keys()
            # Each key is inherited from one of the two parents with equal probability.
            cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
                    for k in cand_p1.keys()}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_crossover:
                break
        _logger.info("Found %d architectures with crossover.", len(result))
        return result

    def _get_mutation(self, best):
        """
        Generate up to ``num_mutation`` unseen candidates by re-sampling keys of a random parent.

        Every key is re-sampled with probability ``m_prob``. At most
        ``10 * num_mutation`` attempts are made.
        """
        result = []
        for _ in range(10 * self.num_mutation):
            # Shallow copy is enough: entries are replaced, never modified in place.
            cand = best[self.random_state.randint(len(best))].copy()
            # Draw from the seeded generator (not the global numpy one) so that
            # mutation is as reproducible as the rest of the search.
            mutation_sample = self.random_state.random_sample(len(cand))
            for s, k in zip(mutation_sample, cand):
                if s < self.m_prob:
                    choices = self._search_space[k]["_value"]
                    index = self.random_state.randint(len(choices))
                    cand[k] = {"_value": choices[index], "_idx": index}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_mutation:
                break
        _logger.info("Found %d architectures with mutation.", len(result))
        return result

    def _get_architecture_repr(self, cand):
        """
        Compact string for logging: each ``"key": {...}`` entry of the hash code is reduced to its index.
        """
        return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
                      self._hashcode(cand))

    def _is_legal(self, cand):
        """
        Return True when ``cand`` has not been generated before.
        """
        return self._hashcode(cand) not in self._reward_dict

    def _select_top_candidates(self):
        """
        Return the ``num_select`` current candidates with the highest recorded reward.
        """
        def reward_query(cand):
            return self._reward_dict[self._hashcode(cand)]

        _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
        result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
        _logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
        return result

    @staticmethod
    def _hashcode(d):
        """
        Canonical (key-sorted) JSON string of ``d``, used as dictionary key for a candidate.
        """
        return json.dumps(d, sort_keys=True)

    def _bind_and_send_parameters(self):
        """
        There are two types of resources: parameter ids and candidates. This function is called at
        necessary times to bind these resources to send new trials with st_callback.

        Returns
        -------
        list
            the candidates that were sent
        """
        result = []
        while self._sending_parameter_queue and self._to_evaluate_queue:
            parameter_id = self._sending_parameter_queue.popleft()
            parameters = self._to_evaluate_queue.popleft()
            self._id2candidate[parameter_id] = parameters
            result.append(parameters)
            self._pending_result_ids.add(parameter_id)
            self._st_callback(parameter_id, parameters)
            _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
        return result

    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
        """
        Callback function necessary to implement a tuner. This will put more parameter ids into the
        parameter id queue.
        """
        if "st_callback" in kwargs and self._st_callback is None:
            self._st_callback = kwargs["st_callback"]
        for parameter_id in parameter_id_list:
            self._sending_parameter_queue.append(parameter_id)
        self._bind_and_send_parameters()
        return []  # always not use this. might induce problem of over-sending

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """
        Callback function. Receive a trial result.

        The reward is recorded for the candidate that was bound to ``parameter_id``.
        """
        _logger.info("Candidate %d, reported reward %f", parameter_id, value)
        self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value

    def trial_end(self, parameter_id, success, **kwargs):
        """
        Callback function when a trial is ended and resource is released.

        Once no result is pending and nothing is left to evaluate, the next round starts.
        """
        self._pending_result_ids.remove(parameter_id)
        if not self._pending_result_ids and not self._to_evaluate_queue:
            # a new epoch now
            self._next_round()
            assert self._st_callback is not None
            self._bind_and_send_parameters()

    def export_results(self, result):
        """
        Export a number of candidates to `checkpoints` dir.

        Each candidate is written to ``checkpoints/<epoch>_<index>.json`` as a mapping
        from key to a one-hot boolean list.

        Parameters
        ----------
        result : list of dict
            Chosen architectures to be exported.
        """
        os.makedirs("checkpoints", exist_ok=True)
        for i, cand in enumerate(result):
            converted = dict()
            for cand_key, cand_val in cand.items():
                onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
                converted[cand_key] = onehot
            with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
                json.dump(converted, fp)

Просмотреть файл

@ -1,66 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.algorithms.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
    """
    A random mutator with flops limit.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    flops_func : callable
        Callable that takes a candidate from `sample_search` and returns its flops. When `flops_func`
        is None, functions related to flops will be deactivated.
    flops_lb : number
        Lower bound of flops.
    flops_ub : number
        Upper bound of flops.
    flops_bin_num : number
        Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more
        uniform, but the sampling will be slower.
    flops_sample_timeout : int
        Maximum number of attempts to sample before giving up and use a random candidate.
    """

    def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
                 flops_bin_num=7, flops_sample_timeout=500):
        super().__init__(model)
        self._flops_func = flops_func
        if self._flops_func is not None:
            self._flops_bin_num = flops_bin_num
            # Edges of `flops_bin_num` equally wide bins covering [flops_lb, flops_ub].
            self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
            self._flops_sample_timeout = flops_sample_timeout

    def sample_search(self):
        """
        Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
        relative to flops.

        A random flops bin is picked for every attempt and the first candidate falling into
        its bin is returned. After ``flops_sample_timeout`` failed attempts an unconstrained
        random candidate is returned instead.

        Returns
        -------
        dict
        """
        if self._flops_func is not None:
            for times in range(self._flops_sample_timeout):
                idx = np.random.randint(self._flops_bin_num)
                cand = super().sample_search()
                flops = self._flops_func(cand)
                if self._flops_bins[idx] <= flops <= self._flops_bins[idx + 1]:
                    # Log the flops value (the candidate itself is a dict and cannot be
                    # formatted with %f) and the one-based number of attempts.
                    _logger.debug("Sampled candidate flops %f in %d times.", flops, times + 1)
                    return cand
            _logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
        return super().sample_search()

    def sample_final(self):
        """
        Implement only to suffice the interface of Mutator.
        """
        return self.sample_search()

Просмотреть файл

@ -1,95 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
    """
    Trainer for a supernet that is later used for evolution search.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    metrics : callable
        Returns a dict that maps metrics keys to metrics data.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
    mutator : nni.nas.pytorch.mutator.Mutator
        A mutator object that has been initialized with the model. When ``None``,
        a :class:`SPOSSupernetTrainingMutator` is created for ``model``.
    batch_size : int
        Batch size.
    workers: int
        Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
        automatic detects GPU and selects GPU first.
    log_frequency : int
        Number of mini-batches to log metrics.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None):
        assert torch.cuda.is_available()
        if mutator is None:
            mutator = SPOSSupernetTrainingMutator(model)
        # The two dataset slots of the base trainer are left empty: this trainer
        # is driven by the ready-made loaders stored below.
        super().__init__(model, mutator, loss, metrics, optimizer, num_epochs, None, None,
                         batch_size, workers, device, log_frequency, callbacks)
        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def train_one_epoch(self, epoch):
        """
        Train for one pass over ``train_loader``; the mutator is reset before every batch.
        """
        self.model.train()
        meters = AverageMeterGroup()
        for batch_idx, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            self.mutator.reset()
            logits = self.model(inputs)
            loss = self.loss(logits, targets)
            loss.backward()
            self.optimizer.step()

            batch_metrics = self.metrics(logits, targets)
            batch_metrics["loss"] = loss.item()
            meters.update(batch_metrics)

            if self.log_frequency is None or batch_idx % self.log_frequency != 0:
                continue
            logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                        self.num_epochs, batch_idx + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        """
        Evaluate for one pass over ``valid_loader`` without gradients; the mutator is reset before every batch.
        """
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(self.valid_loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                self.mutator.reset()
                logits = self.model(inputs)
                loss = self.loss(logits, targets)

                batch_metrics = self.metrics(logits, targets)
                batch_metrics["loss"] = loss.item()
                meters.update(batch_metrics)

                if self.log_frequency is None or batch_idx % self.log_frequency != 0:
                    continue
                logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, batch_idx + 1, len(self.valid_loader), meters)

Просмотреть файл

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture

Просмотреть файл

@ -1,217 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import json
import logging
import os
import sys
import tensorflow as tf
import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.tensorflow.mutator import Mutator
logger = logging.getLogger(__name__)
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
    """
    Wrapper of :class:`~nni.nas.tensorflow.classic_nas.mutator.ClassicMutator` to make it more meaningful,
    similar to ``get_next_parameter`` for HPO.

    It will generate search space based on ``model``.
    If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
    generating search space for the experiment.
    If not, there are still two modes: one is nni experiment mode where users
    use ``nnictl`` to start an experiment. The other is standalone mode
    where users directly run the trial command; this mode chooses the first
    one(s) for each LayerChoice and InputChoice.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    # Constructing the mutator is the whole job; the instance itself is not returned.
    ClassicMutator(model)
class ClassicMutator(Mutator):
    """
    Mutator that applies the architecture chosen by the tuner.
    It implements the forward function of LayerChoice and InputChoice,
    so that only the chosen ones are activated.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """

    def __init__(self, model):
        super().__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()

        dump_path = os.environ.get(NNI_GEN_SEARCH_SPACE)
        if dump_path is not None:
            # dry run for only generating search space
            self._dump_search_space(dump_path)
            sys.exit(0)

        if trial_env_vars.NNI_PLATFORM is None:
            logger.warning("This is in standalone mode, the chosen are the first one(s).")
            self._chosen_arch = self._standalone_generate_chosen()
        else:
            # get chosen arch from tuner
            self._chosen_arch = nni.get_next_parameter()
            if self._chosen_arch is None:
                if trial_env_vars.NNI_PLATFORM != "unittest":
                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
                # happens if NNI_PLATFORM is intentionally set, e.g., in UT
                logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
                self._chosen_arch = self._standalone_generate_chosen()
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Turn a layer choice decision into a boolean one-hot tensor.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Position in the candidate list that is selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for corresponding search space.
        """
        # doesn't support multihot for layer choice yet
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        one_hot = tf.one_hot(idx, len(mutable))
        flat = tf.reshape(one_hot, [-1])
        return tf.cast(flat, tf.bool)

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Turn an input choice decision into a boolean multi-hot tensor.

        Parameters
        ----------
        mutable : Mutable
        idx : list of int
            Positions in the candidate list that are selected.
        value : list of str
            The verbose representations of the selected values.
        search_space_item : dict
            The corresponding search space entry; its ``"candidates"`` list is read.
        """
        candidates = search_space_item["candidates"]
        selected = [False] * mutable.n_candidates
        for position, label in zip(idx, value):
            assert 0 <= position < mutable.n_candidates and candidates[position] == label, \
                "Index '{}' in search space '{}' is not '{}'".format(position, candidates, label)
            assert not selected[position], "'{}' is selected twice in '{}', which is not allowed.".format(position, idx)
            selected[position] = True
        return tf.cast(selected, tf.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch into one mask per LayerChoice / InputChoice.

        Returns
        -------
        dict
            maps mutable keys to boolean tensors
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        masks = dict()
        for mutable in self.mutables:
            if not isinstance(mutable, (LayerChoice, InputChoice)):
                if not isinstance(mutable, MutableScope):
                    raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
                continue

            assert mutable.key in self._chosen_arch, \
                "Expected '{}' in chosen arch, but not found.".format(mutable.key)
            data = self._chosen_arch[mutable.key]
            assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                "'{}' is not a valid choice.".format(data)

            if isinstance(mutable, LayerChoice):
                to_mask = self._sample_layer_choice
            else:
                to_mask = self._sample_input_choice
            masks[mutable.key] = to_mask(mutable, data["_idx"], data["_value"],
                                         self._search_space[mutable.key]["_value"])
        return masks

    def _standalone_generate_chosen(self):
        """
        Build the chosen architecture for standalone mode,
        i.e., take the first one(s) for LayerChoice and InputChoice.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, spec in self._search_space.items():
            kind = spec["_type"]
            if kind == LAYER_CHOICE:
                chosen_arch[key] = {"_value": spec["_value"][0], "_idx": 0}
            elif kind == INPUT_CHOICE:
                candidates = spec["_value"]["candidates"]
                n_chosen = spec["_value"]["n_chosen"]
                if n_chosen is None:
                    # no fixed count means every candidate is taken
                    n_chosen = len(candidates)
                chosen_arch[key] = {"_value": candidates[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, spec))
        return chosen_arch

    def _generate_search_space(self):
        """
        Build the search space from the mutables.
        The search space has the following format:
        ::
            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }
            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                search_space[mutable.key] = {"_type": LAYER_CHOICE, "_value": mutable.names}
            elif isinstance(mutable, InputChoice):
                input_spec = {"candidates": mutable.choose_from,
                              "n_chosen": mutable.n_chosen}
                search_space[mutable.key] = {"_type": INPUT_CHOICE, "_value": input_spec}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return search_space

    def _dump_search_space(self, file_path):
        """
        Write the search space to ``file_path`` as key-sorted, indented JSON.
        """
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)

Просмотреть файл

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer

Просмотреть файл

@ -1,162 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction
from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
class EnasMutator(Mutator):
    """
    Mutator that samples LayerChoice / InputChoice decisions with an LSTM controller.

    While sampling, it accumulates ``sample_log_prob``, ``sample_entropy`` and
    ``sample_skip_penalty``, which are exposed for the trainer.

    Parameters
    ----------
    model
        Model passed on to the base ``Mutator``.
    lstm_size : int
        Number of units of every LSTM cell; also the size of the embeddings and attention layers.
    lstm_num_layers : int
        Number of LSTM cells stacked in the controller.
    tanh_constant : float or None
        When not None, logits are replaced by ``tanh_constant * tanh(logit)``.
    cell_exit_extra_step : bool
        When True, one extra controller step is run after the children of a ``MutableScope``
        and its output is stored as the anchor of that scope.
    skip_target : float
        Used to build the target distribution ``[1 - skip_target, skip_target]`` of the skip penalty.
    temperature : float or None
        When not None, logits are divided by it.
    branch_bias : float
        For layer choices whose key contains ``'reduce'``: added to the logits of choices whose
        type name contains ``'conv'``, subtracted from the others.
    entropy_reduction : str
        ``'sum'`` or ``'mean'``; how per-sample log-probabilities and entropies are reduced.
    """

    def __init__(self, model,
                 lstm_size=64,
                 lstm_num_layers=1,
                 tanh_constant=1.5,
                 cell_exit_extra_step=False,
                 skip_target=0.4,
                 temperature=None,
                 branch_bias=0.25,
                 entropy_reduction='sum'):
        super().__init__(model)
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step

        # Controller: a stateful RNN over `lstm_num_layers` LSTM cells.
        cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
        self.lstm = RNN(cells, stateful=True)
        # Initial controller input (random, not a trainable variable).
        self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
        self.skip_targets = tf.constant([1.0 - skip_target, skip_target])

        self.max_layer_choice = 0
        self.bias_dict = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # The first layer choice fixes the expected number of candidates.
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choice have the same number of candidates."
                # Biases are decided from the key and the type name of each choice.
                if 'reduce' in mutable.key:
                    bias = []
                    for choice in mutable.choices:
                        if 'conv' in str(type(choice)).lower():
                            bias.append(branch_bias)
                        else:
                            bias.append(-branch_bias)
                    self.bias_dict[mutable.key] = tf.constant(bias)

        # exposed for trainer
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

        # internal nn layers
        self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
        self.soft = Dense(self.max_layer_choice, use_bias=False)
        self.attn_anchor = Dense(lstm_size, use_bias=False)
        self.attn_query = Dense(lstm_size, use_bias=False)
        self.v_attn = Dense(1, use_bias=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
        # Per-sample (unreduced) cross entropy on raw logits.
        self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)

        # See `_initialize`: the RNN states are not reset before the very first sample.
        self._first_sample = True

    def sample_search(self):
        """
        Sample a decision for every LayerChoice / InputChoice in the mutable tree.

        Returns
        -------
        dict
            maps mutable keys to boolean mask tensors
        """
        self._initialize()
        self._sample(self.mutables)
        self._first_sample = False
        return self._choices

    def sample_final(self):
        """
        Same as :meth:`sample_search`.
        """
        return self.sample_search()

    def _sample(self, tree):
        """
        Recursively sample the mutable of ``tree`` (unless its key is already decided) and then its children.
        """
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        # After the children: optionally record an anchor for the scope itself.
        if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)

    def _initialize(self):
        """
        Reset the per-sample state: choices, anchors, controller input and the accumulated statistics.
        """
        self._choices = {}
        self._anchors_hid = {}
        self._inputs = self.g_emb
        # seems the `input_shape` parameter of RNN does not work
        # workaround it by omitting `reset_states` for first run
        if not self._first_sample:
            self.lstm.reset_states()
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _sample_layer_choice(self, mutable):
        """
        Sample one branch of ``mutable`` and return it as a boolean one-hot mask.

        Also accumulates ``sample_log_prob`` / ``sample_entropy`` and feeds the embedding of
        the sampled branch back as the next controller input.
        """
        logit = self.soft(self.lstm(self._inputs))
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * tf.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        # tf.random.categorical takes log-probabilities.
        softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
        branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
        # Cross entropy of the sampled branch under the logits.
        log_prob = self.cross_entropy_loss(branch_id, logit)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.math.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
        mask = tf.one_hot(branch_id, self.max_layer_choice)
        return tf.cast(tf.reshape(mask, [-1]), tf.bool)

    def _sample_input_choice(self, mutable):
        """
        Sample which inputs of ``mutable`` are used and return them as a boolean mask.

        Only ``n_chosen is None`` (any subset) and ``n_chosen == 1`` are supported.
        """
        query, anchors = [], []
        for label in mutable.choose_from:
            # Labels without a recorded anchor get one from a fresh controller step.
            if label not in self._anchors_hid:
                self._anchors_hid[label] = self.lstm(self._inputs)
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = tf.concat(query, axis=0)
        query = tf.tanh(query + self.attn_query(anchors[-1]))
        query = self.v_attn(query)

        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * tf.tanh(query)

        if mutable.n_chosen is None:
            # Independent two-way decision per candidate.
            logit = tf.concat([-query, query], axis=1)
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip_prob = tf.math.sigmoid(logit)
            # KL-style penalty against the `skip_targets` distribution.
            kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(skip, logit)

            skip = tf.cast(skip, tf.float32)
            # Next controller input: sum of the selected anchors divided by (1 + number selected).
            inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
            self._inputs = tf.reshape(inputs, [1, 1, -1])

        else:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = tf.reshape(query, [1, -1])
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
            # when the size is 1, tf does not accept tensor here, complaining the shape is wrong
            # but using a numpy array seems fine
            # NOTE(review): every other call passes (sampled index, logits); here the arguments
            # are (logit, query) and the sampled `index` is not used -- confirm this is intended.
            log_prob = self.cross_entropy_loss(logit, query.numpy())
            self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
        return tf.cast(skip, tf.bool)

Просмотреть файл

@ -1,205 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import logging
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer:
def __init__(
self,
model,
loss,
metrics,
reward_function,
optimizer,
batch_size,
num_epochs,
dataset_train,
dataset_valid,
log_frequency=100,
entropy_weight=0.0001,
skip_weight=0.8,
baseline_decay=0.999,
child_steps=500,
mutator_lr=0.00035,
mutator_steps=50,
mutator_steps_aggregate=20,
aux_weight=0.4,
test_arc_per_epoch=1,
):
self.model = model
self.loss = loss
self.metrics = metrics
self.reward_function = reward_function
self.optimizer = optimizer
self.batch_size = batch_size
self.num_epochs = num_epochs
x, y = dataset_train
split = int(len(x) * 0.9)
self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.child_steps = child_steps
self.mutator_lr = mutator_lr
self.mutator_steps = mutator_steps
self.mutator_steps_aggregate = mutator_steps_aggregate
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=self.mutator_lr)
self.baseline = 0.0
def train(self, validate=True):
for epoch in range(self.num_epochs):
logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
def validate(self):
self.validate_one_epoch(-1)
def train_one_epoch(self, epoch):
train_loader, valid_loader = self._create_train_loader()
# Sample model and train
meters = AverageMeterGroup()
for step in range(1, self.child_steps + 1):
x, y = next(train_loader)
self.mutator.reset()
with tf.GradientTape() as tape:
logits = self.model(x, training=True)
if isinstance(logits, tuple):
logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y)
else:
aux_loss = 0.0
metrics = self.metrics(y, logits)
loss = self.loss(y, logits) + self.aux_weight * aux_loss
grads = tape.gradient(loss, self.model.trainable_weights)
grads = fill_zero_grads(grads, self.model.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, 5.0)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
metrics["loss"] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
if self.log_frequency and step % self.log_frequency == 0:
logger.info(
"Model Epoch [%d/%d] Step [%d/%d] %s",
epoch + 1,
self.num_epochs,
step,
self.child_steps,
meters,
)
# Train sampler (mutator)
meters = AverageMeterGroup()
for mutator_step in range(1, self.mutator_steps + 1):
grads_list = []
for step in range(1, self.mutator_steps_aggregate + 1):
with tf.GradientTape() as tape:
x, y = next(valid_loader)
self.mutator.reset()
logits = self.model(x, training=False)
metrics = self.metrics(y, logits)
reward = (
self.reward_function(y, logits)
+ self.entropy_weight * self.mutator.sample_entropy
)
self.baseline = self.baseline * self.baseline_decay + reward * (
1 - self.baseline_decay
)
loss = self.mutator.sample_log_prob * (reward - self.baseline)
loss += self.skip_weight * self.mutator.sample_skip_penalty
meters.update(
{
"reward": reward,
"loss": tf.reduce_mean(loss).numpy(),
"ent": self.mutator.sample_entropy.numpy(),
"log_prob": self.mutator.sample_log_prob.numpy(),
"baseline": self.baseline,
"skip": self.mutator.sample_skip_penalty,
}
)
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
if self.log_frequency and cur_step % self.log_frequency == 0:
logger.info(
"RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
epoch + 1,
self.num_epochs,
mutator_step,
self.mutator_steps,
step,
self.mutator_steps_aggregate,
meters,
)
grads = tape.gradient(loss, self.mutator.trainable_weights)
grads = fill_zero_grads(grads, self.mutator.trainable_weights)
grads_list.append(grads)
total_grads = [
tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
]
total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
self.mutator_optim.apply_gradients(
zip(total_grads, self.mutator.trainable_weights)
)
def validate_one_epoch(self, epoch):
    """
    Evaluate ``self.test_arc_per_epoch`` architectures on the test set.

    For each architecture index a fresh ``AverageMeterGroup`` is filled with the
    metrics and mean loss of every test batch (``self.mutator.reset()`` is called
    before each batch), and a single summary line is logged.

    Parameters
    ----------
    epoch : int
        Index of the current epoch; only used in the log message, where it is
        printed as ``epoch + 1``.
    """
    for arc_id in range(self.test_arc_per_epoch):
        # ``_create_validate_loader`` hands back a one-shot iterator, so it has to be
        # rebuilt for every architecture. A single shared iterator is exhausted after
        # the first pass and leaves every later summary empty.
        test_loader = self._create_validate_loader()
        meters = AverageMeterGroup()
        for x, y in test_loader:
            self.mutator.reset()
            logits = self.model(x, training=False)
            if isinstance(logits, tuple):
                # Only the first element of a tuple output is scored.
                logits, _ = logits
            metrics = self.metrics(y, logits)
            loss = self.loss(y, logits)
            metrics["loss"] = tf.reduce_mean(loss).numpy()
            meters.update(metrics)
        logger.info(
            "Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
            epoch + 1,
            self.num_epochs,
            arc_id + 1,
            self.test_arc_per_epoch,
            meters.summary(),
        )
def _create_train_loader(self):
    """
    Build the pair of batch iterators consumed while training.

    Returns
    -------
    tuple
        ``(train_iter, valid_iter)``. Both come from the same pipeline, applied to
        ``self.train_set`` and ``self.valid_set`` respectively: shuffle with a
        buffer of 1000000, ``repeat()`` without a count, then batch by
        ``self.batch_size``.
    """
    pipelines = [
        dataset.shuffle(1000000).repeat().batch(self.batch_size)
        for dataset in (self.train_set, self.valid_set)
    ]
    train_iter, valid_iter = (iter(pipeline) for pipeline in pipelines)
    return train_iter, valid_iter
def _create_validate_loader(self):
    """
    Return a new iterator over ``self.test_set``.

    The dataset is shuffled with a buffer of 1000000 and batched by
    ``self.batch_size``. There is no ``repeat()``, so the iterator ends after one
    pass over the data.
    """
    batched = self.test_set.shuffle(1000000).batch(self.batch_size)
    return iter(batched)

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench201Cell']
from collections import OrderedDict
from typing import Callable, List, Dict, Union, Optional
import torch
import torch.nn as nn
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
class NasBench201Cell(nn.Module):
    """
    The cell search space introduced by NAS-Bench-201.

    Reference: `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search
    <https://arxiv.org/abs/2001.00326>`__.

    The cell is a densely connected DAG over ``num_tensors`` nodes, every node being a tensor.
    An edge goes from the i-th node to the j-th node for every i < j, and each edge applies one
    operation to the source node. The operation is chosen from ``op_candidates``; each candidate
    is a callable that takes the input dimension and the output dimension and returns a ``Module``.

    The input is expected to have shape :math:`[N, C_{in}, *]` and the output has shape
    :math:`[N, C_{out}, *]`.

    The size of the search space is :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of
    operation candidates and :math:`N` is ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable, or dict of str to callable
        Operation candidates. Each one is called as ``candidate(in_dim, out_dim)`` and must return
        an ``nn.Module``. A list is keyed by the position of each candidate, as a string.
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cell sharing the same label will semantically share the same choice.
    """

    @staticmethod
    def _make_dict(x):
        """Turn a list into an ``OrderedDict`` keyed by stringified position; pass anything else to ``OrderedDict``."""
        if isinstance(x, list):
            x = [(str(position), item) for position, item in enumerate(x)]
        return OrderedDict(x)

    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        # ``layers[k]`` holds the edges that end in node ``k + 1``, ordered by source node.
        self.layers = nn.ModuleList()
        self.in_features, self.out_features, self.num_tensors = in_features, out_features, num_tensors

        candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            incoming = nn.ModuleList()
            for j in range(tid):
                # Only edges that leave node 0 (the cell input) see ``in_features``.
                width = out_features if j else in_features
                choices = OrderedDict()
                for key, factory in candidates.items():
                    choices[key] = factory(width, out_features)
                # put __ here to be compatible with base engine
                incoming.append(LayerChoice(choices, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(incoming)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Run the cell on ``inputs`` and return its last node.

        Node 0 is ``inputs``. Every later node is the sum of its incoming edges, where the edge
        coming from node ``i`` is applied to node ``i``.
        """
        nodes: List[torch.Tensor] = [inputs]
        for incoming in self.layers:
            edge_outputs: List[torch.Tensor] = [
                edge(nodes[source]) for source, edge in enumerate(incoming)  # type: ignore
            ]
            nodes.append(torch.sum(torch.stack(edge_outputs), 0))
        return nodes[-1]

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -3,21 +3,17 @@
import copy
import warnings
from collections import OrderedDict
from typing import Callable, List, Dict, Union, Tuple, Optional
from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
from .choice import ValueChoice, ValueChoiceX, ChoiceOf
from .mutation_utils import Mutable, get_fixed_value
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
__all__ = ['Repeat']
class Repeat(Mutable):
@ -159,77 +155,3 @@ class Repeat(Mutable):
def __len__(self):
    """Report ``self.max_depth`` as the length of this module."""
    depth = self.max_depth
    return depth
class NasBench201Cell(nn.Module):
    """
    The cell search space introduced by NAS-Bench-201.

    Reference: `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search
    <https://arxiv.org/abs/2001.00326>`__.

    The cell is a densely connected DAG over ``num_tensors`` nodes, every node being a tensor.
    An edge goes from the i-th node to the j-th node for every i < j, and each edge applies one
    operation to the source node. The operation is chosen from ``op_candidates``; each candidate
    is a callable that takes the input dimension and the output dimension and returns a ``Module``.

    The input is expected to have shape :math:`[N, C_{in}, *]` and the output has shape
    :math:`[N, C_{out}, *]`.

    The size of the search space is :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of
    operation candidates and :math:`N` is ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable, or dict of str to callable
        Operation candidates. Each one is called as ``candidate(in_dim, out_dim)`` and must return
        an ``nn.Module``. A list is keyed by the position of each candidate, as a string.
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cell sharing the same label will semantically share the same choice.
    """

    @staticmethod
    def _make_dict(x):
        """Turn a list into an ``OrderedDict`` keyed by stringified position; pass anything else to ``OrderedDict``."""
        if isinstance(x, list):
            x = [(str(position), item) for position, item in enumerate(x)]
        return OrderedDict(x)

    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        # ``layers[k]`` holds the edges that end in node ``k + 1``, ordered by source node.
        self.layers = nn.ModuleList()
        self.in_features, self.out_features, self.num_tensors = in_features, out_features, num_tensors

        candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            incoming = nn.ModuleList()
            for j in range(tid):
                # Only edges that leave node 0 (the cell input) see ``in_features``.
                width = out_features if j else in_features
                choices = OrderedDict()
                for key, factory in candidates.items():
                    choices[key] = factory(width, out_features)
                # put __ here to be compatible with base engine
                incoming.append(LayerChoice(choices, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(incoming)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Run the cell on ``inputs`` and return its last node.

        Node 0 is ``inputs``. Every later node is the sum of its incoming edges, where the edge
        coming from node ``i`` is applied to node ``i``.
        """
        nodes: List[torch.Tensor] = [inputs]
        for incoming in self.layers:
            edge_outputs: List[torch.Tensor] = [
                edge(nodes[source]) for source, edge in enumerate(incoming)  # type: ignore
            ]
            nodes.append(torch.sum(torch.stack(edge_outputs), 0))
        return nodes[-1]

Просмотреть файл

Просмотреть файл

@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
class StackedLSTMCell(nn.Module):
    """
    A stack of ``layers`` ``nn.LSTMCell`` modules, each mapping ``size`` inputs to ``size`` hidden units.

    Parameters
    ----------
    layers : int
        Number of stacked cells.
    size : int
        Input size and hidden size shared by every cell.
    bias : bool
        Passed to each ``nn.LSTMCell`` as ``bias``.
    """

    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList()
        for _ in range(self.lstm_num_layers):
            self.lstm_modules.append(nn.LSTMCell(size, size, bias=bias))

    def forward(self, inputs, hidden):
        """
        Advance every layer by one step.

        Parameters
        ----------
        inputs : torch.Tensor
            Input of the first layer.
        hidden : tuple
            ``(h, c)``, where ``h[i]`` and ``c[i]`` are the previous hidden and cell state of layer ``i``.

        Returns
        -------
        tuple of list
            ``(next_h, next_c)``: the new hidden states and new cell states, one entry per layer.
        """
        prev_h, prev_c = hidden
        states = []
        for depth, lstm in enumerate(self.lstm_modules):
            h_new, c_new = lstm(inputs, (prev_h[depth], prev_c[depth]))
            states.append((h_new, c_new))
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = h_new[-1].view(1, -1)
        next_h = [h_state for h_state, _ in states]
        next_c = [c_state for _, c_state in states]
        return next_h, next_c
class ReinforceField:
    """
    One decision the controller has to make.

    Parameters
    ----------
    name : str
        Name of the field.
    total : int
        Number of choices available for the field.
    choose_one : bool
        True if one and only one choice is meant to be selected. Otherwise, any number
        of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name, self.total, self.choose_one = name, total, choose_one

    def __repr__(self):
        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
    """
    A controller that mutates the graph with RL.

    Calling :meth:`resample` draws one decision per field with a stacked LSTM and accumulates
    ``sample_log_prob``, ``sample_entropy`` and ``sample_skip_penalty`` on the instance.

    Parameters
    ----------
    fields : list of ReinforceField
        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    skip_target : float
        Target probability that skipconnect (chosen by InputChoice) will appear.
        If the chosen number of inputs is away from the ``skip_target``, there will be
        a sample skip penalty which is a KL divergence added.
    temperature : float
        Temperature constant that divides the logits. Not applied if ``None``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                 skip_target=0.4, temperature=None, entropy_reduction='sum'):
        super().__init__()
        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.skip_target = skip_target

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        # Learned embedding that is fed to the LSTM as the very first input.
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        # Fixed target distribution over (not selected, selected) for the skip penalty.
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
                                         requires_grad=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        # One output head and one embedding table per field, looked up by field name.
        self.soft = nn.ModuleDict({
            field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
        })
        self.embedding = nn.ModuleDict({
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

    def resample(self):
        """
        Reset the sampling state and draw a decision for every field, in order.

        Returns
        -------
        dict
            Maps each field name to the value returned by ``_sample_single`` for that field.
        """
        self._initialize()
        return {field.name: self._sample_single(field) for field in self.fields}

    def _initialize(self):
        """Reset the LSTM input and states and zero the accumulated sample statistics."""
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        """Feed the current input to the stacked LSTM and store the new states."""
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _sample_single(self, field):
        """
        Sample the decision for one field and update the running statistics.

        Parameters
        ----------
        field : ReinforceField
            Field to sample; ``field.name`` selects the output head and the embedding table.

        Returns
        -------
        int or list of int
            The sampled indices. A list with exactly one element is unwrapped to that element.
        """
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
            # Every choice gets its own binary (not selected / selected) distribution.
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, sampled)
            # From here on ``sampled`` holds the indices of the selected choices, not a 0/1 mask.
            sampled = sampled.nonzero().view(-1)
            # Count the selected choices. Summing the indices instead would treat "only
            # choice 0 selected" as an empty selection and would divide by the wrong number.
            num_selected = sampled.numel()
            if num_selected:
                self._inputs = (torch.sum(self.embedding[field.name](sampled), 0) / (1. + num_selected)).unsqueeze(0)
            else:
                self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

        sampled = sampled.detach().cpu().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
        return sampled

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше