Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
This commit is contained in:
Yuge Zhang 2022-08-01 15:21:43 +08:00 committed by GitHub
Parents d6dcb48319 bc6d8796d2
Commit a0fd003671
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
243 changed files with 18175 additions and 25706 deletions

View file

@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer

View file

@@ -1,143 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from nni.algorithms.nas.pytorch.darts import DartsMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator # pylint: disable=wrong-import-order
class RegularizedDartsMutator(DartsMutator):
"""
This is essentially :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator`, with two differences.
1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in
the forward pass and thus consume no memory.
2. Choices are regularized, to prevent the mutator from overfitting to particular choices.
"""
def reset(self):
"""
Warnings
--------
Renamed to :func:`~reset_with_loss`, which returns the regularization loss on reset.
"""
raise ValueError("You should probably call `reset_with_loss`.")
def cut_choices(self, cut_num=2):
"""
Cut the choices with the smallest weights.
``cut_num`` should be the cumulative number of cuts, e.g., if the first cut removes 2,
pass 4 the second time to cut another two.
Parameters
----------
cut_num : int
Total number of choices cut so far (cumulative).
Warnings
--------
Though the parameters are set to :math:`-\infty` to be bypassed, they still receive a gradient of 0,
which introduces a ``nan`` problem when calling ``optimizer.step()``. A simple fix is to
reset the ``nan`` entries to :math:`-\infty` each time after the parameters are updated.
"""
# `cut_choices` is implemented but not used in current implementation of CdartsTrainer
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
_, idx = torch.topk(-self.choices[mutable.key], cut_num)
with torch.no_grad():
for i in idx:
self.choices[mutable.key][i] = -float("inf")
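# Hedged sketch (not part of the original file): the warning above suggests
# resetting ``nan`` entries back to -inf after each ``optimizer.step()``.
# ``mutator`` is a hypothetical RegularizedDartsMutator instance:
#
#     with torch.no_grad():
#         for param in mutator.parameters():
#             param[torch.isnan(param)] = float("-inf")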
def reset_with_loss(self):
"""
Resample and return the loss. If the loss is 0, it returns ``None`` to avoid device issues.
Currently the loss penalty is proportional to the L1-norm of the parameters of modules
whose type name contains certain substrings: ``poolwithoutbn``, ``identity``, ``dilconv``.
"""
self._cache, reg_loss = self.sample_search()
return reg_loss
def sample_search(self):
result = super().sample_search()
loss = []
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
def need_reg(choice):
return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])
for i, choice in enumerate(mutable.choices):
if need_reg(choice):
norm = torch.abs(self.choices[mutable.key][i])
if norm < 1E10:
loss.append(norm)
if not loss:
return result, None
return result, sum(loss)
def export(self, logger=None):
"""
Export an architecture with logger. The genotype will be printed with the logger.
Returns
-------
tuple
A mapping from mutable keys to decisions, and the plotted genotypes (``None`` if unavailable).
"""
result = self.sample_final()
genotypes = None  # otherwise unbound when `plot_genotype` is unavailable or logger is None
if hasattr(self.model, "plot_genotype") and logger is not None:
genotypes = self.model.plot_genotype(result, logger)
return result, genotypes
class RegularizedMutatorParallel(DistributedDataParallel):
"""
Parallelize :class:`~RegularizedDartsMutator`.
This makes the :func:`~RegularizedDartsMutator.reset_with_loss` method parallelized,
and also makes :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
easily accessible.
"""
def reset_with_loss(self):
"""
Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
"""
result = self.module.reset_with_loss()
self.callback_queued = False
return result
def cut_choices(self, *args, **kwargs):
"""
Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
"""
self.module.cut_choices(*args, **kwargs)
def export(self, logger):
"""
Parallelized :func:`~RegularizedDartsMutator.export`.
"""
return self.module.export(logger)
class DartsDiscreteMutator(Mutator):
"""
A mutator that applies the final sampling result of a parent mutator to another model for training.
Parameters
----------
model : nn.Module
The model to which the mutator is applied.
parent_mutator : nni.nas.pytorch.mutator.Mutator
The mutator that provides ``sample_final`` method, that will be called to get the architecture.
"""
def __init__(self, model, parent_mutator):
super().__init__(model)
self.__dict__["parent_mutator"] = parent_mutator  # assign via __dict__ so the parent's parameters are not registered here
def sample_search(self):
return self.parent_mutator.sample_final()

View file

@@ -1,275 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import apex # pylint: disable=import-error
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup # pylint: disable=wrong-import-order
from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics
PHASE_SMALL = "small"
PHASE_LARGE = "large"
class InteractiveKLLoss(nn.Module):
def __init__(self, temperature):
super().__init__()
self.temperature = temperature
# self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
self.kl_loss = nn.KLDivLoss()
def forward(self, student, teacher):
return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
F.softmax(teacher / self.temperature, dim=1))
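# Hedged usage sketch (names illustrative, not from the original file): the
# temperature softens both distributions before the KL divergence is computed.
#
#     criterion = InteractiveKLLoss(temperature=2.0)
#     student_logits = torch.randn(4, 10)
#     teacher_logits = torch.randn(4, 10)
#     loss = criterion(student_logits, teacher_logits)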
class CdartsTrainer(object):
"""
CDARTS trainer.
Parameters
----------
model_small : nn.Module
PyTorch model to be trained. This is the search network of CDARTS.
model_large : nn.Module
PyTorch model to be trained. This is the evaluation network of CDARTS.
criterion : callable
Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
loaders : list of torch.utils.data.DataLoader
List of train data and valid data loaders, for training weights and architecture weights respectively.
samplers : list of torch.utils.data.Sampler
List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
logger : logging.Logger
The logger for logging. Will use nni logger by default (if logger is ``None``).
regular_coeff : float
The coefficient of regular loss.
regular_ratio : float
The ratio of regular loss.
warmup_epochs : int
Number of epochs to warm up the search network.
fix_head : bool
``True`` to fix the parameters of the auxiliary heads, otherwise they are trained.
epochs : int
Number of epochs planned for training.
steps_per_epoch : int
Steps of one epoch.
loss_alpha : float
Coefficient that balances the interactive (distillation) loss against the classification loss.
loss_T : float
Temperature of the interactive (distillation) loss.
distributed : bool
``True`` if using distributed training, else non-distributed training.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping for weights.
interactive_type : string
``kl`` or ``smoothl1``.
output_path : string
Log storage path.
w_lr : float
Learning rate of the search network parameters.
w_momentum : float
Momentum of the search and the evaluation network.
w_weight_decay : float
The weight decay of the search and the evaluation network parameters.
alpha_lr : float
Learning rate of the architecture parameters.
alpha_weight_decay : float
The weight decay of the architecture parameters.
nasnet_lr : float
Learning rate of the evaluation network parameters.
local_rank : int
Rank of the current process in distributed training.
share_module : bool
``True`` to share the stem and auxiliary heads between the search and the evaluation network.
"""
def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
nasnet_lr=0.2, local_rank=0, share_module=True):
if logger is None:
logger = logging.getLogger(__name__)
train_loader, valid_loader = loaders
train_sampler, valid_sampler = samplers
self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)
self.regular_coeff = regular_coeff
self.regular_ratio = regular_ratio
self.warmup_epochs = warmup_epochs
self.fix_head = fix_head
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
if self.steps_per_epoch is None:
self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
self.loss_alpha = loss_alpha
self.grad_clip = grad_clip
if interactive_type == "kl":
self.interactive_loss = InteractiveKLLoss(loss_T)
elif interactive_type == "smoothl1":
self.interactive_loss = nn.SmoothL1Loss()
else:
raise ValueError("Unsupported interactive_type: %s" % interactive_type)
self.loss_T = loss_T
self.distributed = distributed
self.log_frequency = log_frequency
self.main_proc = not distributed or local_rank == 0
self.logger = logger
self.checkpoint_dir = output_path
if self.main_proc:
os.makedirs(self.checkpoint_dir, exist_ok=True)
if distributed:
torch.distributed.barrier()
self.model_small = model_small
self.model_large = model_large
if self.fix_head:
for param in self.model_small.aux_head.parameters():
param.requires_grad = False
for param in self.model_large.aux_head.parameters():
param.requires_grad = False
self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
self.criterion = criterion
self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
momentum=w_momentum, weight_decay=w_weight_decay)
self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
momentum=w_momentum, weight_decay=w_weight_decay)
self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
betas=(0.5, 0.999), weight_decay=alpha_weight_decay)
if distributed:
apex.parallel.convert_syncbn_model(self.model_small)
apex.parallel.convert_syncbn_model(self.model_large)
self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
if share_module:
self.model_small.callback_queued = True
self.model_large.callback_queued = True
# the large mutator is never optimized, so it does not need to be parallelized
def _warmup(self, phase, epoch):
assert phase in [PHASE_SMALL, PHASE_LARGE]
if phase == PHASE_SMALL:
model, optimizer = self.model_small, self.optimizer_small
elif phase == PHASE_LARGE:
model, optimizer = self.model_large, self.optimizer_large
model.train()
meters = AverageMeterGroup()
for step in range(self.steps_per_epoch):
x, y = next(self.train_loader)
x, y = x.cuda(), y.cuda()
optimizer.zero_grad()
logits_main, _ = model(x)
loss = self.criterion(logits_main, y)
loss.backward()
self._clip_grad_norm(model)
optimizer.step()
prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics, self.distributed)
meters.update(metrics)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs,
step + 1, self.steps_per_epoch, phase, meters)
def _clip_grad_norm(self, model):
if isinstance(model, DistributedDataParallel):
nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
def _reset_nan(self, parameters):
with torch.no_grad():
for param in parameters:
for i, p in enumerate(param):
if p != p: # equivalent to `isnan(p)`
param[i] = float("-inf")
def _joint_train(self, epoch):
self.model_large.train()
self.model_small.train()
meters = AverageMeterGroup()
for step in range(self.steps_per_epoch):
trn_x, trn_y = next(self.train_loader)
val_x, val_y = next(self.valid_loader)
trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
val_x, val_y = val_x.cuda(), val_y.cuda()
# step 1. optimize architecture
self.optimizer_alpha.zero_grad()
self.optimizer_large.zero_grad()
reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
(self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
loss_regular = self.mutator_small.reset_with_loss()
if loss_regular:
loss_regular *= reg_decay
logits_search, ensemble_logits_search = self.model_small(val_x)
logits_main, ensemble_logits_main = self.model_large(val_x)
loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
loss_interactive = self.interactive_loss(ensemble_logits_search, ensemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
loss = loss_cls + loss_interactive + loss_regular
loss.backward()
self._clip_grad_norm(self.model_large)
self.optimizer_large.step()
self.optimizer_alpha.step()
# NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`
# step 2. optimize op weights
self.optimizer_small.zero_grad()
with torch.no_grad():
# resample architecture since parameters have been changed
self.mutator_small.reset_with_loss()
logits_search_train, _ = self.model_small(trn_x)
loss_weight = self.criterion(logits_search_train, trn_y)
loss_weight.backward()
self._clip_grad_norm(self.model_small)
self.optimizer_small.step()
metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
"loss_regular": loss_regular, "loss_weight": loss_weight}
metrics = reduce_metrics(metrics, self.distributed)
meters.update(metrics)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
step + 1, self.steps_per_epoch, meters)
def train(self):
for epoch in range(self.epochs):
if epoch < self.warmup_epochs:
with torch.no_grad(): # otherwise grads will be retained on the architecture params
self.mutator_small.reset_with_loss()
self._warmup(PHASE_SMALL, epoch)
else:
with torch.no_grad():
self.mutator_large.reset()
self._warmup(PHASE_LARGE, epoch)
self._joint_train(epoch)
self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))
def export(self, file, genotype_file):
if self.main_proc:
mutator_export, genotypes = self.mutator_small.export(self.logger)
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
with open(genotype_file, "w") as f:
f.write(str(genotypes))
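# Hedged usage sketch (models, loaders and samplers are hypothetical; the
# defaults assume distributed training on CUDA):
#
#     trainer = CdartsTrainer(model_small, model_large, nn.CrossEntropyLoss(),
#                             [train_loader, valid_loader], [train_sampler, valid_sampler],
#                             epochs=32, distributed=False)
#     trainer.train()  # warm-up phases first, then joint training with per-epoch export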

View file

@@ -1,76 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import torch
import torch.distributed as dist
class CyclicIterator:
def __init__(self, loader, sampler, distributed):
self.loader = loader
self.sampler = sampler
self.epoch = 0
self.distributed = distributed
self._next_epoch()
def _next_epoch(self):
if self.distributed:
self.sampler.set_epoch(self.epoch)
self.iterator = iter(self.loader)
self.epoch += 1
def __len__(self):
return len(self.loader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iterator)
except StopIteration:
self._next_epoch()
return next(self.iterator)
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
return o.tolist()
return super().default(o)
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))
return res
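# Hedged numeric sketch of `accuracy` (values illustrative):
#
#     output = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
#     target = torch.tensor([1, 2])
#     prec1, = accuracy(output, target, topk=(1,))  # tensor(0.5000): only the first sample is correct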
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= float(os.environ["WORLD_SIZE"])
return rt
def reduce_metrics(metrics, distributed=False):
if distributed:
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
return {k: v.item() for k, v in metrics.items()}

View file

@@ -1,221 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import sys
import torch
import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.pytorch.mutator import Mutator
logger = logging.getLogger(__name__)
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
"""
Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
similar to ``get_next_parameter`` for HPO.
It will generate search space based on ``model``.
If the env variable ``NNI_GEN_SEARCH_SPACE`` exists, this runs in dry-run mode,
only generating the search space for the experiment.
If not, there are still two modes: one is NNI experiment mode, where users
use ``nnictl`` to start an experiment; the other is standalone mode,
where users run the trial command directly, and the first one(s)
of each LayerChoice and InputChoice are chosen.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
ClassicMutator(model)
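# Hedged usage sketch in a trial script (`Net` is a hypothetical model that
# embeds LayerChoice / InputChoice mutables):
#
#     model = Net()
#     get_and_apply_next_architecture(model)  # queries the tuner, or picks the first choices standalone
#     # ... regular training loop follows ...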
class ClassicMutator(Mutator):
"""
This mutator is to apply the architecture chosen from tuner.
It implements the forward function of LayerChoice and InputChoice,
to only activate the chosen ones.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
def __init__(self, model):
super(ClassicMutator, self).__init__(model)
self._chosen_arch = {}
self._search_space = self._generate_search_space()
if NNI_GEN_SEARCH_SPACE in os.environ:
# dry run for only generating search space
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
sys.exit(0)
if trial_env_vars.NNI_PLATFORM is None:
logger.warning("This is in standalone mode, the chosen are the first one(s).")
self._chosen_arch = self._standalone_generate_chosen()
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
if self._chosen_arch is None:
if trial_env_vars.NNI_PLATFORM == "unittest":
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
self._chosen_arch = self._standalone_generate_chosen()
else:
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
self.reset()
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
"""
Convert layer choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : int
Index of the list item to select.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
# doesn't support multihot for layer choice yet
onehot_list = [False] * len(mutable)
assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
onehot_list[idx] = True
return torch.tensor(onehot_list, dtype=torch.bool) # pylint: disable=not-callable
def _sample_input_choice(self, mutable, idx, value, search_space_item):
"""
Convert input choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : list of int
Indices of the list items to select.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
candidate_repr = search_space_item["candidates"]
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
return torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
def sample_search(self):
"""
See :meth:`sample_final`.
"""
return self.sample_final()
def sample_final(self):
"""
Convert the chosen arch and apply it on model.
"""
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
self._chosen_arch.keys())
result = dict()
for mutable in self.mutables:
if isinstance(mutable, (LayerChoice, InputChoice)):
assert mutable.key in self._chosen_arch, \
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
if isinstance(mutable, LayerChoice):
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, InputChoice):
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return result
def _standalone_generate_chosen(self):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice.
::
{ key_name: {"_value": "conv1",
"_idx": 0} }
{ key_name: {"_value": ["in1"],
"_idx": [0]} }
Returns
-------
dict
the chosen architecture
"""
chosen_arch = {}
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
if n_chosen is None:
n_chosen = len(choices)
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
return chosen_arch
def _generate_search_space(self):
"""
Generate search space from mutables.
Here is the search space format:
::
{ key_name: {"_type": "layer_choice",
"_value": ["conv1", "conv2"]} }
{ key_name: {"_type": "input_choice",
"_value": {"candidates": ["in1", "in2"],
"n_chosen": 1}} }
Returns
-------
dict
the generated search space
"""
search_space = {}
for mutable in self.mutables:
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = mutable.names
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
def _dump_search_space(self, file_path):
with open(file_path, "w") as ss_file:
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)

View file

@@ -1,403 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from copy import deepcopy
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .utils import accuracy, reduce_metrics
logger = logging.getLogger(__name__)
class CreamSupernetTrainer(Trainer):
"""
This trainer trains a supernet and output prioritized architectures that can be used for other tasks.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
val_loss : callable
Called with logits and targets for validation only. Returns a loss tensor.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader for training. Raises ``StopIteration`` when one epoch is exhausted.
valid_loader : iterable
Data loader for validation. Raises ``StopIteration`` when one epoch is exhausted.
mutator : Mutator
A mutator object that has been initialized with the model.
batch_size : int
Batch size.
log_frequency : int
Number of mini-batches to log metrics.
meta_sta_epoch : int
Epoch from which the meta matching network starts picking the teacher architecture.
update_iter : int
Interval (in steps) of updating the meta matching network.
slices : int
Batch size of the mini training data used when training the meta matching network.
pool_size : int
Size of the prioritized board.
pick_method : str
How to pick the teacher network, ``top1`` or ``meta``.
choice_num : int
Number of operations in the supernet.
sta_num : tuple of int
Number of layers in each stage of the supernet (5 stages in the supernet).
acc_gap : int
Maximum accuracy improvement that bypasses the FLOPs limitation.
flops_dict : dict
Dictionary of the FLOPs of each layer's operations in the supernet.
flops_fixed : int
FLOPs of the fixed part of the supernet.
local_rank : int
Rank of the current process in distributed training.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def __init__(self, model, loss, val_loss,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, log_frequency=None,
meta_sta_epoch=20, update_iter=200, slices=2,
pool_size=10, pick_method='meta', choice_num=6,
sta_num=(4, 4, 4, 4, 4), acc_gap=5,
flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
assert torch.cuda.is_available()
super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
optimizer, num_epochs, None, None,
batch_size, None, None, log_frequency, callbacks)
self.model = model
self.loss = loss
self.val_loss = val_loss
self.train_loader = train_loader
self.valid_loader = valid_loader
self.log_frequency = log_frequency
self.batch_size = batch_size
self.optimizer = optimizer
self.num_epochs = num_epochs
self.meta_sta_epoch = meta_sta_epoch
self.update_iter = update_iter
self.slices = slices
self.pick_method = pick_method
self.pool_size = pool_size
self.local_rank = local_rank
self.choice_num = choice_num
self.sta_num = sta_num
self.acc_gap = acc_gap
self.flops_dict = flops_dict
self.flops_fixed = flops_fixed
self.current_student_arch = None
self.current_teacher_arch = None
self.main_proc = (local_rank == 0)
self.current_epoch = 0
self.prioritized_board = []
# size of prioritized board
def _board_size(self):
return len(self.prioritized_board)
# select teacher architecture according to the logit difference
def _select_teacher(self):
self._replace_mutator_cand(self.current_student_arch)
if self.pick_method == 'top1':
meta_value, teacher_cand = 0.5, sorted(
self.prioritized_board, reverse=True)[0][3]
elif self.pick_method == 'meta':
meta_value, cand_idx, teacher_cand = -1000000000, -1, None
for now_idx, item in enumerate(self.prioritized_board):
inputx = item[4]
output = torch.nn.functional.softmax(self.model(inputx), dim=1)
weight = self.model.module.forward_meta(output - item[5])
if weight > meta_value:
meta_value = weight
cand_idx = now_idx
teacher_cand = self.prioritized_board[cand_idx][3]
assert teacher_cand is not None
meta_value = torch.sigmoid(-weight)
else:
raise ValueError('Method Not supported')
return meta_value, teacher_cand
# check whether to update prioritized board
def _isUpdateBoard(self, prec1, flops):
if self.current_epoch <= self.meta_sta_epoch:
return False
if len(self.prioritized_board) < self.pool_size:
return True
if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
return True
if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
return True
return False
# update prioritized board
def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
if self._isUpdateBoard(prec1, flops):
val_prec1 = prec1
training_data = deepcopy(inputs[:self.slices].detach())
if len(self.prioritized_board) == 0:
features = deepcopy(outputs[:self.slices].detach())
else:
features = deepcopy(
teacher_output[:self.slices].detach())
self.prioritized_board.append(
(val_prec1,
prec1,
flops,
self.current_student_arch,
training_data,
torch.nn.functional.softmax(
features,
dim=1)))
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
if len(self.prioritized_board) > self.pool_size:
del self.prioritized_board[-1]
# only update student network weights
def _update_student_weights_only(self, grad_1):
for weight, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1):
weight.grad = grad_item
torch.nn.utils.clip_grad_norm_(
self.model.module.rand_parameters(self.current_student_arch), 1)
self.optimizer.step()
for weight, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1):
del weight.grad
# only update meta networks weights
def _update_meta_weights_only(self, teacher_cand, grad_teacher):
for weight, grad_item in zip(self.model.module.rand_parameters(
teacher_cand, self.pick_method == 'meta'), grad_teacher):
weight.grad = grad_item
# clip gradients
torch.nn.utils.clip_grad_norm_(
self.model.module.rand_parameters(
self.current_student_arch, self.pick_method == 'meta'), 1)
self.optimizer.step()
for weight, grad_item in zip(self.model.module.rand_parameters(
teacher_cand, self.pick_method == 'meta'), grad_teacher):
del weight.grad
# simulate sgd updating
def _simulate_sgd_update(self, w, g, optimizer):
return g * optimizer.param_groups[-1]['lr'] + w
# split training images into several slices
def _get_minibatch_input(self, input): # pylint: disable=redefined-builtin
slice = self.slices # pylint: disable=redefined-builtin
x = deepcopy(input[:slice].clone().detach())
return x
# calculate 1st gradient of student architectures
def _calculate_1st_gradient(self, kd_loss):
self.optimizer.zero_grad()
grad = torch.autograd.grad(
kd_loss,
self.model.module.rand_parameters(self.current_student_arch),
create_graph=True)
return grad
# calculate 2nd gradient of meta networks
def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
self.optimizer.zero_grad()
grad_student_val = torch.autograd.grad(
validation_loss,
self.model.module.rand_parameters(self.current_student_arch),
retain_graph=True)
grad_teacher = torch.autograd.grad(
students_weight[0],
self.model.module.rand_parameters(
teacher_cand,
self.pick_method == 'meta'),
grad_outputs=grad_student_val)
return grad_teacher
# forward training data
def _forward_training(self, x, meta_value):
self._replace_mutator_cand(self.current_student_arch)
output = self.model(x)
with torch.no_grad():
self._replace_mutator_cand(self.current_teacher_arch)
teacher_output = self.model(x)
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
kd_loss = meta_value * \
self._cross_entropy_loss_with_soft_target(output, soft_label)
return kd_loss
# calculate soft target loss
def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
logsoftmax = torch.nn.LogSoftmax(dim=1)  # explicit dim; implicit dim is deprecated
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
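# Hedged note on the soft-target loss above: with a one-hot `soft_target` it
# reduces to ordinary cross entropy, e.g. (values illustrative):
#
#     pred = torch.tensor([[2.0, 0.5]])
#     soft_target = torch.tensor([[1.0, 0.0]])
#     # equals F.cross_entropy(pred, torch.tensor([0]))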
# forward validation data
def _forward_validation(self, input, target): # pylint: disable=redefined-builtin
slice = self.slices # pylint: disable=redefined-builtin
x = input[slice:slice * 2].clone()
self._replace_mutator_cand(self.current_student_arch)
output_2 = self.model(x)
validation_loss = self.loss(output_2, target[slice:slice * 2])
return validation_loss
def _isUpdateMeta(self, batch_idx):
isUpdate = True
isUpdate &= (self.current_epoch > self.meta_sta_epoch)
isUpdate &= (batch_idx > 0)
isUpdate &= (batch_idx % self.update_iter == 0)
isUpdate &= (self._board_size() > 0)
return isUpdate
def _replace_mutator_cand(self, cand):
self.mutator._cache = cand
# update meta matching networks
def _run_update(self, input, target, batch_idx): # pylint: disable=redefined-builtin
if self._isUpdateMeta(batch_idx):
x = self._get_minibatch_input(input)
meta_value, teacher_cand = self._select_teacher()
kd_loss = self._forward_training(x, meta_value)
# calculate 1st gradient
grad_1st = self._calculate_1st_gradient(kd_loss)
# simulate updated student weights
students_weight = [
self._simulate_sgd_update(
p, grad_item, self.optimizer) for p, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1st)]
# update student weights
self._update_student_weights_only(grad_1st)
validation_loss = self._forward_validation(input, target)
# calculate 2nd gradient
grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)
# update meta matching networks
self._update_meta_weights_only(teacher_cand, grad_teacher)
# delete internal variants
del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight
def _get_cand_flops(self, cand):
flops = 0
for block_id, block in enumerate(cand):
if block in ('LayerChoice1', 'LayerChoice23'):  # skip fixed blocks that have no operation choices
continue
for idx, choice in enumerate(cand[block]):
flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
return flops + self.flops_fixed
def train_one_epoch(self, epoch):
self.current_epoch = epoch
meters = AverageMeterGroup()
self.steps_per_epoch = len(self.train_loader)
for step, (input_data, target) in enumerate(self.train_loader):
self.mutator.reset()
self.current_student_arch = self.mutator._cache
input_data, target = input_data.cuda(), target.cuda()
# calculate flops of current architecture
cand_flops = self._get_cand_flops(self.mutator._cache)
# update meta matching network
self._run_update(input_data, target, step)
if self._board_size() > 0:
# select teacher architecture
meta_value, teacher_cand = self._select_teacher()
self.current_teacher_arch = teacher_cand
# forward supernet
if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
self._replace_mutator_cand(self.current_student_arch)
output = self.model(input_data)
loss = self.loss(output, target)
kd_loss, teacher_output, teacher_cand = None, None, None
else:
self._replace_mutator_cand(self.current_student_arch)
output = self.model(input_data)
gt_loss = self.loss(output, target)
with torch.no_grad():
self._replace_mutator_cand(self.current_teacher_arch)
teacher_output = self.model(input_data).detach()
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2
# update network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# update metrics
prec1, prec5 = accuracy(output, target, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics)
meters.update(metrics)
# update prioritized board
self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
step + 1, len(self.train_loader), meters)
if self.main_proc and self.num_epochs == epoch + 1:
for idx, i in enumerate(self.prioritized_board):
logger.info("No.%s %s", idx, i[:4])
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
self.mutator.reset()
logits = self.model(x)
loss = self.val_loss(logits, y)
prec1, prec5 = accuracy(logits, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)

View file

@@ -1,37 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch.distributed as dist
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))
return res
def reduce_metrics(metrics):
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= float(os.environ["WORLD_SIZE"])
return rt

View file

@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import DartsMutator
from .trainer import DartsTrainer

View file

@@ -1,85 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
_logger = logging.getLogger(__name__)
class DartsMutator(Mutator):
"""
Connects the model in a DARTS (differentiable) way.
An extra connection is automatically inserted for each LayerChoice. When this connection is selected,
no op on this LayerChoice is used (namely a ``ZeroOp``), in which case every element in the exported
choice list is ``false`` (not chosen).
All input choices will be fully connected in the search phase. On exporting, an input choice will
choose inputs based on the keys in ``choose_from``. If a key is the key of a LayerChoice, the top logit
of that LayerChoice joins the competition of the input choice against the other logits. Otherwise,
the logit is assumed to be 0.
It's possible to cut branches by setting the parameter in ``choices`` at a particular position to ``-inf``.
After softmax, the value becomes 0, and the framework ignores 0 values and does not connect. Note that the
gradient at the ``-inf`` position will be 0, and since arithmetic with ``-inf`` produces ``nan``, you need
to handle the gradient update phase carefully.
Attributes
----------
choices: ParameterDict
dict that maps keys of LayerChoices to weighted-connection float tensors.
"""
def __init__(self, model):
super().__init__(model)
self.choices = nn.ParameterDict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
def device(self):
for v in self.choices.values():
return v.device
def sample_search(self):
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
elif isinstance(mutable, InputChoice):
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
return result
def sample_final(self):
result = dict()
edges_max = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice):
if mutable.n_chosen is not None:
weights = []
for src_key in mutable.choose_from:
if src_key not in edges_max:
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
weights.append(edges_max.get(src_key, 0.))
weights = torch.tensor(weights) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
# If an edge is never selected, there is no need to calculate any op on this edge.
# This is to eliminate redundant calculation.
result[src_key] = torch.zeros_like(result[src_key])
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
else:
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result
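# Hedged sketch of the extra "zero op" column described in the class docstring:
# a LayerChoice with 3 candidates stores 4 logits; the last logit competes in
# the softmax but is dropped on export.
#
#     logits = torch.tensor([0.1, 0.2, 0.3, 2.0])
#     weights = F.softmax(logits, dim=-1)[:-1]  # weights of the 3 real ops, summing to < 1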

View file

@@ -1,214 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
logger = logging.getLogger(__name__)
class DartsTrainer(Trainer):
"""
DARTS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : DartsMutator
Use this to customize your own DartsMutator. A DartsMutator is instantiated by default.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
``True`` to use second-order optimization, ``False`` for first-order optimization.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
n_train = len(self.dataset_train)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
self.model.train()
self.mutator.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
self.mutator.eval()
meters = AverageMeterGroup()
with torch.no_grad():
self.mutator.reset()
for step, (X, y) in enumerate(self.test_loader):
X, y = X.to(self.device), y.to(self.device)
logits = self.model(X)
metrics = self.metrics(logits, y)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.test_loader), meters)
def _logits_and_loss(self, X, y):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
self._write_graph_status()
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.optimizer.param_groups[0]["lr"]
momentum = self.optimizer.param_groups[0]["momentum"]
weight_decay = self.optimizer.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w'
"""
# no need to zero_grad; autograd.grad computes gradients without accumulating them
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.optimizer.state[w].get("momentum_buffer", 0.)
# update in place; plain rebinding (`w = w - ...`) would leave the parameter unchanged
w -= lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())
dalphas = []
for e in [eps, -2. * eps]:
# first pass: w+ = w + eps*dw; second pass subtracts 2*eps to reach w- = w - eps*dw
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
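# Hedged 1-D sketch of the finite-difference trick above: for f(w) = w**2,
# (f'(w + eps) - f'(w - eps)) / (2 * eps) recovers f''(w) = 2 exactly.
#
#     w, eps = 3.0, 0.01
#     fd = (2 * (w + eps) - 2 * (w - eps)) / (2 * eps)  # == 2.0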

View file

@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer

View file

@@ -1,197 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
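# Hedged usage sketch (batch size 1, as noted in the comment above):
#
#     cell = StackedLSTMCell(layers=2, size=64, bias=False)
#     h = [torch.zeros(1, 64) for _ in range(2)]
#     c = [torch.zeros(1, 64) for _ in range(2)]
#     h, c = cell(torch.zeros(1, 64), (h, c))  # lists of per-layer hidden/cell states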
class EnasMutator(Mutator):
"""
A mutator that mutates the graph with RL.
Parameters
----------
model : nn.Module
PyTorch model.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` is not used if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with the original repo.
If a mutable has ``reduce`` in its key, all its op choices
that contain ``conv`` in their typename receive a bias of ``+self.branch_bias`` initially, while the others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
assert self.max_layer_choice == len(mutable), \
"ENAS mutator requires all layer choices to have the same number of candidates."
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
for choice in mutable])
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def sample_search(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label)  # anchor not recorded yet (e.g., NO_KEY); take an extra step and mark it now
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
return skip.bool()
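# Hedged sketch of how a REINFORCE-style trainer may consume the quantities
# accumulated above (weights and exact update rule are illustrative):
#
#     mutator.reset()  # samples an architecture, filling sample_log_prob etc.
#     reward = reward_fn(logits, labels) + entropy_weight * mutator.sample_entropy
#     baseline = baseline * decay + reward * (1 - decay)
#     loss = mutator.sample_log_prob * (reward - baseline) \
#            + skip_weight * mutator.sample_skip_penalty
#     loss.backward()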

View file

@@ -1,209 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from itertools import cycle
import torch
import torch.nn as nn
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer(Trainer):
"""
ENAS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, return a tensor, which will be fed to the RL controller as reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : EnasMutator
Use when customizing your own mutator or a mutator with customized parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
child_steps : int
How many mini-batches for model training per epoch.
mutator_lr : float
Learning rate for RL controller.
mutator_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
mutator_steps : int
Number of mini-batches for each epoch of RL controller learning.
aux_weight : float
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
test_arc_per_epoch : int
How many architectures are chosen for direct test after each epoch.
"""
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
test_arc_per_epoch=1):
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.batch_size = batch_size
self.workers = workers
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.mutator_steps_aggregate = mutator_steps_aggregate
self.mutator_steps = mutator_steps
self.child_steps = child_steps
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=self.batch_size,
num_workers=self.workers)
self.train_loader = cycle(self.train_loader)
self.valid_loader = cycle(self.valid_loader)
def train_one_epoch(self, epoch):
# Sample model and train
self.model.train()
self.mutator.eval()
meters = AverageMeterGroup()
for step in range(1, self.child_steps + 1):
x, y = next(self.train_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
with torch.no_grad():
self.mutator.reset()
self._write_graph_status()
logits = self.model(x)
if isinstance(logits, tuple):
logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y)
else:
aux_loss = 0.
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss = loss + self.aux_weight * aux_loss
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
self.optimizer.step()
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
self.num_epochs, step, self.child_steps, meters)
# Train sampler (mutator)
self.model.eval()
self.mutator.train()
meters = AverageMeterGroup()
for mutator_step in range(1, self.mutator_steps + 1):
self.mutator_optim.zero_grad()
for step in range(1, self.mutator_steps_aggregate + 1):
x, y = next(self.valid_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
self._write_graph_status()
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
reward += self.entropy_weight * self.mutator.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
metrics["reward"] = reward
metrics["loss"] = loss.item()
metrics["ent"] = self.mutator.sample_entropy.item()
metrics["log_prob"] = self.mutator.sample_log_prob.item()
metrics["baseline"] = self.baseline
metrics["skip"] = self.mutator.sample_skip_penalty
loss /= self.mutator_steps_aggregate
loss.backward()
meters.update(metrics)
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
if self.log_frequency is not None and cur_step % self.log_frequency == 0:
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
meters)
nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
self.mutator_optim.step()
def validate_one_epoch(self, epoch):
with torch.no_grad():
for arc_id in range(self.test_arc_per_epoch):
meters = AverageMeterGroup()
for x, y in self.test_loader:
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
meters.summary())

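For context, a minimal usage sketch of the trainer above; the model and datasets are illustrative placeholders, not part of the original file:

import torch.nn as nn
import torch.optim as optim

model = MyEnasSearchSpace()  # hypothetical model containing LayerChoice / InputChoice mutables

def reward_fn(logits, y):
    # reward = batch accuracy, a common choice for the RL controller
    return (logits.argmax(dim=-1) == y).float().mean()

trainer = EnasTrainer(model,
                      loss=nn.CrossEntropyLoss(),
                      metrics=lambda logits, y: {"acc": reward_fn(logits, y).item()},
                      reward_function=reward_fn,
                      optimizer=optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4),
                      num_epochs=310,
                      dataset_train=train_set,   # e.g. a CIFAR-10 train split
                      dataset_valid=valid_set)   # e.g. a CIFAR-10 test split
trainer.train()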
View file

@ -1,14 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import
from .mutator import FBNetMutator # noqa: F401
from .trainer import FBNetTrainer # noqa: F401
from .utils import ( # noqa: F401
LookUpTable,
NASConfig,
RegularizerLoss,
model_init,
supernet_sample,
)

View file

@ -1,268 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
class MixedOp(nn.Module):
"""
This class is to instantiate and manage info of one LayerChoice.
It includes architecture weights and member functions for the weights.
"""
def __init__(self, mutable, latency):
"""
Parameters
----------
mutable : LayerChoice
A LayerChoice in user model
latency : List
performance cost for each op in mutable
"""
super(MixedOp, self).__init__()
self.latency = latency
n_choices = len(mutable)
self.path_alpha = nn.Parameter(
torch.FloatTensor([1.0 / n_choices for i in range(n_choices)])
)
self.path_alpha.requires_grad = False
self.temperature = 1.0
def get_path_alpha(self):
"""Return the architecture parameter."""
return self.path_alpha
def get_weighted_latency(self):
"""Return the weighted perf_cost of current mutable."""
soft_masks = self.probs_over_ops()
weighted_latency = sum(m * l for m, l in zip(soft_masks, self.latency))
return weighted_latency
def set_temperature(self, temperature):
"""
Set the annealed temperature for gumbel softmax.
Parameters
----------
temperature : float
The annealed temperature for gumbel softmax
"""
self.temperature = temperature
def to_requires_grad(self):
"""Enable gradient calculation."""
self.path_alpha.requires_grad = True
def to_disable_grad(self):
"""Disable gradient calculation."""
self.path_alpha.requires_grad = False
def probs_over_ops(self):
"""Apply gumbel softmax to generate probability distribution."""
return F.gumbel_softmax(self.path_alpha, self.temperature)
def forward(self, mutable, x):
"""
Define forward of LayerChoice.
Parameters
----------
mutable : LayerChoice
this layer's mutable
x : tensor
inputs of this layer, only support one input
Returns
-------
output: tensor
output of this layer
"""
candidate_ops = list(mutable)
soft_masks = self.probs_over_ops()
output = sum(m * op(x) for m, op in zip(soft_masks, candidate_ops))
return output
@property
def chosen_index(self):
"""
choose the op with max prob
Returns
-------
int
index of the chosen one
"""
alphas = self.path_alpha.data.detach().cpu().numpy()
index = int(np.argmax(alphas))
return index
class FBNetMutator(BaseMutator):
"""
This mutator initializes and operates all the LayerChoices of the supernet.
It is for the related trainer to control the training flow of LayerChoices,
coordinating with whole training process.
"""
def __init__(self, model, lookup_table):
"""
Init a MixedOp instance for each mutable i.e., LayerChoice.
And register the instantiated MixedOp in corresponding LayerChoice.
If it is not registered in LayerChoice, DataParallel will not work,
because the architecture weights are not included in the DataParallel model.
When MixedOps are registered, we use ```requires_grad``` to control
whether to calculate gradients of the architecture weights.
Parameters
----------
model : pytorch model
The model that users want to tune,
it includes search space defined with nni nas apis
lookup_table : class
lookup table object to manage model space information,
including candidate ops for each stage as the model space,
input channels/output channels/stride/fm_size as the layer config,
and the performance information for perf_cost accumulation.
"""
super(FBNetMutator, self).__init__(model)
self.mutable_list = []
# Collect the op names of the candidate ops within each mutable
ops_names_mutable = dict()
left = 0
right = 1
for stage_name in lookup_table.layer_num:
right = lookup_table.layer_num[stage_name]
stage_ops = lookup_table.lut_ops[stage_name]
ops_names = [op_name for op_name in stage_ops]
for i in range(left, left + right):
ops_names_mutable[i] = ops_names
left += right
# Create the mixed op
for i, mutable in enumerate(self.undedup_mutables):
ops_names = ops_names_mutable[i]
latency_mutable = lookup_table.lut_perf[i]
latency = [latency_mutable[op_name] for op_name in ops_names]
self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable, latency)
def on_forward_layer_choice(self, mutable, *args, **kwargs):
"""
Callback of layer choice forward. This function defines the forward
logic of the input mutable. So the mutable is only an interface; its real
implementation is defined in the mutator.
Parameters
----------
mutable: LayerChoice
the layer choice whose forward logic is defined here
args: list of torch.Tensor
inputs of this mutable
kwargs: dict
inputs of this mutable
Returns
-------
torch.Tensor
output of this mutable, i.e., LayerChoice
int
index of the chosen op
"""
# FIXME: return mask, to be consistent with other algorithms
idx = mutable.registered_module.chosen_index
return mutable.registered_module(mutable, *args, **kwargs), idx
def num_arch_params(self):
"""
The number of mutables, i.e., LayerChoice
Returns
-------
int
the number of LayerChoice in user model
"""
return len(self.mutable_list)
def get_architecture_parameters(self):
"""
Get all the architecture parameters.
Yields
------
PyTorch Parameter
path_alpha of the traversed mutable
"""
for mutable in self.undedup_mutables:
yield mutable.registered_module.get_path_alpha()
def get_weighted_latency(self):
"""
Get the latency weighted by gumbel softmax coefficients.
Yields
------
Tuple
the weighted_latency of the traversed mutable
"""
for mutable in self.undedup_mutables:
yield mutable.registered_module.get_weighted_latency()
def set_temperature(self, temperature):
"""
Set the annealed temperature of the op for gumbel softmax.
Parameters
----------
temperature : float
The annealed temperature for gumbel softmax
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_temperature(temperature)
def arch_requires_grad(self):
"""
Make architecture weights require gradient
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_requires_grad()
def arch_disable_grad(self):
"""
Disable gradient of architecture weights, i.e., does not
calculate gradient for them.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_disable_grad()
def sample_final(self):
"""
Generate the final chosen architecture.
Returns
-------
dict
the choice of each mutable, i.e., LayerChoice
"""
result = dict()
for mutable in self.undedup_mutables:
assert isinstance(mutable, LayerChoice)
index = mutable.registered_module.chosen_index
# pylint: disable=not-callable
result[mutable.key] = (
F.one_hot(torch.tensor(index), num_classes=len(mutable))
.view(-1)
.bool()
)
return result

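The MixedOp above relaxes the discrete op choice with a Gumbel-softmax whose temperature is annealed during training, and weights each op's latency by the resulting soft mask. A standalone sketch of both steps, with hypothetical numbers:

import torch
import torch.nn.functional as F

alpha = torch.full((4,), 0.25)                 # uniform architecture weights over 4 candidate ops
latency = torch.tensor([1.0, 2.0, 4.0, 8.0])   # per-op cost, as read from the lookup table

for temperature in (5.0, 1.0, 0.1):
    soft_mask = F.gumbel_softmax(alpha, tau=temperature)
    # High temperature gives a near-uniform mask; low temperature a near one-hot mask.
    weighted_latency = (soft_mask * latency).sum()
    print(temperature, soft_mask, weighted_latency)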
View file

@ -1,413 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import json
import os
import time
import torch
import numpy as np
from torch.autograd import Variable
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import FBNetMutator
from .utils import RegularizerLoss, accuracy
class FBNetTrainer(BaseTrainer):
def __init__(
self,
model,
model_optim,
criterion,
device,
device_ids,
lookup_table,
train_loader,
valid_loader,
n_epochs=120,
load_ckpt=False,
arch_path=None,
logger=None,
):
"""
Parameters
----------
model : pytorch model
the user model, which has mutables
model_optim : pytorch optimizer
the user defined optimizer
criterion : pytorch loss
the main task loss, nn.CrossEntropyLoss() is for classification
device : pytorch device
the devices to train/search the model
device_ids : list of int
the indexes of devices used for training
lookup_table : class
lookup table object for fbnet training
train_loader : pytorch data loader
data loader for the training set
valid_loader : pytorch data loader
data loader for the validation set
n_epochs : int
number of epochs to train/search
load_ckpt : bool
whether load checkpoint
arch_path : str
the path to store chosen architecture
logger : logger
the logger
"""
self.model = model
self.model_optim = model_optim
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.dev_num = len(device_ids)
self.n_epochs = n_epochs
self.lookup_table = lookup_table
self.config = lookup_table.config
self.start_epoch = self.config.start_epoch
self.temp = self.config.init_temperature
self.exp_anneal_rate = self.config.exp_anneal_rate
self.mode = self.config.mode
self.load_ckpt = load_ckpt
self.arch_path = arch_path
self.logger = logger
# scheduler of learning rate
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
model_optim, T_max=n_epochs, last_epoch=-1
)
# init mutator
self.mutator = FBNetMutator(model, lookup_table)
self.mutator.set_temperature(self.temp)
# DataParallel should be put behind the init of mutator
self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
self.model.to(device)
# build architecture optimizer
self.arch_optimizer = torch.optim.AdamW(
self.mutator.get_architecture_parameters(),
self.config.nas_lr,
weight_decay=self.config.nas_weight_decay,
)
self.reg_loss = RegularizerLoss(config=self.config)
self.criterion = criterion
self.epoch = 0
def _layer_choice_sample(self):
"""
Sample the index of the chosen op within each layer choice
"""
stages = [stage_name for stage_name in self.lookup_table.layer_num]
stage_lnum = [self.lookup_table.layer_num[stage] for stage in stages]
# get the choice idx in each layer
choice_ids = list()
layer_id = 0
for param in self.mutator.get_architecture_parameters():
param_np = param.cpu().detach().numpy()
op_idx = np.argmax(param_np)
choice_ids.append(op_idx)
self.logger.info(
"layer {}: {}, index: {}".format(layer_id, param_np, op_idx)
)
layer_id += 1
# get the arch_sample
choice_names = list()
layer_id = 0
for i, stage_name in enumerate(stages):
ops_names = [op for op in self.lookup_table.lut_ops[stage_name]]
for _ in range(stage_lnum[i]):
searched_op = ops_names[choice_ids[layer_id]]
choice_names.append(searched_op)
layer_id += 1
self.logger.info(choice_names)
return choice_names
def _get_perf_cost(self, requires_grad=True):
"""
Get the accumulated performance cost.
"""
perf_cost = Variable(
torch.zeros(1), requires_grad=requires_grad
).to(self.device, non_blocking=True)
for latency in self.mutator.get_weighted_latency():
perf_cost = perf_cost + latency
return perf_cost
def _validate(self):
"""
Do validation. During validation, LayerChoices use the mixed-op.
Returns
-------
float, float, float
average loss, average top1 accuracy, average top5 accuracy
"""
self.valid_loader.batch_sampler.drop_last = False
batch_time = AverageMeter("batch_time")
losses = AverageMeter("losses")
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
# test on validation set under eval mode
self.model.eval()
end = time.time()
with torch.no_grad():
for i, (images, labels) in enumerate(self.valid_loader):
images = images.to(self.device, non_blocking=True)
labels = labels.to(self.device, non_blocking=True)
output = self.model(images)
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(self.valid_loader):
test_log = (
"Valid" + ": [{0}/{1}]\t"
"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"Loss {loss.val:.4f} ({loss.avg:.4f})\t"
"Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
"Top-5 acc {top5.val:.3f} ({top5.avg:.3f})".format(
i,
len(self.valid_loader) - 1,
batch_time=batch_time,
loss=losses,
top1=top1,
top5=top5,
)
)
self.logger.info(test_log)
return losses.avg, top1.avg, top5.avg
def _train_epoch(self, epoch, optimizer, arch_train=False):
"""
Train one epoch.
"""
batch_time = AverageMeter("batch_time")
data_time = AverageMeter("data_time")
losses = AverageMeter("losses")
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
# switch to train mode
self.model.train()
data_loader = self.valid_loader if arch_train else self.train_loader
end = time.time()
for i, (images, labels) in enumerate(data_loader):
data_time.update(time.time() - end)
images = images.to(self.device, non_blocking=True)
labels = labels.to(self.device, non_blocking=True)
output = self.model(images)
loss = self.criterion(output, labels)
# hardware-aware loss
perf_cost = self._get_perf_cost(requires_grad=True)
regu_loss = self.reg_loss(perf_cost)
if self.mode.startswith("mul"):
loss = loss * regu_loss
elif self.mode.startswith("add"):
loss = loss + regu_loss
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0].item(), images.size(0))
top5.update(acc5[0].item(), images.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0:
batch_log = (
"Warmup Train [{0}][{1}]\t"
"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"Loss {losses.val:.4f} ({losses.avg:.4f})\t"
"Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
"Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t".format(
epoch + 1,
i,
batch_time=batch_time,
data_time=data_time,
losses=losses,
top1=top1,
top5=top5,
)
)
self.logger.info(batch_log)
def _warm_up(self):
"""
Warm up the model; the architecture weights are not trained during warm-up.
"""
for epoch in range(self.epoch, self.start_epoch):
self.logger.info("\n--------Warmup epoch: %d--------\n", epoch + 1)
self._train_epoch(epoch, self.model_optim)
# adjust learning rate
self.scheduler.step()
# validation
val_loss, val_top1, val_top5 = self._validate()
val_log = (
"Warmup Valid [{0}/{1}]\t"
"loss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}".format(
epoch + 1, self.start_epoch, val_loss, val_top1, val_top5
)
)
self.logger.info(val_log)
if epoch % 10 == 0:
filename = os.path.join(
self.config.model_dir, "checkpoint_%s.pth" % epoch
)
self.save_checkpoint(epoch, filename)
def _train(self):
"""
Train the model: it trains both model weights and architecture weights.
Architecture weights are trained according to the schedule.
Before updating architecture weights, ```requires_grad``` is enabled;
it is disabled again after the update, so that architecture weights
are not updated while training model weights.
"""
arch_param_num = self.mutator.num_arch_params()
self.logger.info("#arch_params: {}".format(arch_param_num))
self.epoch = max(self.start_epoch, self.epoch)
ckpt_path = self.config.model_dir
choice_names = None
top1_best = 0.0
for epoch in range(self.epoch, self.n_epochs):
self.logger.info("\n--------Train epoch: %d--------\n", epoch + 1)
# update the weight parameters
self._train_epoch(epoch, self.model_optim)
# adjust learning rate
self.scheduler.step()
self.logger.info("Update architecture parameters")
# update the architecture parameters
self.mutator.arch_requires_grad()
self._train_epoch(epoch, self.arch_optimizer, True)
self.mutator.arch_disable_grad()
# temperature annealing
self.temp = self.temp * self.exp_anneal_rate
self.mutator.set_temperature(self.temp)
# sample the architecture of sub-network
choice_names = self._layer_choice_sample()
# validate
val_loss, val_top1, val_top5 = self._validate()
val_log = (
"Valid [{0}]\t"
"loss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}".format(
epoch + 1, val_loss, val_top1, val_top5
)
)
self.logger.info(val_log)
if epoch % 10 == 0:
filename = os.path.join(ckpt_path, "checkpoint_%s.pth" % epoch)
self.save_checkpoint(epoch, filename, choice_names)
val_top1 = val_top1.cpu().item()  # convert the accuracy tensor to a plain float for comparison
if val_top1 > top1_best:
filename = os.path.join(ckpt_path, "checkpoint_best.pth")
self.save_checkpoint(epoch, filename, choice_names)
top1_best = val_top1
def save_checkpoint(self, epoch, filename, choice_names=None):
"""
Save a checkpoint of the whole model.
Model weights and architecture weights are saved to ```filename```,
and the currently chosen architecture is saved to ```arch_path```.
"""
state = {
"model": self.model.state_dict(),
"optim": self.model_optim.state_dict(),
"epoch": epoch,
"arch_sample": choice_names,
}
torch.save(state, filename)
self.logger.info("Save checkpoint to {0:}".format(filename))
if self.arch_path:
self.export(self.arch_path)
def load_checkpoint(self, filename):
"""
Load the checkpoint from ```filename```.
"""
ckpt = torch.load(filename)
self.epoch = ckpt["epoch"]
self.model.load_state_dict(ckpt["model"])
self.model_optim.load_state_dict(ckpt["optim"])
def train(self):
"""
Train the whole model.
"""
if self.load_ckpt:
ckpt_path = self.config.model_dir
filename = os.path.join(ckpt_path, "checkpoint_best.pth")
if os.path.exists(filename):
self.load_checkpoint(filename)
if self.epoch < self.start_epoch:
self._warm_up()
self._train()
def export(self, file_name):
"""
Export the chosen architecture into a file
Parameters
----------
file_name : str
the file that stores exported chosen architecture
"""
exported_arch = self.mutator.sample_final()
with open(file_name, "w") as f:
json.dump(
exported_arch,
f,
indent=2,
sort_keys=True,
cls=TorchTensorEncoder,
)
def validate(self):
raise NotImplementedError
def checkpoint(self):
raise NotImplementedError

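A minimal usage sketch of the trainer above, assuming a search-space definition, primitives dict, model, and data loaders already exist (all such names are illustrative):

import logging
import torch
import torch.nn as nn

config = NASConfig(perf_metric="flops", model_dir="./fbnet_ckpt", search_space=my_search_space)
lookup_table = LookUpTable(config, primitives=my_primitives)
model_optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
trainer = FBNetTrainer(model, model_optim, nn.CrossEntropyLoss(),
                       torch.device("cuda"), device_ids=[0],
                       lookup_table=lookup_table,
                       train_loader=train_loader, valid_loader=valid_loader,
                       arch_path="./fbnet_arch.json",
                       logger=logging.getLogger("fbnet"))
trainer.train()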
View file

@ -1,433 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import ast
import os
import timeit
import torch
import numpy as np
import torch.nn as nn
from nni.compression.pytorch.utils import count_flops_params
LUT_FILE = "lut.npy"
LUT_JSON_FILE = "lut.txt"
LUT_PATH = "lut"
DATA_TYPE = "float"
class NASConfig:
def __init__(
self,
perf_metric="flops",
lut_load=False,
lut_load_format="json",
model_dir=None,
nas_lr=0.01,
nas_weight_decay=5e-4,
mode="mul",
alpha=0.25,
beta=0.6,
start_epoch=50,
init_temperature=5.0,
exp_anneal_rate=np.exp(-0.045),
search_space=None,
):
# LUT of performance metric
# flops counts the multiply operations; latency is the measured time cost on the target platform
self.perf_metric = perf_metric
assert perf_metric in [
"flops",
"latency",
], "perf_metric should be ['flops', 'latency']"
# whether to load or create the lut file
self.lut_load = lut_load
assert lut_load_format in [
"json",
"numpy",
], "lut_load_format should be ['json', 'numpy']"
self.lut_load_format = lut_load_format
# necessary dirs
self.lut_en = model_dir is not None
if self.lut_en:
self.model_dir = model_dir
os.makedirs(model_dir, exist_ok=True)
self.lut_path = os.path.join(model_dir, LUT_PATH)
os.makedirs(self.lut_path, exist_ok=True)
# NAS learning setting
self.nas_lr = nas_lr
self.nas_weight_decay = nas_weight_decay
# hardware-aware loss setting
self.mode = mode
assert mode in ["mul", "add"], "mode should be ['mul', 'add']"
self.alpha = alpha
self.beta = beta
# NAS training setting
self.start_epoch = start_epoch
self.init_temperature = init_temperature
self.exp_anneal_rate = exp_anneal_rate
# definition of search blocks and space
self.search_space = search_space
class RegularizerLoss(nn.Module):
"""Auxilliary loss for hardware-aware NAS."""
def __init__(self, config):
"""
Parameters
----------
config : class
to manage the configuration for NAS training, and search space etc.
"""
super(RegularizerLoss, self).__init__()
self.mode = config.mode
self.alpha = config.alpha
self.beta = config.beta
def forward(self, perf_cost, batch_size=1):
"""
Parameters
----------
perf_cost : tensor
the accumulated performance cost
batch_size : int
batch size for normalization
Returns
-------
output: tensor
the hardware-aware constraint loss
"""
if self.mode == "mul":
log_loss = torch.log(perf_cost / batch_size) ** self.beta
return self.alpha * log_loss
elif self.mode == "add":
linear_loss = (perf_cost / batch_size) ** self.beta
return self.alpha * linear_loss
else:
raise NotImplementedError
def accuracy(output, target, topk=(1,)):
"""
Computes the precision@k for the specified values of k
Parameters
----------
output : pytorch tensor
output, e.g., predicted value
target : pytorch tensor
label
topk : tuple
specify top1 and top5
Returns
-------
list
accuracy of top1 and top5
"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def supernet_sample(model, state_dict, sampled_arch=[], lookup_table=None):
"""
Initialize the searched sub-model from supernet.
Parameters
----------
model : pytorch model
the created subnet
state_dict : checkpoint
the checkpoint of supernet, including the pre-trained params
sampled_arch : list of str
the searched layer names of the subnet
lookup_table : class
to manage the candidate ops, layer information and layer performance
"""
replace = list()
stages = [stage for stage in lookup_table.layer_num]
stage_lnum = [lookup_table.layer_num[stage] for stage in stages]
if sampled_arch:
layer_id = 0
for i, stage in enumerate(stages):
ops_names = [op_name for op_name in lookup_table.lut_ops[stage]]
for _ in range(stage_lnum[i]):
searched_op = sampled_arch[layer_id]
op_i = ops_names.index(searched_op)
replace.append(
[
"blocks.{}.".format(layer_id),
"blocks.{}.op.".format(layer_id),
"blocks.{}.{}.".format(layer_id, op_i),
]
)
layer_id += 1
model_init(model, state_dict, replace=replace)
def model_init(model, state_dict, replace=[]):
"""Initialize the model from state_dict."""
prefix = "module."
param_dict = dict()
for k, v in state_dict.items():
if k.startswith(prefix):
k = k[7:]
param_dict[k] = v
for k, (name, m) in enumerate(model.named_modules()):
if replace:
for layer_replace in replace:
assert len(layer_replace) == 3, "Each replace entry should have three elements."
pre_scope, key, replace_key = layer_replace
if pre_scope in name:
name = name.replace(key, replace_key)
# Copy the state_dict to current model
if (name + ".weight" in param_dict) or (
name + ".running_mean" in param_dict
):
if isinstance(m, nn.BatchNorm2d):
shape = m.running_mean.shape
if shape == param_dict[name + ".running_mean"].shape:
if m.weight is not None:
m.weight.data = param_dict[name + ".weight"]
m.bias.data = param_dict[name + ".bias"]
m.running_mean = param_dict[name + ".running_mean"]
m.running_var = param_dict[name + ".running_var"]
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
shape = m.weight.data.shape
if shape == param_dict[name + ".weight"].shape:
m.weight.data = param_dict[name + ".weight"]
if m.bias is not None:
m.bias.data = param_dict[name + ".bias"]
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data = param_dict[name + ".weight"]
if m.bias is not None:
m.bias.data = param_dict[name + ".bias"]
class LookUpTable:
"""Build look-up table for NAS."""
def __init__(self, config, primitives):
"""
Parameters
----------
config : class
to manage the configuration for NAS training, and search space etc.
"""
self.config = config
# definition of search blocks and space
self.search_space = config.search_space
# layers for NAS
self.cnt_layers = len(self.search_space["input_shape"])
# constructors for each operation
self.lut_ops = {
stage_name: {
op_name: primitives[op_name]
for op_name in self.search_space["stages"][stage_name]["ops"]
}
for stage_name in self.search_space["stages"]
}
self.layer_num = {
stage_name: self.search_space["stages"][stage_name]["layer_num"]
for stage_name in self.search_space["stages"]
}
# arguments for the op constructors; input_shapes are kept just for convenience
self.layer_configs, self.layer_in_shapes = self._layer_configs()
# lookup_table
self.perf_metric = config.perf_metric
if config.lut_en:
self.lut_perf = None
self.lut_file = os.path.join(config.lut_path, LUT_FILE)
self.lut_json_file = LUT_JSON_FILE
if config.lut_load:
if config.lut_load_format == "numpy":
# Load data from numpy file
self._load_from_file()
else:
# Load data from json file
self._load_from_json_file()
else:
self._create_perfs()
def _layer_configs(self):
"""Generate basic params for different layers."""
# layer_configs are : c_in, c_out, stride, fm_size
layer_configs = [
[
self.search_space["input_shape"][layer_id][0],
self.search_space["channel_size"][layer_id],
self.search_space["strides"][layer_id],
self.search_space["fm_size"][layer_id],
]
for layer_id in range(self.cnt_layers)
]
# layer_in_shapes are (C_in, input_w, input_h)
layer_in_shapes = self.search_space["input_shape"]
return layer_configs, layer_in_shapes
def _create_perfs(self, cnt_of_runs=200):
"""Create performance cost for each op."""
if self.perf_metric == "latency":
self.lut_perf = self._calculate_latency(cnt_of_runs)
elif self.perf_metric == "flops":
self.lut_perf = self._calculate_flops()
self._write_lut_to_file()
def _calculate_flops(self, eps=0.001):
"""FLOPs cost."""
flops_lut = [{} for i in range(self.cnt_layers)]
layer_id = 0
for stage_name in self.lut_ops:
stage_ops = self.lut_ops[stage_name]
ops_num = self.layer_num[stage_name]
for _ in range(ops_num):
for op_name in stage_ops:
layer_config = self.layer_configs[layer_id]
key_params = {"fm_size": layer_config[3]}
op = stage_ops[op_name](*layer_config[0:3], **key_params)
# measured in Flops
in_shape = self.layer_in_shapes[layer_id]
x = (1, in_shape[0], in_shape[1], in_shape[2])
flops, _, _ = count_flops_params(op, x, verbose=False)
flops = eps if flops == 0.0 else flops
flops_lut[layer_id][op_name] = float(flops)
layer_id += 1
return flops_lut
def _calculate_latency(self, cnt_of_runs):
"""Latency cost."""
LATENCY_BATCH_SIZE = 1
latency_lut = [{} for i in range(self.cnt_layers)]
layer_id = 0
for stage_name in self.lut_ops:
stage_ops = self.lut_ops[stage_name]
ops_num = self.layer_num[stage_name]
for _ in range(ops_num):
for op_name in stage_ops:
layer_config = self.layer_configs[layer_id]
key_params = {"fm_size": layer_config[3]}
op = stage_ops[op_name](*layer_config[0:3], **key_params)
input_data = torch.randn(
(LATENCY_BATCH_SIZE, *self.layer_in_shapes[layer_id])
)
globals()["op"], globals()["input_data"] = op, input_data
total_time = timeit.timeit(
"output = op(input_data)",
setup="gc.enable()",
globals=globals(),
number=cnt_of_runs,
)
# measured in micro-second
latency_lut[layer_id][op_name] = (
total_time / cnt_of_runs / LATENCY_BATCH_SIZE * 1e6
)
layer_id += 1
return latency_lut
def _write_lut_to_file(self):
"""Save lut as numpy file."""
np.save(self.lut_file, self.lut_perf)
def _load_from_file(self):
"""Load numpy file."""
self.lut_perf = np.load(self.lut_file, allow_pickle=True)
def _load_from_json_file(self):
"""Load json file."""
"""
lut_json_file ('lut.txt') format:
{'op_name': operator_name,
'op_data_shape': (input_w, input_h, C_in, C_out, stride),
'op_dtype': data_type,
'op_latency': latency}
{...}
{...}
"""
latency_file = open(self.lut_json_file, "r")
ops_latency = latency_file.readlines()
"""ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}"""
ops_lut = {}
for op_latency in ops_latency:
assert isinstance(op_latency, str) or isinstance(op_latency, dict)
if isinstance(op_latency, str):
record = ast.literal_eval(op_latency)
elif isinstance(op_latency, dict):
record = op_latency
op_name = record["op_name"]
"""op_data_shape: (input_w, input_h, C_in, C_out, stride)"""
op_data_shape = record["op_data_shape"]
op_dtype = record["op_dtype"]
op_latency = record["op_latency"]
if op_name not in ops_lut:
ops_lut[op_name] = {}
if op_data_shape not in ops_lut[op_name]:
ops_lut[op_name][op_data_shape] = {}
ops_lut[op_name][op_data_shape][op_dtype] = op_latency
self.lut_perf = [{} for i in range(self.cnt_layers)]
layer_id = 0
for stage_name in self.lut_ops:
stage_ops = self.lut_ops[stage_name]
ops_num = self.layer_num[stage_name]
for _ in range(ops_num):
for op_name in stage_ops:
layer_config = self.layer_configs[layer_id]
layer_in_shape = self.layer_in_shapes[layer_id]
input_w = layer_in_shape[1]
input_h = layer_in_shape[2]
c_in = layer_config[0]
c_out = layer_config[1]
stride = layer_config[2]
op_data_shape = (input_w, input_h, c_in, c_out, stride)
if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
self.lut_perf[layer_id][op_name] = \
ops_lut[op_name][op_data_shape][DATA_TYPE]
layer_id += 1

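As a worked example of the RegularizerLoss above in its "mul" mode, with the defaults alpha=0.25 and beta=0.6 and batch_size=1 (all numbers hypothetical):

import torch

alpha, beta = 0.25, 0.6
perf_cost = torch.tensor([300.0])            # e.g. accumulated FLOPs of the sampled path
regu = alpha * torch.log(perf_cost) ** beta  # 0.25 * log(300) ** 0.6, roughly 0.71
task_loss = torch.tensor(1.2)                # hypothetical cross-entropy value
loss = task_loss * regu                      # "mul" mode multiplies; "add" mode would add instead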
View file

@ -1,93 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import numpy as np
import torch
from torch import nn
from nni.algorithms.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator):
"""
It works with PdartsTrainer to calculate ops weights,
and drop weights in different PDARTS epochs.
"""
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
self.pdarts_epoch_index = pdarts_epoch_index
self.pdarts_num_to_drop = pdarts_num_to_drop
if switches is None:
self.switches = {}
else:
self.switches = switches
super(PdartsMutator, self).__init__(model)
# this loop goes through mutables with different keys;
# it mainly updates the length of choices.
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
choices = self.choices[mutable.key]
operations_count = np.sum(switches)
# +1 and -1 are caused by the zero operation in the DARTS network:
# the zero operation is not in the network's choices list, but its weight is,
# so one extra weight and switch are needed for zero.
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
self.switches[mutable.key] = switches
# update LayerChoice instances in the model;
# this physically removes the dropped candidate operations.
for module in self.model.modules():
if isinstance(module, LayerChoice):
switches = self.switches.get(module.key)
choices = self.choices[module.key]
if len(module) > len(choices):
# iterate from last to first, so that removals do not affect earlier indexes.
for index in range(len(switches) - 1, -1, -1):
if not switches[index]:
del module[index]
assert len(module) <= len(choices), "Failed to remove dropped choices."
def export(self):
# Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
results = super().sample_final()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
# Since some operations were physically dropped,
# False must be filled back in to track the dropped operations.
trained_result = results[mutable.key]
trained_index = 0
switches = self.switches[mutable.key]
result = torch.Tensor(switches).bool()
for index in range(len(result)):
if result[index]:
result[index] = trained_result[trained_index]
trained_index += 1
results[mutable.key] = result
return results
def drop_paths(self):
"""
This method is called when a PDARTS epoch is finished.
It prepares switches for the next epoch.
Candidate operations whose switch is False will be dropped in the next epoch.
"""
all_switches = copy.deepcopy(self.switches)
for key in all_switches:
switches = all_switches[key]
idxs = []
for j in range(len(switches)):
if switches[j]:
idxs.append(j)
sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
for idx in drop:
switches[idxs[idx]] = False
return all_switches

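The switch bookkeeping in drop_paths above can be shown standalone: keep one boolean per candidate op and turn off the ones with the smallest architecture weights (values hypothetical):

import numpy as np

switches = [True, True, True, True]
weights = np.array([0.05, 0.40, 0.15, 0.40])  # op weights, excluding the extra zero-op entry
num_to_drop = 2
for idx in np.argsort(weights)[:num_to_drop]:
    switches[idx] = False
# switches is now [False, True, False, True]; the two weakest ops are dropped next epoch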
View file

@ -1,86 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder
from .mutator import PdartsMutator
logger = logging.getLogger(__name__)
class PdartsTrainer(BaseTrainer):
"""
This trainer implements the PDARTS algorithm.
PDARTS is based on the DARTS algorithm and provides a network-growth approach to find deeper and better networks.
This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
pdarts_num_layers specifies how many layers to add on top of the first epoch's network.
pdarts_num_to_drop specifies how many candidate operations to drop in each epoch,
so that the grown network stays at a similar size.
"""
def __init__(self, model_creator, init_layers, metrics,
num_epochs, dataset_train, dataset_valid,
pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
super(PdartsTrainer, self).__init__()
self.model_creator = model_creator
self.init_layers = init_layers
self.pdarts_num_layers = pdarts_num_layers
self.pdarts_num_to_drop = pdarts_num_to_drop
self.pdarts_epoch = len(pdarts_num_to_drop)
self.darts_parameters = {
"metrics": metrics,
"num_epochs": num_epochs,
"dataset_train": dataset_train,
"dataset_valid": dataset_valid,
"batch_size": batch_size,
"workers": workers,
"device": device,
"log_frequency": log_frequency,
"unrolled": unrolled
}
self.callbacks = callbacks if callbacks is not None else []
def train(self):
switches = None
for epoch in range(self.pdarts_epoch):
layers = self.init_layers+self.pdarts_num_layers[epoch]
model, criterion, optim, lr_scheduler = self.model_creator(layers)
self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
for callback in self.callbacks:
callback.build(model, self.mutator, self)
callback.on_epoch_begin(epoch)
darts_callbacks = []
if lr_scheduler is not None:
darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
callbacks=darts_callbacks, **self.darts_parameters)
logger.info("start pdarts training epoch %s...", epoch)
self.trainer.train()
switches = self.mutator.drop_paths()
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def validate(self):
self.trainer.validate()
def export(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
raise NotImplementedError("Not implemented yet")

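A minimal usage sketch of the trainer above; the model_creator callable and datasets are illustrative stand-ins for a real DARTS-style search space:

import torch
import torch.nn as nn

def model_creator(layers):
    model = DartsSearchSpace(n_layers=layers)  # hypothetical DARTS-style search space
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50, eta_min=0.001)
    return model, criterion, optim, lr_scheduler

trainer = PdartsTrainer(model_creator, init_layers=5, metrics=my_metrics,
                        num_epochs=50, dataset_train=train_set, dataset_valid=valid_set)
trainer.train()
trainer.export("pdarts_arch.json")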
View file

@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer

View file

@ -1,478 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
from .utils import detach_variable
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = detach_variable(x)
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class MixedOp(nn.Module):
"""
This class is to instantiate and manage info of one LayerChoice.
It includes architecture weights, binary weights, and member functions
operating the weights.
forward_mode:
forward/backward mode for LayerChoice: None, two, full, and full_v2.
For training architecture weights, we use full_v2 by default, and for training
model weights, we use None.
"""
forward_mode = None
def __init__(self, mutable):
"""
Parameters
----------
mutable : LayerChoice
A LayerChoice in user model
"""
super(MixedOp, self).__init__()
self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
self.active_index = [0]
self.inactive_index = None
self.log_prob = None
self.current_prob_over_ops = None
self.n_choices = len(mutable)
def get_ap_path_alpha(self):
return self.ap_path_alpha
def to_requires_grad(self):
self.ap_path_alpha.requires_grad = True
self.ap_path_wb.requires_grad = True
def to_disable_grad(self):
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
def forward(self, mutable, x):
"""
Define forward of LayerChoice. For 'full_v2', backward is also defined.
The 'two' mode is explained in section 3.2.1 in the paper.
The 'full_v2' mode is explained in Appendix D in the paper.
Parameters
----------
mutable : LayerChoice
this layer's mutable
x : tensor
inputs of this layer, only support one input
Returns
-------
output: tensor
output of this layer
"""
if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
candidate_ops = list(mutable)  # MixedOp does not store the ops; fetch them from the mutable
output = 0
for _i in self.active_index:
oi = candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi
for _i in self.inactive_index:
oi = candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi.detach()
elif MixedOp.forward_mode == 'full_v2':
def run_function(key, candidate_ops, active_id):
def forward(_x):
return candidate_ops[active_id](_x)
return forward
def backward_function(key, candidate_ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(candidate_ops)):
if k != active_id:
out_k = candidate_ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
output = ArchGradientFunction.apply(
x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
else:
output = self.active_op(mutable)(x)
return output
@property
def probs_over_ops(self):
"""
Apply softmax on alpha to generate probability distribution
Returns
-------
pytorch tensor
probability distribution
"""
probs = F.softmax(self.ap_path_alpha, dim=0) # softmax to probability
return probs
@property
def chosen_index(self):
"""
choose the op with max prob
Returns
-------
int
index of the chosen one
numpy.float32
prob of the chosen one
"""
probs = self.probs_over_ops.data.cpu().numpy()
index = int(np.argmax(probs))
return index, probs[index]
def active_op(self, mutable):
"""
assume only one path is active
Returns
-------
PyTorch module
the chosen operation
"""
return mutable[self.active_index[0]]
@property
def active_op_index(self):
"""
return active op's index, the active op is sampled
Returns
-------
int
index of the active op
"""
return self.active_index[0]
def set_chosen_op_active(self):
"""
set chosen index, active and inactive indexes
"""
chosen_idx, _ = self.chosen_index
self.active_index = [chosen_idx]
self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
[_i for _i in range(chosen_idx + 1, self.n_choices)]
def binarize(self, mutable):
"""
Sample based on alpha, and set binary weights accordingly.
ap_path_wb is set in this function, hence the name binarize.
Parameters
----------
mutable : LayerChoice
this layer's mutable
"""
self.log_prob = None
# reset binary gates
self.ap_path_wb.data.zero_()
probs = self.probs_over_ops
if MixedOp.forward_mode == 'two':
# sample two ops according to probs
sample_op = torch.multinomial(probs.data, 2, replacement=False)
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in sample_op
]), dim=0)
self.current_prob_over_ops = torch.zeros_like(probs)
for i, idx in enumerate(sample_op):
self.current_prob_over_ops[idx] = probs_slice[i]
# choose one to be active and the other to be inactive according to probs_slice
c = torch.multinomial(probs_slice.data, 1)[0] # 0 or 1
active_op = sample_op[c].item()
inactive_op = sample_op[1-c].item()
self.active_index = [active_op]
self.inactive_index = [inactive_op]
# set binary gate
self.ap_path_wb.data[active_op] = 1.0
else:
sample = torch.multinomial(probs, 1)[0].item()
self.active_index = [sample]
self.inactive_index = [_i for _i in range(0, sample)] + \
[_i for _i in range(sample + 1, len(mutable))]
self.log_prob = torch.log(probs[sample])
self.current_prob_over_ops = probs
self.ap_path_wb.data[sample] = 1.0
# avoid over-regularization
for choice in mutable:
for _, param in choice.named_parameters():
param.grad = None
@staticmethod
def delta_ij(i, j):
if i == j:
return 1
else:
return 0
def set_arch_param_grad(self, mutable):
"""
Calculate alpha gradient for this LayerChoice.
It is calculated from the gradients of the binary gates and the probs of the ops.
"""
binary_grads = self.ap_path_wb.grad.data
if self.active_op(mutable).is_zero_layer():
self.ap_path_alpha.grad = None
return
if self.ap_path_alpha.grad is None:
self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
if MixedOp.forward_mode == 'two':
involved_idx = self.active_index + self.inactive_index
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in involved_idx
]), dim=0).data
for i in range(2):
for j in range(2):
origin_i = involved_idx[i]
origin_j = involved_idx[j]
self.ap_path_alpha.grad.data[origin_i] += \
binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
for _i, idx in enumerate(self.active_index):
self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
for _i, idx in enumerate(self.inactive_index):
self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
else:
probs = self.probs_over_ops.data
for i in range(self.n_choices):
for j in range(self.n_choices):
self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
return
def rescale_updated_arch_param(self):
"""
rescale architecture weights for the 'two' mode.
"""
if not isinstance(self.active_index[0], tuple):
assert self.active_op.is_zero_layer()
return
involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]
offset = math.log(
sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
)
for idx in involved_idx:
self.ap_path_alpha.data[idx] -= offset
class ProxylessNasMutator(BaseMutator):
"""
This mutator initializes and operates all the LayerChoices of the input model.
It is for the corresponding trainer to control the training process of LayerChoices,
coordinating with whole training process.
"""
def __init__(self, model):
"""
Init a MixedOp instance for each mutable, i.e., LayerChoice,
and register the instantiated MixedOp in the corresponding LayerChoice.
If it is not registered in LayerChoice, DataParallel will not work,
because the architecture weights are not included in the DataParallel model.
When MixedOps are registered, we use ```requires_grad``` to control
whether to calculate gradients of the architecture weights.
Parameters
----------
model : pytorch model
The model that users want to tune, it includes search space defined with nni nas apis
"""
super(ProxylessNasMutator, self).__init__(model)
self._unused_modules = None
self.mutable_list = []
for mutable in self.undedup_mutables:
self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable)
def on_forward_layer_choice(self, mutable, *args, **kwargs):
"""
Callback of layer choice forward. This function defines the forward
logic of the input mutable. So the mutable is only an interface; its real
implementation is defined in the mutator.
Parameters
----------
mutable: LayerChoice
the layer choice whose forward logic is defined here
args: list of torch.Tensor
inputs of this mutable
kwargs: dict
inputs of this mutable
Returns
-------
torch.Tensor
output of this mutable, i.e., LayerChoice
int
index of the chosen op
"""
# FIXME: return mask, to be consistent with other algorithms
idx = mutable.registered_module.active_op_index
return mutable.registered_module(mutable, *args, **kwargs), idx
def reset_binary_gates(self):
"""
For each LayerChoice, binarize binary weights
based on alpha to only activate one op.
It traverses all the mutables in the model to do this.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.binarize(mutable)
def set_chosen_op_active(self):
"""
For each LayerChoice, set the op with highest alpha as the chosen op.
Usually used for validation.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_chosen_op_active()
def num_arch_params(self):
"""
The number of mutables, i.e., LayerChoice
Returns
-------
int
the number of LayerChoice in user model
"""
return len(self.mutable_list)
def set_arch_param_grad(self):
"""
For each LayerChoice, calculate gradients for architecture weights, i.e., alpha
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_arch_param_grad(mutable)
def get_architecture_parameters(self):
"""
Get all the architecture parameters.
Yields
------
PyTorch Parameter
ap_path_alpha of the traversed mutable
"""
for mutable in self.undedup_mutables:
yield mutable.registered_module.get_ap_path_alpha()
def change_forward_mode(self, mode):
"""
Update forward mode of MixedOps, as training architecture weights and
model weights use different forward modes.
"""
MixedOp.forward_mode = mode
def get_forward_mode(self):
"""
Get forward mode of MixedOp
Returns
-------
string
the current forward mode of MixedOp
"""
return MixedOp.forward_mode
def rescale_updated_arch_param(self):
"""
Rescale architecture weights in 'two' mode.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.rescale_updated_arch_param()
def unused_modules_off(self):
"""
Remove unused modules for each mutable.
The removed modules are kept in ```self._unused_modules``` so they can be restored later.
"""
self._unused_modules = []
for mutable in self.undedup_mutables:
mixed_op = mutable.registered_module
unused = {}
if self.get_forward_mode() in ['full', 'two', 'full_v2']:
involved_index = mixed_op.active_index + mixed_op.inactive_index
else:
involved_index = mixed_op.active_index
for i in range(mixed_op.n_choices):
if i not in involved_index:
unused[i] = mutable[i]
mutable[i] = None
self._unused_modules.append(unused)
def unused_modules_back(self):
"""
Restore the removed modules.
"""
if self._unused_modules is None:
return
for m, unused in zip(self.mutable_list, self._unused_modules):
for i in unused:
m[i] = unused[i]
self._unused_modules = None
def arch_requires_grad(self):
"""
Make architecture weights require gradient
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_requires_grad()
def arch_disable_grad(self):
"""
Disable gradients of the architecture weights, i.e., do not
calculate gradients for them.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_disable_grad()
def sample_final(self):
"""
Generate the final chosen architecture.
Returns
-------
dict
the choice of each mutable, i.e., LayerChoice
"""
result = dict()
for mutable in self.undedup_mutables:
assert isinstance(mutable, LayerChoice)
index, _ = mutable.registered_module.chosen_index
# pylint: disable=not-callable
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
return result

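The alpha-gradient rule in set_arch_param_grad above is dL/d(alpha_i) = sum_j binary_grads[j] * p_j * (delta_ij - p_i). A standalone numeric sketch, with hypothetical gate gradients:

import torch
import torch.nn.functional as F

alpha = torch.tensor([0.1, 0.3, -0.2])         # architecture weights for 3 candidate ops
probs = F.softmax(alpha, dim=0)
binary_grads = torch.tensor([0.5, -0.1, 0.2])  # hypothetical gradients w.r.t. the binary gates
grad = torch.zeros_like(alpha)
for i in range(3):
    for j in range(3):
        grad[i] += binary_grads[j] * probs[j] * (float(i == j) - probs[i])
# grad now matches what set_arch_param_grad accumulates into ap_path_alpha.grad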
View file

@ -1,500 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import time
import json
import logging
import torch
from torch import nn as nn
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import ProxylessNasMutator
from .utils import cross_entropy_with_label_smoothing, accuracy
logger = logging.getLogger(__name__)
class ProxylessNasTrainer(BaseTrainer):
def __init__(self, model, model_optim, device,
train_loader, valid_loader, label_smoothing=0.1,
n_epochs=120, init_lr=0.025, binary_mode='full_v2',
arch_init_type='normal', arch_init_ratio=1e-3,
arch_optim_lr=1e-3, arch_weight_decay=0,
grad_update_arch_param_every=5, grad_update_steps=1,
warmup=True, warmup_epochs=25,
arch_valid_frequency=1,
load_ckpt=False, ckpt_path=None, arch_path=None):
"""
Parameters
----------
model : pytorch model
the user model, which has mutables
model_optim : pytorch optimizer
the user defined optimizer
device : pytorch device
the devices to train/search the model
train_loader : pytorch data loader
data loader for the training set
valid_loader : pytorch data loader
data loader for the validation set
label_smoothing : float
coefficient for label smoothing
n_epochs : int
number of epochs to train/search
init_lr : float
init learning rate for training the model
binary_mode : str
the forward/backward mode for the binary weights in mutator
arch_init_type : str
the way to init architecture parameters
arch_init_ratio : float
the ratio to init architecture parameters
arch_optim_lr : float
learning rate of the architecture parameters optimizer
arch_weight_decay : float
weight decay of the architecture parameters optimizer
grad_update_arch_param_every : int
update architecture weights once every this many minibatches
grad_update_steps : int
the number of training steps performed during each update of architecture weights
warmup : bool
whether to do warmup
warmup_epochs : int
the number of warmup epochs
arch_valid_frequency : int
frequency of printing validation result
load_ckpt : bool
whether load checkpoint
ckpt_path : str
checkpoint path, if load_ckpt is True, ckpt_path cannot be None
arch_path : str
the path to store chosen architecture
"""
self.model = model
self.model_optim = model_optim
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.n_epochs = n_epochs
self.init_lr = init_lr
self.warmup = warmup
self.warmup_epochs = warmup_epochs
self.arch_valid_frequency = arch_valid_frequency
self.label_smoothing = label_smoothing
self.train_batch_size = train_loader.batch_sampler.batch_size
self.valid_batch_size = valid_loader.batch_sampler.batch_size
# update architecture parameters every this number of minibatches
self.grad_update_arch_param_every = grad_update_arch_param_every
# the number of steps per architecture parameter update
self.grad_update_steps = grad_update_steps
self.binary_mode = binary_mode
self.load_ckpt = load_ckpt
self.ckpt_path = ckpt_path
self.arch_path = arch_path
# init mutator
self.mutator = ProxylessNasMutator(model)
# DataParallel should be put behind the init of mutator
self.model = torch.nn.DataParallel(self.model)
self.model.to(self.device)
# iter of valid dataset for training architecture weights
self._valid_iter = None
# init architecture weights
self._init_arch_params(arch_init_type, arch_init_ratio)
# build architecture optimizer
self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
arch_optim_lr,
weight_decay=arch_weight_decay,
betas=(0, 0.999),
eps=1e-8)
self.criterion = nn.CrossEntropyLoss()
self.warmup_curr_epoch = 0
self.train_curr_epoch = 0
def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
"""
Initialize architecture weights
"""
for param in self.mutator.get_architecture_parameters():
if init_type == 'normal':
param.data.normal_(0, init_ratio)
elif init_type == 'uniform':
param.data.uniform_(-init_ratio, init_ratio)
else:
raise NotImplementedError
def _validate(self):
"""
Do validation. During validation, LayerChoices use the chosen active op.
Returns
-------
float, float, float
average loss, average top1 accuracy, average top5 accuracy
"""
self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
self.valid_loader.batch_sampler.drop_last = False
self.mutator.set_chosen_op_active()
# remove unused modules to save memory
self.mutator.unused_modules_off()
# test on validation set under train mode
self.model.train()
batch_time = AverageMeter('batch_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
end = time.time()
with torch.no_grad():
for i, (images, labels) in enumerate(self.valid_loader):
images, labels = images.to(self.device), labels.to(self.device)
output = self.model(images)
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(self.valid_loader):
test_log = 'Valid' + ': [{0}/{1}]\t'\
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
# also log top-5 accuracy
test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
logger.info(test_log)
self.mutator.unused_modules_back()
return losses.avg, top1.avg, top5.avg
def _warm_up(self):
"""
Warm up the model, during warm up, architecture weights are not trained.
"""
lr_max = 0.05
data_loader = self.train_loader
nBatch = len(data_loader)
T_total = self.warmup_epochs * nBatch # total num of batches
for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
logger.info('warm_up epoch: %d', epoch)
for i, (images, labels) in enumerate(data_loader):
data_time.update(time.time() - end)
# lr
T_cur = epoch * nBatch + i
warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
for param_group in self.model_optim.param_groups:
param_group['lr'] = warmup_lr
images, labels = images.to(self.device), labels.to(self.device)
# compute output
self.mutator.reset_binary_gates() # random sample binary gates
self.mutator.unused_modules_off() # remove unused module for speedup
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
self.model.zero_grad()
loss.backward()
self.model_optim.step()
# unused modules back
self.mutator.unused_modules_back()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=warmup_lr)
logger.info(batch_log)
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.warmup_curr_epoch += 1
def _get_update_schedule(self, nBatch):
"""
Generate the schedule for training architecture weights. Each key is the minibatch
index after which architecture weights are updated; each value is the number of steps for that update.
Parameters
----------
nBatch : int
the total number of minibatches in one epoch
Returns
-------
dict
the schedule for updating architecture weights
"""
schedule = {}
for i in range(nBatch):
if (i + 1) % self.grad_update_arch_param_every == 0:
schedule[i] = self.grad_update_steps
return schedule
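# Editorial note (not in the original file): with the default
# grad_update_arch_param_every=5 and, say, nBatch=12, the schedule is
# {4: grad_update_steps, 9: grad_update_steps}, i.e., architecture weights
# are updated after the 5th and the 10th minibatch of each epoch.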
def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
"""
Calculate the cosine-annealed learning rate.
"""
T_total = self.n_epochs * nBatch
T_cur = epoch * nBatch + batch
lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
return lr
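# Editorial sanity check of the cosine schedule above: at T_cur=0 the learning
# rate equals init_lr, at T_cur=T_total/2 it equals init_lr/2, and it decays
# to 0 at T_cur=T_total.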
def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
"""
Adjust the learning rate of a given optimizer and return the new learning rate.
Parameters
----------
optimizer : pytorch optimizer
the used optimizer
epoch : int
the current epoch number
batch : int
the current minibatch
nBatch : int
the total number of minibatches in one epoch
Returns
-------
float
the adjusted learning rate
"""
new_lr = self._calc_learning_rate(epoch, batch, nBatch)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
def _train(self):
"""
Train the model. This trains both model weights and architecture weights.
Architecture weights are trained according to the schedule. Before updating
architecture weights, ``requires_grad`` is enabled. It is disabled again
after the update, so that architecture weights are not updated while
training model weights.
"""
nBatch = len(self.train_loader)
arch_param_num = self.mutator.num_arch_params()
binary_gates_num = self.mutator.num_arch_params()
logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)
update_schedule = self._get_update_schedule(nBatch)
for epoch in range(self.train_curr_epoch, self.n_epochs):
logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
for i, (images, labels) in enumerate(self.train_loader):
data_time.update(time.time() - end)
lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
# train weight parameters
images, labels = images.to(self.device), labels.to(self.device)
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
self.model.zero_grad()
loss.backward()
self.model_optim.step()
self.mutator.unused_modules_back()
if epoch > 0:
for _ in range(update_schedule.get(i, 0)):
start_time = time.time()
# GradientArchSearchConfig
self.mutator.arch_requires_grad()
arch_loss, exp_value = self._gradient_step()
self.mutator.arch_disable_grad()
used_time = time.time() - start_time
log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t null %s' % \
(epoch + 1, i, used_time, arch_loss, exp_value)
logger.info(log_str)
batch_time.update(time.time() - end)
end = time.time()
# training log
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=lr)
logger.info(batch_log)
# validate
if (epoch + 1) % self.arch_valid_frequency == 0:
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.train_curr_epoch += 1
def _valid_next_batch(self):
"""
Get the next minibatch from the validation set.
Returns
-------
(tensor, tensor)
the tuple of images and labels
"""
if self._valid_iter is None:
self._valid_iter = iter(self.valid_loader)
try:
data = next(self._valid_iter)
except StopIteration:
self._valid_iter = iter(self.valid_loader)
data = next(self._valid_iter)
return data
def _gradient_step(self):
"""
This gradient step is for updating architecture weights.
Mutator is intensively used in this function to operate on
architecture weights.
Returns
-------
float, None
loss of the model, and the expected value (currently always ``None``)
"""
# use the same batch size as train batch size for architecture weights
self.valid_loader.batch_sampler.batch_size = self.train_batch_size
self.valid_loader.batch_sampler.drop_last = True
self.model.train()
self.mutator.change_forward_mode(self.binary_mode)
time1 = time.time() # time
# sample a batch of data from validation set
images, labels = self._valid_next_batch()
images, labels = images.to(self.device), labels.to(self.device)
time2 = time.time() # time
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
time3 = time.time()
ce_loss = self.criterion(output, labels)
expected_value = None
loss = ce_loss
self.model.zero_grad()
loss.backward()
self.mutator.set_arch_param_grad()
self.arch_optimizer.step()
if self.mutator.get_forward_mode() == 'two':
self.mutator.rescale_updated_arch_param()
self.mutator.unused_modules_back()
self.mutator.change_forward_mode(None)
time4 = time.time()
logger.info('(%.4f, %.4f, %.4f)', time2 - time1, time3 - time2, time4 - time3)
return loss.data.item(), expected_value.item() if expected_value is not None else None
def save_checkpoint(self):
"""
Save a checkpoint of the whole model. Model weights and architecture weights are
saved to ``ckpt_path``, and the currently chosen architecture is saved to ``arch_path``.
"""
if self.ckpt_path:
state = {
'warmup_curr_epoch': self.warmup_curr_epoch,
'train_curr_epoch': self.train_curr_epoch,
'model': self.model.state_dict(),
'optim': self.model_optim.state_dict(),
'arch_optim': self.arch_optimizer.state_dict()
}
torch.save(state, self.ckpt_path)
if self.arch_path:
self.export(self.arch_path)
def load_checkpoint(self):
"""
Load the checkpoint from ``ckpt_path``.
"""
assert self.ckpt_path is not None, "If load_ckpt is True, ckpt_path should not be None"
ckpt = torch.load(self.ckpt_path)
self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
self.train_curr_epoch = ckpt['train_curr_epoch']
self.model.load_state_dict(ckpt['model'])
self.model_optim.load_state_dict(ckpt['optim'])
self.arch_optimizer.load_state_dict(ckpt['arch_optim'])
def train(self):
"""
Train the whole model.
"""
if self.load_ckpt:
self.load_checkpoint()
if self.warmup:
self._warm_up()
self._train()
def export(self, file_name):
"""
Export the chosen architecture into a file
Parameters
----------
file_name : str
the file that stores exported chosen architecture
"""
exported_arch = self.mutator.sample_final()
with open(file_name, 'w') as f:
json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def validate(self):
raise NotImplementedError
def checkpoint(self):
raise NotImplementedError


@ -1,78 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
def detach_variable(inputs):
"""
Recursively detach tensors from the computation graph, preserving their ``requires_grad`` flags.
Parameters
----------
inputs : tensor or tuple of tensors
the tensors to detach
"""
if isinstance(inputs, tuple):
return tuple([detach_variable(x) for x in inputs])
else:
x = inputs.detach()
x.requires_grad = inputs.requires_grad
return x
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
"""
Parameters
----------
pred : pytorch tensor
predicted value
target : pytorch tensor
label
label_smoothing : float
the degree of label smoothing
Returns
-------
pytorch tensor
cross entropy
"""
logsoftmax = nn.LogSoftmax(dim=1)
n_classes = pred.size(1)
# convert to one-hot
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros_like(pred)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
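# Editorial worked example: with n_classes=10 and label_smoothing=0.1, the
# one-hot target [0, ..., 1, ..., 0] becomes [0.01, ..., 0.91, ..., 0.01],
# since 1 * (1 - 0.1) + 0.1 / 10 = 0.91 for the true class and 0.1 / 10 = 0.01
# for every other class.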
def accuracy(output, target, topk=(1,)):
"""
Compute the top-k accuracy for the specified values of k
Parameters
----------
output : pytorch tensor
output, e.g., predicted value
target : pytorch tensor
label
topk : tuple
the values of k for which accuracy is computed, e.g., ``(1, 5)``
Returns
-------
list
accuracy for each requested value of k
"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
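# Editorial usage sketch (hypothetical tensors):
#   output = torch.randn(32, 1000)          # logits for a batch of 32
#   target = torch.randint(0, 1000, (32,))  # ground-truth labels
#   top1, top5 = accuracy(output, target, topk=(1, 5))
# Each returned element is a 1-element tensor holding a percentage in [0, 100].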


@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class RandomMutator(Mutator):
"""
Random mutator that samples a random candidate in the search space each time ``reset()`` is called.
It uses PyTorch's random functions, so users can set the PyTorch seed to ensure deterministic behavior.
"""
def sample_search(self):
"""
Sample a random candidate.
"""
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
gen_index = torch.randint(high=len(mutable), size=(1, ))
result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
elif isinstance(mutable, InputChoice):
if mutable.n_chosen is None:
result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
else:
perm = torch.randperm(mutable.n_candidates)
mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable
return result
def sample_final(self):
"""
Same as :meth:`sample_search`.
"""
return self.sample_search()
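# Editorial example of a sampled result (keys are hypothetical): a LayerChoice
# with 3 candidates may yield tensor([False, True, False]); an InputChoice with
# n_candidates=4 and n_chosen=2 may yield tensor([True, False, True, False]),
# marking the two chosen inputs.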


@ -1,6 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer


@ -1,223 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
"""
SPOS evolution tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
self.num_population = num_population
self.m_prob = m_prob
self.num_crossover = num_crossover
self.num_mutation = num_mutation
self.epoch = 0
self.candidates = []
self._search_space = None
self.random_state = np.random.RandomState(0)
# async status
self._to_evaluate_queue = deque()
self._sending_parameter_queue = deque()
self._pending_result_ids = set()
self._reward_dict = dict()
self._id2candidate = dict()
self._st_callback = None
def update_search_space(self, search_space):
"""
Handle the initialization/update event of search space.
"""
self._search_space = search_space
self._next_round()
def _next_round(self):
_logger.info("Epoch %d, generating...", self.epoch)
if self.epoch == 0:
self._get_random_population()
self.export_results(self.candidates)
else:
best_candidates = self._select_top_candidates()
self.export_results(best_candidates)
if self.epoch >= self.max_epochs:
return
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
self._get_random_population()
self.epoch += 1
def _random_candidate(self):
chosen_arch = dict()
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
index = self.random_state.randint(len(choices))
chosen_arch[key] = {"_value": choices[index], "_idx": index}
elif val["_type"] == INPUT_CHOICE:
raise NotImplementedError("Input choice is not implemented yet.")
return chosen_arch
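# Editorial note on the candidate format (key name is hypothetical): each
# LayerChoice key maps to the chosen operator and its index, e.g.,
# {"layer_1": {"_value": "conv3x3", "_idx": 0}}.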
def _add_to_evaluate_queue(self, cand):
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
self._reward_dict[self._hashcode(cand)] = 0.
self._to_evaluate_queue.append(cand)
def _get_random_population(self):
while len(self.candidates) < self.num_population:
cand = self._random_candidate()
if self._is_legal(cand):
_logger.info("Random candidate generated.")
self._add_to_evaluate_queue(cand)
self.candidates.append(cand)
def _get_crossover(self, best):
result = []
for _ in range(10 * self.num_crossover):
cand_p1 = best[self.random_state.randint(len(best))]
cand_p2 = best[self.random_state.randint(len(best))]
assert cand_p1.keys() == cand_p2.keys()
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
for k in cand_p1.keys()}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_crossover:
break
_logger.info("Found %d architectures with crossover.", len(result))
return result
def _get_mutation(self, best):
result = []
for _ in range(10 * self.num_mutation):
cand = best[self.random_state.randint(len(best))].copy()
mutation_sample = self.random_state.random_sample(len(cand))
for s, k in zip(mutation_sample, cand):
if s < self.m_prob:
choices = self._search_space[k]["_value"]
index = self.random_state.randint(len(choices))
cand[k] = {"_value": choices[index], "_idx": index}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_mutation:
break
_logger.info("Found %d architectures with mutation.", len(result))
return result
def _get_architecture_repr(self, cand):
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
self._hashcode(cand))
def _is_legal(self, cand):
if self._hashcode(cand) in self._reward_dict:
return False
return True
def _select_top_candidates(self):
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
return result
@staticmethod
def _hashcode(d):
return json.dumps(d, sort_keys=True)
def _bind_and_send_parameters(self):
"""
There are two types of resources: parameter ids and candidates. This function is called at
necessary times to bind these resources to send new trials with st_callback.
"""
result = []
while self._sending_parameter_queue and self._to_evaluate_queue:
parameter_id = self._sending_parameter_queue.popleft()
parameters = self._to_evaluate_queue.popleft()
self._id2candidate[parameter_id] = parameters
result.append(parameters)
self._pending_result_ids.add(parameter_id)
self._st_callback(parameter_id, parameters)
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
return result
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Callback function necessary to implement a tuner. This will put more parameter ids into the
parameter id queue.
"""
if "st_callback" in kwargs and self._st_callback is None:
self._st_callback = kwargs["st_callback"]
for parameter_id in parameter_id_list:
self._sending_parameter_queue.append(parameter_id)
self._bind_and_send_parameters()
return [] # always not use this. might induce problem of over-sending
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Callback function. Receive a trial result.
"""
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
def trial_end(self, parameter_id, success, **kwargs):
"""
Callback function when a trial is ended and resource is released.
"""
self._pending_result_ids.remove(parameter_id)
if not self._pending_result_ids and not self._to_evaluate_queue:
# a new epoch now
self._next_round()
assert self._st_callback is not None
self._bind_and_send_parameters()
def export_results(self, result):
"""
Export a number of candidates to the ``checkpoints`` directory.
Parameters
----------
result : list of dict
Chosen architectures to be exported.
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
converted = dict()
for cand_key, cand_val in cand.items():
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
converted[cand_key] = onehot
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
json.dump(converted, fp)
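# Editorial note: a candidate entry with "_idx": 1 over a 3-way choice is
# exported as the one-hot list [false, true, false] in the dumped JSON.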


@ -1,66 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.algorithms.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
"""
A random mutator with flops limit.
Parameters
----------
model : nn.Module
PyTorch model.
flops_func : callable
Callable that takes a candidate from ``sample_search`` and returns its flops. When ``flops_func``
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins the flops interval is divided into, to ensure uniformity. A bigger number gives a more
uniform distribution, but sampling will be slower.
flops_sample_timeout : int
Maximum number of attempts to sample before giving up and using a random candidate.
"""
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
self._flops_bin_num = flops_bin_num
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
self._flops_sample_timeout = flops_sample_timeout
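# Editorial note with illustrative numbers: flops_lb=300e6, flops_ub=500e6 and
# flops_bin_num=7 give bin edges [300e6, ~328.6e6, ..., 500e6]; a candidate is
# accepted only if its flops fall into the uniformly drawn bin, which flattens
# the flops distribution of sampled candidates.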
def sample_search(self):
"""
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
relative to flops.
Returns
-------
dict
"""
if self._flops_func is not None:
for times in range(self._flops_sample_timeout):
idx = np.random.randint(self._flops_bin_num)
cand = super().sample_search()
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
return cand
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
return super().sample_search()
def sample_final(self):
"""
Implemented only to satisfy the interface of Mutator.
"""
return self.sample_search()


@ -1,95 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : nni.nas.pytorch.mutator.Mutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
valid_loader : iterable
Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
batch_size : int
Batch size.
workers : int
Number of threads for data preprocessing. Not used by this trainer. May be removed in the future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
automatically detects and prefers GPU.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
batch_size, workers, device, log_frequency, callbacks)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)


@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture


@ -1,217 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import json
import logging
import os
import sys
import tensorflow as tf
import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.tensorflow.mutator import Mutator
logger = logging.getLogger(__name__)
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
"""
Wrapper of :class:`~nni.nas.tensorflow.classic_nas.mutator.ClassicMutator` to make it more meaningful,
similar to ``get_next_parameter`` for HPO.
It will generate a search space based on ``model``.
If the env variable ``NNI_GEN_SEARCH_SPACE`` exists, it runs in dry-run mode,
only generating the search space for the experiment.
If not, there are two modes: NNI experiment mode, where users
use ``nnictl`` to start an experiment; and standalone mode,
where users run the trial command directly, in which case the first
one(s) of each LayerChoice and InputChoice are chosen.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
ClassicMutator(model)
class ClassicMutator(Mutator):
"""
This mutator is to apply the architecture chosen from tuner.
It implements the forward function of LayerChoice and InputChoice,
to only activate the chosen ones.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
def __init__(self, model):
super(ClassicMutator, self).__init__(model)
self._chosen_arch = {}
self._search_space = self._generate_search_space()
if NNI_GEN_SEARCH_SPACE in os.environ:
# dry run for only generating search space
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
sys.exit(0)
if trial_env_vars.NNI_PLATFORM is None:
logger.warning("This is in standalone mode, the chosen are the first one(s).")
self._chosen_arch = self._standalone_generate_chosen()
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
if self._chosen_arch is None:
if trial_env_vars.NNI_PLATFORM == "unittest":
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
self._chosen_arch = self._standalone_generate_chosen()
else:
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
self.reset()
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
"""
Convert layer choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : int
The ``idx``-th element of the list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
# doesn't support multihot for layer choice yet
assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
mask = tf.one_hot(idx, len(mutable))
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
def _sample_input_choice(self, mutable, idx, value, search_space_item):
"""
Convert input choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : int
The ``idx``-th element of the list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
candidate_repr = search_space_item["candidates"]
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
return tf.cast(multihot_list, tf.bool) # pylint: disable=not-callable
def sample_search(self):
"""
See :meth:`sample_final`.
"""
return self.sample_final()
def sample_final(self):
"""
Convert the chosen arch and apply it on model.
"""
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
self._chosen_arch.keys())
result = dict()
for mutable in self.mutables:
if isinstance(mutable, (LayerChoice, InputChoice)):
assert mutable.key in self._chosen_arch, \
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
if isinstance(mutable, LayerChoice):
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, InputChoice):
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return result
def _standalone_generate_chosen(self):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice.
::
{ key_name: {"_value": "conv1",
"_idx": 0} }
{ key_name: {"_value": ["in1"],
"_idx": [0]} }
Returns
-------
dict
the chosen architecture
"""
chosen_arch = {}
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
if n_chosen is None:
n_chosen = len(choices)
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
return chosen_arch
def _generate_search_space(self):
"""
Generate search space from mutables.
Here is the search space format:
::
{ key_name: {"_type": "layer_choice",
"_value": ["conv1", "conv2"]} }
{ key_name: {"_type": "input_choice",
"_value": {"candidates": ["in1", "in2"],
"n_chosen": 1}} }
Returns
-------
dict
the generated search space
"""
search_space = {}
for mutable in self.mutables:
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = mutable.names
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
def _dump_search_space(self, file_path):
with open(file_path, "w") as ss_file:
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)


@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer


@ -1,162 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction
from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
class EnasMutator(Mutator):
def __init__(self, model,
lstm_size=64,
lstm_num_layers=1,
tanh_constant=1.5,
cell_exit_extra_step=False,
skip_target=0.4,
temperature=None,
branch_bias=0.25,
entropy_reduction='sum'):
super().__init__(model)
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
self.lstm = RNN(cells, stateful=True)
self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
self.skip_targets = tf.constant([1.0 - skip_target, skip_target])
self.max_layer_choice = 0
self.bias_dict = {}
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
assert self.max_layer_choice == len(mutable), \
"ENAS mutator requires all layer choices to have the same number of candidates."
if 'reduce' in mutable.key:
bias = []
for choice in mutable.choices:
if 'conv' in str(type(choice)).lower():
bias.append(branch_bias)
else:
bias.append(-branch_bias)
self.bias_dict[mutable.key] = tf.constant(bias)
# exposed for trainer
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
# internal nn layers
self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
self.soft = Dense(self.max_layer_choice, use_bias=False)
self.attn_anchor = Dense(lstm_size, use_bias=False)
self.attn_query = Dense(lstm_size, use_bias=False)
self.v_attn = Dense(1, use_bias=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
self._first_sample = True
def sample_search(self):
self._initialize()
self._sample(self.mutables)
self._first_sample = False
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)
def _initialize(self):
self._choices = {}
self._anchors_hid = {}
self._inputs = self.g_emb
# seems the `input_shape` parameter of RNN does not work
# workaround it by omitting `reset_states` for first run
if not self._first_sample:
self.lstm.reset_states()
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _sample_layer_choice(self, mutable):
logit = self.soft(self.lstm(self._inputs))
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * tf.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
log_prob = self.cross_entropy_loss(branch_id, logit)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.math.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
mask = tf.one_hot(branch_id, self.max_layer_choice)
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._anchors_hid[label] = self.lstm(self._inputs)
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = tf.concat(query, axis=0)
query = tf.tanh(query + self.attn_query(anchors[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * tf.tanh(query)
if mutable.n_chosen is None:
logit = tf.concat([-query, query], axis=1)
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip_prob = tf.math.sigmoid(logit)
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(skip, logit)
skip = tf.cast(skip, tf.float32)
inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
self._inputs = tf.reshape(inputs, [1, 1, -1])
else:
assert mutable.n_chosen == 1, "ENAS supports input choice with n_chosen=1 or n_chosen=None only."
logit = tf.reshape(query, [1, -1])
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
# when the size is 1, tf does not accept tensor here, complaining the shape is wrong
# but using a numpy array seems fine
log_prob = self.cross_entropy_loss(logit, query.numpy())
self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
return tf.cast(skip, tf.bool)


@ -1,205 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import logging
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer:
def __init__(
self,
model,
loss,
metrics,
reward_function,
optimizer,
batch_size,
num_epochs,
dataset_train,
dataset_valid,
log_frequency=100,
entropy_weight=0.0001,
skip_weight=0.8,
baseline_decay=0.999,
child_steps=500,
mutator_lr=0.00035,
mutator_steps=50,
mutator_steps_aggregate=20,
aux_weight=0.4,
test_arc_per_epoch=1,
):
self.model = model
self.loss = loss
self.metrics = metrics
self.reward_function = reward_function
self.optimizer = optimizer
self.batch_size = batch_size
self.num_epochs = num_epochs
x, y = dataset_train
split = int(len(x) * 0.9)
self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.child_steps = child_steps
self.mutator_lr = mutator_lr
self.mutator_steps = mutator_steps
self.mutator_steps_aggregate = mutator_steps_aggregate
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=self.mutator_lr)
self.baseline = 0.0
def train(self, validate=True):
for epoch in range(self.num_epochs):
logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
def validate(self):
self.validate_one_epoch(-1)
def train_one_epoch(self, epoch):
train_loader, valid_loader = self._create_train_loader()
# Sample model and train
meters = AverageMeterGroup()
for step in range(1, self.child_steps + 1):
x, y = next(train_loader)
self.mutator.reset()
with tf.GradientTape() as tape:
logits = self.model(x, training=True)
if isinstance(logits, tuple):
logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y)
else:
aux_loss = 0.0
metrics = self.metrics(y, logits)
loss = self.loss(y, logits) + self.aux_weight * aux_loss
grads = tape.gradient(loss, self.model.trainable_weights)
grads = fill_zero_grads(grads, self.model.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, 5.0)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
metrics["loss"] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
if self.log_frequency and step % self.log_frequency == 0:
logger.info(
"Model Epoch [%d/%d] Step [%d/%d] %s",
epoch + 1,
self.num_epochs,
step,
self.child_steps,
meters,
)
# Train sampler (mutator)
meters = AverageMeterGroup()
for mutator_step in range(1, self.mutator_steps + 1):
grads_list = []
for step in range(1, self.mutator_steps_aggregate + 1):
with tf.GradientTape() as tape:
x, y = next(valid_loader)
self.mutator.reset()
logits = self.model(x, training=False)
metrics = self.metrics(y, logits)
reward = (
self.reward_function(y, logits)
+ self.entropy_weight * self.mutator.sample_entropy
)
self.baseline = self.baseline * self.baseline_decay + reward * (
1 - self.baseline_decay
)
loss = self.mutator.sample_log_prob * (reward - self.baseline)
loss += self.skip_weight * self.mutator.sample_skip_penalty
meters.update(
{
"reward": reward,
"loss": tf.reduce_mean(loss).numpy(),
"ent": self.mutator.sample_entropy.numpy(),
"log_prob": self.mutator.sample_log_prob.numpy(),
"baseline": self.baseline,
"skip": self.mutator.sample_skip_penalty,
}
)
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
if self.log_frequency and cur_step % self.log_frequency == 0:
logger.info(
"RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
epoch + 1,
self.num_epochs,
mutator_step,
self.mutator_steps,
step,
self.mutator_steps_aggregate,
meters,
)
grads = tape.gradient(loss, self.mutator.trainable_weights)
grads = fill_zero_grads(grads, self.mutator.trainable_weights)
grads_list.append(grads)
total_grads = [
tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
]
total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
self.mutator_optim.apply_gradients(
zip(total_grads, self.mutator.trainable_weights)
)
def validate_one_epoch(self, epoch):
test_loader = self._create_validate_loader()
for arc_id in range(self.test_arc_per_epoch):
meters = AverageMeterGroup()
for x, y in test_loader:
self.mutator.reset()
logits = self.model(x, training=False)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(y, logits)
loss = self.loss(y, logits)
metrics["loss"] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
logger.info(
"Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
epoch + 1,
self.num_epochs,
arc_id + 1,
self.test_arc_per_epoch,
meters.summary(),
)
def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)
def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))


@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .execution import *
from .fixed import fixed_arch
from .mutable import *
from .utils import *


@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .evaluator import Evaluator
from .functional import FunctionalEvaluator
shortcut_framework(__name__)
del shortcut_framework


@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['Evaluator']
import abc
from typing import Any, Callable, Type, Union, cast
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
the training code. The configuration includes basic runtime information the trainer needs to know (such as the number
of GPUs) or tunable parameters (such as the learning rate), depending on the implementation of the training code.
Each config should define how it is interpreted in ``_execute()``, which takes only one argument, the mutated model class.
For example, a functional evaluator might directly import the function and call it.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
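# Editorial sketch (not part of this file): a minimal concrete subclass showing
# how _execute / _dump / _load cooperate. The class name and its field are
# hypothetical.
class EchoEvaluator(Evaluator):
    def __init__(self, message: str):
        self.message = message
    def _execute(self, model_cls):
        # instantiate the model if a class/callable is given, then "evaluate" it
        model = model_cls() if callable(model_cls) else model_cls
        print(self.message, model)
    def _dump(self):
        return {'type': self.__class__, 'message': self.message}
    @staticmethod
    def _load(ir):
        return EchoEvaluator(ir['message'])
    def __eq__(self, other):
        return isinstance(other, EchoEvaluator) and self.message == other.message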


@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import nni
from .evaluator import Evaluator
@nni.trace
class FunctionalEvaluator(Evaluator):
"""
Functional evaluator that directly takes a function, and is thus general-purpose.
Attributes
----------
function
The function to be called. It takes the mutated model class as its first argument.
arguments
Keyword arguments for the function other than model.
"""
def __init__(self, function, **kwargs):
self.function = function
self.arguments = kwargs
@staticmethod
def _load(ir):
return FunctionalEvaluator(ir['function'], **ir['arguments'])
def _dump(self):
return {
'type': self.__class__,
'function': self.function,
'arguments': self.arguments
}
def _execute(self, model_cls):
return self.function(model_cls, **self.arguments)
def __eq__(self, other):
return self.function == other.function and self.arguments == other.arguments
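# Editorial usage sketch (``fit`` and its argument are hypothetical):
def fit(model_cls, learning_rate):
    model = model_cls()  # instantiate the mutated model
    ...  # train ``model`` and report metrics, e.g., via nni.report_final_result
# FunctionalEvaluator(fit, learning_rate=1e-3).evaluate(SomeModelClass) is then
# equivalent to fit(SomeModelClass, learning_rate=1e-3).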


@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RandomMutator
from .lightning import *


@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import Dict, List, Optional, Union, Type
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
__all__ = [
'_MultiModelSupervisedLearningModule', 'MultiModelSupervisedLearningModule',
'_ClassificationModule', 'Classification',
'_RegressionModule', 'Regression',
]
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.criterion_cls = criterion
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
self.metrics_args = metrics
self.n_models = n_models
def dump_kwargs(self):
kwargs = {}
kwargs['criterion'] = self.criterion_cls
kwargs['metrics'] = self.metrics_args
kwargs['n_models'] = self.n_models
kwargs['learning_rate'] = self.hparams['learning_rate']
kwargs['weight_decay'] = self.hparams['weight_decay']
kwargs['optimizer'] = self.optimizer
return kwargs
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
multi_loss = []
for idx, y_hat in enumerate(multi_y_hat):
loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
self.log(f'train_loss_{idx}', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
multi_loss.append(loss)
return sum(multi_loss)
def validation_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
# TODO: split metric of multiple models?
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
ret = []
for idx in range(self.n_models):
ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
return ret
else:
warnings.warn('Multiple metrics without "default" is not supported by current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
"""
Lightning module of supervised learning for cross-graph optimization.
Users who need cross-graph optimization should use this module.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
metrics : dict of torchmetrics.Metric
Mapping from metric name to metric class (not an instance).
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
"""
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class _ClassificationModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Classification(Lightning):
"""
Trainer that is used for classification.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
class _RegressionModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Regression(Lightning):
"""
Evaluator that is used for regression.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)


@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
class BypassStrategy(SingleDeviceStrategy):
strategy_name = "single_device"
def model_to_device(self) -> None:
pass
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
Parameters
----------
use_cgo : bool
Whether cross-graph optimization (CGO) is used.
If it is True, CGO will manage device placement.
Any device placement from pytorch lightning will be bypassed.
default: False
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, use_cgo=False, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
if 'strategy' in trainer_kwargs:
raise ValueError("cgo.trainer does not support specifying strategy")
trainer_kwargs['strategy'] = BypassStrategy()
super().__init__(**trainer_kwargs)
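# Sketch: with ``use_cgo=True`` the BypassStrategy above is installed, so
# pytorch-lightning skips ``model_to_device`` and CGO controls placement.
def _cgo_trainer_demo() -> Trainer:
    trainer = Trainer(use_cgo=True, max_epochs=1)
    # Trainer(use_cgo=True, accelerator='gpu') would raise ValueError instead.
    return trainer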


@ -0,0 +1,412 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Union, Optional, List, Callable, Type
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torch.utils.data as torch_data
import nni
from nni.common.serializer import is_traceable
try:
from .cgo import trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
from nni.nas.evaluator import Evaluator
from nni.typehint import Literal
__all__ = [
'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
'_AccuracyWithLogits', '_SupervisedLearningModule', '_ClassificationModule', '_RegressionModule',
# FIXME: hack to make it importable for tests
]
class LightningModule(pl.LightningModule):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
It's a subclass of ``pytorch_lightning.LightningModule``.
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""
running_mode: Literal['multi', 'oneshot'] = 'multi'
"""An indicator of whether current module is running in a multi-trial experiment or an one-shot.
This flag should be automatically set by experiments when they start to run.
"""
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
Parameters
----------
model : callable or nn.Module
Either an ``nn.Module`` instance, or a callable that returns an ``nn.Module``.
"""
if isinstance(model, nn.Module):
self.model = model
else:
self.model = model()
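# A minimal sketch of the ``set_model`` contract above; ``nn.Linear(8, 2)``
# is only a stand-in for an architecture generated by a strategy.
def _set_model_demo() -> LightningModule:
    module = LightningModule()
    module.set_model(nn.Linear(8, 2))            # instance is used directly
    module.set_model(lambda: nn.Linear(8, 2))    # callable is invoked first
    return module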
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""
DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
Since the arguments passed to the initialization need to be serialized, ``LightningModule``, ``Trainer`` or
``DataLoader`` in this file should be used. Another option is to hide the dataloader in the Lightning module, in
which case dataloaders are not required for this class to work.
Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics``
in the trainer. Two hooks are added at the end of the validation epoch and at the end of ``fit``, respectively. The metric name
and type depend on the specific task.
.. warning::
The Lightning evaluator is stateful. If you reuse a previous Lightning evaluator,
please note that the inner ``lightning_module`` and ``trainer`` will be reused.
Parameters
----------
lightning_module
Lightning module that defines the training logic.
trainer
Lightning trainer that handles the training.
train_dataloaders
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
val_dataloaders
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
"""
def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloaders: Optional[Any] = None,
val_dataloaders: Optional[Any] = None,
train_dataloader: Optional[Any] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
if cgo_import_failed:
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.nas.evaluator.pytorch.cgo.trainer'
if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}',
RuntimeWarning)
if not _check_dataloader(val_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {val_dataloaders}',
RuntimeWarning)
self.module = lightning_module
self.trainer = trainer
self.train_dataloaders = train_dataloaders
self.val_dataloaders = val_dataloaders
@staticmethod
def _load(ir):
return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])
def _dump(self):
return {
'type': self.__class__,
'module': self.module,
'trainer': self.trainer,
'train_dataloaders': self.train_dataloaders,
'val_dataloaders': self.val_dataloaders
}
def _execute(self, model_cls):
return self.fit(model_cls)
@property
def train_dataloader(self):
warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
return self.train_dataloaders
def __eq__(self, other):
eq_func = False
eq_args = False
if other is None:
return False
if hasattr(self, "function") and hasattr(other, "function"):
eq_func = getattr(self, "function") == getattr(other, "function")
elif not (hasattr(self, "function") or hasattr(other, "function")):
eq_func = True
if hasattr(self, "arguments") and hasattr(other, "arguments"):
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
eq_args = True
return eq_func and eq_args
def fit(self, model):
"""
Fit the model with provided dataloader, with Lightning trainer.
Parameters
----------
model : nn.Module
The model to fit.
"""
self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)
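# A minimal sketch of wiring a custom module into the evaluator above.
# ``train_loader``/``val_loader`` are assumed to be traced dataloaders; a real
# module would also implement validation so that metrics can be reported to NNI.
def _lightning_demo(train_loader, val_loader) -> Lightning:
    class MyModule(LightningModule):
        def forward(self, x):
            return self.model(x)  # ``self.model`` is injected via set_model()
        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn_functional.cross_entropy(self(x), y)
        def configure_optimizers(self):
            return optim.Adam(self.parameters())
    return Lightning(MyModule(), Trainer(max_epochs=1),
                     train_dataloaders=train_loader, val_dataloaders=val_loader)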
def _check_dataloader(dataloader):
# Check the type of dataloader recursively.
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
if isinstance(dataloader, dict):
return all([_check_dataloader(v) for v in dataloader.values()])
if isinstance(dataloader, torch_data.DataLoader):
return is_traceable(dataloader)
return True
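# Sketch of what the recursive check above accepts: loaders built from the
# traced ``DataLoader`` alias pass, plain ones fail, and the check descends
# into lists and dicts. ``dataset`` is a hypothetical placeholder.
def _check_dataloader_demo(dataset) -> bool:
    traced = DataLoader(dataset, batch_size=32)             # traced: passes
    plain = torch_data.DataLoader(dataset, batch_size=32)   # plain: fails
    return _check_dataloader({'train': traced}) and not _check_dataloader([plain])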
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
trainer: pl.Trainer
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log('train_' + name, metric(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
if self.running_mode == 'multi' and self.export_onnx is not None:
self.export_onnx.parent.mkdir(exist_ok=True)
try:
self.to_onnx(self.export_onnx, x, export_params=True)
except RuntimeError as e:
warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
self.export_onnx = None
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
if not self.trainer.sanity_checking and self.running_mode == 'multi':
# Don't report metric when sanity checking
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
if self.running_mode == 'multi':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
return self.trainer.callback_metrics['val_' + metric_name].item()
else:
warnings.warn('Multiple metrics without "default" is not supported by current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn_functional.softmax(pred, dim=-1), target)
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Classification(Lightning):
"""
Evaluator that is used for classification.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, the model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Examples
--------
>>> evaluator = Classification()
To use customized criterion and optimizer:
>>> evaluator = Classification(nn.LabelSmoothingCrossEntropy, optimizer=torch.optim.SGD)
Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:
>>> evaluator = Classification(accelerator='gpu', devices=2, strategy='ddp')
"""
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Regression(Lightning):
"""
Evaluator that is used for regression.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, the model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Examples
--------
>>> evaluator = Regression()
Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:
>>> evaluator = Regression(gpus=1)
"""
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)


@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .nasbench201 import NASBench201Cell
from .api import *
from .common import *

nni/nas/execution/api.py Normal file

@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import warnings
from typing import Iterable
from nni.nas.execution.common import (
Model, ModelStatus,
AbstractExecutionEngine,
DefaultListener
)
_execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is not None:
warnings.warn('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.',
RuntimeWarning)
_execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine:
global _execution_engine
assert _execution_engine is not None, 'You need to set an execution engine before using it.'
return _execution_engine
def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
global _default_listener
if _default_listener is None:
_default_listener = DefaultListener()
engine.register_graph_listener(_default_listener)
return _default_listener
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
get_and_register_default_listener(engine)
engine.submit_models(*models)
def list_models(*models: Model) -> Iterable[Model]:
engine = get_execution_engine()
get_and_register_default_listener(engine)
return engine.list_models()
def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine())
while True:
time.sleep(1)
left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
if not left_models:
break
def query_available_resources() -> int:
engine = get_execution_engine()
resources = engine.query_available_resource()
return resources if isinstance(resources, int) else len(resources)
def is_stopped_exec(model: Model) -> bool:
return model.status in (ModelStatus.Trained, ModelStatus.Failed)
def budget_exhausted() -> bool:
engine = get_execution_engine()
return engine.budget_exhausted()
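# Sketch of the synchronous strategy loop these helpers are designed for:
# poll resources, submit, then block until training ends. ``generate_model``
# is a hypothetical callable producing frozen ``Model`` objects.
def _naive_strategy_loop(generate_model) -> None:
    while not budget_exhausted():
        if query_available_resources() <= 0:
            time.sleep(1)
            continue
        model = generate_model()
        submit_models(model)
        wait_models(model)
        print(model.status, model.metric)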


@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
from .graph_op import *
from .graph import *
from .integration_api import *
from .integration import *
from .listener import *
from .utils import *


@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type
from .graph import Model, MetricData
__all__ = [
'GraphData', 'WorkerInfo', 'MetricData',
'AbstractGraphListener', 'AbstractExecutionEngine'
]
GraphData: Type[Any] = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by execution engine.
Execution engine will submit this kind of data through NNI to worker machine, and train it there.
A `GraphData` object describes a (merged) executable graph.
This is the trial's "hyper-parameter" in NNI's term and will be transferred in JSON format.
See `AbstractExecutionEngine` for details.
"""
WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.
This describes the properties of a worker machine (e.g., memory size).
"""
class AbstractGraphListener(ABC):
"""
Abstract listener interface to receive graph events.
Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
"""
@abstractmethod
def on_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the final metric of a graph.
"""
raise NotImplementedError
@abstractmethod
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the latest intermediate metric of a training graph.
"""
pass
@abstractmethod
def on_training_end(self, model: Model, success: bool) -> None:
"""
Reports that a graph either has been fully trained or has failed during training.
"""
pass
class AbstractExecutionEngine(ABC):
"""
The abstract interface of execution engine.
Most of these APIs are used by strategy, except `trial_execute_graph`, which is invoked by the framework in the trial.
Strategy will get the singleton execution engine object through a global API,
and use it in either a sync or an async manner.
Execution engine is responsible for submitting (maybe-optimized) models to NNI,
and assigning their metrics to the `Model` objects after training.
Execution engine is also responsible for launching the graph in the trial process,
because it is the only component that understands graph data, or "hyper-parameter" in NNI's term.
Execution engine will leverage NNI Advisor APIs, which are yet open for discussion.
In the synchronous use case, the strategy will have a loop that calls `submit_models` and `wait_models` repeatedly,
and will receive metrics from `Model` attributes.
Execution engine could assume that strategy only submits graphs when there are available resources (for now).
In the asynchronous use case, the strategy will register a listener to receive events,
while still using `submit_models` to train.
There will be a `BaseExecutionEngine` subclass.
Inner-graph optimization is supposed to derive from `BaseExecutionEngine`,
and override `submit_models` and `trial_execute_graph`.
Cross-graph optimization is supposed to derive from `AbstractExecutionEngine` directly,
because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.
There might be some util functions that benefit all optimization methods,
but non-mandatory utils should not be covered by the abstract interface.
"""
@abstractmethod
def submit_models(self, *models: Model) -> None:
"""
Submit models to NNI.
This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
"""
raise NotImplementedError
@abstractmethod
def list_models(self) -> Iterable[Model]:
"""
Get all models that have been submitted.
Execution engine should store a copy of models that have been submitted and return a list of copies in this method.
"""
raise NotImplementedError
@abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]: # type: ignore
"""
Returns information of all idle workers.
If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.
Could be left unimplemented for first iteration.
"""
raise NotImplementedError
@abstractmethod
def budget_exhausted(self) -> bool:
"""
Check whether the user-configured max trial number or max execution duration has been reached.
"""
raise NotImplementedError
@abstractmethod
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
"""
Register a listener to receive graph events.
Could be left unimplemented for first iteration.
"""
raise NotImplementedError
@abstractclassmethod
def trial_execute_graph(cls) -> MetricData:
"""
Train the graph and return its metrics, in a separate trial process.
Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.
Because this method will be invoked in a trial process on the training platform,
it has a different context from other methods and has no access to global variables or `self`.
However, util APIs like `.utils.experiment_config()` should still be available.
"""
raise NotImplementedError
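# A minimal conforming engine, as a sketch only: it "trains" models in-process
# with a dummy metric instead of going through NNI trial jobs, which is what a
# real subclass (e.g. the BaseExecutionEngine mentioned above) would do.
class _InProcessEngine(AbstractExecutionEngine):
    def __init__(self):
        self._models: List[Model] = []
        self._listeners: List[AbstractGraphListener] = []
    def submit_models(self, *models: Model) -> None:
        self._models.extend(models)
        for model in models:
            model.metric = 0.0  # dummy final metric
            for listener in self._listeners:
                listener.on_metric(model, model.metric)
                listener.on_training_end(model, True)
    def list_models(self) -> Iterable[Model]:
        return list(self._models)
    def query_available_resource(self) -> int:
        return 1  # pretend one worker is always idle
    def budget_exhausted(self) -> bool:
        return False
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)
    @classmethod
    def trial_execute_graph(cls) -> MetricData:
        raise RuntimeError('not used: this sketch never launches trial processes')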


@ -0,0 +1,774 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Model representation for engines based on graph.
"""
from __future__ import annotations
import json
from enum import Enum
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING:
from .mutator import Mutator
from nni.nas.evaluator import Evaluator
from nni.nas.utils import uid
from .graph_op import Cell, Operation, _IOPseudoOperation
__all__ = [
'Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData',
'DebugEvaluator',
]
MetricData = Any
"""
Type hint for graph metrics (loss, accuracy, etc).
"""
EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
"""
class Model:
"""
Represents a neural network model.
During mutation, one :class:`Model` object is created for each trainable snapshot.
For example, consider a mutator that inserts a node at an edge in each iteration.
In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
These 4 primitives operate on one :class:`Model` object.
When they are all done, the model will be set to "frozen" (trainable) status and be submitted to the execution engine.
Then a new iteration starts, and a new :class:`Model` object is created by forking the last model.
Attributes
----------
python_object
Python object of base model. It will be ``None`` when the base model is not available.
python_class
Python class that base model is converted from.
python_init_params
Initialization parameters of python class.
status
See :class:`ModelStatus`.
root_graph
The outermost graph which usually takes dataset as input and feeds output to loss function.
graphs
All graphs (subgraphs) in this model.
evaluator
Model evaluator
history
Mutation history.
``self`` is directly mutated from ``self.history[-1]``;
``self.history[-1]`` is mutated from ``self.history[-2]``, and so on.
``self.history[0]`` is the base graph.
metric
Training result of the model, or ``None`` if it's not yet trained or has failed to train.
intermediate_metrics
Intermediate training metrics. If the model is not trained, it's an empty list.
"""
def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead'
self.model_id: int = uid('model')
self.python_object: Optional[Any] = None # type is uncertain because it could differ between DL frameworks
self.python_class: Optional[Type] = None
self.python_init_params: Optional[Dict[str, Any]] = None
self.status: ModelStatus = ModelStatus.Mutating
self._root_graph_name: str = '_model'
self.graphs: Dict[str, Graph] = {}
self.evaluator: Optional[Evaluator] = None
self.history: List['Mutation'] = []
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []
def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
f'python_class={self.python_class})'
@property
def root_graph(self) -> 'Graph':
return self.graphs[self._root_graph_name]
def fork(self) -> 'Model':
"""
Create a new model which has the same topology, names, and IDs as the current one.
Can only be invoked on a frozen model.
The new model will be in `Mutating` state.
This API is used in mutator base class.
"""
new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name
new_model.python_class = self.python_class
new_model.python_init_params = self.python_init_params
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.evaluator = self.evaluator # TODO this needs a clever copy (not deepcopy) if we need mutation
new_model.history = [*self.history]
# Note: the history is not updated. It will be updated when the model is changed, that is in mutator.
return new_model
@staticmethod
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
for graph_name, graph_data in ir.items():
if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register()
if '_evaluator' in ir:
model.evaluator = Evaluator._load(ir['_evaluator'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
if self.evaluator is not None:
ret['_evaluator'] = self.evaluator._dump()
return ret
def get_nodes(self) -> Iterable['Node']:
"""
Traverse through all the nodes.
"""
for graph in self.graphs.values():
for node in graph.nodes:
yield node
def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given label.
There could be multiple nodes with the same label. A namespace name can uniquely
identify a graph or node.
NOTE: the implementation does not support the class abstraction
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_label(label)
matched_nodes.extend(nodes)
return matched_nodes
def get_nodes_by_type(self, type_name: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given type.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_type(type_name)
matched_nodes.extend(nodes)
return matched_nodes
def get_node_by_name(self, node_name: str) -> 'Node' | None:
"""
Traverse all the nodes to find the matched node with the given name.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_name(node_name)
matched_nodes.extend(nodes)
assert len(matched_nodes) <= 1
if matched_nodes:
return matched_nodes[0]
else:
return None
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_python_name(python_name)
matched_nodes.extend(nodes)
# assert len(matched_nodes) <= 1
if matched_nodes:
return matched_nodes[0]
else:
return None
def get_cell_nodes(self) -> List['Node']:
matched_nodes = []
for graph in self.graphs.values():
nodes = [node for node in graph.nodes if isinstance(node.operation, Cell)]
matched_nodes.extend(nodes)
return matched_nodes
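# Illustrative sketch of how a mutator-side helper might use the API above;
# ``base_model`` is assumed to be a frozen model produced by graph conversion.
def _fork_and_inspect(base_model: 'Model') -> 'Model':
    new_model = base_model.fork()       # the copy starts in Mutating status
    for node in new_model.get_nodes():
        print(node.name, node.operation)
    return new_model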
class ModelStatus(Enum):
"""
The status of model.
A model is created in `Mutating` status.
When the mutation is done and the model gets ready to train, its status becomes `Frozen`.
When training starts, the model's status becomes `Training`.
If training ends successfully, the model's `metric` attribute gets set and its status becomes `Trained`.
If training fails, the status becomes `Failed`.
"""
Mutating = "mutating"
Frozen = "frozen"
Training = "training"
Trained = "trained"
Failed = "failed"
_InputPseudoUid = -1
_OutputPseudoUid = -2
class Graph:
"""
Graph topology.
This class simply represents the topology, with no semantic meaning.
All other information like metric, non-graph functions, mutation history, etc should go to :class:`Model`.
Each graph belongs to and only belongs to one :class:`Model`.
Attributes
----------
model
The model containing (and owning) this graph.
id
Unique ID in the model.
If two models have graphs of identical ID, they are semantically the same graph.
Typically this means one graph is mutated from another, or they are both mutated from one ancestor.
name
Mnemonic name of this graph. It should have a one-to-one mapping with ID.
input_names
Optional mnemonic names of input parameters.
output_names
Optional mnemonic names of output values.
input_node
Incoming node.
output_node
Output node.
hidden_nodes
Hidden nodes
nodes
All input/output/hidden nodes.
edges
Edges.
python_name
The name of the ``torch.nn.Module``; it should have a one-to-one mapping with items in the python model.
"""
def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
assert _internal, '`Graph()` is private'
self.model: Model = model
self.id: int = graph_id
self.name: str = name or f'_generated_{graph_id}'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self.python_name: Optional[str] = None
self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
self.hidden_nodes: List[Node] = []
self.edges: List[Edge] = []
def __repr__(self):
return f'Graph(id={self.id}, name={self.name}, ' + \
f'input_names={self.input_node.operation.io_names}, ' + \
f'output_names={self.output_node.operation.io_names}, ' + \
f'num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
@property
def nodes(self) -> List['Node']:
return [self.input_node, self.output_node] + self.hidden_nodes
def _add_input(self, input_name) -> None:
if self.input_node.operation.io_names is None:
self.input_node.operation.io_names = [input_name]
else:
self.input_node.operation.io_names.append(input_name)
def _add_output(self, output_name) -> None:
if self.output_node.operation.io_names is None:
self.output_node.operation.io_names = [output_name]
else:
self.output_node.operation.io_names.append(output_name)
@overload
def add_node(self, name: str, operation: Operation) -> 'Node': ...
@overload
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def add_node(self, name, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, cast(dict, parameters), name)
return Node(self, uid(), name, op, _internal=True)._register()
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node': # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, cast(dict, parameters), name)
new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None))
self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
self.del_edge(edge)
return new_node
# mutation
def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail, _internal=True)._register()
def del_edge(self, edge: 'Edge') -> None:
self.edges.remove(edge)
def get_node_by_name(self, name: str) -> Optional['Node']:
"""
Returns the node which has the specified name, or `None` if no node has this name.
"""
found = [node for node in self.nodes if node.name == name]
return found[0] if found else None
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
"""
Returns the node which has the specified python_name, or `None` if no node has this python_name.
"""
found = [node for node in self.nodes if node.python_name == python_name]
return found[0] if found else None
def get_nodes_by_type(self, operation_type: str) -> List['Node']:
"""
Returns the nodes whose operation has the specified type.
"""
return [node for node in self.hidden_nodes if node.operation.type == operation_type]
def get_node_by_id(self, node_id: int) -> Optional['Node']:
"""
Returns the node which has the specified ID, or `None` if no node has this ID.
"""
found = [node for node in self.nodes if node.id == node_id]
return found[0] if found else None
def get_nodes_by_label(self, label: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.label == label]
def get_nodes_by_name(self, name: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.name == name]
def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
return [node for node in self.nodes if node.python_name == python_name]
def topo_sort(self) -> List['Node']:
node_to_fanin = {}
curr_nodes = []
for node in self.nodes:
fanin = len(node.incoming_edges)
node_to_fanin[node] = fanin
if fanin == 0:
curr_nodes.append(node)
sorted_nodes = []
while curr_nodes:
curr_node = curr_nodes.pop(0)
sorted_nodes.append(curr_node)
# use successor_slots because a node may connect to another node multiple times
# to different slots
for successor_slot in curr_node.successor_slots:
successor = successor_slot[0]
node_to_fanin[successor] -= 1
if node_to_fanin[successor] == 0:
curr_nodes.append(successor)
for key in node_to_fanin:
assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
key,
node_to_fanin[key],
key.predecessors[0],
self.edges,
node_to_fanin.values(),
node_to_fanin.keys())
return sorted_nodes
def fork(self) -> 'Graph':
"""
Fork the model and return the corresponding graph in the new model.
This shortcut might be helpful because many algorithms only care about the "stem" subgraph instead of the whole model.
"""
return self.model.fork().graphs[self.name]
def __eq__(self, other: object) -> bool:
return self is other
def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
# TODO: use node copy instead
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
new_graph.python_name = self.python_name
for node in self.hidden_nodes:
new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
new_node.python_name = node.python_name
new_node.update_label(node.label)
new_node._register()
id_to_new_node = {node.id: node for node in new_graph.nodes}
for edge in self.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
return new_graph
def _copy(self) -> 'Graph':
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, uid(), _internal=True)._register()
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
new_graph.python_name = self.python_name
id_to_new_node = {} # old node ID -> new node object
for old_node in self.hidden_nodes:
new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
new_node.python_name = old_node.python_name
new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node
for edge in self.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
return new_graph
def _register(self) -> 'Graph':
self.model.graphs[self.name] = self
return self
def _rename_graph(self, old_name, new_name):
self.model.graphs[old_name].name = new_name
self.model.graphs[new_name] = self.model.graphs[old_name]
del self.model.graphs[old_name]
@staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, uid(), name, _internal=True)
graph.input_node.operation.io_names = ir.get('inputs')
graph.output_node.operation.io_names = ir.get('outputs')
for node_name, node_data in ir['nodes'].items():
Node._load(graph, node_name, node_data)._register()
for edge_data in ir['edges']:
Edge._load(graph, edge_data)._register()
return graph
def _dump(self) -> Any:
return {
'inputs': self.input_node.operation.io_names,
'outputs': self.output_node.operation.io_names,
'nodes': {node.name: node._dump() for node in self.hidden_nodes},
'edges': [edge._dump() for edge in self.edges]
}
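# Sketch: building a two-node chain by hand and sorting it topologically.
# Constructors are private, so this mirrors what the graph converter does
# internally; it assumes the default framework is PyTorch.
def _tiny_graph_demo() -> List['Node']:
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    fc = graph.add_node('fc', 'Linear', {'in_features': 4, 'out_features': 2})
    act = graph.add_node('act', 'ReLU', {})
    graph.add_edge((graph.input_node, None), (fc, None))
    graph.add_edge((fc, None), (act, None))
    graph.add_edge((act, None), (graph.output_node, None))
    return graph.topo_sort()  # [_inputs, fc, act, _outputs]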
class Node:
"""
An operation or an opaque subgraph inside a graph.
Each node belongs to and only belongs to one :class:`Graph`.
Nodes should never be created with the constructor. Use :meth:`Graph.add_node` instead.
The node itself is for topology only.
Information of tensor calculation should all go inside ``operation`` attribute.
TODO: parameter of subgraph (cell)
It's easy to assign parameters on cell node, but it's hard to "use" them.
We need to design a way to reference stored cell parameters in inner node operations.
e.g. ``self.fc = Linear(self.units)`` <- how to express ``self.units`` in IR?
Attributes
----------
graph
The graph containing this node.
id
Unique ID in the model.
If two models have nodes with the same ID, they are semantically the same node.
name
Mnemonic name. It should have a one-to-one mapping with ID.
python_name
The name of the ``torch.nn.Module``; it should have a one-to-one mapping with items in the python model.
label
Optional. If two nodes have the same label, they are considered the same by the mutator.
operation
Operation.
cell
Read only shortcut to get the referenced subgraph.
If this node is not a subgraph (is a primitive operation), accessing ``cell`` will raise an error.
predecessors
Predecessor nodes of this node in the graph. This is an optional mutation helper.
successors
Successor nodes of this node in the graph. This is an optional mutation helper.
incoming_edges
Incoming edges of this node in the graph. This is an optional mutation helper.
outgoing_edges
Outgoing edges of this node in the graph. This is an optional mutation helper.
"""
def __init__(self, graph, node_id, name, operation, _internal=False):
self.graph: Graph = graph
self.id: int = node_id
self.name: str = name or f'_generated_{node_id}'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self.python_name: Optional[str] = None
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: Optional[str] = None
def __repr__(self):
return f'Node(id={self.id}, name={self.name}, python_name={self.python_name}, label={self.label}, operation={self.operation})'
@property
def predecessors(self) -> List['Node']:
return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id))
@property
def successors(self) -> List['Node']:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
@property
def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
@property
def incoming_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.tail is self]
@property
def outgoing_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.head is self]
@property
def cell(self) -> Graph:
assert isinstance(self.operation, Cell)
return self.graph.model.graphs[self.operation.parameters['cell']]
def update_label(self, label: Optional[str]) -> None:
self.label = label
@overload
def update_operation(self, operation: Operation) -> None: ...
@overload
def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...
def update_operation(self, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else:
self.operation = Operation.new(operation_or_type, cast(dict, parameters))
# mutation
def remove(self) -> None:
assert not self.incoming_edges and not self.outgoing_edges
self.graph.hidden_nodes.remove(self)
# mutation
def specialize_cell(self) -> Graph:
"""
Only available if the operation is a cell.
Duplicate the cell template and let this node reference to newly created copy.
"""
new_cell = self.cell._copy()._register()
self.operation = Cell(new_cell.name)
return new_cell
def __eq__(self, other: object) -> bool:
return self is other
def __hash__(self) -> int:
return hash(id(self))
def _register(self) -> 'Node':
self.graph.hidden_nodes.append(self)
return self
@staticmethod
def _load(graph: Graph, name: str, ir: Any) -> 'Node':
if ir['operation']['type'] == '_cell':
op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}), attributes=ir['operation'].get('attributes', {}))
else:
op = Operation.new(ir['operation']['type'],
ir['operation'].get('parameters', {}),
attributes=ir['operation'].get('attributes', {}))
node = Node(graph, uid(), name, op)
if 'label' in ir:
node.update_label(ir['label'])
return node
def _dump(self) -> Any:
ret: Dict[str, Any] = {
'operation': {
'type': self.operation.type,
'parameters': self.operation.parameters,
'attributes': self.operation.attributes
}
}
if isinstance(self.operation, Cell):
ret['operation']['cell_name'] = self.operation.cell_name
if self.label is not None:
ret['label'] = self.label
if self.python_name is not None:
ret['python_name'] = self.python_name
return ret
class Edge:
"""
A tensor, or "data flow", between two nodes.
Example forward code snippet: ::
a, b, c = split(x)
p = concat(a, c)
q = sum(b, p)
z = relu(q)
Edges in above snippet: ::
+ head: (split, 0), tail: (concat, 0) # a in concat
+ head: (split, 2), tail: (concat, 1) # c in concat
+ head: (split, 1), tail: (sum, -1 or 0) # b in sum
+ head: (concat, null), tail: (sum, -1 or 1) # p in sum
+ head: (sum, null), tail: (relu, null) # q in relu
Attributes
----------
graph
Graph.
head
Head node.
tail
Tail node.
head_slot
Index of outputs in head node.
If the node has only one output, this should be ``null``.
tail_slot
Index of inputs in tail node.
If the node has only one input, this should be ``null``.
If the node does not care about order, this can be ``-1``.
"""
def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
assert _internal, '`Edge()` is private'
self.graph: Graph = head[0].graph
self.head: Node = head[0]
self.tail: Node = tail[0]
self.head_slot: Optional[int] = head[1]
self.tail_slot: Optional[int] = tail[1]
def __repr__(self):
return f'Edge(head=({self.head}, {self.head_slot}), tail=({self.tail}, {self.tail_slot}))'
# mutation
def remove(self) -> None:
self.graph.edges.remove(self)
def _register(self) -> 'Edge':
self.graph.edges.append(self)
return self
@staticmethod
def _load(graph: Graph, ir: Any) -> 'Edge':
head = graph.get_node_by_name(ir['head'][0])
tail = graph.get_node_by_name(ir['tail'][0])
assert head is not None and tail is not None
return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True)
def _dump(self) -> Any:
return {
'head': [self.head.name, self.head_slot],
'tail': [self.tail.name, self.tail_slot]
}
class Mutation:
"""
An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
the model that it comes from, and the model that it becomes.
In general cases, the mutation logs are not reliable and should not be replayed as the mutators can
be arbitrarily complex. However, for inline mutations, the labels correspond to mutator labels here,
this can be useful for metadata visualization and python execution mode.
Attributes
----------
mutator
Mutator.
samples
Decisions/choices.
from_
Model that it comes from.
to
Model that it becomes.
"""
def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model): # noqa: F821
self.mutator: 'Mutator' = mutator # noqa: F821
self.samples: List[Any] = samples
self.from_: Model = from_
self.to: Model = to
def __repr__(self):
return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'
class IllegalGraphError(ValueError):
def __init__(self, graph, *args):
self._debug_dump_graph(graph)
super().__init__(*args)
@staticmethod
def _debug_dump_graph(graph):
if isinstance(graph, Graph):
graph = graph._dump()
with open('generated/debug.json', 'w') as dump_file:
json.dump(graph, dump_file, indent=4)
class DebugEvaluator(Evaluator):
@staticmethod
def _load(ir: Any) -> 'DebugEvaluator':
return DebugEvaluator()
def _dump(self) -> Any:
return {'type': DebugEvaluator}
def _execute(self, model_cls: type) -> Any:
pass
def __eq__(self, other) -> bool:
return True


@ -0,0 +1,251 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations used in graph-based engine.
"""
from typing import (Any, Dict, List, Optional, cast)
from nni.common.framework import get_default_framework
__all__ = ['Operation', 'Cell', 'PyTorchOperation', 'TensorFlowOperation']
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
class Operation:
"""
Calculation logic of a graph node.
The constructor is private. Use `Operation.new()` to create operation object.
`Operation` is a naive record.
Do not "mutate" its attributes or store information related to a specific node.
All complex logic should be implemented in the `Node` class.
Attributes
----------
type
Operation type name (e.g. Conv2D).
If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
parameters
Arbitrary key-value parameters (e.g. kernel_size).
"""
io_names: List[str] = []
def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
assert _internal, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type_name
self.parameters: Dict[str, Any] = parameters
self.attributes: Dict[str, Any] = attributes
def to_init_code(self, field: str) -> str:
raise NotImplementedError()
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise NotImplementedError()
def _to_class_name(self) -> str:
raise NotImplementedError()
def __bool__(self) -> bool:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
parameters = parameters or {}
attributes = attributes or {}
if type_name == '_cell':
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
else:
if get_default_framework() in ('torch', 'pytorch'):
from nni.nas.execution.pytorch import op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type_name)
elif get_default_framework() in ('tf', 'tensorflow'):
from nni.nas.execution.tensorflow import op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type_name)
else:
raise ValueError(f'Unsupported framework: {get_default_framework()}')
return cls(type_name, parameters, _internal=True, attributes=attributes)
@classmethod
def _find_subclass(cls, subclass_name):
for subclass in cls.__subclasses__():
if subclass.__name__ == subclass_name:
return subclass
return cls
def __repr__(self):
type_name = type(self).__name__
args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
if type_name != self.type:
args = [f'type="{self.type}"'] + args
return f'{type_name}({", ".join(args)})'
def __eq__(self, other):
return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
class PyTorchOperation(Operation):
@classmethod
def _find_subclass(cls, subclass_name):
if cls.to_class_name(subclass_name) is not None:
subclass_name = 'ModuleOperator'
if cls.is_functional(subclass_name):
subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \
subclass_name in cast(Any, subclass)._ori_type_name:
return subclass
for subclass in cls.__subclasses__():
if hasattr(subclass, '_artificial_op_name') and \
subclass_name in cast(Any, subclass)._artificial_op_name:
return subclass
return cls
@classmethod
def to_class_name(cls, type_name) -> Optional[str]:
if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'):
return type_name[len('__mutated__.'):]
else:
return None
@classmethod
def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.')
def _to_class_name(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):]
else:
return None
def get_import_pkg(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):].split('.')[0]
else:
return None
def to_init_code(self, field: str) -> Optional[str]:
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
"""
Parameters
----------
field : str
the name of member submodule
output : str
the output name (lvalue) of this line of code
inputs : List[str]
variables used in this line of code
inputs_value : List[Any]
some variables are actually constant, their real values are recorded in ```inputs_value```.
if not constant, we simply put None at the corresponding index
Returns
-------
str
generated code line
"""
if self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
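# Sketch of the factory above (assumes the default framework is PyTorch, so a
# PyTorchOperation subclass is resolved); parameter values are arbitrary examples.
def _operation_demo() -> Operation:
    conv = Operation.new('__torch__.torch.nn.modules.conv.Conv2d',
                         {'in_channels': 3, 'out_channels': 8, 'kernel_size': 3})
    assert isinstance(conv, PyTorchOperation)
    cell = Operation.new('_cell', cell_name='my_cell')  # references a subgraph
    assert cell.type == '_cell'
    return conv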
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
class Cell(PyTorchOperation):
"""
TODO: this is pytorch cell
An operation reference to a subgraph.
Example code:
```
def __init__(...):
...
self.cell = CustomCell(...)
self.relu = K.layers.ReLU()
...
def forward(...):
...
x = self.cell(x)
...
```
In the above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.
TODO: parameters of subgraph (see `Node` class)
Attributes
----------
type
Always "_cell".
parameters
A dict with only one item; the key is "cell" and the value is cell's name.
framework
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
self.type = '_cell'
self.cell_name = cell_name
self.parameters = parameters or {}
self.attributes = attributes or {}
def _to_class_name(self):
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation):
"""
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool:
return False
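# --- Usage sketch (illustrative only; not part of the original module) ---
# `to_init_code` and `to_forward_code` are the two hooks that code generation
# relies on: the former emits a line for the generated `__init__`, the latter
# a line for the generated `forward`. Assuming the default framework resolves
# to pytorch, something along these lines is expected (the Linear example is
# made up):
#
#     op = Operation.new('__torch__.torch.nn.Linear',
#                        {'in_features': 4, 'out_features': 2})
#     op.to_init_code('fc')
#     # -> "self.fc = torch.nn.Linear(in_features=4, out_features=2)"
#
# `to_forward_code` is specialized per operation type (see `Cell` above for a
# concrete override); the base class raises for types it cannot handle.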

View file

@ -0,0 +1,235 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['RetiariiAdvisor']
import logging
import os
from typing import Any, Callable, Optional, Dict, List, Tuple
import nni
from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.tuner_command_channel import CommandType
from nni.utils import MetricType
from .graph import MetricData
from .integration_api import register_advisor
_logger = logging.getLogger(__name__)
class RetiariiAdvisor(MsgDispatcherBase):
"""
The class is to connect Retiarii components to NNI backend.
It can be considered as a Python wrapper of NNI manager.
It will function as the main thread when running a Retiarii experiment through NNI.
The strategy will be launched as its thread, which will call APIs in the execution engine. The
execution engine will then find the advisor singleton and send payloads to the advisor.
When metrics are sent back, the advisor receives the payloads first and then calls the callback
function (a member function of the graph listener).
The conversion the advisor provides is minimal. It is only a send/receive module, and the execution
engine needs to handle all the rest.
Attributes
----------
send_trial_callback
request_trial_jobs_callback
trial_end_callback
intermediate_metric_callback
final_metric_callback
"""
def __init__(self, url: str):
super().__init__(url)
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
self.send_trial_callback: Optional[Callable[[dict], None]] = None
self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.parameters_count = 0
# Sometimes messages arrive before the callbacks get registered,
# or the engine is allowed to be absent during the experiment.
# In such cases we need to store the messages and invoke the callbacks later.
self.call_queue: List[Tuple[str, list]] = []
def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
"""
Register callbacks for NNI backend.
Parameters
----------
callbacks
A dictionary of callbacks.
The key is the name of the callback. The value is the callback function.
"""
self.send_trial_callback = callbacks.get('send_trial')
self.request_trial_jobs_callback = callbacks.get('request_trial_jobs')
self.trial_end_callback = callbacks.get('trial_end')
self.intermediate_metric_callback = callbacks.get('intermediate_metric')
self.final_metric_callback = callbacks.get('final_metric')
self.process_queued_callbacks()
def process_queued_callbacks(self) -> None:
"""
Process callbacks in queue.
Consume the messages that haven't been handled previously.
"""
processed_idx = []
for queue_idx, (call_name, call_args) in enumerate(self.call_queue):
if call_name == 'send_trial' and self.send_trial_callback is not None:
self.send_trial_callback(*call_args) # pylint: disable=not-callable
processed_idx.append(queue_idx)
if call_name == 'request_trial_jobs' and self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(*call_args) # pylint: disable=not-callable
processed_idx.append(queue_idx)
if call_name == 'trial_end' and self.trial_end_callback is not None:
self.trial_end_callback(*call_args) # pylint: disable=not-callable
processed_idx.append(queue_idx)
if call_name == 'intermediate_metric' and self.intermediate_metric_callback is not None:
self.intermediate_metric_callback(*call_args) # pylint: disable=not-callable
processed_idx.append(queue_idx)
if call_name == 'final_metric' and self.final_metric_callback is not None:
self.final_metric_callback(*call_args) # pylint: disable=not-callable
processed_idx.append(queue_idx)
# Remove processed messages
for idx in reversed(processed_idx):
self.call_queue.pop(idx)
def invoke_callback(self, name: str, *args: Any) -> None:
"""
Invoke callback.
"""
self.call_queue.append((name, list(args)))
self.process_queued_callbacks()
def handle_initialize(self, data):
"""callback for initializing the advisor
Parameters
----------
data: dict
search space
"""
self.handle_update_search_space(data)
self.send(CommandType.Initialized, '')
def _validate_placement_constraint(self, placement_constraint):
if placement_constraint is None:
raise ValueError('placement_constraint is None')
if 'type' not in placement_constraint:
raise ValueError('placement_constraint must have `type`')
if 'gpus' not in placement_constraint:
raise ValueError('placement_constraint must have `gpus`')
if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
raise ValueError('placement_constraint.type must be either `None`, `GPUNumber` or `Device`')
if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
raise ValueError('placement_constraint.gpus must be an empty list when type == None')
if placement_constraint['type'] == 'GPUNumber':
if len(placement_constraint['gpus']) != 1:
raise ValueError('placement_constraint.gpus currently only supports one host when type == GPUNumber')
for e in placement_constraint['gpus']:
if not isinstance(e, int):
raise ValueError('placement_constraint.gpus must be a list of numbers when type == GPUNumber')
if placement_constraint['type'] == 'Device':
for e in placement_constraint['gpus']:
if not isinstance(e, tuple):
raise ValueError('placement_constraint.gpus must be a list of tuples when type == Device')
if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
raise ValueError("placement_constraint.gpus's tuples must be (str, int)")
def send_trial(self, parameters, placement_constraint=None):
"""
Send parameters to NNI.
Parameters
----------
parameters : Any
Any payload.
Returns
-------
int
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
"""
self.parameters_count += 1
if placement_constraint is None:
placement_constraint = {
'type': 'None',
'gpus': []
}
self._validate_placement_constraint(placement_constraint)
new_trial = {
'parameter_id': self.parameters_count,
'parameters': parameters,
'parameter_source': 'algorithm',
'placement_constraint': placement_constraint,
'version_info': version_dump()
}
_logger.debug('New trial sent: %s', new_trial)
try:
send_payload = nni.dump(new_trial, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
except PayloadTooLarge:
raise ValueError(
'Serialization failed when trying to dump the model because the payload is too large (larger than 64 KB). '
'This is usually caused by pickling large objects (like datasets) by mistake. '
'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
'for how to resolve such issue. '
)
# trial parameters can be super large; the pickle size limit above is configurable via PICKLE_SIZE_LIMIT
# nevertheless, the payload could still be blocked by the pipe / NNI manager
self.send(CommandType.NewTrialJob, send_payload)
self.invoke_callback('send_trial', parameters)
return self.parameters_count
def mark_experiment_as_ending(self):
self.send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials):
_logger.debug('Request trial jobs: %s', num_trials)
self.invoke_callback('request_trial_jobs', num_trials)
def handle_update_search_space(self, data):
_logger.debug('Received search space: %s', data)
self.search_space = data
def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data)
self.invoke_callback('trial_end', nni.load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
self.invoke_callback('intermediate_metric', data['parameter_id'], self._process_value(data['value']))
elif data['type'] == MetricType.FINAL:
self.invoke_callback('final_metric', data['parameter_id'], self._process_value(data['value']))
@staticmethod
def _process_value(value) -> Any: # hopefully a float
value = nni.load(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
else:
return value
return value
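# --- Usage sketch (illustrative only; not part of the original module) ---
# An execution engine is expected to wire its handlers through
# `register_callbacks`; messages that arrive earlier are queued by
# `invoke_callback` and drained once the callbacks show up. The URL and
# lambdas below are made up:
#
#     advisor = RetiariiAdvisor('ws://localhost:8081/tuner')
#     advisor.register_callbacks({
#         'trial_end': lambda param_id, success: print(param_id, success),
#         'final_metric': lambda param_id, metric: print(param_id, metric),
#     })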

View file

@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = [
'get_advisor', 'register_advisor', 'send_trial', 'receive_trial_parameters', 'get_experiment_id',
'_advisor' # FIXME: hack to make it importable for tests
]
import warnings
from typing import NewType, Any
import nni
from nni.common.version import version_check
# NOTE: this is only for passing flake8; we cannot import RetiariiAdvisor
# because it would introduce a circular import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor = None # type is RetiariiAdvisor
def get_advisor():
# return type: RetiariiAdvisor
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor):
# type of advisor: RetiariiAdvisor
global _advisor
if _advisor is not None:
warnings.warn('Advisor is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
_advisor = advisor
def send_trial(parameters: dict, placement_constraint=None) -> int:
"""
Send a new trial. Executed on tuner end.
Return an ID that uniquely identifies this trial.
"""
return get_advisor().send_trial(parameters, placement_constraint)
def receive_trial_parameters() -> dict:
"""
Receive a new trial. Executed on trial end.
Reload with our own JSON loader because NNI didn't use the Retiarii serializer to load the data.
"""
params = nni.get_next_parameter()
# version check, optional
raw_params = nni.trial._params
if raw_params is not None and 'version_info' in raw_params:
version_check(raw_params['version_info'])
else:
warnings.warn('Version check failed because `version_info` is not found.')
return params
def get_experiment_id() -> str:
return nni.get_experiment_id()
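# --- Usage sketch (illustrative only; not part of the original module) ---
# On the tuner end, an engine pairs `send_trial` with a placement-constraint
# dict (see RetiariiAdvisor._validate_placement_constraint for the accepted
# schema); the parameters below are made up:
#
#     trial_id = send_trial({'layer1': 'conv3x3'},
#                           placement_constraint={'type': 'None', 'gpus': []})
#
# On the trial end, `receive_trial_parameters()` returns the same dict after
# deserialization and an optional version check.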

View file

@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['DefaultListener']
from .graph import Model, ModelStatus, MetricData
from .engine import AbstractGraphListener
class DefaultListener(AbstractGraphListener):
def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
model.intermediate_metrics.append(metric)
def on_training_end(self, model: Model, success: bool) -> None:
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed

View file

@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['unpack_if_only_one', 'get_mutation_dict', 'mutation_dict_to_summary', 'get_mutation_summary']
from typing import Any, List
from .graph import Model
def unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {}
for label, samples in mutation.items():
# FIXME: this check might be wrong
if not isinstance(samples, list):
mutation_summary[label] = samples
else:
for i, sample in enumerate(samples):
mutation_summary[f'{label}_{i}'] = sample
return mutation_summary
def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation)
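# --- Illustrative example (a sketch; the mutation dict below is made up) ---
# `get_mutation_dict` collapses each mutator's samples via `unpack_if_only_one`,
# and `mutation_dict_to_summary` flattens multi-sample labels with an index suffix.
if __name__ == '__main__':
    _demo_mutation = {'cell/op': 'conv3x3', 'cell/inputs': [0, 2]}
    print(mutation_dict_to_summary(_demo_mutation))
    # -> {'cell/op': 'conv3x3', 'cell/inputs_0': 0, 'cell/inputs_1': 2}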

View file

@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from nni.nas.execution.common import Model, receive_trial_parameters, get_mutation_dict
from .graph import BaseExecutionEngine
class BenchmarkGraphData:
SUPPORTED_BENCHMARK_LIST = [
'nasbench101',
'nasbench201-cifar10',
'nasbench201-cifar100',
'nasbench201-imagenet16',
'nds-cifar10',
'nds-imagenet',
'nlp'
]
def __init__(self, mutation: Dict[str, Any], benchmark: str,
metric_name: Optional[str] = None,
db_path: Optional[str] = None) -> None:
self.mutation = mutation # mutation dict, e.g., {'layer1': 'conv3x3', ...}
self.benchmark = benchmark # e.g., nasbench101, nasbench201, ...
self.metric_name = metric_name # metric to query, e.g., valid_acc
self.db_path = db_path # path to the directory of the database
def dump(self) -> dict:
from nni.nas.benchmarks.constants import DATABASE_DIR
return {
'mutation': self.mutation,
'benchmark': self.benchmark,
'metric_name': self.metric_name,
'db_path': self.db_path or DATABASE_DIR # the database path needs to be passed from manager to worker
}
@staticmethod
def load(data) -> 'BenchmarkGraphData':
return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
def __repr__(self) -> str:
return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
class BenchmarkExecutionEngine(BaseExecutionEngine):
"""
Execution engine that does not actually run any trial, but queries the database for results.
The database query is done on trial end to make sure intermediate metrics are available.
It will also support an accelerated mode that returns metrics immediately without even going through the NNI manager
(not implemented yet).
"""
def __init__(self, benchmark: Union[str, Callable[[BenchmarkGraphData], Tuple[float, List[float]]]], acceleration: bool = False):
super().__init__()
if isinstance(benchmark, str):
assert benchmark in BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST, \
f'{benchmark} is not one of the supported benchmarks: {BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST}'
self.benchmark = benchmark
self.acceleration = acceleration
def pack_model_data(self, model: Model) -> Any:
# called when a new model is submitted to backend.
# convert a Model into a data that is acceptable by trial end.
mutation = get_mutation_dict(model)
graph_data = BenchmarkGraphData(mutation, self.benchmark)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = BenchmarkGraphData.load(receive_trial_parameters())
assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
final, intermediates = cls.query_in_benchmark(graph_data)
import nni
for i in intermediates:
nni.report_intermediate_result(i)
nni.report_final_result(final)
@staticmethod
def query_in_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
if not isinstance(graph_data.benchmark, str):
return graph_data.benchmark(graph_data)
# built-in benchmarks with default query setting
if graph_data.benchmark == 'nasbench101':
from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
arch = None
for t in graph_data.mutation.values():
if isinstance(t, dict):
arch = t
if arch is None:
raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
return _convert_to_final_and_intermediates(
query_nb101_trial_stats(arch, 108, include_intermediates=True),
'valid_acc'
)
elif graph_data.benchmark.startswith('nasbench201'):
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
dataset = graph_data.benchmark.split('-')[-1]
return _convert_to_final_and_intermediates(
query_nb201_trial_stats(_flatten_architecture(graph_data.mutation), 200, dataset, include_intermediates=True),
'valid_acc',
)
elif graph_data.benchmark.startswith('nds'):
# FIXME: not tested yet
from nni.nas.benchmarks.nds import query_nds_trial_stats
dataset = graph_data.benchmark.split('-')[-1]
return _convert_to_final_and_intermediates(
query_nds_trial_stats(None, None, None, None, _flatten_architecture(graph_data.mutation),
dataset, include_intermediates=True),
'valid_acc'
)
elif graph_data.benchmark.startswith('nlp'):
# FIXME: not tested yet
from nni.nas.benchmarks.nlp import query_nlp_trial_stats
# TODO: I'm not sure of the available datasets in this benchmark, and the docs are missing.
return _convert_to_final_and_intermediates(
query_nlp_trial_stats(_flatten_architecture(graph_data.mutation), 'ptb', include_intermediates=True),
'valid_acc'
)
else:
raise ValueError(f'{graph_data.benchmark} is not a supported benchmark.')
def _flatten_architecture(mutation: Dict[str, Any], benchmark: Optional[str] = None):
# STRONG ASSUMPTION HERE!
# This assumes that the benchmarked search space is a one-level search space,
# i.e., it is either ONE cell or ONE network.
# Two-cell search spaces like NDS are not supported yet.
# Some benchmarks even need special handling to pop out invalid keys. I don't think this is a good design.
# Support double underscores to be compatible with the naming convention in the base engine.
ret = {k.split('/')[-1].split('__')[-1]: v for k, v in mutation.items()}
if benchmark == 'nasbench101':
ret = {k: v for k, v in ret.items() if k.startswith('op') or k.startswith('input')}
ret = {k: v if k.startswith('op') or isinstance(v, list) else [v] for k, v in ret.items()}
return ret
def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_name: str) -> Tuple[float, List[float]]:
# convert benchmark results from database to
# final result (float) and intermediate results (list of floats)
benchmark_result = list(benchmark_result)
assert len(benchmark_result) > 0, 'Invalid query. Results from the benchmark are empty.'
if len(benchmark_result) > 1:
benchmark_result = random.choice(benchmark_result)
else:
benchmark_result = benchmark_result[0]
benchmark_result = cast(dict, benchmark_result)
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
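# --- Illustrative example (a sketch; the mutation keys below are made up) ---
# `_flatten_architecture` strips the path prefix and the double-underscore
# prefix from mutation keys before the benchmark query:
if __name__ == '__main__':
    _demo = {'model/1__op1': 'conv3x3', 'model/1__input2': [0, 1]}
    print(_flatten_architecture(_demo))
    # -> {'op1': 'conv3x3', 'input2': [0, 1]}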

View file

@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import PdartsTrainer
from .engine import *

Просмотреть файл

@ -0,0 +1,402 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['CGOExecutionEngine', 'TrialSubmission']
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas import utils
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
Model, ModelStatus, MetricData, Node,
RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
_logger = logging.getLogger(__name__)
def _noop(*args, **kwargs):
pass
@dataclass
class TrialSubmission:
model: Model
placement: Dict[Node, Device]
grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
Otherwise, a model will be submitted independently without any cross-graph optimization.
Parameters
----------
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's REST server
rest_url_prefix
The URL prefix of the experiment's REST entry
"""
def __init__(self, training_service: RemoteConfig,
max_concurrency: int | None = None,
batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int | None = max_concurrency
devices = self._construct_devices(training_service)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models = {}
self._original_model_to_multi_model = {}
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[Device]] = {}
self._history: List[Model] = []
self._queuing_models: List[Model] = []
self._models_to_retry: List[Model] = []
self._queue_lock = threading.Lock()
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': _noop,
'request_trial_jobs': _noop,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self):
self._stopped = True
self._consumer_thread.join()
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: Model) -> None:
curr_time = time.time()
_logger.info('%d models are submitted', len(models))
self._queue_lock.acquire()
self._queuing_models.extend([(curr_time, _) for _ in models])
self._queue_lock.release()
def _submit_retry_models(self, models: List[Model]) -> None:
_logger.info('%d models are retried', len(models))
self._queue_lock.acquire()
self._models_to_retry.extend(models)
self._queue_lock.release()
def _consume_models(self):
# a thread that monitors self._models_to_retry and self._queuing_models and consumes them in batches
while not self._stopped:
if len(self._models_to_retry) > 0:
self._queue_lock.acquire()
# retrying jobs should be scheduled first
while len(self._models_to_retry) > 0 and len(self.available_devices) > 0:
# submit a single model to avoid cross-graph optimization
self._submit_models_in_batch(self._models_to_retry[0])
self._models_to_retry = self._models_to_retry[1:]
self._queue_lock.release()
if len(self._queuing_models) > 0:
self._queue_lock.acquire()
curr_time = time.time()
num_models_to_submit = len(self.available_devices)
if self.max_concurrency:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)
if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
self._queuing_models = self._queuing_models[num_models_to_submit:]
self._queue_lock.release()
time.sleep(1)
def _extract_placement_constraint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint
def _submit_models_in_batch(self, *models: Model) -> None:
_logger.info('%d models are submitted in batch', len(models))
_logger.debug('model id: %s', str([m.model_id for m in models]))
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
placement_constraint = self._extract_placement_constraint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))
# currently, it is impossible for the search strategy to submit more models than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[trial_id] = model
self._trial_to_original_models[trial_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._trial_to_original_models[trial_id].append(m.model_id)
self._history.append(m)
def list_models(self) -> Iterable[Model]:
return self._history
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
# try the available devices first so that the trial can be launched as early as possible;
# if the free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models = AssemblePolicy().group(logical_plan, self.all_devices)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model.evaluator, Lightning), \
"cross-graph optimization only supports PyTorch Lightning as the evaluator"
assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
"cross-graph optimization only supports MultiModelSupervisedLearningModule"
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module.dump_kwargs().copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
model.evaluator.module = new_module
phy_models_and_placements.append((model, model_placement, list(multi_model.keys())))
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
# def _send_trial_callback(self, paramater: dict) -> None:
# if len(self.available_devices) == 0:
# _logger.warning('There is no available devices, but trial is submitted.')
# _logger.debug('Resource used. Remaining: %d', len(self.available_devices))
# def _request_trial_jobs_callback(self, num_trials: int) -> None:
# self.resources += num_trials
# _logger.info('on_resource_available: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if success:
original_model.status = ModelStatus.Trained
else:
original_model.status = ModelStatus.Failed
# the failed models in a multi-model will be retried one by one w/o CGO
if len(self._trial_to_original_models[trial_id]) > 1:
models_to_retry.append(original_model)
for listener in self._listeners:
listener.on_training_end(original_model, success)
if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)
self.available_devices.extend(self._trial_used_devices[trial_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[trial_id]
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
_logger.debug(metrics)
if isinstance(metrics, float):
self._listeners[0].on_metric(self._running_models[trial_id], metrics)
else:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].metric = merged_metrics[model_id]
for listener in self._listeners:
listener.on_metric(self._original_models[model_id], merged_metrics[model_id])
def query_available_resource(self) -> int:
# the _queuing_models need to use available_devices first
self._queue_lock.acquire()
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
self._queue_lock.release()
return available_for_more_models
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model and hand it over to the trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
_logger.info('CGO_ENGINE trial parameters received')
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
trainer_instance = graph_data.evaluator
model_cls = utils.import_(f'_generated_model.{random_str}._model')
trainer_instance.fit(model_cls())
os.remove(file_name)
class AssemblePolicy:
@staticmethod
def _is_related_node(model: Model, node: Node):
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
# TODO: packing multiple models into one GPU
# currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
assert len(available_devices) > 0 # there should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
# models in one group should
# (1) not use more GPUs than available_devices
# (2) be connected in the logical plan (independent models should be assembled in multiple groups)
# (3) use the same MultiModelSupervisedLearningModule
if len(group_model) > 0 and \
(not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
not AssemblePolicy._check_evaluator(m, group_model)):
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models
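# --- Notes (illustrative only; not part of the original module) ---
# `_extract_placement_constraint` above turns a node->device mapping into the
# constraint schema validated by RetiariiAdvisor; for two GPUs on one host
# (hostname made up) it would look like:
#
#     {'type': 'Device', 'gpus': [('worker-0', 0), ('worker-0', 1)]}
#
# `AssemblePolicy.group` packs one model per device, starting a new group
# whenever a model is disconnected from the current group in the logical plan
# or uses a different evaluator.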

View file

@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
raise NotImplementedError

View file

@ -0,0 +1,336 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from typing import Dict, Optional, Tuple, Any
from nni.nas.utils import uid
from nni.common.device import Device, CPUDevice
from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
class AbstractLogicalNode(Node):
def __init__(self, graph, node_id, name, operation, _internal=False):
super().__init__(graph, node_id, name, operation, _internal=_internal)
self.related_models = []
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
"""
Given a set of models to be formed in a physical model and their device placement,
this function replaces the logical node with an executable physical node for the physical model.
Parameters
----------
multi_model_placement : dict
a dict of models and device placement.
These models will be assembled into the same physical model to run.
Returns
-------
node : Node
the physical node to replace the logical node in the physical model
placement : Device
the device placement of the returned physical node
"""
raise NotImplementedError
def _fork_to(self, graph: Graph):
raise NotImplementedError
class LogicalGraph(Graph):
def __init__(self, model: Model, graph_id: int, name: Optional[str] = None, _internal: bool = False):
super().__init__(model, graph_id, name='logical_' + (name or str(graph_id)), _internal=_internal)
def _dump(self) -> Any:
nodes_dump = {}
for node in self.hidden_nodes:
if isinstance(node, OriginNode):
nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
else:
nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
edges_dump = []
for edge in self.edges:
if isinstance(edge.head, OriginNode):
head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
else:
head_info = edge.head.name
if isinstance(edge.tail, OriginNode):
tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
else:
tail_info = edge.tail.name
edges_dump.append((head_info, tail_info))
return {
'inputs': self.input_node.operation.io_names,
'outputs': self.output_node.operation.io_names,
'nodes': nodes_dump,
'edges': edges_dump
}
def _fork_to(self, model: Model) -> Graph:
new_graph = Graph(model, self.id, self.name,
_internal=True)._register()
for node in self.hidden_nodes:
if isinstance(node, AbstractLogicalNode):
node._fork_to(new_graph)
else:
Node(new_graph, node.id, node.name,
node.operation, _internal=True)._register()
id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}
for edge in self.edges:
new_head = id_to_new_node[edge.head.__repr__()]
new_tail = id_to_new_node[edge.tail.__repr__()]
Edge((new_head, edge.head_slot),
(new_tail, edge.tail_slot), _internal=True)._register()
return new_graph
class OriginNode(AbstractLogicalNode):
"""
This is a logical node representing the original node without any modification.
In assemble, it just returns the original node along with the physical placement given by multi_model_placement.
"""
def __init__(self, logical_graph: LogicalGraph,
original_graph: Graph, original_node: Node,
name: str, operation, _internal=False):
super().__init__(logical_graph, original_node.id, name, operation)
self.original_graph = original_graph
self.original_node = original_node
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
self.original_node.name,
self.original_node.operation)
return new_node, multi_model_placement[self.original_node.graph.model]
def __repr__(self):
return f'OriginNode(id={self.id}, name={self.name}, ' \
f'operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'
def _fork_to(self, graph: Graph):
OriginNode(graph, self.original_graph, self.original_node,
self.name, self.operation)._register()
class LogicalPlan:
def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
self.id = plan_id
self.logical_graph = LogicalGraph(
self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
def add_model(self, model: Model):
self.models.append(model)
# Only optimize the root graph.
self._merge_graph(model.root_graph)
def _merge_graph(self, from_graph):
to_graph = self.logical_graph
id_to_new_node = {} # old node ID -> new node object
for old_node in from_graph.nodes:
new_node = OriginNode(to_graph, old_node.graph,
old_node, old_node.name,
old_node.operation, _internal=True)._register()
id_to_new_node[old_node.id] = new_node
for edge in from_graph.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
def assemble(self, multi_model_placement: Dict[Model, Device]) \
-> Tuple[Model, Dict[Node, Device]]:
"""
Given a set of models to be formed in a physical model and their device placement,
this function replaces all the logical nodes in this LogicalPlan with executable physical nodes
for the physical model.
Parameters
----------
multi_model_placement : dict
a dict of models and device placement.
These models will be assembled into the same physical model to run.
Returns
-------
phy_model : Model
the physical model formed by models in `multi_model_placement`;
all logical nodes are replaced by physical nodes
node_placements : dict
the device placement of the nodes in `phy_model`
"""
phy_model = Model(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
phy_graph._rename_graph(phy_graph.name, "_model")
# merge sub-graphs
for model in multi_model_placement:
if phy_model.evaluator is None and model.evaluator is not None:
phy_model.evaluator = model.evaluator
for graph_name in model.graphs:
if graph_name != model._root_graph_name:
new_graph = model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# the M_ name prefix of hidden_nodes in non-root graphs is added here
for new_node in new_graph.hidden_nodes:
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
assert phy_model.evaluator is not None
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
evaluator_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
# Replace all logical nodes to executable physical nodes
hidden_nodes = phy_graph.hidden_nodes.copy()
node_placements = {}
added_models = []
for node in hidden_nodes:
if isinstance(node, OriginNode):
model_id = node.original_graph.model.model_id
if node.original_graph.model not in multi_model_placement:
for edge in node.incoming_edges:
edge.remove()
for edge in node.outgoing_edges:
edge.remove()
node.remove()
continue
if isinstance(node, AbstractLogicalNode):
new_node, placement = node.assemble(multi_model_placement)
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in evaluator_slot:
added_models.append(model_id)
evaluator_slot[model_id] = len(added_models) - 1
slot = evaluator_slot[model_id]
# If a model's inputs/outputs are not used in the multi-model
# the codegen and trainer should not generate and use them
# "use_input" and "use_output" are used to mark whether
# an input/output of a model is used in a multi-model
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
self.node_replace(node, new_node)
# the M_ name prefix of cells in hidden_nodes of root graphs is added here
# FIXME: merge this rename with non-root graph, only do once.
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
# input should be at CPU, move it to GPU first if necessary
if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
# hack: only support single_server
node_placements[new_node] = CPUDevice(node_id=placement.node_id)
else:
node_placements[new_node] = placement
node.remove()
# If two nodes are placed on different devices, use ToDevice op to copy the node
# TODO: when copying one node to multiple devices, broadcast is more efficient than P2P communication
existing_edges = phy_graph.edges.copy()
# avoid copying a node multiple times onto the same device
copied_op: Dict[Tuple[Node, Device], Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.node_id != tail_placement.node_id:
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
dst_name = edge.head.name + "_to_" + edge.tail.name
to_operation = Operation.new(
'ToDevice', {
"device": tail_placement, "src": (
edge.head.name, edge.head_slot), "dst": dst_name})
to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
node_placements[to_node] = head_placement
edge.head = to_node
edge.head_slot = None
# merge all input nodes into one with multiple slots
input_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
input_nodes.append(node)
for edge in phy_graph.edges:
if edge.head in input_nodes:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
output_nodes.append(node)
for edge in phy_graph.edges:
if edge.tail in output_nodes:
edge.tail_slot = output_slot_mapping[edge.tail]
edge.tail = phy_graph.output_node
for node in input_nodes:
node.remove()
for node in output_nodes:
node.remove()
return phy_model, node_placements
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only a single input slot and a single output slot are supported.
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')
phy_graph = old_node.graph
new_node.graph = phy_graph
new_node._register()
for edge in phy_graph.edges:
if edge.head == old_node:
edge.head = new_node
elif edge.tail == old_node:
edge.tail = new_node
# after the replacement, there might be multiple duplicated edges
# with the same input and output nodes, which should be de-duplicated
self._remove_duplicated_edges()
def _remove_duplicated_edges(self):
# TODO: it does not have duplicated edges if only supporting dedup input
# Duplicated edges appear when a chain of prefix nodes are deduplicated
pass
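# --- Flow sketch (illustrative only; not part of the original module) ---
# A typical cross-graph-optimization pass over a batch of models looks
# roughly like the following (`batch` and `device` are made up):
#
#     plan = LogicalPlan(plan_id=0)
#     for model in batch:
#         plan.add_model(model)
#     # optimizers such as DedupInputOptimizer rewrite the logical graph in place
#     phy_model, node_placements = plan.assemble({m: device for m in batch})
#
# where the device values come from the execution engine's available devices.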

View file

@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Dict, Tuple
from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice
from nni.nas.execution.common.graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode)
_supported_evaluators = [MultiModelSupervisedLearningModule]
class DedupInputNode(AbstractLogicalNode):
"""
This is a logical node representing the node for deduplication.
In assemble, it just returns one copy of the original node when multiple models are assembled.
These models will share the result of a single computation.
"""
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
"Dedup_" + nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
self.related_models = [_.original_graph.model for _ in self.origin_nodes]
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id,
f'M_{node.original_graph.model.model_id}_{node.name}',
node.operation)
return new_node, multi_model_placement[node.original_graph.model]
raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
def _fork_to(self, graph: Graph):
DedupInputNode(graph, self.id, self.origin_nodes)._register()
def __repr__(self) -> str:
return f'DedupNode(id={self.id}, name={self.name}, ' \
f'len(nodes_to_dedup)={len(self.origin_nodes)})'
class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_supported_evaluator(self, evaluator):
for e in _supported_evaluators:
if isinstance(evaluator, e):
return True
return False
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
if root_node.operation.type == '_inputs' and \
node_to_check.operation.type == '_inputs' and \
isinstance(root_node, OriginNode) and \
isinstance(node_to_check, OriginNode):
if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
return False
if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
return True
else:
return False
else:
return False
def convert(self, logical_plan: LogicalPlan) -> None:
nodes_to_skip = set()
while True: # repeat until the logical_graph converges
input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
# _PseudoOperation(type_name="_inputs"))
root_node = None
for node in input_nodes:
if node in nodes_to_skip:
continue
root_node = node
break
if root_node is None:
break # end of convert
else:
nodes_to_dedup = []
for node in input_nodes:
if node in nodes_to_skip:
continue
if self._check_deduplicate_by_node(root_node, node):
nodes_to_dedup.append(node)
assert len(nodes_to_dedup) >= 1
if len(nodes_to_dedup) == 1:
assert nodes_to_dedup[0] == root_node
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
if edge.tail in nodes_to_dedup:
edge.tail = dedup_node
for node in nodes_to_dedup:
node.remove()
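# --- Usage sketch (illustrative only; not part of the original module) ---
# The optimizer is registered on the CGO execution engine and applied to each
# logical plan before assembly:
#
#     optimizer = DedupInputOptimizer()
#     optimizer.convert(logical_plan)   # folds duplicate `_inputs` nodes
#
# after which models sharing an evaluator also share one input node, so the
# input pipeline runs once per assembled multi-model.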

View file

@ -0,0 +1,234 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['model_to_pytorch_script']
import logging
import re
from typing import Dict, List, Tuple, Any, cast
from nni.common.device import Device, GPUDevice
from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
from nni.nas.execution.common.graph_op import PyTorchOperation
from nni.nas.utils import STATE_DICT_PY_MAPPING
from .op_def import ToDevice
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.debug('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node.
Inputs will be formatted with ``_format_variable_name``
Parameters
----------
node : Node
a graph node, get and format its inputs
graph_name : str
subgraph name, to format variable names
Returns
-------
list
the list of input names
list
the list of input values; if an input is of a simple type, record its value,
otherwise the value is None
"""
edges = _sorted_incoming_edges(node)
inputs = []
inputs_value = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(_format_variable_name(edge.head.operation.io_names[edge.head_slot], graph_name))
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
inputs_value.append(None)
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append(_format_variable_name(edge.head.name, graph_name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(_format_variable_name(edge.head.name, graph_name), edge.head_slot))
inputs_value.append(None)
return inputs, inputs_value
def _format_variable_name(name: str, graph_name: str) -> str:
"""
1. replace invalid characters in the node name
2. variable names (full namespace) are too long; shorten the name by removing the prefix ``graph_name``
"""
name = name[len(graph_name):] if name.startswith(graph_name) else name
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
name = re.sub(r'\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
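# Illustrative examples of the rules above (a sketch, traced by hand):
#     _format_variable_name('layer.0', 'graph')          # -> 'layer_0'
#     _format_variable_name('_inputs', 'graph')          # -> 'i_inputs'
#     _format_variable_name('graph/cell/conv', 'graph')  # -> '_cell__conv'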
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
'''
Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU IDs,
we need to remap the GPU IDs when generating code so that they match correctly.
For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0" and "cuda:1" in the generated code.
'''
unique_devices = sorted(list(set([e for e in placement.values() if isinstance(e, GPUDevice)])))
node_gpu_cnt = {}
cuda_remapped_id = {}
for d in unique_devices:
if d.node_id not in node_gpu_cnt:
node_gpu_cnt[d.node_id] = 0
node_gpu_cnt[d.node_id] += 1
cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1
return cuda_remapped_id
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
node_python_mappings = {}
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
cuda_remapped_id = cast(dict, cuda_remapped_id)
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
if node.operation.type == 'shared':
continue
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
assert cuda_remapped_id is not None
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()
node_codes.append(f"{node_code}.to('{device_repr}')")
else:
node_codes.append(node_code)
# Map to module hierarchies in original search space python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs, inputs_value = _format_inputs(node, graph_name)
node_name = _format_variable_name(node.name, graph_name)
submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
output_names, _ = _format_inputs(graph.output_node, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
linebreak = '\n '
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
# TODO: handle imports
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.nas.nn.pytorch
{}
{}
'''
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
def __init__(self):
super().__init__()
{nodes}
def forward(self, {inputs}):
{edges}
return {outputs}
'''
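# A sketch of what the rendered model template might look like for a tiny graph
# (names are illustrative, not the output of an actual run):
#
# class Graph(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(3, 16, 3)
#     def forward(self, _inputs):
#         conv1 = self.conv1(_inputs)
#         return conv1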


@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import CreamSupernetTrainer
from .graph_gen import convert_to_graph


@ -0,0 +1,848 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import torch
from nni.nas.execution.common import Graph, Model, Node, Cell, Operation
from nni.nas.nn.pytorch import InputChoice, Placeholder, LayerChoice
from nni.nas.utils import get_init_parameters_or_fail, get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import (
_convert_name, build_full_name, _without_shape_info,
_extract_info_from_trace_node, get_full_name_by_scope_name,
is_layerchoice_node, match_node, build_cand_name,
build_python_name
)
class GraphConverter:
def __init__(self):
self.global_seq = 0
self.global_graph_id = 0
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
elif _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else:
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
return src_node, src_node_idx
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
# NOTE: compare with the string rather than the type, because the type is defined in pytorch C code.
# `.kind()` can also be used here
if node.outputsAt(0).type().str() == 'None':
attrs = {'type': 'None'}
else:
attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(self, node, module):
assert node.hasAttribute('name')
value = None
if node.inputsAt(0).debugName() == 'self':
_val = getattr(module, node.s('name'))
# TODO: serialize complex data type, and output proper error message
if isinstance(_val, (int, float, str, bool)):
value = _val
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
return node.kind(), attrs
def _remove_mangle(self, module_type_str):
return re.sub(r'\.___torch_mangle_\d+', '', module_type_str)
def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ``targeted_type`` will be removed from the graph if their fanout is 0.
``None`` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph,
module, module_name, module_python_name,
ir_model, ir_graph,
shared_module_index=None):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ``module``
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
``module``'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ``module``
shared_module_index : dict
tracks the modules for which an ir node has already been created;
when such a module is invoked again, the new ir node can simply reference the existing one.
this is how we identify shared modules (i.e., one module invoked multiple times in the `forward` function)
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
if shared_module_index is None:
shared_module_index = {}
# some nodes do not have outputs but modify a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add an output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
# ===================handle control flow: if===================
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back:
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`, etc.
the expression is generated using recursive calls
NOTE: dynamic graphs are not supported
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
elif tensor.node().kind() == 'aten::le':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} <= {right})'
elif tensor.node().kind() == 'aten::ge':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} >= {right})'
elif tensor.node().kind() == 'aten::__not__':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(not {value})'
elif tensor.node().kind() == 'aten::Bool':
value = _generate_expr(tensor.node().inputsAt(0))
return f'bool({value})'
elif tensor.node().kind() == 'aten::__is__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is {right})'
elif tensor.node().kind() == 'aten::__isnot__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is not {right})'
elif tensor.node().kind() == 'aten::ne':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} != {right})'
elif tensor.node().kind() == 'aten::gt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} > {right})'
elif tensor.node().kind() == 'aten::lt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} < {right})'
elif tensor.node().kind() == 'prim::If':
raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
elif tensor.node().kind() == 'aten::abs':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.abs({value}))'
elif tensor.node().kind() == 'aten::sum':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.sum({value}))'
elif tensor.node().kind() == 'aten::item':
value = _generate_expr(tensor.node().inputsAt(0))
return f'({value}.item())'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
'you are suggested to decorate the corresponding class with "@basic_unit".')
expr = _generate_expr(cond_tensor)
return eval(expr)
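# Illustrative trace: `if self.flag == 1` in forward() becomes
# aten::eq(prim::GetAttr[name="flag"], prim::Constant[value=1]);
# with module.flag == 1, the generated expression is '((1) == 1)',
# which eval() turns into True and thus selects the first block.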
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
last_block_node = new_node
return last_block_node
# ===================handle function call===================
def handle_function_callmethod(node):
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
# NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
if node.s('name') in ['forward', 'forward__0']:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_python_name = build_python_name(module_python_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, submodule_python_name,
ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
module_name_space = [submodule_name]
while predecessor.inputsAt(0).debugName() != 'self':
# this is for dealing with nested ModuleList. below is an example
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert predecessor.kind() == 'prim::GetAttr'
module_name_space.append(predecessor.s('name'))
predecessor = predecessor.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
module_name_space.append(predecessor.s('name'))
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
submodule_obj = module
script_submodule = script_module
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
submodule_python_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
if submodule_full_name in shared_module_index:
# this module is invoked more than once, the ir node has already been created
# create a reference node for it.
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self.global_seq += 1
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
subcell.python_name = shared_node_python_name
else:
# this module is processed for the first time, build cell for it
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
subcell.python_name = submodule_python_name
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, InputChoice):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
subcell.python_name = submodule_python_name
shared_module_index[submodule_full_name] = subcell
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
# handle normal member function
assert hasattr(script_module, node.s('name'))
# TODO: support non member functions
assert node.inputsAt(0).debugName() == 'self'
script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
for h_node in method_ir_graph.hidden_nodes:
h_node.graph = ir_graph
ir_graph.hidden_nodes.append(h_node)
for edge in method_ir_graph.edges:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
assert edge.head_slot is not None
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
edge.head_slot = src_node_idx
if edge.tail == method_ir_graph.output_node:
# since the following nodes have not been created, skip this edge
# edge.head is the output node of this method
# TODO: check whether there could be multiple output nodes???
node_index[node] = edge.head
continue
ir_graph.edges.append(edge)
# ===================handle each single node===================
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
handle_function_callmethod(node)
elif node.kind() == 'prim::CallFunction':
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
func_python_name = build_python_name(module_python_name, func_name)
func_node.python_name = func_python_name
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
self.global_seq += 1
prim_op_name = node.kind().split('::')[-1]
new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node, module)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node being None means there is no node in the chosen branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
elif node.kind().startswith('prim::'):
self.global_seq += 1
prim_op_name = node.kind().replace('::', '__')
prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = prim_node
self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_op_python_name = node.kind().replace('aten::', '')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
aten_python_name = build_python_name(module_python_name, aten_op_python_name)
aten_node.python_name = aten_python_name
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
if node_index:
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
else:
# here is an example where the ir_graph and node_index are empty:
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
#       %x.1 : Tensor): return (%x.1)
# add an edge from head to tail to handle this situation
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
def merge_aten_slices(self, ir_graph):
"""
if there are aten::slice nodes, merge the consecutive ones together.
``x[:, :, 1:, 1:]`` in python code will be converted into 4 nodes in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when the slice is for a one-dimensional list, there are only 4 inputs, thus no merge is needed
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
for edge in head_node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(head_node)
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
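# Illustrative walk-through: `x[:, :, 1:, 1:]` appears in TorchScript as a
# chain of 4 aten::slice nodes, each with 5 inputs (tensor, dim, start, end,
# step). The merge keeps the tensor input once and concatenates the remaining
# 4 inputs of each node, so the MergedSlice node ends up with 1 + 4 * 4 = 17
# inputs, matching the `(len(inputs) - 1) % 4 == 0` branch in the MergedSlice
# code generator further below.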
def refine_graph(self, ir_graph):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
self.merge_aten_slices(ir_graph)
def _handle_inputchoice(self, module):
return {
'n_candidates': module.n_candidates,
'n_chosen': module.n_chosen,
'reduction': module.reduction,
'label': module.label
}
def _handle_valuechoice(self, module):
return {
'candidates': module.candidates,
'label': module.label,
}
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
# NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
# that itself contains LayerChoice, InputChoice or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
graph.python_name = module_python_name
candidate_name_list = []
for cand_name in module.names:
cand = module[cand_name]
script_cand = script_module._modules[cand_name]
cand_full_name = build_cand_name(cand_name, module.label)
cand_python_name = build_python_name(module_python_name, cand_name)
candidate_name_list.append(cand_full_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
if subgraph is not None:
cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
cand_node.python_name = cand_python_name
else:
cand_type = '__torch__.' + get_importable_name(cand.__class__)
cand_node = graph.add_node(cand_full_name, cand_type, attrs)
cand_node.python_name = cand_python_name
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.ValueChoice:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and \
original_type_name in torch.nn.__dict__ and \
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serializable (a basic unit), so we won't continue to parse it
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
ir_graph.python_name = module_python_name
# handle graph nodes
self.handle_graph_nodes(script_module, sm_graph, module,
module_name, module_python_name, ir_model, ir_graph)
self.refine_graph(ir_graph)
ir_graph._register()
# add mutation signal for special modules
if original_type_name == OpTypeName.Repeat:
attrs = {
'mutation': 'repeat',
'label': module.label,
'depth': module.depth_choice,
'max_depth': module.max_depth,
'min_depth': module.min_depth,
}
return ir_graph, attrs
return ir_graph, {}
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ``module`` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ``module``
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module; ``None`` means the module is not parsed further
dict
the input arguments of this module
"""
return self._convert_module(script_module, module, module_name, None, ir_model)
class GraphConverterWithShape(GraphConverter):
"""
Convert a pytorch model to nni ir along with input/output shape info.
The base ir is acquired through ``torch.jit.script``,
and the shape info is acquired through ``torch.jit.trace``.
.. warning::
Known issues:
1. ``InputChoice`` and ``ValueChoice`` are not supported yet.
2. Currently random inputs are fed while tracing layerchoice.
If the forward path of a candidate depends on the input data, the wrong path will be traced.
This will result in incomplete shape info.
"""
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, None, ir_model)
self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, dummy_input)
return ir_graph, attrs
def _initialize_parameters(self, ir_model: 'Model'):
for ir_node in ir_model.get_nodes():
if ir_node.operation.parameters is None:
ir_node.operation.parameters = {}
ir_node.operation.attributes.setdefault('input_shape', [])
ir_node.operation.attributes.setdefault('output_shape', [])
def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
# First, trace the whole graph
tm_graph = self._trace(module, dummy_input)
for node in tm_graph.nodes():
shape_parameters, parameters = _extract_info_from_trace_node(node)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node = match_node(ir_model, node, module_name)
if ir_node is not None:
ir_node.operation.attributes.update(shape_parameters)
if parameters:
ir_node.operation.parameters.update(parameters)
self.propagate_shape(ir_model)
# trace each layerchoice
for name, submodule in module.named_modules():
# TODO: support InputChoice and ValueChoice
if isinstance(submodule, LayerChoice):
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
lc_node = ir_model.get_node_by_name(full_name)
assert lc_node is not None, f'Cannot find a node with name {full_name}'
for cand_name in submodule.names:
cand = submodule[cand_name]
cand_name = build_cand_name(cand_name, submodule.label)
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
self._trace_module(cand, cand_name, ir_model, lc_inputs)
def propagate_shape(self, ir_model: 'Model'):
def propagate_shape_for_graph(graph: 'Graph'):
if graph == ir_model.root_graph:
return
graph_node = ir_model.get_node_by_name(graph.name)
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
if not _without_shape_info(graph_node):
return
if is_layerchoice_node(graph_node):
cand_name = graph_node.operation.parameters['candidates'][0]
cand_node = ir_model.get_node_by_name(cand_name)
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
if _without_shape_info(cand_node):
propagate_shape_for_graph(ir_model.graphs[cand_name])
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
graph_node.operation.attributes['output_shape'] = cand_node.operation.attributes['output_shape']
else:
input_shape = [[]] * len(graph.input_node.operation.io_names or [])
output_shape = [[]] * len(graph.output_node.operation.io_names or [])
for edge in graph.input_node.outgoing_edges:
node = edge.tail
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.attributes['input_shape']:
input_shape[edge.head_slot or 0] = node.operation.attributes['input_shape'][edge.tail_slot or 0]
graph_node.operation.attributes['input_shape'] = input_shape
for edge in graph.output_node.incoming_edges:
node = edge.head
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.attributes['output_shape']:
output_shape[edge.tail_slot or 0] = node.operation.attributes['output_shape'][edge.head_slot or 0]
graph_node.operation.attributes['output_shape'] = output_shape
propagate_shape_for_graph(graph_node.graph)
# propagate from node to graph
for node in ir_model.get_nodes():
propagate_shape_for_graph(node.graph)
def _trace(self, module, dummy_input):
traced_module = torch.jit.trace(module, dummy_input)
torch._C._jit_pass_inline(traced_module.graph)
return traced_module.graph
def remove_dummy_nodes(self, ir_model: 'Model'):
# remove identity nodes
for node in ir_model.get_nodes_by_type('noop_identity'):
graph = node.graph
for in_edge in node.incoming_edges:
for out_edge in node.outgoing_edges:
if in_edge.tail_slot == out_edge.head_slot:
graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot))
graph.del_edge(in_edge)
graph.del_edge(out_edge)
break
node.remove()
def convert_to_graph(script_module, module, converter=None, **kwargs):
"""
Convert module to our graph ir, i.e., build a :class:`Model` type
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
converter : `TorchConverter`
default `GraphConverter` is used
kwargs:
will be passed to `converter.convert_module()`
Returns
-------
Model
the constructed IR model
"""
model = Model(_internal=True)
module_name = '_model'
if converter is None:
converter = GraphConverter()
converter.convert_module(script_module, module, module_name, model, **kwargs)
return model
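# Minimal usage sketch (illustrative):
#
#   import torch
#   net = MyNet()                          # some nn.Module defined elsewhere
#   ir_model = convert_to_graph(torch.jit.script(net), net)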


@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum
# special cases that cannot be treated as basic modules from pytorch
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
"""
mapping from an op type to its type name string
"""
Attr = 'Attr'
Constant = 'Constant'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
Repeat = 'Repeat'
Cell = 'Cell'


@ -0,0 +1,259 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from typing_extensions import TypeGuard
from nni.nas.execution.common import Cell, Model, Graph, Node, Edge
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
return '{}__{}{}'.format(prefix, name, str(seq))
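# Illustrative results: build_full_name('_model', 'conv') == '_model__conv',
# build_full_name('_model', ['cells', '0'], 3) == '_model__cells__03'.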
def build_python_name(prefix, name):
if isinstance(name, list):
name = '.'.join(name)
if prefix:
return '{}.{}'.format(prefix, name)
else: # prefix could be None
return name
def build_cand_name(name, label):
return f'layerchoice_{label}_{name}'
def _convert_name(name: str) -> str:
"""
Convert names that use '.' as a separator into valid variable names in code
"""
return name.replace('.', '__')
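# e.g. _convert_name('out.1') == 'out__1'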
def _extract_info_from_trace_node(trace_node):
"""
Extract parameters from a trace node.
Parameters
----------
trace_node: torch._C.Node
"""
input_shape = []
output_shape = []
inputs = list(trace_node.inputs())
# cat input tensors are in a strange place
if trace_node.kind() == 'aten::cat':
input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
else:
for _input in inputs:
input_type = _input.type()
if input_type.kind() == 'TensorType':
shape = input_type.sizes()
if shape:
input_shape.append(shape)
for _output in trace_node.outputs():
output_type = _output.type()
if output_type.kind() == 'TensorType':
shape = output_type.sizes()
if shape:
output_shape.append(shape)
shape_parameters = {
'input_shape': input_shape,
'output_shape': output_shape,
}
if trace_node.kind() == 'aten::cat':
parameters = {'dim': inputs[1].toIValue()}
return shape_parameters, parameters
else:
return shape_parameters, None
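# Illustrative output for a traced conv node: shape_parameters could be
# {'input_shape': [[1, 3, 32, 32]], 'output_shape': [[1, 16, 30, 30]]};
# parameters is None except for aten::cat, where it carries {'dim': ...}.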
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:
return False
def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
full_name = prefix
for last_scope in range(len(scope_names)):
ir_node = ir_model.get_node_by_name(full_name)
# check if it's layerchoice
if is_layerchoice_node(ir_node):
full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
else:
full_name = build_full_name(full_name, scope_names[last_scope])
return full_name
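# Illustrative walk (hypothetical names): scope_names ['cells', '0'] with
# prefix '_model' yields '_model__cells__0'; if an intermediate name resolves
# to a layerchoice node, the 'layerchoice_<label>_<name>' scheme is used instead.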
def match_node(ir_model: Model, torch_node, prefix=''):
"""
Match the corresponding node of a torch._C.Value
"""
scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
# handle the case when the node is not an nn.Module but is directly used in forward().
# because the name can't be directly matched, a hacky way is used:
# match the first node of that kind which has no shape info yet
graph = ir_model.graphs.get(full_name)
if graph is not None:
for node in graph.get_nodes_by_type(torch_node.kind()):
if not node.operation.attributes['input_shape']:
return node
return None
else:
return ir_model.get_node_by_name(full_name)
def _without_shape_info(node: Node):
return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
def flatten_model_graph(ir_model: Model):
"""
Flatten the subgraph into root graph.
"""
def _flatten(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model
def flatten_model_graph_without_layerchoice(ir_model: Model):
"""
Flatten the subgraphs into the root graph and skip all layerchoice nodes
"""
def _flatten_without_layerchoice(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
if is_layerchoice_node(node):
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
del model.graphs[node.name]
node.remove()
return
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten_without_layerchoice(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten_without_layerchoice(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model


@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_evaluator':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
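# Usage sketch (illustrative): given a dumped graph ir dict, this writes and
# renders 'vgraph.jpg' in the working directory (requires the graphviz binaries):
#
#   visualize_model(graph_ir)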


@ -0,0 +1,167 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['BaseGraphData', 'BaseExecutionEngine']
import logging
import os
import random
import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, RetiariiAdvisor, get_mutation_summary,
Model, ModelStatus, MetricData, Evaluator,
send_trial, receive_trial_parameters, get_advisor
)
from nni.nas.utils import import_
from .codegen import model_to_pytorch_script
_logger = logging.getLogger(__name__)
class BaseGraphData:
"""
Data sent between strategy and trial, in graph-based execution engine.
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
self.model_script = model_script
self.evaluator = evaluator
self.mutation_summary = mutation_summary
def dump(self) -> dict:
return {
'model_script': self.model_script,
# engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be None in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
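# Round-trip sketch (illustrative):
#
#   data = BaseGraphData(model_script, evaluator, mutation_summary)
#   restored = BaseGraphData.load(data.dump())
#   assert restored.model_script == data.model_script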
class BaseExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with no optimization at all.
Resource management is implemented in this class.
"""
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
Advisor will call the callbacks when the corresponding event has been triggered.
Base execution engine will get those callbacks and broadcast them to the graph listeners.
Parameters
----------
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []
self.resources = 0
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': self._send_trial_callback,
'request_trial_jobs': self._request_trial_jobs_callback,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
def submit_models(self, *models: Model) -> None:
for model in models:
data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)
def list_models(self) -> Iterable[Model]:
return self._history
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, parameter: dict) -> None:
if self.resources <= 0:
# FIXME: should be a warning message here
_logger.debug('There is no available resource, but trial is submitted.')
self.resources -= 1
_logger.debug('Resource used. Remaining: %d', self.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
self.resources += num_trials
_logger.debug('New resource available. Remaining: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(model, success)
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(model, metrics)
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.metric = metrics
for listener in self._listeners:
listener.on_metric(model, metrics)
def query_available_resource(self) -> int:
return self.resources
def budget_exhausted(self) -> bool:
resp = rest.get(self.port, '/check-status', self.url_prefix)
return resp['status'] == 'DONE'
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)


@ -0,0 +1,548 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import (Any, Dict, List)
import torch
import torch.nn.functional as nn_functional
from nni.nas.execution.common import PyTorchOperation
mem_format = [
'torch.contiguous_format', # 0
'torch.preserve_format', # 1
'torch.channels_last', # 2
]
# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# this maps each scalar type index to its corresponding pytorch type string
scalar_type_to_pytorch_type = [
'torch.uint8', # 0
'torch.int8', # 1
'torch.short', # 2
'torch.int', # 3
'torch.int64', # 4
'torch.half', # 5
'torch.float', # 6
'torch.double', # 7
'torch.complex32', # 8
'torch.complex64', # 9
'torch.complex128', # 10
'torch.bool', # 11
]
class NoOpIdentity(PyTorchOperation):
"""
this operator type is added by us
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(nn_functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): # 'Long()' ???
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
return f'{output} = "{str_val}"'
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
elif self.parameters['type'] in ('dict', 'list', 'tuple'):
# TODO: prim::TupleIndex is not supported yet
return f'{output} = {repr(self.parameters["value"])}'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# there is a single output here, because the following code uses an index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
dim = (len(inputs) - 1) // 4
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif len(inputs) == 4:
# this case is for simple list
return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
else:
raise RuntimeError('Unsupported slice pattern')
# the following Aten classes mean these aten ops are not methods of torch.Tensor
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
# ====================================
class AtenTensors(PyTorchOperation):
_ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas]
matched = overloaded_defs.index(len(inputs))
args_list = []
for idx, arg in enumerate(schemas[matched].arguments):
if arg.name == 'dtype':
arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'layout':
if inputs_value[idx] is not None:
arg_str = 'layout=torch.strided'
print('Warning: only support `torch.strided` for now!!!')
else:
arg_str = ''
elif arg.name == 'device':
arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
elif arg.name == 'memory_format':
arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'pin_memory':
# TODO: deal with this argument
continue
elif arg.name == 'requires_grad':
arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
elif str(arg.type).startswith('Optional['):
arg_str = f'{arg.name}={inputs[idx]}'
else:
arg_str = f'{inputs[idx]}'
if arg_str != '':
args_list.append(arg_str)
op_name = self.type.split('::')[-1]
if hasattr(torch, op_name):
return f'{output} = torch.{op_name}({", ".join(args_list)})'
else:
return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'
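# Illustrative rendering: for aten::zeros whose matched schema has a dtype
# argument carrying scalar type index 6, the emitted line could look like
#   out = torch.zeros(size_var, dtype=torch.float)
# since scalar_type_to_pytorch_type[6] == 'torch.float'.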
# ====================================
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
raise TypeError(f'Unexpected type: {self.type}')
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}
TensorOpExceptions = {
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
}
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
'aten::save', 'aten::tensor', 'aten::wait'
]
def _hidden(name):
return name.startswith('_') and not name.startswith('__')
def _emit_args(args):
    # NOTE: the `out` argument is intentionally kept for now; the filter below is disabled
    return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args]  # if arg.name != 'out'
def _get_tensor_ops():
def is_tensor_method(schema):
if len(schema.arguments) == 0:
return False
self = schema.arguments[0]
if self.name != 'self':
return False
if not self.type.isSubtypeOf(torch._C.TensorType.get()):
return False
return True
op_args = {}
# discover methods
for elem in dir(torch.Tensor):
if not _hidden(elem):
schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
for schema in schemas:
if is_tensor_method(schema):
op_name = 'aten::' + elem
args = _emit_args(schema.arguments[1:])
if op_name in op_args:
op_args[op_name].append(args)
else:
op_args[op_name] = [args]
return op_args.keys(), op_args
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
# remove _tan but not __and__
if not _hidden(elem):
op_name = 'aten::' + elem
if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
continue
args = _emit_args(schema.arguments)
if op_name in torch_op_args:
torch_op_args[op_name].append(args)
else:
torch_op_args[op_name] = [args]
return torch_op_args.keys(), torch_op_args
def _get_torch_ops_exclude_tensor_ops():
tensor_op_names, _ = _get_tensor_ops()
torch_op_names, torch_ops = _get_torch_ops()
torch_exclude_ops = {}
for name in torch_op_names:
if name not in tensor_op_names:
if name not in TorchOpExclude:
# exclude the ops that are not in
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
torch_exclude_ops[name] = torch_ops[name]
return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
"""
corresponding to _get_tensor_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_tensor_ops()
comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=', 'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}
@staticmethod
def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            # all matched overloads must share an identical argument-name signature
            names = [','.join(arg[0] for arg in each) for each in matched]
            return all(name == names[0] for name in names)
overloaded_defs = TensorOps._op_args[_type]
matched = []
for each in overloaded_defs:
# plus 1 because we skip the first argument when generating tensor op def
if len(each) + 1 == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# return any one is okay
return matched[0]
elif _type in ManuallyChooseDef:
return ManuallyChooseDef[_type]
else:
raise RuntimeError(f'tensor op type {_type} has more than one matched: {matched}')
else:
if _type in TensorOpExceptions:
return None
raise RuntimeError(f'tensor op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: deal with conditional ops
if self.type in TensorOps.comparison_ops:
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
matched_args = TensorOps._get_matched_args(self.type, inputs)
if matched_args is None:
return TensorOpExceptions[self.type](output, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = {inputs[0]}.{op_name}({args_str})'
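# Worked example (hypothetical inputs): for self.type == 'aten::view' with
# inputs == ['x', '(1, -1)'], both tensor overloads (size / dtype) match by count,
# so ManuallyChooseDef supplies [('size', 'List[int]', 'None')] and the emitted
# forward line is:
#     out = x.view(size=(1, -1))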
class TorchOps(PyTorchOperation):
"""
corresponding to _get_nn_functional_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
# add 'aten::pixel_shuffle'
_op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
_ori_type_name = _op_args.keys()
@staticmethod
def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            # all matched overloads must share an identical argument-name signature
            names = [','.join(arg[0] for arg in each) for each in matched]
            return all(name == names[0] for name in names)
overloaded_defs = TorchOps._op_args[_type]
matched = []
for each in overloaded_defs:
if len(each) == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# return any one is okay
return matched[0]
else:
raise RuntimeError(f'torch op type {_type} has more than one matched: {matched}')
else:
raise RuntimeError(f'torch op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = torch.{op_name}({args_str})'
class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for unknown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice"
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
attributes: Dict[str, Any] = {}):
self.type = "ToDevice"
self.device = parameters['device']
self.overridden_device_repr = None
self.src = parameters['src']
self.dst = parameters['dst']
def override_device_repr(self, device_repr):
        # A CUDA GPUDevice may remap a physical GPU ID to a different CUDA ID, so the device repr
        # can differ from GPUDevice.device_repr(). override_device_repr is called in
        # pytorch.graph_to_pytorch_model to replace device_repr with the correct CUDA ID.
        # For example, when a job uses physical GPUs 1 and 2, their CUDA IDs should be "cuda:0" and
        # "cuda:1"; self.device.device_repr() would return "cuda:1" and "cuda:2", while
        # override_device_repr should be "cuda:0" and "cuda:1".
self.overridden_device_repr = device_repr
def __repr__(self):
if self.overridden_device_repr is None:
return f'to("{self.device.device_repr()}")'
else:
return f'to("{self.overridden_device_repr}")'
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.overridden_device_repr is None:
forward_code = f'{output} = {inputs[0]}.to("{self.device.device_repr()}")'
else:
forward_code = f'{output} = {inputs[0]}.to("{self.overridden_device_repr}")'
return forward_code
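# Worked example (hypothetical device): with self.device.device_repr() == 'cuda:0'
# and inputs == ['x'], the emitted forward line is:
#     out = x.to("cuda:0")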
class AtenDet(PyTorchOperation):
# for torch 1.9
    # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = torch.det({inputs[0]})'

@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type, cast
import torch.nn as nn
from nni.nas.execution.common import (
Model, receive_trial_parameters,
get_mutation_dict, mutation_dict_to_summary
)
from nni.nas.evaluator import Evaluator
from nni.nas.utils import ContextStack
from .graph import BaseExecutionEngine
class PythonGraphData:
def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_ = class_
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
self.mutation_summary = mutation_dict_to_summary(mutation)
def dump(self) -> dict:
return {
'class': self.class_,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
            # the engine needs to call dump here, otherwise the evaluator
            # would be serialized as binary;
            # also, the evaluator can be None in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))
class PurePythonExecutionEngine(BaseExecutionEngine):
"""
This is the execution engine that doesn't rely on Python-IR converter.
    We didn't explicitly state this independence for now. The front end needs to decide which converter
    (or no converter) to use depending on the execution type. In the future, that logic may be moved into
    this execution engine.
    The execution engine needs to store the class path and init parameters of the base model, and
    re-initialize the model with the mutation dict in the context, so that the mutable modules are
    created as their fixed instances on the fly.
"""
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(
cast(Type[nn.Module], model.python_class),
model.python_init_params or {}, mutation, model.evaluator
)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())
def _model():
return graph_data.class_(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)
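# Usage sketch (hypothetical names): how the 'fixed' context re-creates a model at
# trial time, mirroring trial_execute_graph above. `MySpace` and the mutation dict
# are placeholders.
#
#     from nni.nas.utils import ContextStack
#
#     mutation = {'layer1': 'conv3x3'}        # mutable label -> chosen candidate
#     with ContextStack('fixed', mutation):
#         model = MySpace(num_classes=10)     # mutables resolve to the fixed choices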

@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.nas.execution.common import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal, attributes=None):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)

@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
"""
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
args = parser.parse_args()
if args.exec == 'base':
from .pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine
elif args.exec == 'cgo':
from .pytorch.cgo import CGOExecutionEngine
engine = CGOExecutionEngine
elif args.exec == 'py':
from .pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine
elif args.exec == 'benchmark':
from .pytorch.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
        raise ValueError(f'Unrecognized execution engine: {args.exec}')
engine.trial_execute_graph()
if __name__ == '__main__':
main()
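# Usage sketch: this entrypoint is what RetiariiExeConfig renders into trial_command
# (see experiment_config below), e.g. for the pure-Python engine:
#     python -m nni.retiarii.trial_entry py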

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework

@@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture
from .experiment_config import *
from .engine_config import *

@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
    # input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
    # input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
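# Usage sketch (hypothetical values): configuring the CGO engine explicitly,
# assuming ConfigBase allows attribute-style assignment as elsewhere in NNI configs.
#
#     engine = CgoEngineConfig()
#     engine.max_concurrency_cgo = 2
#     engine.batch_waiting_time = 60
#     engine.dummy_input = [1, 3, 224, 224]   # shape tuple for GraphConverterWithShape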

@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Union, Optional
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
# Internal: to support customized fields in trial command
# Useful when customized python / environment variables are needed
_trial_command_params: Optional[Dict[str, Any]] = None
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
_trial_command_params = {
# Default variables
'envs': '',
            # TODO: maybe use sys.executable rendered on the trial side (e.g., in trial_runner)
'python': sys.executable,
'execution_engine': self.execution_engine.name,
# This should override the parameters above.
**(self._trial_command_params or {})
}
self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()
super()._canonicalize([self])
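# Worked example: with the default _trial_command_params and sys.executable ==
# '/usr/bin/python' (hypothetical path), trial_command_tmpl renders as
# ' /usr/bin/python -m nni.retiarii.trial_entry py', and .strip() removes the
# leading space left by the empty 'envs' field.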

@@ -0,0 +1,361 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']
import logging
import warnings
from threading import Thread
from typing import Any, List, cast
import colorama
import torch
import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
from nni.nas.evaluator import Evaluator
from nni.nas.mutable import Mutator
from nni.nas.nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from nni.nas.utils import is_model_wrapped
from nni.nas.strategy import BaseStrategy
from nni.nas.strategy.utils import dry_run_for_formatted_search_space
from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
_logger = logging.getLogger(__name__)
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine
if oneshot:
base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
elif full_ir:
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs
dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else:
base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = evaluator
if mutators is not None and applied_mutators:
        raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet; '
                           'do not use mutators when you use LayerChoice/InputChoice.')
if mutators is not None:
applied_mutators = mutators
# Add mutations on evaluators
applied_mutators += process_evaluator_mutations(evaluator, applied_mutators)
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, evaluator, applied_mutators):
"""
    Locally run only one trial without launching an experiment, for debugging purposes, then exit.
    For example, it can be used to quickly check for shape mismatches.
    Specifically, it applies the mutators (by default choosing the first candidate of each choice)
    to generate a new model, then runs this model locally.
    The model will be parsed with the graph execution engine.
Parameters
----------
base_model : nni.retiarii.nn.pytorch.nn.Module
the base model
evaluator : nni.retiarii.graph.Evaluator
the training class of the generated models
applied_mutators : list
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from nni.nas.strategy.debug import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')
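# Usage sketch (hypothetical names): debug one mutated model locally before
# launching a full experiment; `Net` and `evaluate_model` are placeholders.
#
#     from nni.nas.evaluator import FunctionalEvaluator
#     debug_mutated_model(Net(), FunctionalEvaluator(evaluate_model), [])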
class RetiariiExperiment(Experiment):
"""
The entry for a NAS experiment.
Users can use this class to start/stop or inspect an experiment, like exporting the results.
    RetiariiExperiment is a sub-class of :class:`nni.experiment.Experiment`, and there are many similarities,
    such as a configurable training service for running the experiment distributedly on remote servers.
    But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configuring:
    - ``trial_code_directory``, which can only be the current working directory.
    - ``search_space``, which is auto-generated in NAS.
    - ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modularized trial code.
    RetiariiExperiment also doesn't have tuner/assessor/advisor, because those roles are implemented in the strategy.
    Also, unlike :class:`nni.experiment.Experiment`, which is bound to a node server,
    RetiariiExperiment only starts a node server to schedule the trials when the strategy is a multi-trial strategy.
    When the strategy is one-shot, launching the node server is omitted, and the experiment runs locally by default.
Configurations of experiments, such as execution engine, number of GPUs allocated,
should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.
Parameters
----------
base_model : nn.Module
The model defining the search space / base skeleton without mutation.
It should be wrapped by decorator ``nni.retiarii.model_wrapper``.
evaluator : nni.retiarii.Evaluator, default = None
Evaluator for the experiment.
If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
applied_mutators : list of nni.retiarii.Mutator, default = None
        Mutators to mutate the base model. If none, mutators are skipped.
Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
strategy : nni.retiarii.strategy.BaseStrategy, default = None
Exploration strategy. Can be multi-trial or one-shot.
trainer : BaseOneShotTrainer
Kept for compatibility purposes.
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
>>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig('local')
>>> exp_config.trial_concurrency = 2
>>> exp_config.max_trial_number = 20
>>> exp_config.training_service.use_active_gpu = False
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
>>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig()
    >>> exp_config.execution_engine = 'oneshot'  # must be set for one-shot strategy
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
    >>> with nni.retiarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module,
evaluator: Evaluator = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: Any = None):
super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
evaluator = trainer
if evaluator is None:
raise ValueError('Evaluator should not be none.')
self.base_model = base_model
self.evaluator: Evaluator = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
self._dispatcher = None
self._dispatcher_thread = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
                          'but it may behave inconsistently compared to when the decorator is added.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _run_strategy(self, config: RetiariiExeConfig):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=config.execution_engine.dummy_input
if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
_logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
        # TODO: we will probably need an execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from nni.nas.execution.pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
set_execution_engine(engine)
def start(self, *args, **kwargs) -> None:
"""
        By design, the only difference between `start` and `run` is that `start` is asynchronous,
        while `run` waits for the experiment to complete. RetiariiExperiment always waits for the
        experiment to complete, as the strategy runs in the foreground.
"""
raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
def run(self,
config: RetiariiExeConfig | None = None,
port: int = 8080,
debug: bool = False) -> None:
"""
Run the experiment.
This function will block until experiment finish or error.
"""
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
'We will try to convert this trainer to our new implementation to run the algorithm. '
'In case you want to stick to the old implementation, '
'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
return
if config is None:
            warnings.warn('`config = None` is deprecated and will be removed in the future. If you are running a one-shot experiment, '
                          'please consider creating a config and setting execution engine to `oneshot`.', DeprecationWarning)
self.config = RetiariiExeConfig()
self.config.execution_engine = OneshotEngineConfig()
else:
self.config = config
if isinstance(self.config.execution_engine, OneshotEngineConfig) \
or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
# this is hacky, will be refactored when oneshot can run on training services
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
self.strategy.run(base_model_ir, self.applied_mutators)
else:
ws_url = f'ws://localhost:{port}/tuner'
canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
self._dispatcher = RetiariiAdvisor(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
self._dispatcher_thread.start()
# FIXME: engine cannot be created twice
self._create_execution_engine(canonicalized_config)
try:
self._run_strategy(canonicalized_config)
# FIXME: move this logic to strategy with a new API provided by execution engine
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
        _logger.info('Search process is done. The experiment is still alive; call `stop()` to terminate it.')
def stop(self) -> None:
"""
Stop background experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
if self._dispatcher_thread:
self._dispatcher_thread.join()
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
"""
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
The concrete behavior of export depends on each strategy.
See the documentation of each strategy for detailed specifications.
Parameters
----------
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Support ``code`` and ``dict``. Not supported by one-shot algorithms.
If ``code``, the python code of model will be returned.
If ``dict``, the mutation history will be returned.
"""
# TODO: the base class may also need this method
if formatter == 'code':
config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.'
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.evaluator.export()
try:
# this currently works for one-shot algorithms
return self.strategy.export_top_models(top_k=top_k)
except NotImplementedError:
# when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

nni/nas/fixed.py
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any
from .utils import ContextStack
_logger = logging.getLogger(__name__)
def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
    fixed_arch : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
        Context manager that provides a fixed architecture when creating the model.
"""
if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)
if verbose:
        _logger.info('Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
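# Usage sketch (hypothetical architecture dict): fixed_arch also accepts the dict
# exported by an experiment directly, mirroring the docstring example above.
#
#     arch = {'layer1': 'conv3x3', 'layer2': 'maxpool'}
#     with fixed_arch(arch):
#         model = Model(3, 224, 224)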

nni/nas/hub/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework

@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mobilenetv3 import MobileNetV3Space
from .nasbench101 import NasBench101
from .nasbench201 import NasBench201
from .nasnet import NDS, NASNet, ENAS, AmoebaNet, PNAS, DARTS
from .proxylessnas import ProxylessNAS
from .shufflenet import ShuffleNetSpace
from .autoformer import AutoformerSpace

@@ -0,0 +1,469 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple, cast, Any, Dict
import torch
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper, basic_unit
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class RelativePosition2D(nn.Module):
def __init__(self, head_embed_dim, length=14,) -> None:
super().__init__()
self.head_embed_dim = head_embed_dim
        self.length = length
self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
trunc_normal_(self.embeddings_table_v, std=.02)
trunc_normal_(self.embeddings_table_h, std=.02)
def forward(self, length_q, length_k):
# remove the first cls token distance computation
length_q = length_q - 1
length_k = length_k - 1
        # initialize on the device directly, rather than moving to the device afterwards
range_vec_q = torch.arange(length_q, device=self.embeddings_table_v.device)
range_vec_k = torch.arange(length_k, device=self.embeddings_table_v.device)
# compute the row and column distance
length_q_sqrt = int(length_q ** 0.5)
distance_mat_v = (range_vec_k[None, :] // length_q_sqrt - range_vec_q[:, None] // length_q_sqrt)
distance_mat_h = (range_vec_k[None, :] % length_q_sqrt - range_vec_q[:, None] % length_q_sqrt)
        # clip the distances to the range [-length, length]
        distance_mat_clipped_v = torch.clamp(distance_mat_v, -self.length, self.length)
        distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.length, self.length)
        # translate the distances to the range [1, 2 * length + 1]; 0 is reserved for the cls token
        final_mat_v = distance_mat_clipped_v + self.length + 1
        final_mat_h = distance_mat_clipped_h + self.length + 1
# pad the 0 which represent the cls token
final_mat_v = F.pad(final_mat_v, (1, 0, 1, 0), "constant", 0)
final_mat_h = F.pad(final_mat_h, (1, 0, 1, 0), "constant", 0)
final_mat_v = final_mat_v.long()
final_mat_h = final_mat_h.long()
# get the embeddings with the corresponding distance
embeddings = self.embeddings_table_v[final_mat_v] + self.embeddings_table_h[final_mat_h]
return embeddings
class RelativePositionAttention(nn.Module):
"""
    This class is designed to support relative position in attention.
    The pytorch built-in nn.MultiheadAttention() does not support relative position embedding.
    Different from absolute position embedding, relative position embedding encodes the relative
    distance between input tokens and learns their pairwise relations.
    It is commonly calculated via a look-up table with learnable parameters interacting with queries
    and keys in self-attention modules.
"""
def __init__(
self, embed_dim, num_heads,
attn_drop=0., proj_drop=0.,
qkv_bias=False, qk_scale=None,
rpe_length=14, rpe=False,
head_dim=64):
super().__init__()
self.num_heads = num_heads
# head_dim is fixed 64 in official autoformer. set head_dim = None to use flex head dim.
self.head_dim = head_dim or (embed_dim // num_heads)
        self.scale = qk_scale or self.head_dim ** -0.5  # use self.head_dim so that head_dim=None (flex) works
        # Please refer to MixedMultiheadAttention for details.
        self.q = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
        self.k = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
        self.v = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.head_dim * num_heads, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rpe = rpe
        if rpe:
            self.rel_pos_embed_k = RelativePosition2D(self.head_dim, rpe_length)
            self.rel_pos_embed_v = RelativePosition2D(self.head_dim, rpe_length)
def forward(self, x):
B, N, _ = x.shape
head_dim = self.head_dim
        # num_heads cannot be read from self.num_heads directly (it may be a value choice),
        # so use -1 and let reshape infer it implicitly.
num_heads = -1
q = self.q(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
k = self.k(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
v = self.v(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
num_heads = q.size(1)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rpe:
r_p_k = self.rel_pos_embed_k(N, N)
attn = attn + (
q.permute(2, 0, 1, 3).reshape(N, num_heads * B, head_dim) @ r_p_k.transpose(2, 1)
).transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)
if self.rpe:
attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, N)
r_p_v = self.rel_pos_embed_v(N, N)
            # The attention map has size (B, num_heads, N, N). Reshape it to (N, B*num_heads, N) and
            # batch-matmul with the relative position embedding of V (N, N, head_dim), giving
            # (N, B*num_heads, head_dim). Then reshape back to the same size as x: (B, N, num_heads * head_dim).
x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, num_heads, N, head_dim).transpose(2, 1).reshape(B, N, num_heads * head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
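# Shape sanity-check sketch (hypothetical sizes): with the defaults above
# (head_dim=64, rpe_length=14), an input of 1 cls token + 14*14 patches works:
#
#     attn = RelativePositionAttention(embed_dim=192, num_heads=3, rpe=True)
#     x = torch.randn(2, 197, 192)
#     assert attn(x).shape == (2, 197, 192)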
class TransformerEncoderLayer(nn.Module):
"""
    This class is designed to support RelativePositionAttention().
    The pytorch built-in nn.TransformerEncoderLayer() does not support custom attention.
"""
def __init__(
self, embed_dim, num_heads, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, rpe=False,
drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
pre_norm=True, rpe_length=14, head_dim=64
):
super().__init__()
self.normalize_before = pre_norm
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.dropout = drop_rate
self.attn = RelativePositionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
rpe=rpe,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
rpe_length=rpe_length,
head_dim=head_dim
)
self.attn_layer_norm = nn.LayerNorm(embed_dim)
self.ffn_layer_norm = nn.LayerNorm(embed_dim)
self.activation_fn = nn.GELU()
self.fc1 = nn.Linear(
cast(int, embed_dim),
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio))
)
self.fc2 = nn.Linear(
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
cast(int, embed_dim)
)
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def forward(self, x):
"""
Args:
            x (Tensor): input to the layer of shape `(batch, patch_num, sample_embed_dim)`
Returns:
encoded output of shape `(batch, patch_num, sample_embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
x = self.attn(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.drop_path(x)
x = residual + x
x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
x = self.fc1(x)
x = self.activation_fn(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.drop_path(x)
x = residual + x
x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
return x
@basic_unit
class ClsToken(nn.Module):
""" Concat class token with dim=embed_dim before patch embedding.
"""
def __init__(self, embed_dim: int):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(self.cls_token, std=.02)
def forward(self, x):
return torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
class MixedClsToken(MixedOperation, ClsToken):
""" Mixed class token concat operation.
Supported arguments are:
- ``embed_dim``
Prefix of cls_token will be sliced.
"""
bound_type = ClsToken
argument_list = ['embed_dim']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self, embed_dim,
inputs: torch.Tensor) -> torch.Tensor:
embed_dim_ = _W(embed_dim)
cls_token = _S(self.cls_token)[..., :embed_dim_]
return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)
@basic_unit
class AbsPosEmbed(nn.Module):
""" Add absolute position embedding on patch embedding.
"""
def __init__(self, length: int, embed_dim: int):
super().__init__()
self.pos_embed = nn.Parameter(torch.zeros(1, length, embed_dim))
trunc_normal_(self.pos_embed, std=.02)
def forward(self, x):
return x + self.pos_embed
class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
""" Mixed absolute position embedding add operation.
Supported arguments are:
- ``embed_dim``
Prefix of pos_embed will be sliced.
"""
bound_type = AbsPosEmbed
argument_list = ['embed_dim']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self, embed_dim,
inputs: torch.Tensor) -> torch.Tensor:
embed_dim_ = _W(embed_dim)
pos_embed = _S(self.pos_embed)[..., :embed_dim_]
return inputs + pos_embed
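# Note on the design above: both mixed operations allocate their parameter with the
# largest candidate embed_dim (super_init_argument takes the max over all options)
# and slice a prefix of the last dimension at forward time to fit the sampled value.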
@model_wrapper
class AutoformerSpace(nn.Module):
"""
The search space that is proposed in `Autoformer <https://arxiv.org/abs/2107.00651>`__.
There are four searchable variables: depth, embedding dimension, heads number and MLP ratio.
Parameters
----------
search_embed_dim : list of int
The search space of embedding dimension.
search_mlp_ratio : list of float
The search space of MLP ratio.
search_num_heads : list of int
The search space of number of heads.
search_depth: list of int
The search space of depth.
img_size : int
Size of input image.
patch_size : int
Size of image patch.
in_chans : int
Number of channels of the input image.
num_classes : int
Number of classes for classifier.
qkv_bias : bool
Whether to use bias item in the qkv embedding.
drop_rate : float
Drop rate of the MLP projection in MSA and FFN.
attn_drop_rate : float
Drop rate of attention.
drop_path_rate : float
Drop path rate.
pre_norm : bool
Whether to use pre_norm. Otherwise post_norm is used.
global_pool : bool
Whether to use global pooling to generate the image representation. Otherwise the cls_token is used.
abs_pos : bool
Whether to use absolute positional embeddings.
qk_scale : float
The scaler on score map in self-attention.
rpe : bool
Whether to use relative position encoding.
"""
def __init__(
self,
search_embed_dim: Tuple[int, ...] = (192, 216, 240),
search_mlp_ratio: Tuple[float, ...] = (3.0, 3.5, 4.0),
search_num_heads: Tuple[int, ...] = (3, 4),
search_depth: Tuple[int, ...] = (12, 13, 14),
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
qkv_bias: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
pre_norm: bool = True,
global_pool: bool = False,
abs_pos: bool = True,
qk_scale: Optional[float] = None,
rpe: bool = True,
):
super().__init__()
# define search space parameters
embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
depth = nn.ValueChoice(list(search_depth), label="depth")
mlp_ratios = [nn.ValueChoice(list(search_mlp_ratio), label=f"mlp_ratio_{i}") for i in range(max(search_depth))]
num_heads = [nn.ValueChoice(list(search_num_heads), label=f"num_head_{i}") for i in range(max(search_depth))]
self.patch_embed = nn.Conv2d(
in_chans, cast(int, embed_dim),
kernel_size = patch_size,
stride = patch_size
)
self.patches_num = int((img_size // patch_size) ** 2)
self.global_pool = global_pool
self.cls_token = ClsToken(cast(int, embed_dim))
self.pos_embed = AbsPosEmbed(self.patches_num+1, cast(int, embed_dim)) if abs_pos else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, max(search_depth))] # stochastic depth decay rule
self.blocks = nn.Repeat(
lambda index: TransformerEncoderLayer(
embed_dim = embed_dim, num_heads = num_heads[index], mlp_ratio=mlp_ratios[index],
qkv_bias = qkv_bias, drop_rate = drop_rate, attn_drop = attn_drop_rate, drop_path=dpr[index],
rpe_length=img_size // patch_size, qk_scale=qk_scale, rpe=rpe, pre_norm=pre_norm, head_dim = 64
), depth
)
self.norm = nn.LayerNorm(cast(int, embed_dim)) if pre_norm else nn.Identity()
self.head = nn.Linear(cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
@classmethod
def get_extra_mutation_hooks(cls):
return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
if name == 'autoformer-tiny':
mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
arch: Dict[str, Any] = {
'embed_dim': 192,
'depth': 13
}
for i in range(14):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (240, 216, 192),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (4, 3),
'search_depth': (14, 13, 12),
})
elif name == 'autoformer-small':
mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
arch: Dict[str, Any] = {
'embed_dim': 384,
'depth': 13
}
for i in range(14):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (448, 384, 320),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (7, 6, 5),
'search_depth': (14, 13, 12),
})
elif name == 'autoformer-base':
mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
arch: Dict[str, Any] = {
'embed_dim': 576,
'depth': 14
}
for i in range(16):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (624, 576, 528),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (10, 9, 8),
'search_depth': (16, 15, 14),
})
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = FixedFactory(cls, arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x.permute(0, 2, 3, 1).view(B, self.patches_num, -1)
x = self.cls_token(x)
x = self.pos_embed(x)
x = self.blocks(x)
x = self.norm(x)
if self.global_pool:
x = torch.mean(x[:, 1:], dim=1)
else:
x = x[:, 0]
x = self.head(x)
return x
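# Usage sketch: loading a searched sub-model; pretrained weights are only loaded
# when pretrained=True.
#
#     model = AutoformerSpace.load_searched_model('autoformer-tiny')
#     x = torch.randn(1, 3, 224, 224)
#     logits = model(x)   # shape (1, 1000)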

@@ -0,0 +1,664 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Tuple, Optional, Callable, Union, List, Type, cast
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.typehint import Literal
from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class SqueezeExcite(nn.Module):
"""Squeeze-and-excite layer.
We can't use the op from ``torchvision.ops`` because it's not (yet) properly wrapped,
and ValueChoice couldn't be processed.
Reference:
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L26
- https://github.com/d-li14/mobilenetv3.pytorch/blob/3e6938cedcbbc5ee5bc50780ea18e644702d85fc/mobilenetv3.py#L53
"""
def __init__(self,
channels: int,
reduction_ratio: float = 0.25,
gate_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None):
super().__init__()
rd_channels = make_divisible(channels * reduction_ratio, 8)
gate_layer = gate_layer or nn.Hardsigmoid
activation_layer = activation_layer or nn.ReLU
self.conv_reduce = nn.Conv2d(channels, rd_channels, 1, bias=True)
self.act1 = activation_layer(inplace=True)
self.conv_expand = nn.Conv2d(rd_channels, channels, 1, bias=True)
self.gate = gate_layer()
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
return x * self.gate(x_se)
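# Shape sanity-check sketch (hypothetical sizes): gating keeps the input shape.
#
#     se = SqueezeExcite(64)          # rd_channels = make_divisible(64 * 0.25, 8) = 16
#     feat = torch.randn(2, 64, 14, 14)
#     assert se(feat).shape == feat.shape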
def _se_or_skip(hidden_ch: int, input_ch: int, optional: bool, se_from_exp: bool, label: str) -> nn.Module:
ch = hidden_ch if se_from_exp else input_ch
if optional:
return nn.LayerChoice({
'identity': nn.Identity(),
'se': SqueezeExcite(ch)
}, label=label)
else:
return SqueezeExcite(ch)
def _act_fn(act_alias: Literal['hswish', 'swish', 'relu']) -> Type[nn.Module]:
if act_alias == 'hswish':
return nn.Hardswish
elif act_alias == 'swish':
return nn.SiLU
elif act_alias == 'relu':
return nn.ReLU
else:
raise ValueError(f'Unsupported act alias: {act_alias}')
@model_wrapper
class MobileNetV3Space(nn.Module):
"""
MobileNetV3Space implements the largest search space in `TuNAS <https://arxiv.org/abs/2008.06120>`__.
    The search dimensions include widths, expand ratios, kernel sizes, and SE ratio.
    Some of them can be turned off via arguments to narrow down the search space.
    Different from the ProxylessNAS search space, this space is implemented with :class:`nn.ValueChoice`.
    We use the following snippet as a reference:
    https://github.com/google-research/google-research/blob/20736344591f774f4b1570af64624ed1e18d2867/tunas/mobile_search_space_v3.py#L728
    We have ``num_blocks``, which equals the length of ``self.blocks`` (the main body of the network).
For simplicity, the following parameter specification assumes ``num_blocks`` equals 8 (body + head).
If a shallower body is intended, arrays including ``base_widths``, ``squeeze_excite``, ``depth_range``,
``stride``, ``activation`` should also be shortened accordingly.
Parameters
----------
num_labels
Dimensions for classification head.
base_widths
Widths of each stage, from stem, to body, to head.
Length should be 9, i.e., ``num_blocks + 1`` (because there is a stem width in front).
width_multipliers
A range of widths multiplier to choose from. The choice is independent for each stage.
Or it can be a fixed float. This will be applied on ``base_widths``,
and we would also make sure that widths can be divided by 8.
expand_ratios
A list of expand ratios to choose from. Independent for every **block**.
squeeze_excite
        Indicates whether each stage can have an optional SE layer.
        Expects an array of length 6 for stages 0 to 5. Each element can be one of ``force``, ``optional``, ``none``.
depth_range
A range (e.g., ``(1, 4)``),
or a list of range (e.g., ``[(1, 3), (1, 4), (1, 4), (1, 3), (0, 2)]``).
        If a list, the length should be 5. The depths are specified for stages 1 to 5.
stride
        Stride for all stages (including stem and head). Length should be the same as ``base_widths``.
    activation
        Activation (class) for all stages. Length is the same as ``base_widths``.
se_from_exp
Calculate SE channel reduction from expanded (mid) channels.
dropout_rate
Dropout rate at classification head.
bn_eps
Epsilon of batch normalization.
bn_momentum
Momentum of batch normalization.
"""
widths: List[Union[nn.ChoiceOf[int], int]]
depth_range: List[Tuple[int, int]]
def __init__(
self, num_labels: int = 1000,
base_widths: Tuple[int, ...] = (16, 16, 16, 32, 64, 128, 256, 512, 1024),
width_multipliers: Union[Tuple[float, ...], float] = (0.5, 0.625, 0.75, 1.0, 1.25, 1.5, 2.0),
expand_ratios: Tuple[float, ...] = (1., 2., 3., 4., 5., 6.),
squeeze_excite: Tuple[Literal['force', 'optional', 'none'], ...] = (
'none', 'none', 'optional', 'none', 'optional', 'optional'
),
depth_range: Union[List[Tuple[int, int]], Tuple[int, int]] = (1, 4),
stride: Tuple[int, ...] = (2, 1, 2, 2, 2, 1, 2, 1, 1),
activation: Tuple[Literal['hswish', 'swish', 'relu'], ...] = (
'hswish', 'relu', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'
),
se_from_exp: bool = True,
dropout_rate: float = 0.2,
bn_eps: float = 1e-3,
bn_momentum: float = 0.1
):
super().__init__()
self.num_blocks = len(base_widths) - 1 # without stem, equal to len(self.blocks)
assert self.num_blocks >= 4
assert len(base_widths) == len(stride) == len(activation) == self.num_blocks + 1
# The final two blocks can't have SE
assert len(squeeze_excite) == self.num_blocks - 2 and all(se in ['force', 'optional', 'none'] for se in squeeze_excite)
# The first and final two blocks can't have variational depth
if isinstance(depth_range[0], int):
depth_range = cast(Tuple[int, int], depth_range)
assert len(depth_range) == 2 and depth_range[1] >= depth_range[0] >= 1
self.depth_range = [depth_range] * (self.num_blocks - 3)
else:
assert len(depth_range) == self.num_blocks - 3
self.depth_range = cast(List[Tuple[int, int]], depth_range)
for d in self.depth_range:
d = cast(Tuple[int, int], d)
# pylint: disable=unsubscriptable-object
assert len(d) == 2 and d[1] >= d[0] >= 1, f'{d} does not satisfy depth constraints'
self.widths = []
for i, base_width in enumerate(base_widths):
if isinstance(width_multipliers, float):
self.widths.append(make_divisible(base_width * width_multipliers, 8))
else:
self.widths.append(
# According to tunas, stem and stage 0 share one width multiplier
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
make_divisible(
nn.ValueChoice(list(width_multipliers), label=f's{max(i - 1, 0)}_width_mult') * base_width, 8
)
)
self.expand_ratios = expand_ratios
self.se_from_exp = se_from_exp
# NOTE: The built-in hardswish produces slightly different output from 3rd-party implementation
# But I guess it doesn't really matter.
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/layers/activations.py#L79
self.stem = ConvBNReLU(
3, self.widths[0],
nn.ValueChoice([3, 5], label=f'stem_ks'),
stride=stride[0], activation_layer=_act_fn(activation[0])
)
blocks: List[nn.Module] = [
# Stage 0
# FIXME: this should be an optional layer.
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
DepthwiseSeparableConv(
self.widths[0], self.widths[1],
nn.ValueChoice([3, 5, 7], label=f's0_i0_ks'),
stride=stride[1],
squeeze_excite=cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
_se_or_skip, optional=squeeze_excite[0] == 'optional', se_from_exp=self.se_from_exp, label=f's0_i0_se'
)) if squeeze_excite[0] != 'none' else None,
activation_layer=_act_fn(activation[1])
),
]
blocks += [
# Stage 1-5 (by default)
self._make_stage(i, self.widths[i], self.widths[i + 1], squeeze_excite[i], stride[i + 1], _act_fn(activation[i + 1]))
for i in range(1, self.num_blocks - 2)
]
# Head
blocks += [
ConvBNReLU(
self.widths[self.num_blocks - 2],
self.widths[self.num_blocks - 1],
kernel_size=1,
stride=stride[self.num_blocks - 1],
activation_layer=_act_fn(activation[self.num_blocks - 1])
),
nn.AdaptiveAvgPool2d(1),
# In some implementations, this is a linear layer instead.
# Should be equivalent.
ConvBNReLU(
self.widths[self.num_blocks - 1],
self.widths[self.num_blocks],
kernel_size=1,
stride=stride[self.num_blocks],
norm_layer=nn.Identity,
activation_layer=_act_fn(activation[self.num_blocks])
)
]
self.blocks = nn.Sequential(*blocks)
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(cast(int, self.widths[self.num_blocks]), num_labels),
)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def _make_stage(self, stage_idx, inp, oup, se, stride, act):
def layer_builder(idx):
exp = nn.ValueChoice(list(self.expand_ratios), label=f's{stage_idx}_i{idx}_exp')
ks = nn.ValueChoice([3, 5, 7], label=f's{stage_idx}_i{idx}_ks')
# 'force' always applies SE; 'optional' lets a layer choice decide between SE and skip
se_or_skip = cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
_se_or_skip, optional=se == 'optional', se_from_exp=self.se_from_exp, label=f's{stage_idx}_i{idx}_se'
)) if se != 'none' else None
return InvertedResidual(
inp if idx == 0 else oup,
oup, exp, ks,
stride=stride if idx == 0 else 1, # only the first layer in each stage can have stride > 1
squeeze_excite=se_or_skip,
activation_layer=act,
)
# mutable depth
min_depth, max_depth = self.depth_range[stage_idx - 1]
if stride != 1:
min_depth = max(min_depth, 1)
return nn.Repeat(layer_builder, depth=(min_depth, max_depth), label=f's{stage_idx}_depth')
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
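# A usage sketch for `fixed_arch` (hedged: `MobileNetV3Space` is an assumed
# name for this class, whose header is above this excerpt):
#
#   factory = MobileNetV3Space.fixed_arch({'stem_ks': 3, 's0_i0_ks': 3, ...})
#   model = factory(width_multipliers=1.0)  # accepts the same kwargs as __init__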
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'mobilenetv3-large-100':
# NOTE: Use bicubic interpolation to evaluate this
# With default interpolation, it yields top-1 75.722
arch = {
'stem_ks': 3,
's0_i0_ks': 3,
's1_depth': 2,
's1_i0_exp': 4,
's1_i0_ks': 3,
's1_i1_exp': 3,
's1_i1_ks': 3,
's2_depth': 3,
's2_i0_exp': 3,
's2_i0_ks': 5,
's2_i1_exp': 3,
's2_i1_ks': 5,
's2_i2_exp': 3,
's2_i2_ks': 5,
's3_depth': 4,
's3_i0_exp': 6,
's3_i0_ks': 3,
's3_i1_exp': 2.5,
's3_i1_ks': 3,
's3_i2_exp': 2.3,
's3_i2_ks': 3,
's3_i3_exp': 2.3,
's3_i3_ks': 3,
's4_depth': 2,
's4_i0_exp': 6,
's4_i0_ks': 3,
's4_i1_exp': 6,
's4_i1_ks': 3,
's5_depth': 3,
's5_i0_exp': 6,
's5_i0_ks': 5,
's5_i1_exp': 6,
's5_i1_ks': 5,
's5_i2_exp': 6,
's5_i2_ks': 5,
}
init_kwargs.update(
base_widths=[16, 16, 24, 40, 80, 112, 160, 960, 1280],
expand_ratios=[1.0, 2.0, 2.3, 2.5, 3.0, 4.0, 6.0],
bn_eps=1e-5,
bn_momentum=0.1,
width_multipliers=1.0,
squeeze_excite=['none', 'none', 'force', 'none', 'force', 'force']
)
elif name.startswith('mobilenetv3-small-'):
# Evaluate with bicubic interpolation
multiplier = int(name.split('-')[-1]) / 100
widths = [16, 16, 24, 40, 48, 96, 576, 1024]
for i in range(7):
if i > 0 or multiplier >= 0.75:
# fix_stem = True when multiplier < 0.75
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/mobilenetv3.py#L421
widths[i] = make_divisible(widths[i] * multiplier, 8)
init_kwargs.update(
base_widths=widths,
width_multipliers=1.0,
expand_ratios=[3.0, 3.67, 4.0, 4.5, 6.0],
bn_eps=1e-05,
bn_momentum=0.1,
squeeze_excite=['force', 'none', 'force', 'force', 'force'],
activation=['hswish', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'],
stride=[2, 2, 2, 2, 1, 2, 1, 1],
depth_range=(1, 2),
)
arch = {
'stem_ks': 3,
's0_i0_ks': 3,
's1_depth': 2,
's1_i0_exp': 4.5,
's1_i0_ks': 3,
's1_i1_exp': 3.67,
's1_i1_ks': 3,
's2_depth': 3,
's2_i0_exp': 4.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's3_depth': 2,
's3_i0_exp': 3.0,
's3_i0_ks': 5,
's3_i1_exp': 3.0,
's3_i1_ks': 5,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5
}
elif name.startswith('cream'):
# https://github.com/microsoft/Cream/tree/main/Cream
# bilinear interpolation
level = name.split('-')[-1]
# region cream arch specification
if level == '014':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's4_depth': 1,
's4_i0_exp': 6.0,
's4_i0_ks': 3,
's5_depth': 1,
's5_i0_exp': 6.0,
's5_i0_ks': 5
}
elif level == '043':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 3,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 3,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's5_depth': 2,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5
}
elif level == '114':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's5_depth': 2,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5
}
elif level == '287':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 3,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 3,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's4_depth': 4,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's4_i3_exp': 6.0,
's4_i3_ks': 5,
's5_depth': 3,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 6.0,
's5_i2_ks': 5
}
elif level == '481':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 4,
's1_i0_exp': 6.0,
's1_i0_ks': 5,
's1_i1_exp': 4.0,
's1_i1_ks': 7,
's1_i2_exp': 6.0,
's1_i2_ks': 5,
's1_i3_exp': 6.0,
's1_i3_ks': 3,
's2_depth': 4,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 4.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's2_i3_exp': 4.0,
's2_i3_ks': 3,
's3_depth': 5,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's3_i3_exp': 6.0,
's3_i3_ks': 3,
's3_i4_exp': 6.0,
's3_i4_ks': 3,
's4_depth': 4,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's4_i3_exp': 6.0,
's4_i3_ks': 5,
's5_depth': 4,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 6.0,
's5_i2_ks': 5,
's5_i3_exp': 6.0,
's5_i3_ks': 5
}
elif level == '604':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 5,
's1_i0_exp': 6.0,
's1_i0_ks': 5,
's1_i1_exp': 6.0,
's1_i1_ks': 5,
's1_i2_exp': 4.0,
's1_i2_ks': 5,
's1_i3_exp': 6.0,
's1_i3_ks': 5,
's1_i4_exp': 6.0,
's1_i4_ks': 5,
's2_depth': 5,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 4.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's2_i3_exp': 4.0,
's2_i3_ks': 5,
's2_i4_exp': 6.0,
's2_i4_ks': 5,
's3_depth': 5,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 4.0,
's3_i1_ks': 5,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's3_i3_exp': 4.0,
's3_i3_ks': 5,
's3_i4_exp': 6.0,
's3_i4_ks': 5,
's4_depth': 6,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 4.0,
's4_i2_ks': 5,
's4_i3_exp': 4.0,
's4_i3_ks': 5,
's4_i4_exp': 6.0,
's4_i4_ks': 5,
's4_i5_exp': 6.0,
's4_i5_ks': 5,
's5_depth': 6,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 4.0,
's5_i2_ks': 5,
's5_i3_exp': 6.0,
's5_i3_ks': 5,
's5_i4_exp': 6.0,
's5_i4_ks': 5,
's5_i5_exp': 6.0,
's5_i5_ks': 5
}
else:
raise ValueError(f'Unsupported cream model level: {level}')
# endregion
init_kwargs.update(
base_widths=[16, 16, 24, 40, 80, 96, 192, 320, 1280],
width_multipliers=1.0,
expand_ratios=[4.0, 6.0],
bn_eps=1e-5,
bn_momentum=0.1,
squeeze_excite=['force'] * 6,
activation=['swish'] * 9
)
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
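# Usage sketch (assumptions: the class name `MobileNetV3Space`, and network
# access for the pretrained-weight download when `download=True`):
#
#   model = MobileNetV3Space.load_searched_model('mobilenetv3-large-100', pretrained=True, download=True)
#   logits = model(torch.randn(1, 3, 224, 224))  # evaluate with bicubic interpolation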

View file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Famous building blocks of search spaces."""
from .autoactivation import *
from .nasbench101 import *
from .nasbench201 import *

View file

@ -0,0 +1,259 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from packaging.version import Version
import torch
import torch.nn as nn
from nni.nas.utils import basic_unit
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
__all__ = ['AutoActivation']
# torch.asinh was added in PyTorch 1.7 and torch.sinc in 1.8; the ops that
# need them are only registered on sufficiently new versions (guards below).
TorchVersion = '1.8.0'
# ============== unary function modules ==============
@basic_unit
class UnaryIdentity(nn.Module):
def forward(self, x):
return x
@basic_unit
class UnaryNegative(nn.Module):
def forward(self, x):
return -x
@basic_unit
class UnaryAbs(nn.Module):
def forward(self, x):
return torch.abs(x)
@basic_unit
class UnarySquare(nn.Module):
def forward(self, x):
return torch.square(x)
@basic_unit
class UnaryPow(nn.Module):
def forward(self, x):
return torch.pow(x, 3)
@basic_unit
class UnarySqrt(nn.Module):
def forward(self, x):
return torch.sqrt(x)
@basic_unit
class UnaryMul(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x * self.beta
@basic_unit
class UnaryAdd(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x + self.beta
@basic_unit
class UnaryLogAbs(nn.Module):
def forward(self, x):
return torch.log(torch.abs(x) + 1e-7)
@basic_unit
class UnaryExp(nn.Module):
def forward(self, x):
return torch.exp(x)
@basic_unit
class UnarySin(nn.Module):
def forward(self, x):
return torch.sin(x)
@basic_unit
class UnaryCos(nn.Module):
def forward(self, x):
return torch.cos(x)
@basic_unit
class UnarySinh(nn.Module):
def forward(self, x):
return torch.sinh(x)
@basic_unit
class UnaryCosh(nn.Module):
def forward(self, x):
return torch.cosh(x)
@basic_unit
class UnaryTanh(nn.Module):
def forward(self, x):
return torch.tanh(x)
if Version(torch.__version__) >= Version(TorchVersion):
@basic_unit
class UnaryAsinh(nn.Module):
def forward(self, x):
return torch.asinh(x)
@basic_unit
class UnaryAtan(nn.Module):
def forward(self, x):
return torch.atan(x)
if Version(torch.__version__) >= Version(TorchVersion):
@basic_unit
class UnarySinc(nn.Module):
def forward(self, x):
return torch.sinc(x)
@basic_unit
class UnaryMax(nn.Module):
def forward(self, x):
return torch.max(x, torch.zeros_like(x))
@basic_unit
class UnaryMin(nn.Module):
def forward(self, x):
return torch.min(x, torch.zeros_like(x))
@basic_unit
class UnarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x)
@basic_unit
class UnaryLogExp(nn.Module):
def forward(self, x):
return torch.log(1 + torch.exp(x))
@basic_unit
class UnaryExpSquare(nn.Module):
def forward(self, x):
return torch.exp(-torch.square(x))
@basic_unit
class UnaryErf(nn.Module):
def forward(self, x):
return torch.erf(x)
unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']
if Version(torch.__version__) >= Version(TorchVersion):
unary_modules.append('UnaryAsinh')
unary_modules.append('UnarySinc')
# ============== binary function modules ==============
@basic_unit
class BinaryAdd(nn.Module):
def forward(self, x):
return x[0] + x[1]
@basic_unit
class BinaryMul(nn.Module):
def forward(self, x):
return x[0] * x[1]
@basic_unit
class BinaryMinus(nn.Module):
def forward(self, x):
return x[0] - x[1]
@basic_unit
class BinaryDivide(nn.Module):
def forward(self, x):
return x[0] / (x[1] + 1e-7)
@basic_unit
class BinaryMax(nn.Module):
def forward(self, x):
return torch.max(x[0], x[1])
@basic_unit
class BinaryMin(nn.Module):
def forward(self, x):
return torch.min(x[0], x[1])
@basic_unit
class BinarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x[0]) * x[1]
@basic_unit
class BinaryExpSquare(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.square(x[0] - x[1]))
@basic_unit
class BinaryExpAbs(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.abs(x[0] - x[1]))
@basic_unit
class BinaryParamAdd(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return self.beta * x[0] + (1 - self.beta) * x[1]
binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']
class AutoActivation(nn.Module):
"""
This module is an implementation of the paper `Searching for Activation Functions <https://arxiv.org/abs/1710.05941>`__.
Parameters
----------
unit_num : int
the number of core units
Notes
-----
Currently, ``beta`` is not a per-channel parameter.
"""
def __init__(self, unit_num: int = 1, label: str | None = None):
super().__init__()
self._label = generate_new_label(label)
self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label=f'{self.label}__unary_0')
for i in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label=f'{self.label}__unary_{i + 1}')
self.unaries.append(one_unary)
for i in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules], label=f'{self.label}__binary_{i}')
self.binaries.append(one_binary)
@property
def label(self):
return self._label
def forward(self, x):
out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries):
out = binary(torch.stack([out, unary(x)]))
return out
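# Usage sketch: AutoActivation is a drop-in activation module. With `unit_num`
# core units the search space contains unit_num + 1 unary choices and
# unit_num binary choices.
#
#   act = AutoActivation(unit_num=1)
#   out = act(torch.randn(8, 16))  # output has the same shape as the input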

View file

@ -0,0 +1,418 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench101Cell', 'NasBench101Mutator']
import logging
from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
import numpy as np
import torch
import torch.nn as nn
from nni.nas.mutable import InvalidMutation, Mutator
from nni.nas.execution.common import Model
from nni.nas.nn.pytorch import InputChoice, ValueChoice, LayerChoice
from nni.nas.nn.pytorch.mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__)
def compute_vertex_channels(input_channels, output_channels, matrix):
"""
This is (almost) copied from the original NAS-Bench-101 implementation.
Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of channels at each interior vertex.
Interior vertices have the same number of channels as the max of the channels of the vertices they feed into.
The output channels are divided amongst the vertices that are directly connected to it.
When the division is not even, some vertices may receive an extra channel to compensate.
Parameters
----------
input_channels : int
Input channel count.
output_channels : int
Output channel count.
matrix : np.ndarray
adjacency matrix for the module (pruned by model_spec).
Returns
-------
list of int
list of channel counts, in order of the vertices.
"""
num_vertices = np.shape(matrix)[0]
vertex_channels = [0] * num_vertices
vertex_channels[0] = input_channels
vertex_channels[num_vertices - 1] = output_channels
if num_vertices == 2:
# Edge case where module only has input and output vertices
return vertex_channels
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
in_degree = np.sum(matrix[1:], axis=0)
interior_channels = output_channels // in_degree[num_vertices - 1]
correction = output_channels % in_degree[num_vertices - 1] # Remainder to add
# Set channels of vertices that flow directly to output
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
vertex_channels[v] = interior_channels
if correction:
vertex_channels[v] += 1
correction -= 1
# Set channels for all other vertices to the max of the out edges, going backwards.
# (num_vertices - 2) index skipped because it only connects to output.
for v in range(num_vertices - 3, 0, -1):
if not matrix[v, num_vertices - 1]:
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
assert vertex_channels[v] > 0
_logger.debug('vertex_channels: %s', str(vertex_channels))
# Sanity check, verify that channels never increase and final channels add up.
final_fan_in = 0
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
final_fan_in += vertex_channels[v]
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
assert vertex_channels[v] >= vertex_channels[dst]
assert final_fan_in == output_channels or num_vertices == 2
# num_vertices == 2 means only input/output nodes, so 0 fan-in
return vertex_channels
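# Worked example (hypothetical 5-vertex cell where vertices 1, 2 and 3 all feed
# the output): with input_channels=128 and output_channels=100, the output's
# in-degree is 3, so interior_channels = 100 // 3 = 33 with correction = 1;
# vertex 1 receives 34 channels, vertices 2 and 3 receive 33, summing to 100.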
def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
"""
Prune the extraneous parts of the graph.
General procedure:
1. Remove parts of graph not connected to input.
2. Remove parts of graph not connected to output.
3. Reorder the vertices so that they are consecutive after steps 1 and 2.
These 3 steps can be combined by deleting the rows and columns of the
vertices that are not reachable from both the input and output (in reverse).
"""
num_vertices = np.shape(matrix)[0]
# compute reachability within num_vertices steps via powers of (matrix + I).
connections = np.linalg.matrix_power(matrix + np.eye(num_vertices), num_vertices)
visited_from_input = set([i for i in range(num_vertices) if connections[0, i]])
visited_from_output = set([i for i in range(num_vertices) if connections[i, -1]])
# Any vertex that isn't connected to both input and output is extraneous to the computation graph.
extraneous = set(range(num_vertices)).difference(
visited_from_input.intersection(visited_from_output))
if len(extraneous) > num_vertices - 2:
raise InvalidMutation('Non-extraneous graph has fewer than 2 vertices, '
'the input is not connected to the output and the spec is invalid.')
matrix = np.delete(matrix, list(extraneous), axis=0)
matrix = np.delete(matrix, list(extraneous), axis=1)
for index in sorted(extraneous, reverse=True):
del ops[index]
return matrix, ops
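# Worked example (hypothetical 4-vertex graph with edges 0->1, 0->2 and 1->3):
# vertex 2 never reaches the output, so it is extraneous; prune() deletes its
# row/column and its entry in `ops`, leaving the chain 0 -> 1 -> 3.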
def truncate(inputs, channels):
input_channels = inputs.size(1)
if input_channels < channels:
raise ValueError('input channel < output channels for truncate')
elif input_channels == channels:
return inputs # No truncation necessary
else:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert input_channels - channels == 1
return inputs[:, :channels]
class _NasBench101CellFixed(nn.Module):
"""
The fixed version of NAS-Bench-101 Cell, used in python-version execution engine.
"""
def __init__(self, operations: List[Callable[[int], nn.Module]],
adjacency_list: List[List[int]],
in_features: int, out_features: int, num_nodes: int,
projection: Callable[[int, int], nn.Module]):
super().__init__()
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
del operations # operations is no longer needed. Delete it to avoid misuse
# add pseudo nodes
raw_operations.insert(0, 'IN')
raw_operations.append('OUT')
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
del num_nodes # raw number of nodes is no longer used
self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
self.num_nodes = len(self.connection_matrix)
self.in_features = in_features
self.out_features = out_features
_logger.info('Pruned number of nodes: %d', self.num_nodes)
_logger.info('Pruned connection matrix: %s', str(self.connection_matrix))
self.projections = nn.ModuleList([nn.Identity()])
self.ops = nn.ModuleList([nn.Identity()])
for i in range(1, self.num_nodes):
self.projections.append(projection(in_features, self.hidden_features[i]))
for i in range(1, self.num_nodes - 1):
operation = cast(Callable[[int], nn.Module], self.operations[i])
self.ops.append(operation(self.hidden_features[i]))
@staticmethod
def build_connection_matrix(adjacency_list, num_nodes):
adjacency_list = [[]] + adjacency_list # add adjacency for first node
connections = np.zeros((num_nodes, num_nodes), dtype='int')
for i, lst in enumerate(adjacency_list):
assert all([0 <= k < i for k in lst])
for k in lst:
connections[k, i] = 1
return connections
def forward(self, inputs):
tensors = [inputs]
for t in range(1, self.num_nodes - 1):
# Create interior connections, truncating if necessary
add_in = [truncate(tensors[src], self.hidden_features[t])
for src in range(1, t) if self.connection_matrix[src, t]]
# Create add connection from projected input
if self.connection_matrix[0, t]:
add_in.append(self.projections[t](tensors[0]))
if len(add_in) == 1:
vertex_input = add_in[0]
else:
vertex_input = sum(add_in)
# Perform op at vertex t
vertex_out = self.ops[t](vertex_input)
tensors.append(vertex_out)
# Construct the final output tensor by concatenating all fan-in and adding the input.
if np.sum(self.connection_matrix[:, -1]) == 1:
src = np.where(self.connection_matrix[:, -1] == 1)[0][0]
return self.projections[-1](tensors[0]) if src == 0 else tensors[src]
outputs = torch.cat([tensors[src] for src in range(1, self.num_nodes - 1) if self.connection_matrix[src, -1]], 1)
if self.connection_matrix[0, -1]:
outputs += self.projections[-1](tensors[0])
assert outputs.size(1) == self.out_features
return outputs
class NasBench101Cell(Mutable):
"""
Cell structure that is proposed in NAS-Bench-101.
Proposed by `NAS-Bench-101: Towards Reproducible Neural Architecture Search <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.
This cell is usually used to evaluate NAS algorithms, because a "comprehensive analysis" of this search space
is available: a full architecture dataset that "maps 423k unique architectures to metrics
including run time and accuracy". You can also use this space in your own space design, in which case it should be possible
to leverage the benchmark results to narrow the huge space down to a few efficient architectures.
The space of this cell architecture consists of all possible directed acyclic graphs on no more than ``max_num_nodes`` nodes,
where each possible node (other than IN and OUT) has one of ``op_candidates``, representing the corresponding operation.
Edges connecting the nodes can be no more than ``max_num_edges``.
To align with the paper settings, the two vertices specially labeled as IN and OUT are also counted into
``max_num_nodes`` in our implementation. The default value of ``max_num_nodes`` is 7, and that of ``max_num_edges`` is 9.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. The shape
of each hidden node will first be computed automatically, depending on the cell structure. Each of the ``op_candidates``
should be a callable that accepts the computed ``num_features`` and returns a ``Module``. For example,
.. code-block:: python
def conv_bn_relu(num_features):
return nn.Sequential(
nn.Conv2d(num_features, num_features, 1),
nn.BatchNorm2d(num_features),
nn.ReLU()
)
The output of each node is the sum of its input nodes, fed into its operation, except for the last node (output node),
which is the concatenation of its input *hidden* nodes, plus the *IN* node (if IN and OUT are connected).
When the input tensor is added to any other tensor, the shapes could mismatch. Therefore, a projection transformation
is needed to transform the input tensor. In the paper, this is simply a Conv1x1 followed by BN and ReLU. The ``projection``
parameter accepts ``in_features`` and ``out_features`` and returns a ``Module``. This parameter has no default value,
as we make no assumption that users are dealing with images. An example for this parameter is,
.. code-block:: python
def projection_fn(in_features, out_features):
return nn.Conv2d(in_features, out_features, 1)
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function that accepts the number of features and returns an ``nn.Module``.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
projection : callable
Projection module that is used to preprocess the input tensor of the whole cell.
A callable that accepts input features and output features, returning an ``nn.Module``.
max_num_nodes : int
Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
max_num_edges : int
Maximum number of edges in the cell. Default: 9.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Warnings
--------
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-execution-engine>`.
"""
@staticmethod
def _make_dict(x):
if isinstance(x, list):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
@classmethod
def create_fixed_module(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
def make_list(x): return x if isinstance(x, list) else [x]
label, selected = get_fixed_dict(label)
op_candidates = cls._make_dict(op_candidates)
num_nodes = selected[f'{label}/num_nodes']
adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected at most {max_num_edges} edges, found: {adjacency_list}')
return _NasBench101CellFixed(
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
# FIXME: weight inheritance on nasbench101 is not supported yet
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
num_vertices_prior = [2 ** i for i in range(2, max_num_nodes + 1)]
num_vertices_prior = (np.array(num_vertices_prior) / sum(num_vertices_prior)).tolist()
self.num_nodes = ValueChoice(list(range(2, max_num_nodes + 1)),
prior=num_vertices_prior,
label=f'{self._label}/num_nodes')
self.max_num_nodes = max_num_nodes
self.max_num_edges = max_num_edges
op_candidates = self._make_dict(op_candidates)
# this is only for input validation and instantiating enough layer choice and input choice
self.hidden_features = out_features
self.projections = nn.ModuleList([nn.Identity()])
self.ops = nn.ModuleList([nn.Identity()])
self.inputs = nn.ModuleList([nn.Identity()])
for _ in range(1, max_num_nodes):
self.projections.append(projection(in_features, self.hidden_features))
for i in range(1, max_num_nodes):
if i < max_num_nodes - 1:
self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]),
label=f'{self._label}/op{i}'))
self.inputs.append(InputChoice(i, None, label=f'{self._label}/input{i}'))
@property
def label(self):
return self._label
def forward(self, x):
"""
The forward of an input choice simply selects the first among all candidates.
It shouldn't be called directly by users in most cases.
"""
tensors = [x]
for i in range(1, self.max_num_nodes):
node_input = self.inputs[i]([self.projections[i](tensors[0])] + [t for t in tensors[1:]])
if i < self.max_num_nodes - 1:
node_output = self.ops[i](node_input)
else:
node_output = node_input
tensors.append(node_output)
return tensors[-1]
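# A construction sketch, reusing the helpers from the docstring above
# (assumption: `conv_bn_relu` and `projection_fn` are defined as shown there):
#
#   cell = NasBench101Cell([conv_bn_relu], in_features=64, out_features=128,
#                          projection=projection_fn, label='cell')
#   out = cell(torch.randn(2, 64, 8, 8))  # out has 128 channels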
class NasBench101Mutator(Mutator):
# Used for validation in the python execution engine.
def __init__(self, label: str):
super().__init__(label=label)
@staticmethod
def candidates(node):
if 'n_candidates' in node.operation.parameters:
return list(range(node.operation.parameters['n_candidates']))
else:
return node.operation.parameters['candidates']
@staticmethod
def number_of_chosen(node):
if 'n_chosen' in node.operation.parameters:
return node.operation.parameters['n_chosen']
return 1
def mutate(self, model: Model):
max_num_edges = cast(int, None)
for node in model.get_nodes_by_label(self.label):
max_num_edges = node.operation.parameters['max_num_edges']
break
assert max_num_edges is not None
mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected at most {max_num_edges} edges, found: {adjacency_list}')
matrix = _NasBench101CellFixed.build_connection_matrix(adjacency_list, num_nodes)
operations = ['IN'] + [mutation_dict[f'{self.label}/op{i}'][0] for i in range(1, num_nodes - 1)] + ['OUT']
assert len(operations) == len(matrix)
matrix, operations = prune(matrix, operations) # possible to raise InvalidMutation inside
# NOTE: a hack to maintain a clean copy of what nasbench101 cell looks like
self._cur_samples = {}
for i in range(1, len(matrix)):
if i + 1 < len(matrix):
self._cur_samples[f'op{i}'] = operations[i]
self._cur_samples[f'input{i}'] = [k for k in range(i) if matrix[k, i]]
self._cur_samples = [self._cur_samples] # by design, _cur_samples is a list of samples
def dry_run(self, model):
return [], model

View file

@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench201Cell']
from collections import OrderedDict
from typing import Callable, List, Dict, Union, Optional
import torch
import torch.nn as nn
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
class NasBench201Cell(nn.Module):
"""
Cell structure that is proposed in NAS-Bench-201.
Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.
This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
For every i < j, there is an edge from the i-th node to the j-th node.
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
and returns a ``Module``.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.
The space size of this cell is :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
and :math:`N` is defined by ``num_tensors``.
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function that accepts input features and output features, returning an ``nn.Module``.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
num_tensors : int
Number of tensors in the cell (input included). Default: 4
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
"""
@staticmethod
def _make_dict(x):
if isinstance(x, list):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
in_features: int, out_features: int, num_tensors: int = 4,
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.layers = nn.ModuleList()
self.in_features = in_features
self.out_features = out_features
self.num_tensors = num_tensors
op_candidates = self._make_dict(op_candidates)
for tid in range(1, num_tensors):
node_ops = nn.ModuleList()
for j in range(tid):
inp = in_features if j == 0 else out_features
op_choices = OrderedDict([(key, cls(inp, out_features))
for key, cls in op_candidates.items()])
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
self.layers.append(node_ops)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
The forward of an input choice simply selects the first among all candidates.
It shouldn't be called directly by users in most cases.
"""
tensors: List[torch.Tensor] = [inputs]
for layer in self.layers:
current_tensor: List[torch.Tensor] = []
for i, op in enumerate(layer): # type: ignore
current_tensor.append(op(tensors[i])) # type: ignore
tensors.append(torch.sum(torch.stack(current_tensor), 0))
return tensors[-1]
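# A construction sketch (assumption: each candidate maps input/output features
# to a module, as required by the docstring above):
#
#   ops = {'linear': lambda cin, cout: nn.Linear(cin, cout)}
#   cell = NasBench201Cell(ops, in_features=16, out_features=16)
#   out = cell(torch.randn(4, 16))  # a densely connected DAG over 4 tensors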

View file

@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
import torch.nn as nn
from nni.nas import model_wrapper
from .modules.nasbench101 import NasBench101Cell
__all__ = ['NasBench101']
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
class ConvBNReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(ConvBNReLU, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
return self.conv_bn_relu(x)
class Conv3x3BNReLU(ConvBNReLU):
def __init__(self, in_channels, out_channels):
super(Conv3x3BNReLU, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
class Conv1x1BNReLU(ConvBNReLU):
def __init__(self, in_channels, out_channels):
super(Conv1x1BNReLU, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
Projection = Conv1x1BNReLU
@model_wrapper
class NasBench101(nn.Module):
"""The full search space, proposed by `NAS-Bench-101 <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.
It's simply a stack of :class:`NasBench101Cell`. The candidate operations are conv3x3-bn-relu, conv1x1-bn-relu and maxpool3x3.
"""
def __init__(self,
stem_out_channels: int = 128,
num_stacks: int = 3,
num_modules_per_stack: int = 3,
max_num_vertices: int = 7,
max_num_edges: int = 9,
num_labels: int = 10,
bn_eps: float = 1e-5,
bn_momentum: float = 0.003):
super().__init__()
op_candidates = {
'conv3x3-bn-relu': lambda num_features: Conv3x3BNReLU(num_features, num_features),
'conv1x1-bn-relu': lambda num_features: Conv1x1BNReLU(num_features, num_features),
'maxpool3x3': lambda num_features: nn.MaxPool2d(3, 1, 1)
}
# initial stem convolution
self.stem_conv = Conv3x3BNReLU(3, stem_out_channels)
layers = []
in_channels = out_channels = stem_out_channels
for stack_num in range(num_stacks):
if stack_num > 0:
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
layers.append(downsample)
out_channels *= 2
for _ in range(num_modules_per_stack):
cell = NasBench101Cell(op_candidates, in_channels, out_channels,
lambda cin, cout: Projection(cin, cout),
max_num_vertices, max_num_edges, label='cell')
layers.append(cell)
in_channels = out_channels
self.features = nn.ModuleList(layers)
self.gap = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(out_channels, num_labels)
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = bn_eps
module.momentum = bn_momentum
def forward(self, x):
bs = x.size(0)
out = self.stem_conv(x)
for layer in self.features:
out = layer(out)
out = self.gap(out).view(bs, -1)
out = self.classifier(out)
return out
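# Usage sketch (CIFAR-10 sized input; this is a model *space*, to be explored
# by a NAS strategy before a concrete architecture is obtained):
#
#   model_space = NasBench101()
#   out = model_space(torch.randn(2, 3, 32, 32))  # [2, 10] logits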

View file

@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Dict
import torch
import torch.nn as nn
from nni.nas import model_wrapper
from .modules.nasbench201 import NasBench201Cell
__all__ = ['NasBench201']
OPS_WITH_STRIDE = {
'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride),
}
PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)
def forward(self, x):
return self.op(x)
class Pooling(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
if mode == 'avg':
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max':
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError('Invalid mode={:} in Pooling'.format(mode))
def forward(self, x):
if self.preprocess:
x = self.preprocess(x)
return self.op(x)
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
inputs = self.downsample(inputs) # residual
return inputs + basicblock
@model_wrapper
class NasBench201(nn.Module):
"""The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.
It's a stack of :class:`NasBench201Cell`.
"""
def __init__(self,
stem_out_channels: int = 16,
num_modules_per_stack: int = 5,
num_labels: int = 10):
super().__init__()
self.channels = C = stem_out_channels
self.num_modules = N = num_modules_per_stack
self.num_labels = num_labels
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for C_curr, reduction in zip(layer_channels, layer_reductions):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
ops: Dict[str, Callable[[int, int], nn.Module]] = {
# bind `prim` via a default argument; a bare closure would late-bind the loop variable
prim: (lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1)) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr
self.lastact = nn.Sequential(
nn.BatchNorm2d(C_prev),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)
def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
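# Usage sketch (CIFAR-10 sized input, analogous to NasBench101 above):
#
#   model_space = NasBench201()
#   logits = model_space(torch.randn(2, 3, 32, 32))  # [2, 10]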

View file

@ -0,0 +1,862 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""File containing NASNet-series search space.
The implementation is based on NDS.
It's called ``nasnet.py`` simply because NASNet was the first to propose such a structure.
"""
from collections import OrderedDict
from functools import partial
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.nas.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
# the following are NAS operations from
# https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/operations.py
OPS = {
'none': lambda C, stride, affine:
Zero(stride),
'avg_pool_2x2': lambda C, stride, affine:
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
'avg_pool_3x3': lambda C, stride, affine:
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'avg_pool_5x5': lambda C, stride, affine:
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
'max_pool_2x2': lambda C, stride, affine:
nn.MaxPool2d(2, stride=stride, padding=0),
'max_pool_3x3': lambda C, stride, affine:
nn.MaxPool2d(3, stride=stride, padding=1),
'max_pool_5x5': lambda C, stride, affine:
nn.MaxPool2d(5, stride=stride, padding=2),
'max_pool_7x7': lambda C, stride, affine:
nn.MaxPool2d(7, stride=stride, padding=3),
'skip_connect': lambda C, stride, affine:
nn.Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'conv_1x1': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_3x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3': lambda C, stride, affine:
SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine:
SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine:
SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine:
DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine:
DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_sep_conv_3x3': lambda C, stride, affine:
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
'conv_3x1_1x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C, C, (3, 1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_7x1_1x7': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
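# Each OPS entry is a factory taking (channels, stride, affine). A sketch:
#
#   op = OPS['sep_conv_3x3'](16, 1, True)  # SepConv for 16 channels, stride 1
#   y = op(torch.randn(2, 16, 32, 32))     # spatial size preserved at stride 1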
class ReLUConvBN(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_out, kernel_size, stride=stride,
padding=padding, bias=False
),
nn.BatchNorm2d(C_out, affine=affine)
)
class DilConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class SepConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class DilSepConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class Zero(nn.Module):
def __init__(self, stride):
super().__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:, :, ::self.stride, ::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super().__init__()
if isinstance(C_out, int):
assert C_out % 2 == 0
else: # is a value choice
assert all(c % 2 == 0 for c in C_out.all_options())
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class DropPath_(nn.Module):
# https://github.com/khanrc/pt.darts/blob/0.1/models/ops.py
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.training and self.drop_prob > 0.:
keep_prob = 1. - self.drop_prob
mask = torch.zeros((x.size(0), 1, 1, 1), dtype=torch.float, device=x.device).bernoulli_(keep_prob)
return x.div(keep_prob).mul(mask)
return x
class AuxiliaryHead(nn.Module):
def __init__(self, C: int, num_labels: int, dataset: Literal['imagenet', 'cifar']):
super().__init__()
if dataset == 'imagenet':
# assuming input size 14x14
stride = 2
elif dataset == 'cifar':
stride = 3
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=stride, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_labels)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class SequentialBreakdown(nn.Sequential):
"""Return all layers of a sequential."""
def __init__(self, sequential: nn.Sequential):
super().__init__(OrderedDict(sequential.named_children()))
def forward(self, inputs):
result = []
for module in self:
inputs = module(inputs)
result.append(inputs)
return result
class CellPreprocessor(nn.Module):
"""
Aligning the shape of predecessors.
If the last cell is a reduction cell, ``pre0`` should be ``FactorizedReduce`` instead of ``ReLUConvBN``.
See :class:`CellBuilder` on how to calculate those channel numbers.
"""
def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
super().__init__()
if last_cell_reduce:
self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
else:
self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)
def forward(self, cells):
assert len(cells) == 2
pprev, prev = cells
pprev = self.pre0(pprev)
prev = self.pre1(prev)
return [pprev, prev]
class CellPostprocessor(nn.Module):
"""
The cell outputs previous cell + this cell, so that cells can be directly chained.
"""
def forward(self, this_cell, previous_cells):
return [previous_cells[-1], this_cell]
class CellBuilder:
"""The cell builder is used in Repeat.
Builds an cell each time it's "called".
Note that the builder is ephemeral, it can only be called once for every index.
"""
def __init__(self, op_candidates: List[str],
C_prev_in: nn.MaybeChoice[int],
C_in: nn.MaybeChoice[int],
C: nn.MaybeChoice[int],
num_nodes: int,
merge_op: Literal['all', 'loose_end'],
first_cell_reduce: bool, last_cell_reduce: bool):
self.C_prev_in = C_prev_in # This is the out channels of the cell before the last cell.
self.C_in = C_in # This is the out channels of the last cell.
self.C = C # This is NOT C_out of this stage. Instead, C_out = C * len(cell.output_node_indices).
self.op_candidates = op_candidates
self.num_nodes = num_nodes
self.merge_op: Literal['all', 'loose_end'] = merge_op
self.first_cell_reduce = first_cell_reduce
self.last_cell_reduce = last_cell_reduce
self._expect_idx = 0
# It takes an index that is the index in the repeat.
# Number of predecessors for each cell is fixed to 2.
self.num_predecessors = 2
# Number of ops per node is fixed to 2.
self.num_ops_per_node = 2
def op_factory(self, node_index: int, op_index: int, input_index: Optional[int], *,
op: str, channels: int, is_reduction_cell: bool):
if is_reduction_cell and (
input_index is None or input_index < self.num_predecessors
): # input_index could be None when constructing the search space
stride = 2
else:
stride = 1
return OPS[op](channels, stride, True)
def __call__(self, repeat_idx: int):
if self._expect_idx != repeat_idx:
raise ValueError(f'Expect index {self._expect_idx}, found {repeat_idx}')
# Reduction cell means stride = 2 and channel multiplied by 2.
is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {}
for op in self.op_candidates:
ops_factory[op] = partial(self.op_factory, op=op, channels=cast(int, self.C), is_reduction_cell=is_reduction_cell)
cell = nn.Cell(ops_factory, self.num_nodes, self.num_ops_per_node, self.num_predecessors, self.merge_op,
preprocessor=preprocessor, postprocessor=CellPostprocessor(),
label='reduce' if is_reduction_cell else 'normal')
# update state
self.C_prev_in = self.C_in
self.C_in = self.C * len(cell.output_node_indices)
self.last_cell_reduce = is_reduction_cell
self._expect_idx += 1
return cell
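# Sketch: a CellBuilder instance is meant to be handed to nn.Repeat (see
# NDSStage below), which calls it once per repeat index to build chained cells:
#
#   builder = CellBuilder(['sep_conv_3x3', 'skip_connect'], C_prev_in=16, C_in=16, C=16,
#                         num_nodes=4, merge_op='all', first_cell_reduce=False, last_cell_reduce=False)
#   cell0 = builder(0)  # indices must be consecutive, starting from 0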
class NDSStage(nn.Repeat):
"""This class defines NDSStage, a special type of Repeat, for isinstance check, and shape alignment.
In NDS, we can't simply use Repeat to stack the blocks,
because the output shape of each stacked block can be different.
This is a problem for one-shot strategies, because they assume every possible candidate
returns values of the same shape.
Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
to manually align the shapes -- specifically, to transform the first block in each stage.
This is not required, though, when the depth is not changing, or when the mutable depth
causes no problem (e.g., when the minimum depth is large enough).
.. attention::
Assumption: Loose end is treated as ``all`` in ``merge_op`` (the case in one-shot),
which forces the reduction cell and normal cells in the same stage to have exactly the same output shape.
"""
estimated_out_channels_prev: int
"""Output channels of cells in last stage."""
estimated_out_channels: int
"""Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""
downsampling: bool
"""This stage has downsampling"""
def first_cell_transformation_factory(self) -> Optional[nn.Module]:
"""To make the "previous cell" in first cell's output have the same shape as cells in this stage."""
if self.downsampling:
return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
elif self.estimated_out_channels_prev is not self.estimated_out_channels:
# Can't use != here, ValueChoice doesn't support
return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
return None
class NDSStagePathSampling(PathSamplingRepeat):
"""The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
return cls(
module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks),
module.depth_choice
)
def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
super().__init__(*args, **kwargs)
self.first_cell_transformation = first_cell_transformation
def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
if 1 not in sampled or self.first_cell_transformation is None:
return super().reduction(items, sampled)
# items[0] must be the result of first cell
assert len(items[0]) == 2
# Only apply the transformation on "prev" output.
items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
return super().reduction(items, sampled)
class NDSStageDifferentiable(DifferentiableMixedRepeat):
"""The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(
module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks),
module.depth_choice,
softmax,
memo
)
def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
super().__init__(*args, **kwargs)
self.first_cell_transformation = first_cell_transformation
def reduction(
self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
if 1 not in depths or self.first_cell_transformation is None:
return super().reduction(items, weights, depths)
# Same as NDSStagePathSampling
assert len(items[0]) == 2
items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
return super().reduction(items, weights, depths)
_INIT_PARAMETER_DOCS = """
Parameters
----------
width : int or tuple of int
A fixed initial width or a tuple of widths to choose from.
num_cells : int or tuple of int
A fixed number of cells (depths) to stack, or a tuple of depths to choose from.
dataset : "cifar" | "imagenet"
The essential differences are in "stem" cells, i.e., how they process the raw image input.
Choosing "imagenet" means more downsampling at the beginning of the network.
auxiliary_loss : bool
If true, an auxiliary classification head will produce a second prediction.
This makes the network output two logits in the training phase.
"""
class NDS(nn.Module):
__doc__ = """
The unified version of NASNet search space.
We follow the implementation in
`unnas <https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/nas.py>`__.
See `On Network Design Spaces for Visual Recognition <https://arxiv.org/abs/1905.13214>`__ for details.
Different NAS papers usually differ in the way that they specify ``op_candidates`` and ``merge_op``.
``dataset`` here gives a hint about the input resolution, so as to create a reasonable stem and auxiliary heads.
A speciality of NDS is that its depths and widths are mutable.
This is implemented by accepting a list of int as ``num_cells`` / ``width``.
""" + _INIT_PARAMETER_DOCS + """
op_candidates : list of str
List of operator candidates. Must be from ``OPS``.
merge_op : ``all`` or ``loose_end``
See :class:`~nni.retiarii.nn.pytorch.Cell`.
num_nodes_per_cell : int
See :class:`~nni.retiarii.nn.pytorch.Cell`.
"""
def __init__(self,
op_candidates: List[str],
merge_op: Literal['all', 'loose_end'] = 'all',
num_nodes_per_cell: int = 4,
width: Union[Tuple[int, ...], int] = 16,
num_cells: Union[Tuple[int, ...], int] = 20,
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
auxiliary_loss: bool = False):
super().__init__()
self.dataset = dataset
self.num_labels = 10 if dataset == 'cifar' else 1000
self.auxiliary_loss = auxiliary_loss
# preprocess the specified width and depth
if isinstance(width, Iterable):
C = nn.ValueChoice(list(width), label='width')
else:
C = width
self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
if isinstance(num_cells, Iterable):
self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [(i + 1) * self.num_cells // 3 - i * self.num_cells // 3 for i in range(3)]
# the auxiliary head differs for networks targeted at different datasets
if dataset == 'imagenet':
self.stem0 = nn.Sequential(
nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(cast(int, C // 2)),
nn.ReLU(inplace=True),
nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_pprev = C_prev = C_curr = C
last_cell_reduce = True
elif dataset == 'cifar':
self.stem = nn.Sequential(
nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
nn.BatchNorm2d(cast(int, 3 * C))
)
C_pprev = C_prev = 3 * C
C_curr = C
last_cell_reduce = False
else:
raise ValueError(f'Unsupported dataset: {dataset}')
self.stages = nn.ModuleList()
for stage_idx in range(3):
if stage_idx > 0:
C_curr *= 2
# For a stage, we get C_in, C_curr, and C_out.
# C_in is only used in the first cell.
# C_curr is number of channels for each operator in current stage.
# C_out is usually `C * num_nodes_per_cell` because of concat operator.
cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
merge_op, stage_idx > 0, last_cell_reduce)
stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
if isinstance(stage, NDSStage):
stage.estimated_out_channels_prev = cast(int, C_prev)
stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
stage.downsampling = stage_idx > 0
self.stages.append(stage)
# NOTE: output_node_indices will be computed on-the-fly in trial code.
# When constructing model space, it's just all the nodes in the cell,
# which happens to be the case of the one-shot supernet.
# C_pprev is the output channel number of the second-to-last cell among all the cells already built.
if len(stage) > 1:
# Contains more than one cell
C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
else:
# Look up in the out channels of last stage.
C_pprev = C_prev
# This was originally:
# C_prev = num_nodes_per_cell * C_curr
# but due to loose end, it becomes:
C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr
# Useful in aligning the pprev and prev cell.
last_cell_reduce = cell_builder.last_cell_reduce
if stage_idx == 2:
C_to_auxiliary = C_prev
if auxiliary_loss:
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)
def forward(self, inputs):
if self.dataset == 'imagenet':
s0 = self.stem0(inputs)
s1 = self.stem1(s0)
else:
s0 = s1 = self.stem(inputs)
for stage_idx, stage in enumerate(self.stages):
if stage_idx == 2 and self.auxiliary_loss:
s = list(stage([s0, s1]).values())
s0, s1 = s[-1]
if self.training:
# auxiliary loss is attached to the first cell of the last stage.
logits_aux = self.auxiliary_head(s[0][1])
else:
s0, s1 = stage([s0, s1])
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if self.training and self.auxiliary_loss:
return logits, logits_aux # type: ignore
else:
return logits
def set_drop_path_prob(self, drop_prob):
"""
Set the drop probability of Drop-path in the network.
Reference: `FractalNet: Ultra-Deep Neural Networks without Residuals <https://arxiv.org/pdf/1605.07648v4.pdf>`__.
"""
for module in self.modules():
if isinstance(module, DropPath_):
module.drop_prob = drop_prob
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@model_wrapper
class NASNet(NDS):
__doc__ = """
Search space proposed in `Learning Transferable Architectures for Scalable Image Recognition <https://arxiv.org/abs/1707.07012>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~NASNet.NASNET_OPS`.
It has 5 nodes per cell, and the output is concatenation of nodes not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.NASNET_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class ENAS(NDS):
__doc__ = """Search space proposed in `Efficient neural architecture search via parameter sharing <https://arxiv.org/abs/1802.03268>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~ENAS.ENAS_OPS`.
It has 5 nodes per cell, and the output is concatenation of nodes not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.ENAS_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class AmoebaNet(NDS):
__doc__ = """Search space proposed in
`Regularized evolution for image classifier architecture search <https://arxiv.org/abs/1802.01548>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~AmoebaNet.AMOEBA_OPS`.
It has 5 nodes per cell, and the output is concatenation of nodes not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.AMOEBA_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class PNAS(NDS):
__doc__ = """Search space proposed in
`Progressive neural architecture search <https://arxiv.org/abs/1712.00559>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~PNAS.PNAS_OPS`.
It has 5 nodes per cell, and the output is concatenation of all nodes in the cell.
""" + _INIT_PARAMETER_DOCS
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.PNAS_OPS,
merge_op='all',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class DARTS(NDS):
__doc__ = """Search space proposed in `Darts: Differentiable architecture search <https://arxiv.org/abs/1806.09055>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~DARTS.DARTS_OPS`.
It has 4 nodes per cell, and the output is concatenation of all nodes in the cell.
""" + _INIT_PARAMETER_DOCS
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.DARTS_OPS,
merge_op='all',
num_nodes_per_cell=4,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'darts-v2':
init_kwargs.update(
num_cells=20,
width=36,
)
arch = {
'normal/op_2_0': 'sep_conv_3x3',
'normal/op_2_1': 'sep_conv_3x3',
'normal/input_2_0': 0,
'normal/input_2_1': 1,
'normal/op_3_0': 'sep_conv_3x3',
'normal/op_3_1': 'sep_conv_3x3',
'normal/input_3_0': 0,
'normal/input_3_1': 1,
'normal/op_4_0': 'sep_conv_3x3',
'normal/op_4_1': 'skip_connect',
'normal/input_4_0': 1,
'normal/input_4_1': 0,
'normal/op_5_0': 'skip_connect',
'normal/op_5_1': 'dil_conv_3x3',
'normal/input_5_0': 0,
'normal/input_5_1': 2,
'reduce/op_2_0': 'max_pool_3x3',
'reduce/op_2_1': 'max_pool_3x3',
'reduce/input_2_0': 0,
'reduce/input_2_1': 1,
'reduce/op_3_0': 'skip_connect',
'reduce/op_3_1': 'max_pool_3x3',
'reduce/input_3_0': 2,
'reduce/input_3_1': 1,
'reduce/op_4_0': 'max_pool_3x3',
'reduce/op_4_1': 'skip_connect',
'reduce/input_4_0': 0,
'reduce/input_4_1': 2,
'reduce/op_5_0': 'skip_connect',
'reduce/op_5_1': 'max_pool_3x3',
'reduce/input_5_0': 2,
'reduce/input_5_1': 1
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
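# A minimal usage sketch (`darts-v2` targets CIFAR, so `num_labels` is 10):
#
#   model = DARTS.load_searched_model('darts-v2', pretrained=False)
#   logits = model(torch.randn(1, 3, 32, 32))  # -> shape (1, 10)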

View file

@ -0,0 +1,538 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
@overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
...
@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
...
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8.
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_val is None:
min_val = divisor
# This should work for both value choices and constants.
new_v = nn.ValueChoice.max(min_val, round(v + divisor // 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
return nn.ValueChoice.condition(new_v < 0.9 * v, new_v + divisor, new_v)
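# Worked examples with plain ints (per the note above, the ValueChoice ops also
# accept constants):
#
#   make_divisible(37, 8)   # -> 40: round(37 + 4) // 8 * 8
#   make_divisible(7, 8)    # -> 8: clamped up to min_val = divisor
#   make_divisible(20, 16)  # -> 32: 16 < 0.9 * 20, so one extra divisor is added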
def simplify_sequential(sequentials: List[nn.Module]) -> Iterator[nn.Module]:
"""
Flatten the sequential blocks so that the hierarchy looks better.
Eliminate identity modules automatically.
"""
for module in sequentials:
if isinstance(module, nn.Sequential):
for submodule in module.children():
# no recursive expansion
if not isinstance(submodule, nn.Identity):
yield submodule
else:
if not isinstance(module, nn.Identity):
yield module
class ConvBNReLU(nn.Sequential):
"""
The template for a conv-bn-relu block.
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
groups: nn.MaybeChoice[int] = 1,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
# If no normalization is used, set bias to True
# https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L194
norm = norm_layer(cast(int, out_channels))
no_normalization = isinstance(norm, nn.Identity)
blocks: List[nn.Module] = [
nn.Conv2d(
cast(int, in_channels),
cast(int, out_channels),
cast(int, kernel_size),
stride,
cast(int, padding),
dilation=dilation,
groups=cast(int, groups),
bias=no_normalization
),
# Normalization, regardless of batchnorm or identity
norm,
# Some PyTorch implementations insert an SE block here, to faithfully reproduce the paper.
# We follow the more widely accepted approach of putting SE outside the block.
# Reference: https://github.com/d-li14/mobilenetv3.pytorch/issues/18
activation_layer(inplace=True)
]
super().__init__(*simplify_sequential(blocks))
class DepthwiseSeparableConv(nn.Sequential):
"""
In the original MobileNetV2 implementation, this is InvertedResidual when expand ratio = 1.
Residual connection is added if input and output shape are the same.
References:
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L90
- https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L433
- https://github.com/ultmaster/AceNAS/blob/46c8895f/searchspace/proxylessnas/utils.py#L100
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
blocks = [
# dw
ConvBNReLU(in_channels, in_channels, stride=stride, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=activation_layer),
# optional se
squeeze_excite(in_channels, in_channels) if squeeze_excite else nn.Identity(),
# pw-linear
ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
]
super().__init__(*simplify_sequential(blocks))
# NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
self.has_skip = stride == 1 and in_channels is out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip:
return x + super().forward(x)
else:
return super().forward(x)
class InvertedResidual(nn.Sequential):
"""
An Inverted Residual Block, sometimes called an MBConv Block, is a type of residual block used for image models
that uses an inverted structure for efficiency reasons.
It was originally proposed for the `MobileNetV2 <https://arxiv.org/abs/1801.04381>`__ CNN architecture.
It has since been reused for several mobile-optimized CNNs.
It follows a narrow -> wide -> narrow approach, hence the inversion.
It first widens with a 1x1 convolution, then uses a 3x3 depthwise convolution (which greatly reduces the number of parameters),
then a 1x1 convolution is used to reduce the number of channels so input and output can be added.
This implementation is sort of a mixture between:
- https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L453
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L134
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
expand_ratio: nn.MaybeChoice[float],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
self.stride = stride
self.out_channels = out_channels
assert stride in [1, 2]
hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))
# NOTE: this equivalence check (==) does NOT work for ValueChoice; we need to use "is"
self.has_skip = stride == 1 and in_channels is out_channels
layers: List[nn.Module] = [
# point-wise convolution
# NOTE: some papers omit this point-wise convolution when stride = 1.
# In our implementation, if this pw convolution is intended to be omitted,
# please use SepConv instead.
ConvBNReLU(in_channels, hidden_ch, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer),
# depth-wise
ConvBNReLU(hidden_ch, hidden_ch, stride=stride, kernel_size=kernel_size, groups=hidden_ch,
norm_layer=norm_layer, activation_layer=activation_layer),
# SE
squeeze_excite(
cast(int, hidden_ch),
cast(int, in_channels)
) if squeeze_excite is not None else nn.Identity(),
# pw-linear
ConvBNReLU(hidden_ch, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity),
]
super().__init__(*simplify_sequential(layers))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip:
return x + super().forward(x)
else:
return super().forward(x)
def inverted_residual_choice_builder(
expand_ratios: List[int],
kernel_sizes: List[int],
downsample: bool,
stage_input_width: int,
stage_output_width: int,
label: str
):
def builder(index):
stride = 1
inp = stage_output_width
if index == 0:
# first layer in stage
# do downsample and width reshape
inp = stage_input_width
if downsample:
stride = 2
oup = stage_output_width
op_choices = {}
for exp_ratio in expand_ratios:
for kernel_size in kernel_sizes:
op_choices[f'k{kernel_size}e{exp_ratio}'] = InvertedResidual(inp, oup, exp_ratio, kernel_size, stride)
# It can be implemented with ValueChoice, but we use LayerChoice here
# to be aligned with the intention of the original ProxylessNAS.
return nn.LayerChoice(op_choices, label=f'{label}_i{index}')
return builder
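# A sketch of how this builder is used (values mirror the ProxylessNAS stage
# construction below; the label 's2' is illustrative):
#
#   builder = inverted_residual_choice_builder(
#       [3, 6], [3, 5, 7], downsample=True,
#       stage_input_width=16, stage_output_width=32, label='s2')
#   first_block = builder(0)  # LayerChoice over k3e3 ... k7e6, labelled 's2_i0'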
@model_wrapper
class ProxylessNAS(nn.Module):
"""
The search space proposed by `ProxylessNAS <https://arxiv.org/abs/1812.00332>`__.
Following the official implementation, the inverted residual with kernel size / expand ratio variations in each layer
is implemented with a :class:`nn.LayerChoice` with all-combination candidates. That means,
when used in weight sharing, these candidates will be treated as separate layers and won't be shared in a fine-grained manner.
We note that :class:`MobileNetV3Space` differs in this respect.
This space could be implemented as part of :class:`MobileNetV3Space`, but we keep them separate, following convention.
"""
def __init__(self, num_labels: int = 1000,
base_widths: Tuple[int, ...] = (32, 16, 32, 40, 80, 96, 192, 320, 1280),
dropout_rate: float = 0.,
width_mult: float = 1.0,
bn_eps: float = 1e-3,
bn_momentum: float = 0.1):
super().__init__()
assert len(base_widths) == 9
# the widths list also includes the width of the final feature-mix layer
widths = [make_divisible(width * width_mult, 8) for width in base_widths]
downsamples = [True, False, True, True, True, False, True, False]
self.num_labels = num_labels
self.dropout_rate = dropout_rate
self.bn_eps = bn_eps
self.bn_momentum = bn_momentum
self.stem = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)
blocks: List[nn.Module] = [
# first stage is fixed
DepthwiseSeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
]
# https://github.com/ultmaster/AceNAS/blob/46c8895fd8a05ffbc61a6b44f1e813f64b4f66b7/searchspace/proxylessnas/__init__.py#L21
for stage in range(2, 8):
# Rather than returning a fixed module here,
# we return a builder that dynamically creates a module for each `repeat_idx`.
builder = inverted_residual_choice_builder(
[3, 6], [3, 5, 7], downsamples[stage], widths[stage - 1], widths[stage], f's{stage}')
if stage < 7:
blocks.append(nn.Repeat(builder, (1, 4), label=f's{stage}_depth'))
else:
# No mutation for depth in the last stage.
# Directly call the builder to instantiate one block
blocks.append(builder(0))
self.blocks = nn.Sequential(*blocks)
# final layers
self.feature_mix_layer = ConvBNReLU(widths[7], widths[8], kernel_size=1, norm_layer=nn.BatchNorm2d)
self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
self.dropout_layer = nn.Dropout(dropout_rate)
self.classifier = nn.Linear(widths[-1], num_labels)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.feature_mix_layer(x)
x = self.global_avg_pooling(x)
x = x.view(x.size(0), -1) # flatten
x = self.dropout_layer(x)
x = self.classifier(x)
return x
def no_weight_decay(self):
# this is useful for timm optimizer
# no regularizer to linear layer
if hasattr(self, 'classifier'):
return {'classifier.weight', 'classifier.bias'}
return set()
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'acenas-m1':
arch = {
's2_depth': 2,
's2_i0': 'k3e6',
's2_i1': 'k3e3',
's3_depth': 3,
's3_i0': 'k5e3',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's4_depth': 2,
's4_i0': 'k3e6',
's4_i1': 'k5e3',
's5_depth': 4,
's5_i0': 'k7e6',
's5_i1': 'k3e6',
's5_i2': 'k3e6',
's5_i3': 'k7e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e3',
's6_i3': 'k7e3',
's7_depth': 1,
's7_i0': 'k7e6'
}
elif name == 'acenas-m2':
arch = {
's2_depth': 1,
's2_i0': 'k5e3',
's3_depth': 3,
's3_i0': 'k3e6',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's4_depth': 2,
's4_i0': 'k7e6',
's4_i1': 'k5e6',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k5e3',
's5_i2': 'k5e6',
's5_i3': 'k3e6',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k5e6',
's6_i2': 'k5e3',
's6_i3': 'k5e6',
's7_depth': 1,
's7_i0': 'k7e6'
}
elif name == 'acenas-m3':
arch = {
's2_depth': 2,
's2_i0': 'k3e3',
's2_i1': 'k3e6',
's3_depth': 2,
's3_i0': 'k5e3',
's3_i1': 'k3e3',
's4_depth': 3,
's4_i0': 'k5e6',
's4_i1': 'k7e6',
's4_i2': 'k3e6',
's5_depth': 4,
's5_i0': 'k7e6',
's5_i1': 'k7e3',
's5_i2': 'k7e3',
's5_i3': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e3',
's6_i2': 'k7e6',
's6_i3': 'k3e3',
's7_depth': 1,
's7_i0': 'k5e6'
}
elif name == 'proxyless-cpu':
arch = {
's2_depth': 4,
's2_i0': 'k3e6',
's2_i1': 'k3e3',
's2_i2': 'k3e3',
's2_i3': 'k3e3',
's3_depth': 4,
's3_i0': 'k3e6',
's3_i1': 'k3e3',
's3_i2': 'k3e3',
's3_i3': 'k5e3',
's4_depth': 2,
's4_i0': 'k3e6',
's4_i1': 'k3e3',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k3e3',
's5_i2': 'k3e3',
's5_i3': 'k3e3',
's6_depth': 4,
's6_i0': 'k5e6',
's6_i1': 'k5e3',
's6_i2': 'k5e3',
's6_i3': 'k3e3',
's7_depth': 1,
's7_i0': 'k5e6'
}
init_kwargs['base_widths'] = [40, 24, 32, 48, 88, 104, 216, 360, 1432]
elif name == 'proxyless-gpu':
arch = {
's2_depth': 1,
's2_i0': 'k5e3',
's3_depth': 2,
's3_i0': 'k7e3',
's3_i1': 'k3e3',
's4_depth': 2,
's4_i0': 'k7e6',
's4_i1': 'k5e3',
's5_depth': 3,
's5_i0': 'k5e6',
's5_i1': 'k3e3',
's5_i2': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e6',
's6_i3': 'k5e6',
's7_depth': 1,
's7_i0': 'k7e6'
}
init_kwargs['base_widths'] = [40, 24, 32, 56, 112, 128, 256, 432, 1728]
elif name == 'proxyless-mobile':
arch = {
's2_depth': 2,
's2_i0': 'k5e3',
's2_i1': 'k3e3',
's3_depth': 4,
's3_i0': 'k7e3',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's3_i3': 'k5e3',
's4_depth': 4,
's4_i0': 'k7e6',
's4_i1': 'k5e3',
's4_i2': 'k5e3',
's4_i3': 'k5e3',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k5e3',
's5_i2': 'k5e3',
's5_i3': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e3',
's6_i3': 'k7e3',
's7_depth': 1,
's7_i0': 'k7e6'
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
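# A minimal usage sketch; the fixed model is an ordinary PyTorch module:
#
#   model = ProxylessNAS.load_searched_model('proxyless-mobile')
#   out = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)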
def reset_parameters(model, model_init='he_fout', init_div_groups=False,
bn_momentum=0.1, bn_eps=1e-5):
for m in model.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.momentum = bn_momentum
m.eps = bn_eps
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()

View file

@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class ShuffleNetBlock(nn.Module):
"""
The basic building block of ShuffleNet, as described in
`ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices <https://arxiv.org/pdf/1707.01083.pdf>`__.
When stride = 1, the block expects an input with ``2 * input channels``; otherwise, an input with ``input channels``.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
super().__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5, 7]
self.channels = in_channels // 2 if stride == 1 else in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.mid_channels = mid_channels
self.kernel_size = kernel_size
self.stride = stride
self.pad = kernel_size // 2
self.oup_main = out_channels - self.channels
self.affine = affine
assert self.oup_main > 0
self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
if stride == 2:
self.branch_proj = nn.Sequential(
# dw
nn.Conv2d(self.channels, self.channels, kernel_size, stride, self.pad,
groups=self.channels, bias=False),
nn.BatchNorm2d(self.channels, affine=affine),
# pw-linear
nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.channels, affine=affine),
nn.ReLU(inplace=True)
)
else:
# empty block to be compatible with torchscript
self.branch_proj = nn.Sequential()
def forward(self, x):
if self.stride == 2:
x_proj, x = self.branch_proj(x), x
else:
x_proj, x = self._channel_shuffle(x)
return torch.cat((x_proj, self.branch_main(x)), 1)
def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc: int = self.channels
c: int = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = cast(int, self.mid_channels)
if token == "d":
# depth-wise conv
if isinstance(pc, int) and isinstance(c, int):
# check can only be done for static channels
assert pc == c, "Depth-wise conv must not change channels."
result.append(nn.Conv2d(pc, c, self.kernel_size, self.stride if first_depth else 1, self.pad,
groups=c, bias=False))
result.append(nn.BatchNorm2d(c, affine=self.affine))
first_depth = False
elif token == "p":
# point-wise conv
result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
result.append(nn.BatchNorm2d(c, affine=self.affine))
result.append(nn.ReLU(inplace=True))
first_point = False
else:
raise ValueError("Conv sequence must be d and p.")
pc = c
return result
def _channel_shuffle(self, x):
bs, num_channels, height, width = x.size()
# NOTE: this assert is commented out for torchscript compatibility
# assert (num_channels % 4 == 0)
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]
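# Channel-shuffle sketch (illustrative; `_channel_shuffle` is internal): for an
# input of shape (bs, C, H, W) with even C, it returns two halves of shape
# (bs, C // 2, H, W) -- the even-indexed channels and the odd-indexed ones.
#
#   block = ShuffleNetBlock(16, 32, mid_channels=16, kernel_size=3, stride=1)
#   a, b = block._channel_shuffle(torch.randn(2, 16, 8, 8))  # each (2, 8, 8, 8)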
class ShuffleXceptionBlock(ShuffleNetBlock):
"""
The ``choice_x`` version of shuffle net block, described in
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
super().__init__(in_channels, out_channels, mid_channels,
kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
@model_wrapper
class ShuffleNetSpace(nn.Module):
"""
The search space proposed in `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
The basic building block design is inspired by a state-of-the-art manually-designed network --
`ShuffleNetV2 <https://openaccess.thecvf.com/content_ECCV_2018/html/Ningning_Light-weight_CNN_Architecture_ECCV_2018_paper.html>`__.
There are 20 choice blocks in total. Each choice block has 4 candidates, namely ``choice_3``, ``choice_5``,
``choice_7`` and ``choice_x``. They differ in kernel sizes and the number of depthwise convolutions.
The size of the search space is :math:`4^{20}`.
Parameters
----------
num_labels : int
Number of classes for the classification head. Default: 1000.
channel_search : bool
If true, for each building block, the number of ``mid_channels``
(output channels of the first 1x1 conv in each building block) varies from 0.2x to 1.6x (quantized to multiples of 0.2).
Here, "kx" means k times the default number of channels.
Otherwise, 1.0x is used by default. Default: false.
affine : bool
Apply affine to all batch norm. Default: true.
"""
def __init__(self,
num_labels: int = 1000,
channel_search: bool = False,
affine: bool = True):
super().__init__()
self.num_labels = num_labels
self.channel_search = channel_search
self.affine = affine
# number of blocks in each stage; 4 stages and 20 blocks in total
self.stage_repeats = [4, 4, 8, 4]
# output channels for all stages, including the very first layer and the very last layer
self.stage_out_channels = [-1, 16, 64, 160, 320, 640, 1024]
# building first layer
out_channels = self.stage_out_channels[1]
self.first_conv = nn.Sequential(
nn.Conv2d(3, out_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
feature_blocks = []
global_block_idx = 0
for stage_idx, num_repeat in enumerate(self.stage_repeats):
for block_idx in range(num_repeat):
# count global index to give names to choices
global_block_idx += 1
# get ready for input and output
in_channels = out_channels
out_channels = self.stage_out_channels[stage_idx + 2]
stride = 2 if block_idx == 0 else 1
# mid channels can be searched
base_mid_channels = out_channels // 2
if self.channel_search:
k_choice_list = [int(base_mid_channels * (.2 * k)) for k in range(1, 9)]
mid_channels = nn.ValueChoice(k_choice_list, label=f'channel_{global_block_idx}')
else:
mid_channels = int(base_mid_channels)
mid_channels = cast(nn.MaybeChoice[int], mid_channels)
choice_block = nn.LayerChoice(dict(
k3=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
k5=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
k7=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
xcep=ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
), label=f'layer_{global_block_idx}')
feature_blocks.append(choice_block)
self.features = nn.Sequential(*feature_blocks)
# final layers
last_conv_channels = self.stage_out_channels[-1]
self.conv_last = nn.Sequential(
nn.Conv2d(out_channels, last_conv_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(last_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Sequential(
nn.Linear(last_conv_channels, num_labels, bias=False),
)
self._initialize_weights()
def forward(self, x):
x = self.first_conv(x)
x = self.features(x)
x = self.conv_last(x)
x = self.globalpool(x)
x = self.dropout(x)
x = x.contiguous().view(-1, self.stage_out_channels[-1])
x = self.classifier(x)
return x
def _initialize_weights(self):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
torch.nn.init.normal_(m.weight, 0, 0.01)
else:
torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
if name == 'spos':
# NOTE: Need BGR tensor, with no normalization
# https://github.com/ultmaster/spacehub-conversion/blob/371a4fd6646b4e11eda3f61187f7c9a1d484b1ca/cutils.py#L63
arch = {
'layer_1': 'k7',
'layer_2': 'k5',
'layer_3': 'k3',
'layer_4': 'k5',
'layer_5': 'k7',
'layer_6': 'k3',
'layer_7': 'k7',
'layer_8': 'k3',
'layer_9': 'k7',
'layer_10': 'k3',
'layer_11': 'k7',
'layer_12': 'xcep',
'layer_13': 'k3',
'layer_14': 'k3',
'layer_15': 'k3',
'layer_16': 'k3',
'layer_17': 'xcep',
'layer_18': 'k7',
'layer_19': 'xcep',
'layer_20': 'xcep'
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory()
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
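# A minimal usage sketch (per the note above, the original SPOS weights expect
# BGR input without normalization):
#
#   model = ShuffleNetSpace.load_searched_model('spos', pretrained=False)
#   out = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)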

View file

@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""This file should be merged to nni/nas/fixed.py"""
from typing import Type
from nni.nas.utils import ContextStack
class FixedFactory:
"""Make a model space ready to create a fixed model.
Examples
--------
>>> factory = FixedFactory(ModelSpaceClass, {"choice1": 3})
>>> model = factory(channels=16, classes=10)
"""
# TODO: mutations on ``init_args`` and ``init_kwargs`` themselves are not supported.
def __init__(self, cls: Type, arch: dict):
self.cls = cls
self.arch = arch
def __call__(self, *init_args, **init_kwargs):
with ContextStack('fixed', self.arch):
return self.cls(*init_args, **init_kwargs)
def __repr__(self):
return f'FixedFactory(class={self.cls}, arch={self.arch})'

View file

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Weights available in this file are processed with scripts in https://github.com/ultmaster/spacehub-conversion,
and uploaded with :func:`nni.common.blob_utils.upload_file`.
"""
import os
from nni.common.blob_utils import NNI_BLOB, nni_cache_home, load_or_download_file
PRETRAINED_WEIGHT_URLS = {
# proxylessnas
'acenas-m1': f'{NNI_BLOB}/nashub/acenas-m1-e215f1b8.pth',
'acenas-m2': f'{NNI_BLOB}/nashub/acenas-m2-a8ee9e8f.pth',
'acenas-m3': f'{NNI_BLOB}/nashub/acenas-m3-66a5ed7b.pth',
'proxyless-cpu': f'{NNI_BLOB}/nashub/proxyless-cpu-2df03430.pth',
'proxyless-gpu': f'{NNI_BLOB}/nashub/proxyless-gpu-dbe6dd15.pth',
'proxyless-mobile': f'{NNI_BLOB}/nashub/proxyless-mobile-8668a978.pth',
# mobilenetv3
'mobilenetv3-large-100': f'{NNI_BLOB}/nashub/mobilenetv3-large-100-420e040a.pth',
'mobilenetv3-small-050': f'{NNI_BLOB}/nashub/mobilenetv3-small-050-05cb7a80.pth',
'mobilenetv3-small-075': f'{NNI_BLOB}/nashub/mobilenetv3-small-075-c87d8acb.pth',
'mobilenetv3-small-100': f'{NNI_BLOB}/nashub/mobilenetv3-small-100-8332faac.pth',
'cream-014': f'{NNI_BLOB}/nashub/cream-014-060aea24.pth',
'cream-043': f'{NNI_BLOB}/nashub/cream-043-bec949e1.pth',
'cream-114': f'{NNI_BLOB}/nashub/cream-114-fc272590.pth',
'cream-287': f'{NNI_BLOB}/nashub/cream-287-a0fcba33.pth',
'cream-481': f'{NNI_BLOB}/nashub/cream-481-d85779b6.pth',
'cream-604': f'{NNI_BLOB}/nashub/cream-604-9ee425f7.pth',
# nasnet
'darts-v2': f'{NNI_BLOB}/nashub/darts-v2-5465b0d2.pth',
# spos
'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
# autoformer
'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
}
def load_pretrained_weight(name: str, **kwargs) -> str:
if name not in PRETRAINED_WEIGHT_URLS:
raise ValueError(f'"{name}" do not have a valid pretrained weight file.')
url = PRETRAINED_WEIGHT_URLS[name]
local_path = os.path.join(nni_cache_home(), 'nashub', url.split('/')[-1])
load_or_download_file(local_path, url, **kwargs)
return local_path
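# Usage sketch: resolve (and optionally download) a cached weight file.
#
#   path = load_pretrained_weight('spos', download=True)
#   state_dict = torch.load(path)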

View file

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import *

124 nni/nas/mutable/mutator.py Normal file
View file

@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)
from nni.nas.execution import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
Choice = Any
class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()
def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
pass
def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
pass
class Mutator:
"""
Mutates graphs in model to generate new model.
`Mutator` class will be used in two places:
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
and then use `Mutator.apply()` to mutate model.
For certain mutator subclasses, the strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
If a mutator has a label, in most cases it means that this mutator is applied to nodes with this label.
"""
def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
self.sampler: Optional[Sampler] = sampler
if label is None:
warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
self.label: str = label
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None
def bind_sampler(self, sampler: Sampler) -> 'Mutator':
"""
Set the sampler which will handle `Mutator.choice` calls.
"""
self.sampler = sampler
return self
def apply(self, model: Model) -> Model:
"""
Apply this mutator on a model.
Returns mutated model.
The model will be copied before mutation and the original model will not be modified.
"""
assert self.sampler is not None
copy = model.fork()
self._cur_model = copy
self._cur_choice_idx = 0
self._cur_samples = []
self.sampler.mutation_start(self, copy)
self.mutate(copy)
self.sampler.mutation_end(self, copy)
copy.history.append(Mutation(self, self._cur_samples, model, copy))
copy.status = ModelStatus.Frozen
self._cur_model = None
self._cur_choice_idx = None
return copy
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on the same or different models,
it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.
"""
sampler_backup = self.sampler
recorder = _RecorderSampler()
self.sampler = recorder
new_model = self.apply(model)
self.sampler = sampler_backup
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
"""
Abstract method to be implemented by subclass.
Mutate a model in place.
"""
raise NotImplementedError()
def choice(self, candidates: Iterable[Choice]) -> Choice:
"""
Ask sampler to make a choice.
"""
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_samples.append(ret)
self._cur_choice_idx += 1
return ret
class _RecorderSampler(Sampler):
def __init__(self):
self.recorded_candidates: List[List[Choice]] = []
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]
class InvalidMutation(Exception):
pass
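# A minimal sketch of scenario 1 above: a mutator that makes one choice per
# application. The actual graph rewriting is elided because it depends on the
# graph IR; `some_model` is a hypothetical frozen Model.
#
#   class OpMutator(Mutator):
#       def mutate(self, model: Model) -> None:
#           op = self.choice(['conv3x3', 'conv5x5'])
#           ...  # rewrite the node(s) labelled `self.label` to use `op`
#
#   candidates, _ = OpMutator(label='layer1').dry_run(some_model)
#   # candidates == [['conv3x3', 'conv5x5']]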

8 nni/nas/nn/__init__.py Normal file
View file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework

1 nni/nas/nn/pytorch/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
_layers.py

Some files were not shown because too many files have changed in this diff.